Обучение модели PyTorch с помощью DataLoader происходит слишком медленно

Я обучаю очень маленькую НС, используя набор данных HAM10000. Для загрузки данных я использую DataLoader, который поставляется с PyTorch:

class CocoDetectionWithFilenames(CocoDetection):
    def __init__(self, root: str, ann_file: str, transform=None):
        super().__init__(root, ann_file, transform)

    def get_filename(self, idx: int) -> str:
        return self.coco.loadImgs(self.ids[idx])[0]["file_name"]


def get_loaders(root: str, ann_file: str) -> tuple[CocoDetection, DataLoader, DataLoader, DataLoader]:
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    dataset = CocoDetectionWithFilenames(
        root=root,
        ann_file=ann_file,
        transform=transform
    )
    train_size = int(0.7 * len(dataset))
    valid_size = int(0.15 * len(dataset))
    test_size = len(dataset) - train_size - valid_size
    train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size, test_size])
    num_workers = os.cpu_count()
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=1024
    )
    valid_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=1024
    )
    test_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return dataset, train_loader, valid_loader, test_loader

Дело в том, что когда мой цикл обучения выполняется, само обучение происходит очень быстро, но программа тратит 95% времени между эпохами — вероятно, загружая данные:

def extract_bboxes(targets: list[dict]) -> list[torch.Tensor]:
    bboxes = []

    for target in targets:
        xs, ys, widths, heights = target["bbox"]

        for idx, _ in enumerate(xs):
            x1, y1, width, height = xs[idx], ys[idx], widths[idx], heights[idx]
            # Convert COCO format (x, y, width, height) to (x1, y1, x2, y2)
            x2, y2 = x1 + width, y1 + height

            bboxes.append(torch.IntTensor([x1, y1, x2, y2]))

    return bboxes

num_epochs = 25
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, targets in train_loader_tqdm:
        images = images.to(device)
        bboxes = extract_bboxes(targets)
        bboxes = torch.stack(bboxes).to(device)

        optimizer.zero_grad(set_to_none=True)

        outputs = model(images)
        loss = criterion(outputs, bboxes)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_train_loss = running_loss / len(train_loader)

    train_losses.append(epoch_train_loss)
    print(f"Epoch {epoch + 1}, Loss: {epoch_train_loss}")
    model.eval()

Как видите, код обучающего цикла довольно прост, ничего странного в нем не происходит.

Вы можете попытаться извлечь ограничивающие рамки перед обучением, а затем использовать эти данные вместо того, чтобы делать это в цикле обучения.

maciek97x 25.06.2024 08:21

@ maciek97x причина медлительности не в этом. Эта функция только перебирает уже загруженные цели и считывает данные. Он не выполняет никакой тяжелой работы — по сути, эта функция почти не занимает времени. Я обновил свой вопрос, чтобы вы его увидели.

Marek M. 25.06.2024 08:27

Можно попробовать уменьшить num_workers и prefetch_factor. Возможно, он все время тратит время на выборку этих 1024 пакетов, используя все потоки.

maciek97x 25.06.2024 08:39

Кажется, ты был прав. Раньше получение данных занимало около 40 секунд, а обучение — 8 секунд, теперь после уменьшения значения num_workers данные извлекаются за 10 секунд. Пожалуйста, опубликуйте это как свой ответ, чтобы я мог его принять.

Marek M. 25.06.2024 08:42
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
4
57
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Попробуйте уменьшить num_workers и prefetch_factor. Он может потратить все время на получение этих 1024 пакетов, используя все потоки.

Другие вопросы по теме