Сейчас я на питоне, поэтому пытаюсь понять эту строку из учебник по pytorch.
x_train, y_train, x_valid, y_valid = map(
torch.tensor, (x_train, y_train, x_valid, y_valid)
)
Я понимаю, как карта работает с одним элементом
def sqr(a):
return a * a
a = [1, 2, 3, 4]
a = map(sqr, a)
print(list(a))
И здесь мне нужно использовать list(a)
, чтобы преобразовать объект карты обратно в список.
Но чего я не понимаю, так это того, как это работает с несколькими переменными?
Если я попытаюсь сделать это
def sqr(a):
return a * a
a = [1, 2, 3, 4]
b = [1, 3, 5, 7]
a, b = map(sqr, (a, b))
print(list(a))
print(list(b))
Я получаю сообщение об ошибке: TypeError: can't multiply sequence by non-int of type 'list'
Пожалуйста, проясните это для меня Спасибо
вам нужно a, b = map(sqr, x) for x in (a, b)
например
ах извините, я прочитал только конец вопроса. это работает в другом случае, потому что функция принимает списки, а не отдельные значения.
map
работает с одним так же, как и со списком/кортежем списков, он извлекает элемент заданного ввода независимо от того, что это такое.
Причина, по которой torch.tensor
работает, заключается в том, что принимает представляет собой список в качестве входных данных.
Если вы развернете следующую строку, которую вы предоставили:
x_train, y_train, x_valid, y_valid = map(
torch.tensor, (x_train, y_train, x_valid, y_valid)
)
это то же самое, что делать:
x_train, y_train, x_valid, y_valid = [torch.tensor(x_train), torch.tensor(y_train), torch.tensor(x_valid), torch.tensor(y_valid)]
С другой стороны, ваша функция sqr
не принимает списки. Он ожидает, что скалярный тип будет квадратным, чего нельзя сказать о ваших a
и b
, это списки.
Однако, если вы измените sqr
на:
def sqr(a):
return [s * s for s in a]
a = [1, 2, 3, 4]
b = [1, 3, 5, 7]
a, b = map(sqr, (a, b))
или, как предложил @Jean, a, b = map(sqr, x) for x in (a, b)
Это будет работать.
ну, вы применяете
map
к кортежу списков.