PyTorch DataLoader превращает наборы данных в итерируемые объекты. У меня уже есть генератор, который дает образцы данных, которые я хочу использовать для обучения и тестирования. Причина, по которой я использую генератор, заключается в том, что общее количество выборок слишком велико для хранения в памяти. Я хотел бы загрузить образцы партиями для обучения.
Как лучше всего это сделать? Могу ли я сделать это без специального DataLoader? Загрузчик данных PyTorch не любит принимать генератор в качестве входных данных. Ниже приведен минимальный пример того, что я хочу сделать, что приводит к ошибке «объект типа« генератор »не имеет len ()».
import torch
from torch import nn
from torch.utils.data import DataLoader
def example_generator():
for i in range(10):
yield i
BATCH_SIZE = 3
train_dataloader = DataLoader(example_generator(),
batch_size = BATCH_SIZE,
shuffle=False)
print(f"Length of train_dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")
Я пытаюсь получить данные от итератора и воспользоваться функциональностью PyTorch DataLoader. Пример, который я привел, является минимальным примером того, чего я хотел бы достичь, но он выдает ошибку.
Обновлено: я хочу иметь возможность использовать эту функцию для сложных генераторов, в которых len заранее неизвестен.
PyTorch DataLoader
на самом деле имеет официальную поддержку итерируемого набора данных, но он просто должен быть экземпляром подкласса torch.utils.data.IterableDataset
:
Набор данных в итерируемом стиле является экземпляром подкласса IterableDataset, реализующий протокол
__iter__()
, и представляет собой итерацию по образцам данных
Таким образом, ваш код будет записан как:
from torch.utils.data import IterableDataset
class MyIterableDataset(IterableDataset):
def __init__(self, iterable):
self.iterable = iterable
def __iter__(self):
return iter(self.iterable)
...
BATCH_SIZE = 3
train_dataloader = DataLoader(MyIterableDataset(example_generator()),
batch_size = BATCH_SIZE,
shuffle=False)