В задаче литкода о том, является ли целое число суммой идеальных квадратов, использование чисел с плавающей запятой вместо целых привело к большему ускорению (проблема с идеальными квадратами). Можно ли с уверенностью предположить, что если целочисленные значения гарантированно будут меньше 10 000 и больше или равны 0, мы можем вместо этого использовать числа с плавающей запятой?
Пример сравнения:
if (n == i*i + j*j * 2)
result3++;
if (n == i*i + k*k)
result2++;
оба int
и float
прошли все тесты (n,i,j,k, все с плавающей запятой или все int), но я все еще не уверен, есть ли какая-либо разница между процессорами (не уверен, что leetcode всегда использует одно и то же) или компилятором или чем-то еще (как время?).
Ссылка на задачу и код: https://leetcode.com/problems/perfect-squares/description/
Код:
#include<iostream>
#include<math.h>
class Solution {
public:
static constexpr int simd =8;
using FAST_TYPE = short;
using MASK_TYPE = short;
const int numSquares(const int n) const noexcept {
if (n==2 || n==8)
return 2;
if (n==3 || n==6 || n==11)
return 3;
if ((int)std::sqrt(n)*(int)std::sqrt(n) == n)
return 1;
FAST_TYPE found2 = 0;
FAST_TYPE found3 = 0;
FAST_TYPE found32 = 0;
FAST_TYPE found33 = 0;
FAST_TYPE found34 = 0;
alignas(64)
FAST_TYPE zeroSimd[simd];
alignas(64)
FAST_TYPE oneSimd[simd];
alignas(64)
FAST_TYPE found3Simd[simd];
alignas(64)
FAST_TYPE found3Simd2[simd];
alignas(64)
FAST_TYPE found3Simd3[simd];
alignas(64)
FAST_TYPE found3Simd4[simd];
alignas(64)
FAST_TYPE mSimd[simd];
alignas(64)
FAST_TYPE kSimd[simd];
alignas(64)
FAST_TYPE k0Simd[simd];
alignas(64)
FAST_TYPE nSimd[simd];
alignas(64)
FAST_TYPE twoSimd[simd];
alignas(64)
FAST_TYPE threeSimd[simd];
alignas(64)
FAST_TYPE iSimd[simd];
alignas(64)
FAST_TYPE jSimd[simd];
alignas(64)
FAST_TYPE ijSimd[simd];
alignas(64)
FAST_TYPE j2Simd[simd];
alignas(64)
FAST_TYPE i2Simd[simd];
alignas(64)
MASK_TYPE mask1Simd[simd];
alignas(64)
MASK_TYPE mask2Simd[simd];
alignas(64)
MASK_TYPE mask3Simd[simd];
alignas(64)
MASK_TYPE mask4Simd[simd];
alignas(64)
FAST_TYPE sum1Simd[simd];
alignas(64)
FAST_TYPE sum2Simd[simd];
alignas(64)
FAST_TYPE sum3Simd[simd];
alignas(64)
FAST_TYPE mulSimd[simd];
for(int i=0;i<simd;i++)
{
zeroSimd[i]=0;
oneSimd[i]=1;
found3Simd[i]=0;
found3Simd2[i]=0;
found3Simd3[i]=0;
found3Simd4[i]=0;
mSimd[i]=i;
nSimd[i]=n;
twoSimd[i]=2;
threeSimd[i]=2;
}
for(int i=1+std::sqrt(n);i>=1;i--)
{
const FAST_TYPE i2 = i*i;
const FAST_TYPE i22 = 2*i*i;
const FAST_TYPE i23 = 3*i*i;
#pragma GCC ivdep
for(int m=0;m<simd;m++)
iSimd[m]=i2;
#pragma GCC ivdep
for(int m=0;m<simd;m++)
i2Simd[m]=i22;
found2 += (i22 == n);
found3+=(i23 == n);
for(int j=i-1;j>=1;j--)
{
const FAST_TYPE j2 = j*j;
const FAST_TYPE j22 = 2*j*j;
const FAST_TYPE j23 = 3*j*j;
#pragma GCC ivdep
for(int m=0;m<simd;m++)
jSimd[m]=j2;
#pragma GCC ivdep
for(int m=0;m<simd;m++)
j2Simd[m]=j22;
#pragma GCC ivdep
for(int m=0;m<simd;m++)
ijSimd[m]=i2+j2;
found2+=(i2 + j2 == n);
found3+=(i2 + j22 == n)+(i22 + j2 == n)+(j23 == n);
const int k32 = j-1 - ((j-1)%simd);
#pragma GCC unroll 2
for(int k0=1;k0<=k32;k0+=simd)
{
#pragma GCC ivdep
for(int m=0;m<simd;m++)
k0Simd[m]=k0;
#pragma GCC ivdep
for(int m=0;m<simd;m++)
kSimd[m] = k0Simd[m]+mSimd[m];
#pragma GCC ivdep
for(int m=0;m<simd;m++)
kSimd[m] = kSimd[m]*kSimd[m];
#pragma GCC ivdep
for(int m=0;m<simd;m++)
sum1Simd[m]=ijSimd[m] + kSimd[m];
#pragma GCC ivdep
for(int m=0;m<simd;m++)
mask1Simd[m]=sum1Simd[m] == nSimd[m];
#pragma GCC ivdep
for(int m=0;m<simd;m++)
found3Simd[m]=mask1Simd[m]?oneSimd[m]:found3Simd[m];
#pragma GCC ivdep
for(int m=0;m<simd;m++)
sum2Simd[m]=i2Simd[m] + kSimd[m];
#pragma GCC ivdep
for(int m=0;m<simd;m++)
mask2Simd[m]=(sum2Simd[m]==nSimd[m]);
#pragma GCC ivdep
for(int m=0;m<simd;m++)
found3Simd2[m]=mask2Simd[m]?oneSimd[m]:found3Simd2[m];
#pragma GCC ivdep
for(int m=0;m<simd;m++)
sum3Simd[m]=j2Simd[m] + kSimd[m];
#pragma GCC ivdep
for(int m=0;m<simd;m++)
mask3Simd[m]=(sum3Simd[m]==nSimd[m]);
#pragma GCC ivdep
for(int m=0;m<simd;m++)
found3Simd3[m]=mask3Simd[m]?oneSimd[m]:found3Simd3[m];
#pragma GCC ivdep
for(int m=0;m<simd;m++)
mulSimd[m]=threeSimd[m]*kSimd[m];
#pragma GCC ivdep
for(int m=0;m<simd;m++)
mask4Simd[m]=(mulSimd[m]==nSimd[m]);
#pragma GCC ivdep
for(int m=0;m<simd;m++)
found3Simd4[m]=mask4Simd[m]?oneSimd[m]:found3Simd4[m];
}
for(int k=k32;k<=j-1;k++)
{
const FAST_TYPE k2 = k*k;
found3+=(i2 + j2 + k2 ==n);
found32+=(i22 + k2 ==n);
found33+=(j22 + k2 ==n);
found34+=(3*k2 ==n);
}
}
}
for(int i=0;i<simd;i++)
{
found3+=found3Simd[i];
found32+=found3Simd2[i];
found33+=found3Simd3[i];
found34+=found3Simd4[i];
}
found3 += found32 + found33 + found34;
if (found2)
return 2;
if (found3)
return 3;
return 4;
}
};
int main()
{
Solution s;
for(int i=10;i<20;i++)
{
std::cout<<i<<" is equal to sum of "<<s.numSquares(i)<< " perfect squares"<<std::endl;
}
}
выход:
10 is equal to sum of 2 perfect squares
11 is equal to sum of 3 perfect squares
12 is equal to sum of 3 perfect squares
13 is equal to sum of 2 perfect squares
14 is equal to sum of 3 perfect squares
15 is equal to sum of 4 perfect squares
16 is equal to sum of 1 perfect squares
17 is equal to sum of 2 perfect squares
18 is equal to sum of 2 perfect squares
19 is equal to sum of 3 perfect squares
также все, что связано с SIMD, похоже, не связано
Я думаю, что это тесно связано: stackoverflow.com/questions/17333/… . TL;DR, вы не можете явно сравнивать числа с плавающей запятой в контексте математических задач, а диапазон [0, 1e5] не очень полезен, даже простой 10*3 == 90/3
не обязательно может быть верным (здесь, в SO, есть лучшие примеры, попробую найти).
Обновлен код с рабочим примером и прямой ссылкой на проблему. Я не знаю, был ли SIMD ответственен за работу поплавка или нет.
Возможные причины, по которым float
работает быстрее в вашем случае: 1. Функция sqrt(int)
отсутствует, поэтому компилятор преобразуется в double
для вызова функции, а затем обратно в short
. Если всё float
, так и остаётся. 2. short
— не очень быстрый тип данных, если только векторизация не работает. Компилятор, скорее всего, расширит знак до int
. 3. Большинство процессоров имеют только один блок целочисленного умножения (если он не векторизован), но два умножителя с плавающей запятой.
Было бы полезно попробовать это с двойной точностью, поскольку сегодня на многих процессорах код double64 фактически немного быстрее, чем float32. 10^8 также больше, чем мантисса float32 из 23 бит = 2^23 = 8388608 ~ 8.4.10^6, которую действительно может содержать. Вы можете безопасно сравнивать целые значения, вписывающиеся в мантиссу, как числа с плавающей запятой, при условии, что вы полностью избегаете деления и никогда не переполняете доступную длину мантиссы. +,- и * безопасны при условии, что вы не переполните доступную длину мантиссы. Я подозреваю, что вы также обнаружите, что он будет работать немного быстрее, если вы сделаете FAST_TYPE = int.
Примечание: возможно, впервые я вижу вопрос о проблеме Leetcode от старшего программиста C++, а не от новичка, ошибочно полагающего, что он сможет выучить язык таким образом.
@pptaszni Невозможно, чтобы 10*3 == 90/3
могло вычислить значение false
, с плавающей запятой или нет. (90 * (1./3)
, это другой вопрос.)
Вопреки распространенному мнению, арифметика с плавающей запятой точна в отношении значений, которые она может представлять. Целые числа от 0 до 2²⁴ = 16777216 точно представляются в формате с плавающей запятой одинарной точности.
@MartinBrown Существует неявный 24-й бит мантиссы (достигаемый посредством нормализации). Таким образом, число с плавающей запятой может фактически представлять целые числа до 2**24. Вот быстрое доказательство: godbolt.org/z/Y5jM1bjzn
Можно ли с уверенностью предположить, что 32-битные числа с плавающей запятой можно напрямую сравнивать друг с другом, если значение соответствует мантиссе?
Один из способов определить, небезопасно ли это, — сравнить результаты с прямым неэффективным эталонным кодом и поискать различия.
Если сравнивать всех, то хотя бы на одной машине и компиляторе можно сравнивать безопасно.
В коде ОП, скорее всего, возникнут проблемы, если int
16-битный.
Ниже приведен тестовый код для сравнения.
/*
* Return true if `n` is the sum of 2 perfect squares
*/
bool IsSumOf2Squares(int n, int *a_ptr, int *b_ptr) {
if (n < 0) {
return false;
}
int b_target;
for (int a = 0; (b_target = n - a * a) >= 0; a++) {
int diff;
for (int b = a; (diff = b_target - b * b) >= 0; b++) {
if (diff == 0) {
*a_ptr = a;
*b_ptr = b;
return true;
}
}
}
return false;
}
int main(void) {
clock_t c0, c1;
c0 = clock();
int count = 0;
int n = 10000;
for (int i = -42; i <= n; i++) {
int a, b;
if (IsSumOf2Squares(i, &a, &b)) {
if (count % 300 == 0) {
printf("%10d: %5d %5d\n", i, a, b);
fflush(stdout);
}
count++;
}
}
c1 = clock();
printf("Count: %d, Time:%gs\n", count, (double) (c1 - c0) / CLOCKS_PER_SEC);
return 0;
}
Выход
0: 0 0
901: 1 30
1933: 13 42
3001: 20 51
4105: 3 64
5213: 37 62
6354: 27 75
7489: 33 80
8656: 40 84
9808: 68 72
Count: 2750, Time:0.015s
Можно ли с уверенностью предположить, что 32-битные числа с плавающей запятой можно напрямую сравнивать друг с другом, если значение соответствует мантиссе?
Это неправильный вопрос; Числа с плавающей запятой всегда можно сравнивать друг с другом, и сравнение покажет, что они равны тогда и только тогда, когда они равны.
Правильный вопрос заключается в том, позволит ли использование арифметики с плавающей запятой получить желаемые результаты.
Согласно IEEE-754 и другим спецификациям чисел с плавающей запятой, каждое число с плавающей запятой представляет одно действительное число. Это представление является точным, без ошибок. Именно операции с плавающей запятой, а не числа, приближают действительную арифметику. При выполнении операции с плавающей запятой результатом с плавающей запятой является результат арифметических операций с действительными числами, округленный до ближайшего значения, представимого в формате с плавающей запятой (с использованием различных правил округления). Если результат вещественной арифметики представим, то это и есть результат; округления не будет.
Стандарт C++ гарантирует, что float
может представлять числа с достаточным разрешением, чтобы различать шестизначные десятичные числа во всем диапазоне. Таким образом, он может представлять все целые числа по крайней мере до 1 000 000, что превышает запрошенный вами диапазон 10 000.
Операции, которые вы выполняете с целочисленными значениями, такими как 3*j*j
, будут точны в этом диапазоне.
Эта операция не всегда будет точной: (int)std::sqrt(n)*(int)std::sqrt(n)
.
Когда n
не является квадратом, sqrt
, конечно, не может дать точный результат. Однако для вашей цели это подходит, так как возвращаемое значение будет усечено, а вычисленное произведение не будет равно n
, поэтому сравнение (int)std::sqrt(n)*(int)std::sqrt(n) == n
будет иметь значение false, как и хотелось.
Проблема заключается в том, что некоторые реализации sqrt
могут не возвращать целое число, даже если n
является целочисленным квадратом.
Если у вас хорошая реализация sqrt
, она вернет точный результат, а сравнение продукта с n
будет иметь значение true, как и пожелано. Если у вас плохая реализация sqrt
и она возвращает значение немного меньше правильного квадратного корня из n
, тогда результат сравнения будет ложным, что вызовет проблему в вашей программе.
Поскольку вас интересует только диапазон до 10 000, вы можете легко проверить sqrt
на всех квадратах этого диапазона, чтобы увидеть, работает ли он так, как хотелось бы.
вы всегда можете сравнить числа с плавающей запятой напрямую через
==
. Проблемы возникают только в том случае, если вы предполагаете, что арифметика с плавающей запятой точна. Ваша ссылка на «проблему и код» не содержит проблем и кода, а код, который вы здесь разместили, является неполным. В качестве примера вопроса это не очень полезно.