Понимание PyTorch einsum

Я знаком с тем, как работает einsum в NumPy. Аналогичная функциональность также предлагается PyTorch: факел.einsum(). В чем сходство и различие с точки зрения функциональности или производительности? Информация, доступная в документации PyTorch, довольно скудна и не дает никаких сведений об этом.

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
24
0
18 049
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Поскольку описание einsum скудно в документации torch, я решил написать этот пост, чтобы задокументировать, сравнить и сопоставить, как ведет себя torch.einsum() по сравнению с numpy.einsum().

Отличия:

  • NumPy допускает как строчные, так и заглавные буквы [a-zA-Z] для «строка нижнего индекса», тогда как PyTorch допускает только строчные буквы [a-z].

  • NumPy принимает nd-массивы, простые списки Python (или кортежи), список списков (или кортеж кортежей, список кортежей, кортеж списков) или даже тензоры PyTorch как операнды (т.е. входы). Это связано с тем, что операнды должен быть только array_like, а не строго массивами NumPy nd. Напротив, PyTorch ожидает, что операнды (т.е. входные данные) будут строго тензорами PyTorch. Он выдаст TypeError, если вы передадите простые списки/кортежи Python (или их комбинации) или NumPy nd-массивы.

  • NumPy поддерживает множество аргументов ключевых слов (например, optimize) в дополнение к nd-arrays, в то время как PyTorch пока не предлагает такой гибкости.

Вот реализации некоторых примеров как в PyTorch, так и в NumPy:

# input tensors to work with

In [16]: vec
Out[16]: tensor([0, 1, 2, 3])

In [17]: aten
Out[17]: 
tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])

In [18]: bten
Out[18]: 
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])

1) Умножение матриц
ПиТорч: torch.matmul(aten, bten) ; aten.mm(bten)
НумПи: np.einsum("ij, jk -> ik", arr1, arr2)

In [19]: torch.einsum('ij, jk -> ik', aten, bten)
Out[19]: 
tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])

2) Извлечь элементы по главной диагонали
ПиТорч: torch.diag(aten)
НумПи: np.einsum("ii -> i", arr)

In [28]: torch.einsum('ii -> i', aten)
Out[28]: tensor([11, 22, 33, 44])

3) Произведение Адамара (т.е. поэлементное произведение двух тензоров)
ПиТорч: aten * bten
НумПи: np.einsum("ij, ij -> ij", arr1, arr2)

In [34]: torch.einsum('ij, ij -> ij', aten, bten)
Out[34]: 
tensor([[ 11,  12,  13,  14],
        [ 42,  44,  46,  48],
        [ 93,  96,  99, 102],
        [164, 168, 172, 176]])

4) Поэлементное возведение в квадрат
ПиТорч: aten ** 2
НумПи: np.einsum("ij, ij -> ij", arr, arr)

In [37]: torch.einsum('ij, ij -> ij', aten, aten)
Out[37]: 
tensor([[ 121,  144,  169,  196],
        [ 441,  484,  529,  576],
        [ 961, 1024, 1089, 1156],
        [1681, 1764, 1849, 1936]])

Общий: Поэлементная nth мощность может быть реализована путем повторения строки нижнего индекса и тензора n раз. Например, поэлементное вычисление 4-й степени тензора может быть выполнено с использованием:

# NumPy: np.einsum('ij, ij, ij, ij -> ij', arr, arr, arr, arr)
In [38]: torch.einsum('ij, ij, ij, ij -> ij', aten, aten, aten, aten)
Out[38]: 
tensor([[  14641,   20736,   28561,   38416],
        [ 194481,  234256,  279841,  331776],
        [ 923521, 1048576, 1185921, 1336336],
        [2825761, 3111696, 3418801, 3748096]])

5) След (т.е. сумма элементов главной диагонали)
ПиТорч: torch.trace(aten)
NumPy: np.einsum("ii -> ", arr)

In [44]: torch.einsum('ii -> ', aten)
Out[44]: tensor(110)

6) Транспонирование матрицы
ПиТорч: torch.transpose(aten, 1, 0)
NumPy: np.einsum("ij -> ji", arr)

In [58]: torch.einsum('ij -> ji', aten)
Out[58]: 
tensor([[11, 21, 31, 41],
        [12, 22, 32, 42],
        [13, 23, 33, 43],
        [14, 24, 34, 44]])

7) Внешний продукт (векторов)
ПиТорч: torch.ger(vec, vec)
NumPy: np.einsum("i, j -> ij", vec, vec)

In [73]: torch.einsum('i, j -> ij', vec, vec)
Out[73]: 
tensor([[0, 0, 0, 0],
        [0, 1, 2, 3],
        [0, 2, 4, 6],
        [0, 3, 6, 9]])

8) Внутренний продукт (векторов) ПиТорч: torch.dot(vec1, vec2)
NumPy: np.einsum("i, i -> ", vec1, vec2)

In [76]: torch.einsum('i, i -> ', vec, vec)
Out[76]: tensor(14)

9) Сумма по оси 0
ПиТорч: torch.sum(aten, 0)
NumPy: np.einsum("ij -> j", arr)

In [85]: torch.einsum('ij -> j', aten)
Out[85]: tensor([104, 108, 112, 116])

10) Сумма по оси 1
ПиТорч: torch.sum(aten, 1)
NumPy: np.einsum("ij -> i", arr)

In [86]: torch.einsum('ij -> i', aten)
Out[86]: tensor([ 50,  90, 130, 170])

11) Пакетное матричное умножение
ПиТорч: torch.bmm(batch_tensor_1, batch_tensor_2)
НумПи: np.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)

# input batch tensors to work with
In [13]: batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
In [14]: batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4) 

In [15]: torch.bmm(batch_tensor_1, batch_tensor_2)  
Out[15]: 
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

# sanity check with the shapes
In [16]: torch.bmm(batch_tensor_1, batch_tensor_2).shape 
Out[16]: torch.Size([2, 4, 4])

# batch matrix multiply using einsum
In [17]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)
Out[17]: 
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

# sanity check with the shapes
In [18]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2).shape

12) Сумма по оси 2
ПиТорч: torch.sum(batch_ten, 2)
NumPy: np.einsum("ijk -> ij", arr3D)

In [99]: torch.einsum("ijk -> ij", batch_ten)
Out[99]: 
tensor([[ 50,  90, 130, 170],
        [  4,   8,  12,  16]])

13) Суммируйте все элементы в тензоре nD
ПиТорч: torch.sum(batch_ten)
NumPy: np.einsum("ijk -> ", arr3D)

In [101]: torch.einsum("ijk -> ", batch_ten)
Out[101]: tensor(480)

14) Сумма по нескольким осям (т.е. маргинализация)
ПиТорч: torch.sum(arr, dim=(dim0, dim1, dim2, dim3, dim4, dim6, dim7))
NumPy: np.einsum("ijklmnop -> n", nDarr)

# 8D tensor
In [103]: nDten = torch.randn((3,5,4,6,8,2,7,9))
In [104]: nDten.shape
Out[104]: torch.Size([3, 5, 4, 6, 8, 2, 7, 9])

# marginalize out dimension 5 (i.e. "n" here)
In [111]: esum = torch.einsum("ijklmnop -> n", nDten)
In [112]: esum
Out[112]: tensor([  98.6921, -206.0575])

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [113]: tsum = torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))

In [115]: torch.allclose(tsum, esum)
Out[115]: True

15) Double Dot Products / Внутреннее произведение Фробениуса (аналогично: torch.sum(adamard-product) cf. 3)
ПиТорч: torch.sum(aten * bten)
НумПи: np.einsum("ij, ij -> ", arr1, arr2)

In [120]: torch.einsum("ij, ij -> ", aten, bten)
Out[120]: tensor(1300)

@FredGuth, да! Теперь я обновил наглядный пример, просто чтобы не путать с фигурами. Это точно так же, как матричное умножение, но пакетное измерение просто зависает для поездки.

kmario23 18.09.2019 21:56

Отличие от numpy docs: «Если метка появляется только один раз, она не суммируется», т. Е. «np.einsum ('i', a) создает представление a без изменений», но «torch.einsum ('i', а)" недействителен.

dashesy 27.09.2019 03:01

@dashesy Я думаю, правильно сказать, что и numpy, и torch ведут себя одинаково, когда мы не выполняем никаких операций с входным массивом/тензором, соответственно. Например: с t = torch.tensor([1, 2, 3]) в качестве входных данных результат torch.einsum('...', t) вернет входной тензор. Аналогично, в NumPy с tn = t.numpy() в качестве входных данных результат np.einsum('...', tn) также вернет тот же входной массив, что и Посмотреть. Так что я не вижу здесь никакой разницы. Я пропустил что-то еще? Не могли бы вы привести пример случая, который вы имели в виду? (то есть тот, который выдал ошибку) :)

kmario23 27.09.2019 03:41

кажется, что numpy документы неверен np.einsum('i', a) упомянутое там недействительно даже в numpy.

dashesy 27.09.2019 19:39

Другие вопросы по теме