Почему плавающая точка тензоров все еще колеблется даже после использования set_printoptions(precision=1)

Я следовал руководству, которое показывает, как правильно выполнять операции с тензорами. Они сказали, что обычно операции между тензорами выполняются вручную посредством циклов перебора тензорного массива. Затем они показали лучший способ, выполнив «скалярное произведение» с треугольной единичной матрицей, усреднить тензоры без потери пространственной информации, и показали, что оба метода дают одинаковые результаты с print(torch.allclose(xbow, xbow2)), что дало «True» в качестве возвращаемого значения для показать, что оба метода работают одинаково. Но когда я последовал их пути, мои результаты оказались «ложными», показывая вероятные различия в результатах тензорных операций.

Судя по тому, что я спросил у ближайшего эксперта вокруг меня, поскольку они используют генератор случайных чисел torch.randn() как способ создания тензоров, полученное число может давать различную точность с плавающей запятой, что снижает точность. Хоть они и уверены в этом, они не знают, почему в учебниках не возникает та же проблема, что и я. Итак, используя то, что у меня есть, я использую set_printoptions(precision=1), чтобы ограничить точку точности значения тензора. Но результаты по-прежнему «ложные». Что я сделал не так или на что здесь обратить внимание?

Коды

Способы создания тензоров из учебных пособий

torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
print(x.shape)
print(x[0]) 

это показывает

torch.Size([4, 8, 2])
tensor([[ 0.2, -0.1],
        [-0.4, -0.9],
        [ 0.6,  0.0],
        [ 1.0,  0.1],
        [ 0.4,  1.2],
        [-1.3, -0.5],
        [ 0.2, -0.2],
        [-0.9,  1.5]])

Чтобы выполнить усреднение с помощью циклов из обучающих программ

# We want x[b, t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)

Я пытаюсь напечатать первый член тензоров

xbow[0]

с результатами

tensor([[ 0.2, -0.1],
        [-0.1, -0.5],
        [ 0.1, -0.3],
        [ 0.4, -0.2],
        [ 0.4,  0.1],
        [ 0.1, -0.0],
        [ 0.1, -0.1],
        [-0.0,  0.1]])

Затем в уроке показан другой метод: torch.tril(torch.ones(T,T))

wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # B(T, T, T) @ (B, T, C ) ----> (B, T, C)

Затем я печатаю первый член результата

print(xbow2[0])

и результат

tensor([[ 0.2, -0.1],
        [-0.1, -0.5],
        [ 0.1, -0.3],
        [ 0.4, -0.2],
        [ 0.4,  0.1],
        [ 0.1, -0.0],
        [ 0.1, -0.1],
        [-0.0,  0.1]])

Он выглядит равным для первого члена, но когда я это делаю xbow == xbow2, это показывает, что некоторые тензоры не равны

tensor([[[ True,  True],
         [ True,  True],
         [ True, False],
         [ True,  True],
         [ True, False],
         [False,  True],
         [False, False],
         [ True,  True]],

        [[ True,  True],
         [ True,  True],
         [False,  True],
         [ True,  True],
         [False, False],
         [False, False],
         [False,  True],
         [False,  True]],

        [[ True,  True],
         [ True,  True],
         [False, False],
         [ True,  True],
         [False, False],
         [False,  True],
         [False, False],
...

Что здесь случилось?

Редактирование 2: Я хочу, чтобы здесь было точное плавающее значение в обоих методах расчета. Я знаю, что могу сравнить оба метода здесь Блог Брюса Доусона о плавающей запятой и я могу проверить, что эти значения равны torch.allclose(xbow,xbow2), но есть ли способ сохранить значение точности, уменьшив тем самым некоторую возможную информацию потери в будущем? Например, когда я хочу усреднить изображения и сжать их, есть ли вероятность того, что эти колебания с плавающей запятой станут проблемой? Как этого избежать?

Я тоже не знаю, почему в учебниках не было такой же проблемы. Но я бы вообще не ожидал, что set_printoptions(precision=1) поможет: это устанавливает точность при распечатке чисел, а не их внутреннюю точность. Таким образом, установка точности на 1 таким образом просто затрудняет обнаружение небольших различий, которые почти наверняка происходят внутри. Попробуйте установить точность 15 или 20 и посмотрите, что получится.

Steve Summit 15.07.2024 11:38

@SteveSummit устанавливает точность в set_printoptions(precision=1) или в другом методе, о котором я не знал?

RedSean 15.07.2024 11:45

Я бы посмотрел, что произойдет, если вы используете set_printoptions(precision=15) или set_printoptions(precision=20). Я не ожидаю, что это изменит ответ — я ожидаю, что небольшие различия все равно будут — но, по крайней мере, вы сможете их увидеть.

Steve Summit 15.07.2024 12:14

@SteveSummit ну, теперь у них смехотворно длинные числа с плавающей запятой, что теперь... Я вижу разницу, но единственный вывод, который я могу сделать, - это округлить ее в большую сторону

RedSean 15.07.2024 18:04
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
1
4
55
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

У вас могут возникнуть проблемы с точностью чисел с плавающей запятой. Прочтите этот пост, чтобы узнать, почему это иногда является проблемой. В этом помогает метод « allclose» в pytorch.

Что ж, в учебнике уже показано, как использовать метод allclose, чтобы доказать, что обе тензорные операции создают что-то равное, и в учебнике оно действительно становится чем-то равным, но не в моем. Мой друг также сказал мне, что это точность с плавающей запятой, но разве set_printoptions(precision=1) на самом деле не помогает?

RedSean 15.07.2024 05:30
Ответ принят как подходящий

Вы неправильно понимаете, что делает set_printoptions. set_printoptions переключает количество отображаемых десятичных знаков при печати тензора. Это чисто для наглядности — сам базовый тензор не изменяется.

Для сравнения xbow и xbow2 вам необходимо настроить значения допуска на torch.allclose. Ошибка между двумя тензорами находится за пределами значения по умолчанию.

torch.allclose(xbow, xbow2)
> False

torch.allclose(xbow, xbow2, atol=1e-7)
> True

Если вы хотите сравнить тензоры с более низкой точностью, вы можете использовать torch.round, чтобы ограничить количество десятичных знаков. Вы также можете использовать это, чтобы найти десятичный знак, где ошибка превышает допуск по умолчанию.

torch.allclose(
    torch.round(xbow, decimals=7), 
    torch.round(xbow2, decimals=7)
)
> False 

torch.allclose(
    torch.round(xbow, decimals=6), 
    torch.round(xbow2, decimals=6)
)
> True

Вы также можете вычислять значения с 64-битными числами с плавающей запятой для более высокой точности.

torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C, dtype=torch.float64)

xbow = torch.zeros((B,T,C), dtype=torch.float64)
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)
        
wei = torch.tril(torch.ones(T,T, dtype=torch.float64))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x

torch.allclose(xbow, xbow2)
> True

Другие вопросы по теме