Что делает Tensor[batch_mask, ...]?

Я видел эту строку кода в реализации BiLSTM:

batch_output = batch_output[batch_mask, ...]

Я предполагаю, что это какая-то операция «маскировки», но в Google нашел мало информации о значении .... Пожалуйста помоги:).

Оригинальный код:

class BiLSTM(nn.Module):
    def __init__(self, vocab_size, tagset, embedding_dim, hidden_dim,
                 num_layers, bidirectional, dropout, pretrained=None):
         # irrelevant code ..........

    def forward(self, batch_input, batch_input_lens, batch_mask):
        batch_size, padding_length = batch_input.size()
        batch_input = self.word_embeds(batch_input)  # size: #batch * padding_length * embedding_dim
        batch_input = rnn_utils.pack_padded_sequence(
            batch_input, batch_input_lens, batch_first=True)
        batch_output, self.hidden = self.lstm(batch_input, self.hidden)
        self.repackage_hidden(self.hidden)
        batch_output, _ = rnn_utils.pad_packed_sequence(batch_output, batch_first=True)
        batch_output = batch_output.contiguous().view(batch_size * padding_length, -1)
        
        #######  HERE  ##########
        batch_output = batch_output[batch_mask, ...]
        #########################

        out = self.hidden2tag(batch_output)
        return out

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
0
17
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Этот оператор маскирует первое измерение batch_output индексами, содержащимися в batch_mask. На практике это означает, что вы выбираете некоторые элементы из пакета.

Вот практический пример:

>>> x = torch.rand(3,1,4,4)
tensor([[[[0.5216, 0.1122, 0.0396, 0.5824],
          [0.7685, 0.5583, 0.2817, 0.9678],
          [0.8878, 0.9477, 0.2554, 0.8261],
          [0.2708, 0.3403, 0.7734, 0.2584]]],


        [[[0.5471, 0.5031, 0.3906, 0.7554],
          [0.1895, 0.3985, 0.7083, 0.7849],
          [0.3128, 0.6733, 0.9223, 0.5345],
          [0.2689, 0.9876, 0.1092, 0.7405]]],


        [[[0.9834, 0.0276, 0.7114, 0.2872],
          [0.3483, 0.2104, 0.1816, 0.5615],
          [0.4323, 0.5329, 0.9198, 0.8647],
          [0.9054, 0.5763, 0.7939, 0.8388]]]])

С маской и операцией маскирования:

>>> mask = [0, 2]
>>> x[mask]
tensor([[[[0.5216, 0.1122, 0.0396, 0.5824],
          [0.7685, 0.5583, 0.2817, 0.9678],
          [0.8878, 0.9477, 0.2554, 0.8261],
          [0.2708, 0.3403, 0.7734, 0.2584]]],


        [[[0.9834, 0.0276, 0.7114, 0.2872],
          [0.3483, 0.2104, 0.1816, 0.5615],
          [0.4323, 0.5329, 0.9198, 0.8647],
          [0.9054, 0.5763, 0.7939, 0.8388]]]])

Где остается только элемент с индексом 0 и 2.

Примечание. x[mask] идентично x[mask, ...], где многоточие не требуется, поскольку все позиционированные размеры получат все свои индексы, выбранные по умолчанию.

Ответ принят как подходящий

Я предполагаю, что batch_mask — булев тензор. В этом случае batch_output[batch_mask] выполняет логическое индексирование, которое выбирает элементы, соответствующие True в batch_mask.

... обычно обозначается как многоточие, а в случае с PyTorch (но также и с другими NumPy-подобными библиотеками) это сокращение позволяет избежать многократного повторения оператора столбца (:). Например, при наличии tensorv, где v.shape равно (2, 3, 4), выражение v[1, :, :] можно переписать как v[1, ...].

Я провел несколько тестов, и использование batch_output[batch_mask, ...] или batch_output[batch_mask], похоже, работает одинаково:

t = torch.arange(24).reshape(2, 3, 4)

# mask.shape == (2, 3)
mask = torch.tensor([[False, True, True], [True, False, False]])

print(torch.all(t[mask] == t[mask, ...]))  # returns True

Другие вопросы по теме