Можно ли с уверенностью предположить, что 32-битные числа с плавающей запятой можно напрямую сравнивать друг с другом, если значение соответствует мантиссе?

В задаче литкода о том, является ли целое число суммой идеальных квадратов, использование чисел с плавающей запятой вместо целых привело к большему ускорению (проблема с идеальными квадратами). Можно ли с уверенностью предположить, что если целочисленные значения гарантированно будут меньше 10 000 и больше или равны 0, мы можем вместо этого использовать числа с плавающей запятой?

Пример сравнения:

if (n == i*i + j*j * 2)
    result3++;

if (n == i*i + k*k)
    result2++;

оба int и float прошли все тесты (n,i,j,k, все с плавающей запятой или все int), но я все еще не уверен, есть ли какая-либо разница между процессорами (не уверен, что leetcode всегда использует одно и то же) или компилятором или чем-то еще (как время?).

Ссылка на задачу и код: https://leetcode.com/problems/perfect-squares/description/

Код:

#include<iostream>
#include<math.h>
class Solution {
public:

static constexpr int simd =8;
using FAST_TYPE = short;
using MASK_TYPE = short;
    const int numSquares(const int n) const noexcept {
        if (n==2 || n==8)
            return 2;
        if (n==3 || n==6 || n==11)
            return 3;
        
        if ((int)std::sqrt(n)*(int)std::sqrt(n) == n)
            return 1;

        FAST_TYPE found2 = 0;
        FAST_TYPE found3 = 0;
        FAST_TYPE found32 = 0;
        FAST_TYPE found33 = 0;
        FAST_TYPE found34 = 0;


        alignas(64)
        FAST_TYPE zeroSimd[simd];
        alignas(64)
        FAST_TYPE oneSimd[simd];
        alignas(64)
        FAST_TYPE found3Simd[simd];
        alignas(64)
        FAST_TYPE found3Simd2[simd];
        alignas(64)
        FAST_TYPE found3Simd3[simd];
        alignas(64)
        FAST_TYPE found3Simd4[simd];                
        alignas(64)
        FAST_TYPE mSimd[simd];
        alignas(64)
        FAST_TYPE kSimd[simd];
        alignas(64)
        FAST_TYPE k0Simd[simd];
        alignas(64)
        FAST_TYPE nSimd[simd];
        alignas(64)
        FAST_TYPE twoSimd[simd];
        alignas(64)
        FAST_TYPE threeSimd[simd];   
        alignas(64)
        FAST_TYPE iSimd[simd];     
        alignas(64)
        FAST_TYPE jSimd[simd];
        alignas(64)
        FAST_TYPE ijSimd[simd];           
        alignas(64)
        FAST_TYPE j2Simd[simd];        
        alignas(64)
        FAST_TYPE i2Simd[simd];      
        alignas(64)
        MASK_TYPE mask1Simd[simd];                    
        alignas(64)
        MASK_TYPE mask2Simd[simd];                    
        alignas(64)
        MASK_TYPE mask3Simd[simd];                    
        alignas(64)
        MASK_TYPE mask4Simd[simd];         
        alignas(64)
        FAST_TYPE sum1Simd[simd];                                              
        alignas(64)
        FAST_TYPE sum2Simd[simd];                                                      
        alignas(64)
        FAST_TYPE sum3Simd[simd];       
        alignas(64)
        FAST_TYPE mulSimd[simd];                                                         
        for(int i=0;i<simd;i++)
        {
            zeroSimd[i]=0;
            oneSimd[i]=1;
            found3Simd[i]=0;
            found3Simd2[i]=0;
            found3Simd3[i]=0;
            found3Simd4[i]=0;
            mSimd[i]=i;
            nSimd[i]=n;
            twoSimd[i]=2;
            threeSimd[i]=2;
            
        }
        for(int i=1+std::sqrt(n);i>=1;i--)
        {
            const FAST_TYPE i2 = i*i;
            const FAST_TYPE i22 = 2*i*i;            
            const FAST_TYPE i23 = 3*i*i;            
            #pragma GCC ivdep
            for(int m=0;m<simd;m++)
                iSimd[m]=i2;
            #pragma GCC ivdep
            for(int m=0;m<simd;m++)
                i2Simd[m]=i22;                
            found2 += (i22 == n);            
            found3+=(i23 == n);   
            for(int j=i-1;j>=1;j--)
            {
                const FAST_TYPE j2 = j*j;
                const FAST_TYPE j22 = 2*j*j;
                const FAST_TYPE j23 = 3*j*j;
                #pragma GCC ivdep
                for(int m=0;m<simd;m++)
                    jSimd[m]=j2;
                #pragma GCC ivdep
                for(int m=0;m<simd;m++)
                    j2Simd[m]=j22;       
                #pragma GCC ivdep
                for(int m=0;m<simd;m++)
                    ijSimd[m]=i2+j2;                                      
                found2+=(i2 + j2 == n);
                found3+=(i2 + j22 == n)+(i22 + j2 == n)+(j23 == n);        
                const int k32 = j-1 - ((j-1)%simd);  
                #pragma GCC unroll 2
                for(int k0=1;k0<=k32;k0+=simd) 
                {
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)
                        k0Simd[m]=k0;
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                 
                        kSimd[m] = k0Simd[m]+mSimd[m];
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                 
                        kSimd[m] = kSimd[m]*kSimd[m];


                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++) 
                        sum1Simd[m]=ijSimd[m] + kSimd[m];
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                             
                        mask1Simd[m]=sum1Simd[m] == nSimd[m];
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                  
                        found3Simd[m]=mask1Simd[m]?oneSimd[m]:found3Simd[m];


                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++) 
                        sum2Simd[m]=i2Simd[m] + kSimd[m];
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                             
                        mask2Simd[m]=(sum2Simd[m]==nSimd[m]);
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                  
                        found3Simd2[m]=mask2Simd[m]?oneSimd[m]:found3Simd2[m];

                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++) 
                        sum3Simd[m]=j2Simd[m] + kSimd[m];
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                             
                        mask3Simd[m]=(sum3Simd[m]==nSimd[m]);                        
                     #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                      
                        found3Simd3[m]=mask3Simd[m]?oneSimd[m]:found3Simd3[m];     


                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++) 
                        mulSimd[m]=threeSimd[m]*kSimd[m];
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)       
                        mask4Simd[m]=(mulSimd[m]==nSimd[m]);                    
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                                                                                                                                                                                                    
                        found3Simd4[m]=mask4Simd[m]?oneSimd[m]:found3Simd4[m];
                    
                }

                
                for(int k=k32;k<=j-1;k++)   
                {                   
                    const FAST_TYPE k2 = k*k;
                    found3+=(i2 + j2 + k2 ==n);
                    found32+=(i22 + k2 ==n);
                    found33+=(j22 + k2 ==n);
                    found34+=(3*k2 ==n);
                }
            }      
        }

        for(int i=0;i<simd;i++)
        {
            found3+=found3Simd[i];
            found32+=found3Simd2[i];
            found33+=found3Simd3[i];
            found34+=found3Simd4[i];
        }
        found3 += found32 + found33 + found34;
        if (found2)
            return 2;

        if (found3)
            return 3;

        return 4;
    }

};

int main()
{
    Solution s;
    for(int i=10;i<20;i++)
    {
        std::cout<<i<<" is equal to sum of "<<s.numSquares(i)<< " perfect squares"<<std::endl; 
    }
}

выход:

10 is equal to sum of 2 perfect squares
11 is equal to sum of 3 perfect squares
12 is equal to sum of 3 perfect squares
13 is equal to sum of 2 perfect squares
14 is equal to sum of 3 perfect squares
15 is equal to sum of 4 perfect squares
16 is equal to sum of 1 perfect squares
17 is equal to sum of 2 perfect squares
18 is equal to sum of 2 perfect squares
19 is equal to sum of 3 perfect squares

вы всегда можете сравнить числа с плавающей запятой напрямую через ==. Проблемы возникают только в том случае, если вы предполагаете, что арифметика с плавающей запятой точна. Ваша ссылка на «проблему и код» не содержит проблем и кода, а код, который вы здесь разместили, является неполным. В качестве примера вопроса это не очень полезно.

463035818_is_not_an_ai 14.08.2024 13:36

также все, что связано с SIMD, похоже, не связано

463035818_is_not_an_ai 14.08.2024 13:38

Я думаю, что это тесно связано: stackoverflow.com/questions/17333/… . TL;DR, вы не можете явно сравнивать числа с плавающей запятой в контексте математических задач, а диапазон [0, 1e5] не очень полезен, даже простой 10*3 == 90/3 не обязательно может быть верным (здесь, в SO, есть лучшие примеры, попробую найти).

pptaszni 14.08.2024 13:39

Обновлен код с рабочим примером и прямой ссылкой на проблему. Я не знаю, был ли SIMD ответственен за работу поплавка или нет.

huseyin tugrul buyukisik 14.08.2024 13:41

Возможные причины, по которым float работает быстрее в вашем случае: 1. Функция sqrt(int) отсутствует, поэтому компилятор преобразуется в double для вызова функции, а затем обратно в short. Если всё float, так и остаётся. 2. short — не очень быстрый тип данных, если только векторизация не работает. Компилятор, скорее всего, расширит знак до int. 3. Большинство процессоров имеют только один блок целочисленного умножения (если он не векторизован), но два умножителя с плавающей запятой.

Homer512 14.08.2024 13:52

Было бы полезно попробовать это с двойной точностью, поскольку сегодня на многих процессорах код double64 фактически немного быстрее, чем float32. 10^8 также больше, чем мантисса float32 из 23 бит = 2^23 = 8388608 ~ 8.4.10^6, которую действительно может содержать. Вы можете безопасно сравнивать целые значения, вписывающиеся в мантиссу, как числа с плавающей запятой, при условии, что вы полностью избегаете деления и никогда не переполняете доступную длину мантиссы. +,- и * безопасны при условии, что вы не переполните доступную длину мантиссы. Я подозреваю, что вы также обнаружите, что он будет работать немного быстрее, если вы сделаете FAST_TYPE = int.

Martin Brown 14.08.2024 14:03

Примечание: возможно, впервые я вижу вопрос о проблеме Leetcode от старшего программиста C++, а не от новичка, ошибочно полагающего, что он сможет выучить язык таким образом.

prapin 14.08.2024 14:54

@pptaszni Невозможно, чтобы 10*3 == 90/3 могло вычислить значение false, с плавающей запятой или нет. (90 * (1./3), это другой вопрос.)

Steve Summit 14.08.2024 18:20

Вопреки распространенному мнению, арифметика с плавающей запятой точна в отношении значений, которые она может представлять. Целые числа от 0 до 2²⁴ = 16777216 точно представляются в формате с плавающей запятой одинарной точности.

Steve Summit 14.08.2024 18:26

@MartinBrown Существует неявный 24-й бит мантиссы (достигаемый посредством нормализации). Таким образом, число с плавающей запятой может фактически представлять целые числа до 2**24. Вот быстрое доказательство: godbolt.org/z/Y5jM1bjzn

Homer512 15.08.2024 00:31
Стоит ли изучать PHP в 2023-2024 годах?
Стоит ли изучать PHP в 2023-2024 годах?
Привет всем, сегодня я хочу высказать свои соображения по поводу вопроса, который я уже много раз получал в своем сообществе: "Стоит ли изучать PHP в...
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
В JavaScript одним из самых запутанных понятий является поведение ключевого слова "this" в стрелочной и обычной функциях.
Приемы CSS-макетирования - floats и Flexbox
Приемы CSS-макетирования - floats и Flexbox
Здравствуйте, друзья-студенты! Готовы совершенствовать свои навыки веб-дизайна? Сегодня в нашем путешествии мы рассмотрим приемы CSS-верстки - в...
Тестирование функциональных ngrx-эффектов в Angular 16 с помощью Jest
В системе управления состояниями ngrx, совместимой с Angular 16, появились функциональные эффекты. Это здорово и делает код определенно легче для...
Концепция локализации и ее применение в приложениях React ⚡️
Концепция локализации и ее применение в приложениях React ⚡️
Локализация - это процесс адаптации приложения к различным языкам и культурным требованиям. Это позволяет пользователям получить опыт, соответствующий...
Пользовательский скаляр GraphQL
Пользовательский скаляр GraphQL
Листовые узлы системы типов GraphQL называются скалярами. Достигнув скалярного типа, невозможно спуститься дальше по иерархии типов. Скалярный тип...
0
10
94
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Можно ли с уверенностью предположить, что 32-битные числа с плавающей запятой можно напрямую сравнивать друг с другом, если значение соответствует мантиссе?

Один из способов определить, небезопасно ли это, — сравнить результаты с прямым неэффективным эталонным кодом и поискать различия.

Если сравнивать всех, то хотя бы на одной машине и компиляторе можно сравнивать безопасно.

В коде ОП, скорее всего, возникнут проблемы, если int 16-битный.

Ниже приведен тестовый код для сравнения.

/*
 * Return true if `n` is the sum of 2 perfect squares
 */
bool IsSumOf2Squares(int n, int *a_ptr, int *b_ptr) {
  if (n < 0) {
    return false;
  }
  int b_target;
  for (int a = 0; (b_target = n - a * a) >= 0; a++) {
    int diff;
    for (int b = a; (diff = b_target - b * b) >= 0; b++) {
      if (diff == 0) {
        *a_ptr = a;
        *b_ptr = b;
        return true;
      }
    }
  }
  return false;
}

int main(void) {
  clock_t c0, c1;
  c0 = clock();
  int count = 0;
  int n = 10000;
  for (int i = -42; i <= n; i++) {
    int a, b;
    if (IsSumOf2Squares(i, &a, &b)) {
      if (count % 300 == 0) {
        printf("%10d: %5d %5d\n", i, a, b);
        fflush(stdout);
      }
      count++;
    }
  }
  c1 = clock();
  printf("Count: %d, Time:%gs\n", count, (double) (c1 - c0) / CLOCKS_PER_SEC);
  return 0;
}

Выход

         0:     0     0
       901:     1    30
      1933:    13    42
      3001:    20    51
      4105:     3    64
      5213:    37    62
      6354:    27    75
      7489:    33    80
      8656:    40    84
      9808:    68    72
Count: 2750, Time:0.015s
Ответ принят как подходящий

Можно ли с уверенностью предположить, что 32-битные числа с плавающей запятой можно напрямую сравнивать друг с другом, если значение соответствует мантиссе?

Это неправильный вопрос; Числа с плавающей запятой всегда можно сравнивать друг с другом, и сравнение покажет, что они равны тогда и только тогда, когда они равны.

Правильный вопрос заключается в том, позволит ли использование арифметики с плавающей запятой получить желаемые результаты.

Согласно IEEE-754 и другим спецификациям чисел с плавающей запятой, каждое число с плавающей запятой представляет одно действительное число. Это представление является точным, без ошибок. Именно операции с плавающей запятой, а не числа, приближают действительную арифметику. При выполнении операции с плавающей запятой результатом с плавающей запятой является результат арифметических операций с действительными числами, округленный до ближайшего значения, представимого в формате с плавающей запятой (с использованием различных правил округления). Если результат вещественной арифметики представим, то это и есть результат; округления не будет.

Стандарт C++ гарантирует, что float может представлять числа с достаточным разрешением, чтобы различать шестизначные десятичные числа во всем диапазоне. Таким образом, он может представлять все целые числа по крайней мере до 1 000 000, что превышает запрошенный вами диапазон 10 000.

Операции, которые вы выполняете с целочисленными значениями, такими как 3*j*j, будут точны в этом диапазоне.

Эта операция не всегда будет точной: (int)std::sqrt(n)*(int)std::sqrt(n).

Когда n не является квадратом, sqrt, конечно, не может дать точный результат. Однако для вашей цели это подходит, так как возвращаемое значение будет усечено, а вычисленное произведение не будет равно n, поэтому сравнение (int)std::sqrt(n)*(int)std::sqrt(n) == n будет иметь значение false, как и хотелось.

Проблема заключается в том, что некоторые реализации sqrt могут не возвращать целое число, даже если n является целочисленным квадратом.

Если у вас хорошая реализация sqrt, она вернет точный результат, а сравнение продукта с n будет иметь значение true, как и пожелано. Если у вас плохая реализация sqrt и она возвращает значение немного меньше правильного квадратного корня из n, тогда результат сравнения будет ложным, что вызовет проблему в вашей программе.

Поскольку вас интересует только диапазон до 10 000, вы можете легко проверить sqrt на всех квадратах этого диапазона, чтобы увидеть, работает ли он так, как хотелось бы.

Другие вопросы по теме