Лучший способ использовать итератор Pyothon в качестве набора данных в PyTorch

PyTorch DataLoader превращает наборы данных в итерируемые объекты. У меня уже есть генератор, который дает образцы данных, которые я хочу использовать для обучения и тестирования. Причина, по которой я использую генератор, заключается в том, что общее количество выборок слишком велико для хранения в памяти. Я хотел бы загрузить образцы партиями для обучения.

Как лучше всего это сделать? Могу ли я сделать это без специального DataLoader? Загрузчик данных PyTorch не любит принимать генератор в качестве входных данных. Ниже приведен минимальный пример того, что я хочу сделать, что приводит к ошибке «объект типа« генератор »не имеет len ()».

import torch
from torch import nn
from torch.utils.data import DataLoader

def example_generator():
    for i in range(10):
        yield i
    

BATCH_SIZE = 3
train_dataloader = DataLoader(example_generator(),
                        batch_size = BATCH_SIZE,
                        shuffle=False)

print(f"Length of train_dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")

Я пытаюсь получить данные от итератора и воспользоваться функциональностью PyTorch DataLoader. Пример, который я привел, является минимальным примером того, чего я хотел бы достичь, но он выдает ошибку.

Обновлено: я хочу иметь возможность использовать эту функцию для сложных генераторов, в которых len заранее неизвестен.

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
2
0
110
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

PyTorch DataLoader на самом деле имеет официальную поддержку итерируемого набора данных, но он просто должен быть экземпляром подкласса torch.utils.data.IterableDataset:

Набор данных в итерируемом стиле является экземпляром подкласса IterableDataset, реализующий протокол __iter__(), и представляет собой итерацию по образцам данных

Таким образом, ваш код будет записан как:

from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

...

BATCH_SIZE = 3

train_dataloader = DataLoader(MyIterableDataset(example_generator()),
                              batch_size = BATCH_SIZE,
                              shuffle=False)

Другие вопросы по теме