Функция, которая возвращает пакет данных при каждом вызове

Я пытаюсь создать функцию, которая возвращает пакет данных (список) каждый раз, когда я ее вызываю.

Он должен иметь возможность повторяться для любого количества шагов обучения и перезапускаться с самого начала после повторения всего набора данных (после каждой эпохи).

def generate_batch(X, batch_size):
    for i in range(0, len(X), batch_size):
        batch = X[i:i+batch_size]
        yield batch

X = [
[1, 2],
[4, 0],
[5, 1], 
[9, 99],
[9, 1],
[1, 1]]

for step in range(num_training_steps):
    x_batch = generate_batch(X, batch_size=2)
    print(list(x_batch))

когда я печатаю вывод функции, я вижу, что она получает все данные (X), а не пакет:

[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]
[[[1, 2], [4, 0]], [[5, 1], [9, 99]], [[9, 1], [1, 1]]]

В чем проблема? это правильный способ использования yield?

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
0
34
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Ответ принят как подходящий

Прежде всего, если вы хотите перезапуститься с самого начала после того, как данные закончились, вам нужно будет обернуть тело функции генератора в бесконечный цикл, например:

def generate_batch(X, batch_size):
    while 1:
        for i in range(0, len(X), batch_size):
            batch = X[i:i+batch_size]
            yield batch

Затем, когда вы делаете:

x_batch = generate_batch(X, batch_size=2)

Теперь x_batch — это генератор. Вам нужно будет перебрать его или вызвать next(), чтобы получить данные по одному пакету за раз. Если вы просто сделаете list(x_batch), он будет повторять и собирать все пакеты для вас в список. Это не то, что вы хотите.

Что вы хотите:

gen = generate_batch(X, batch_size=2)

for step in range(num_training_steps):
    x_batch = next(gen)
    print(x_batch)

Или, альтернативно, если вам нужна вызываемая функция:

gen = generate_batch(X, batch_size=2)
gen = gen.__next__

for step in range(num_training_steps):
    x_batch = gen()
    print(x_batch)

Кроме того, вы, вероятно, захотите дать функции другое имя, например, например. create_batch_generator().

Спасибо @Марко! У меня тут вопрос: зачем нужен цикл while?

Minions 04.05.2022 17:48

@Minions, потому что без него вы бы просто перебирали набор данных однажды, а затем генератор останавливался, но, как вы говорите, вы хотите, чтобы он снова и снова запускался бесконечно.

Marco Bonelli 04.05.2022 18:20

Ну, вы можете использовать для этого itertools.cycle. Это будет продолжать повторять список, как это делает tf.data.RepeatDataset.

В вашем исходном коде есть небольшая настройка

from itertools import cycle

def generate_batch(X, batch_size):
    dataset = cycle(X)
    
    while True:
        batch = list(zip(range(batch_size), dataset))
        yield list(map(lambda x: x[1], batch))

Вот и все. Теперь вы можете подключить его к своему коду.

X = [
[1, 2],
[4, 0],
[5, 1], 
[9, 99],
[9, 1],
[1, 1]]


for step in range(20):
    for batch in generate_batch(X, 2):
        print(batch)

Это будет выглядеть следующим образом

[[1, 2], [4, 0]]
[[9, 99], [9, 1]]
[[1, 2], [4, 0]]
[[9, 99], [9, 1]]
[[1, 2], [4, 0]]
[[9, 99], [9, 1]]
... and so on

Другие вопросы по теме