в последнее время я разрабатываю функцию, способную работать с тензорами с размерностью:
torch.Size([51, 265, 23, 23])
где первый размер — это время, второй — шаблон, а последние 2 — размер шаблона.
Каждый отдельный шаблон может иметь максимум 3 состояния: [-1,0,1] и считается «живым». в то же время шаблон «мертв» во всех остальных случаях, когда он не имеет всех трех состояний.
моя цель — отфильтровать все мертвые шаблоны, проверив последнюю строку (последний временной шаг) тензора.
def filter_patterns(tensor_sims):
# Get the indices of the columns that need to be kept
keep_indices = torch.tensor([i for i in
range(tensor_sims.shape[1]) if
tensor_sims[-1,i].unique().numel() == 3])
# Keep only the columns that meet the condition
tensor_sims = tensor_sims[:, keep_indices]
print(f'Number of patterns: {tensor_sims.shape[1]}')
return tensor_sims
К сожалению, я не могу избавиться от цикла for.
Я пытался поиграться с функцией torch.unique() и параметром dim, попробовал уменьшить размеры тензора и сгладить его, но ничего не получилось.
def filter_patterns(tensor_sims):
# Flatten the spatial dimensions of the last timestep
x_ = tensor_sims[-1].flatten(1)
# Create masks to identify -1, 0, and 1 conditions
mask_minus_one = (x_ == -1).any(dim=1)
mask_zero = (x_ == 0).any(dim=1)
mask_one = (x_ == 1).any(dim=1)
# Combine the masks using logical_and
mask =
mask_minus_one.logical_and(mask_zero).logical_and(mask_one)
# Keep only the columns that meet the condition
tensor_sims = tensor_sims[:, mask]
print(f'Number of patterns: {tensor_sims.shape[1]}')
return tensor_sims
новая реализация работает чрезвычайно быстрее.
Я не верю, что вам сойдет с рук torch.unique
, потому что это не сработает для каждого столбца. Вместо перебора dim=1
вы можете построить три тензора маски для проверки значений -1
, 0
и 1
соответственно. Чтобы вычислить результирующую маску столбца, вы можете воспользоваться некоторой базовой логикой при объединении масок:
Учитывая, что вы проверяете только последний временной шаг, сосредоточьтесь на нем и сгладьте пространственные измерения:
x_ = x[-1].flatten(1)
Три маски для определения условий -1
, 0
и 1
можно получить с помощью: x_ == -1
, x_ == 0
и x_ == 1
соответственно. Объедините их с помощью torch.ological_or
mask = (x_ == -1).logical_or(x_ == 0).logical_or(x_ == 1)
Наконец, проверьте, что все элементы расположены True
по строкам:
keep_indices = mask.all(dim=1)
Спасибо! Я постараюсь как можно скорее