Удивительное поведение numpy float16 при проверке равенства

Я передаю различные биты данных в функцию, которая вычисляет дисперсию по первому измерению. Иногда дисперсия равна нулю, и это нормально, но тогда происходит следующая странная вещь:

>> sigma = data.var(axis=0) + 1e-7 # data has zero variance so all entries should equal 1e-7
>> sigma
array([1.e-07, 1.e-07, 1.e-07, ..., 1.e-07, 1.e-07, 1.e-07], dtype=float16)
>> (sigma==1e-7).all()
True
>> sigma[0]==1e-7
False

Сама по себе четвертая строка объясняется 16-битной точностью, и действительно

>> np.float16(1e-7)==1e-7
False

Но это, кажется, противоречит третьей строке, в которой говорится, что они равны. Это вызвало ошибку в моем коде. Я могу изменить дизайн вокруг этого, но я хочу понять, почему numpy делает это, чтобы меня снова не поймали в будущем.

Используйте all_close при сравнении чисел с плавающей запятой.

hpaulj 06.06.2024 17:36
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
4
1
98
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Насколько я понимаю, 1.e-7 — это 0.0000001

Но float16 не обладает достаточной точностью, чтобы точно представить 0.0000001.

Когда 1.e-7 приводится к float16, оно округляется до ближайшего представимого значения, которое равно 6.6e-7.

В float32 1.e-7 будет 6.9999998e-7

Только использование float64 позволит сохранить это точно.

Ответ принят как подходящий

Это связано с тем, что продвижение типов numpy по-разному обрабатывает скаляры и массивы. Вы можете увидеть это с помощью np.result_type:

>>> np.result_type(sigma, 1E-7)
dtype('float16')

>>> np.result_type(sigma[0], 1E-7)
dtype('float64')

По сути, когда значение массива сравнивается со скалярным значением (первый случай), dtype значения массива имеет приоритет. При сравнении двух скаляров или двух массивов (второй случай) приоритет имеет наивысшая точность.

Это означает, что когда вы оцениваете (sigma == 1E-7), обе стороны сначала приводятся к float16 перед сравнением, тогда как когда вы оцениваете sigma[0] == 1E-7, обе стороны сначала приводятся к float64 перед сравнением.

Поскольку float16 не может идеально представлять значение 1E-7, это приводит к расхождению в случае скалярного сравнения, когда оба значения приводятся к float64:

>>> np.float16(1E-7).astype(np.float64)
1.1920928955078125e-07
>>> np.float64(1E-7)
1e-07

Наконец, обратите внимание, что эти правила приведения типов, специфичные для скаляров, изменяются в NumPy 2.0 (см. NEP 50: Правила продвижения для скаляров), поэтому, если вы запустите свой код с помощью NumPy 2.0, оба случая будут повышены до float16 и вернутся. True.

Другие вопросы по теме