Я видел эту строку кода в реализации BiLSTM:
batch_output = batch_output[batch_mask, ...]
Я предполагаю, что это какая-то операция «маскировки», но в Google нашел мало информации о значении ...
. Пожалуйста помоги:).
Оригинальный код:
class BiLSTM(nn.Module):
def __init__(self, vocab_size, tagset, embedding_dim, hidden_dim,
num_layers, bidirectional, dropout, pretrained=None):
# irrelevant code ..........
def forward(self, batch_input, batch_input_lens, batch_mask):
batch_size, padding_length = batch_input.size()
batch_input = self.word_embeds(batch_input) # size: #batch * padding_length * embedding_dim
batch_input = rnn_utils.pack_padded_sequence(
batch_input, batch_input_lens, batch_first=True)
batch_output, self.hidden = self.lstm(batch_input, self.hidden)
self.repackage_hidden(self.hidden)
batch_output, _ = rnn_utils.pad_packed_sequence(batch_output, batch_first=True)
batch_output = batch_output.contiguous().view(batch_size * padding_length, -1)
####### HERE ##########
batch_output = batch_output[batch_mask, ...]
#########################
out = self.hidden2tag(batch_output)
return out
Этот оператор маскирует первое измерение batch_output
индексами, содержащимися в batch_mask
. На практике это означает, что вы выбираете некоторые элементы из пакета.
Вот практический пример:
>>> x = torch.rand(3,1,4,4)
tensor([[[[0.5216, 0.1122, 0.0396, 0.5824],
[0.7685, 0.5583, 0.2817, 0.9678],
[0.8878, 0.9477, 0.2554, 0.8261],
[0.2708, 0.3403, 0.7734, 0.2584]]],
[[[0.5471, 0.5031, 0.3906, 0.7554],
[0.1895, 0.3985, 0.7083, 0.7849],
[0.3128, 0.6733, 0.9223, 0.5345],
[0.2689, 0.9876, 0.1092, 0.7405]]],
[[[0.9834, 0.0276, 0.7114, 0.2872],
[0.3483, 0.2104, 0.1816, 0.5615],
[0.4323, 0.5329, 0.9198, 0.8647],
[0.9054, 0.5763, 0.7939, 0.8388]]]])
С маской и операцией маскирования:
>>> mask = [0, 2]
>>> x[mask]
tensor([[[[0.5216, 0.1122, 0.0396, 0.5824],
[0.7685, 0.5583, 0.2817, 0.9678],
[0.8878, 0.9477, 0.2554, 0.8261],
[0.2708, 0.3403, 0.7734, 0.2584]]],
[[[0.9834, 0.0276, 0.7114, 0.2872],
[0.3483, 0.2104, 0.1816, 0.5615],
[0.4323, 0.5329, 0.9198, 0.8647],
[0.9054, 0.5763, 0.7939, 0.8388]]]])
Где остается только элемент с индексом 0
и 2
.
Примечание. x[mask]
идентично x[mask, ...]
, где многоточие не требуется, поскольку все позиционированные размеры получат все свои индексы, выбранные по умолчанию.
Я предполагаю, что batch_mask
— булев тензор. В этом случае batch_output[batch_mask]
выполняет логическое индексирование, которое выбирает элементы, соответствующие True
в batch_mask
.
...
обычно обозначается как многоточие, а в случае с PyTorch (но также и с другими NumPy-подобными библиотеками) это сокращение позволяет избежать многократного повторения оператора столбца (:
). Например, при наличии tensor
v
, где v.shape
равно (2, 3, 4)
, выражение v[1, :, :]
можно переписать как v[1, ...]
.
Я провел несколько тестов, и использование batch_output[batch_mask, ...]
или batch_output[batch_mask]
, похоже, работает одинаково:
t = torch.arange(24).reshape(2, 3, 4)
# mask.shape == (2, 3)
mask = torch.tensor([[False, True, True], [True, False, False]])
print(torch.all(t[mask] == t[mask, ...])) # returns True