Я пытаюсь создать функцию, которая возвращает пакет данных (список) каждый раз, когда я ее вызываю.
Он должен иметь возможность повторяться для любого количества шагов обучения и перезапускаться с самого начала после повторения всего набора данных (после каждой эпохи).
def generate_batch(X, batch_size):
for i in range(0, len(X), batch_size):
batch = X[i:i+batch_size]
yield batch
X = [
[1, 2],
[4, 0],
[5, 1],
[9, 99],
[9, 1],
[1, 1]]
for step in range(num_training_steps):
x_batch = generate_batch(X, batch_size=2)
print(list(x_batch))
когда я печатаю вывод функции, я вижу, что она получает все данные (X), а не пакет:
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
В чем проблема? это правильный способ использования yield
?
Прежде всего, если вы хотите перезапуститься с самого начала после того, как данные закончились, вам нужно будет обернуть тело функции генератора в бесконечный цикл, например:
def generate_batch(X, batch_size):
while 1:
for i in range(0, len(X), batch_size):
batch = X[i:i+batch_size]
yield batch
Затем, когда вы делаете:
x_batch = generate_batch(X, batch_size=2)
Теперь x_batch
— это генератор. Вам нужно будет перебрать его или вызвать next()
, чтобы получить данные по одному пакету за раз. Если вы просто сделаете list(x_batch)
, он будет повторять и собирать все пакеты для вас в список. Это не то, что вы хотите.
Что вы хотите:
gen = generate_batch(X, batch_size=2)
for step in range(num_training_steps):
x_batch = next(gen)
print(x_batch)
Или, альтернативно, если вам нужна вызываемая функция:
gen = generate_batch(X, batch_size=2)
gen = gen.__next__
for step in range(num_training_steps):
x_batch = gen()
print(x_batch)
Кроме того, вы, вероятно, захотите дать функции другое имя, например, например. create_batch_generator()
.
@Minions, потому что без него вы бы просто перебирали набор данных однажды, а затем генератор останавливался, но, как вы говорите, вы хотите, чтобы он снова и снова запускался бесконечно.
Ну, вы можете использовать для этого itertools.cycle
. Это будет продолжать повторять список, как это делает tf.data.RepeatDataset.
В вашем исходном коде есть небольшая настройка
from itertools import cycle
def generate_batch(X, batch_size):
dataset = cycle(X)
while True:
batch = list(zip(range(batch_size), dataset))
yield list(map(lambda x: x[1], batch))
Вот и все. Теперь вы можете подключить его к своему коду.
X = [
[1, 2],
[4, 0],
[5, 1],
[9, 99],
[9, 1],
[1, 1]]
for step in range(20):
for batch in generate_batch(X, 2):
print(batch)
Это будет выглядеть следующим образом
[[1, 2], [4, 0]]
[[9, 99], [9, 1]]
[[1, 2], [4, 0]]
[[9, 99], [9, 1]]
[[1, 2], [4, 0]]
[[9, 99], [9, 1]]
... and so on
Спасибо @Марко! У меня тут вопрос: зачем нужен цикл while?