Правильное использование Custom Sampler в Pytorch

У меня есть набор данных типа карты, который используется для задач сегментации экземпляров. Набор данных очень несбалансирован в том смысле, что на одних изображениях всего 10 объектов, а на других — до 1200.

Как я могу ограничить количество объектов в пакете?

Минимальный воспроизводимый пример:

import math
import torch
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler


np.random.seed(0)
random.seed(0)
torch.manual_seed(0)


W = 700
H = 1000

def collate_fn(batch) -> tuple:
    return tuple(zip(*batch))

class SyntheticDataset(Dataset):
    def __init__(self, image_ids):
        self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
        self.num_classes = 9

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx: int):
        """
            returns single sample
        """
        # print("idx: ", idx)

        # deliberately left dangling
        # id = self.image_ids[idx].item()
        # image_id = self.image_ids[idx]
        image_id = torch.as_tensor(idx)
        image = torch.randint(0, 255, (H, W))

        num_objects = random.randint(10, 1200)
        image = torch.randint(0, 255, (3, H, W))
        masks = torch.randint(0, 255, (num_objects, H, W))

        target = {}
        target["image_id"] = image_id

        areas = torch.randint(100, 20000, (1, num_objects), dtype=torch.int64)
        boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
        labels = torch.randint(1, self.num_classes, (1, num_objects), dtype=torch.int64)
        iscrowd = torch.zeros(len(labels), dtype=torch.int64)

        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = areas
        target["iscrowd"] = iscrowd
        target["masks"] = masks

        return image, target, image_id


class BalancedObjectsSampler(BatchSampler):
    """Samples either batch_size images or batches num_objs_per_batch objects.

    Args:
        data_source (list): contains tuples of (img_id).
        batch_size (int): batch size.
        num_objs_per_batch (int): number of objects in a batch.
    Return
        yields the batch_ids/image_ids/image_indices

    """

    def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
        self.data_source = data_source
        self.sampler = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_objs_per_batch = num_objs_per_batch
        self.batch_count = math.ceil(len(self.data_source) / self.batch_size)

    def __iter__(self):

        obj_count = 0
        batch = []
        batches = []
        counter = 0
        for i, (k, s) in enumerate(self.data_source.iteritems()):
            if (
                obj_count <= obj_count + s
                and len(batch) <= self.batch_size - 1
                and obj_count + s <= self.num_objs_per_batch
                and i < len(self.data_source) - 1
            ):
                # because of https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler
                batch.append(i)
                obj_count += s
            else:
                batches.append(batch)
                yield batch
                obj_count = 0
                batch = []
            counter += 1


obj_sums = {}
batch_size = 10
workers = 4
fake_image_ids = np.random.randint(1600000, 1700000, 100)

# assigning any in-range number objects count to each image
for i, k in enumerate(fake_image_ids):
    obj_sums[k] = random.randint(10, 1200)

obj_counts = pd.Series(obj_sums)

train_dataset = SyntheticDataset(image_ids=fake_image_ids)

balanced_sampler = BalancedObjectsSampler(
    data_source=obj_counts,
    batch_size=batch_size,
    num_objs_per_batch=1500,
    drop_last=False,
)

data_loader_sampler = torch.utils.data.DataLoader(
    train_dataset,
    num_workers=workers,
    collate_fn=collate_fn,
    sampler=balanced_sampler,
)

data_loader_iter = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    collate_fn=collate_fn,
)

Итерация по balance_sampler

for i, bal_batch in enumerate(balanced_sampler):
    print(f"batch_{i}: ", bal_batch)

урожаи

batch_0:  [0]
batch_1:  [2, 3]
batch_2:  [5]
batch_3:  [7]
batch_4:  [9, 10]
batch_5:  [12, 13, 14, 15]
batch_6:  [17, 18]
batch_7:  [20, 21, 22]
batch_8:  [24, 25]
batch_9:  [27]
batch_10:  [29]
batch_11:  [31]
batch_12:  [33]
batch_13:  [35, 36, 37]
batch_14:  [39, 40]
batch_15:  [42, 43]
batch_16:  [45, 46]
batch_17:  [48, 49, 50]
batch_18:  [52, 53, 54]
batch_19:  [56]
batch_20:  [58, 59]
batch_21:  [61, 62]
batch_22:  [64]
batch_23:  [66]
batch_24:  [68]
batch_25:  [70, 71]
batch_26:  [73]
batch_27:  [75, 76, 77]
batch_28:  [79, 80]
batch_29:  [82, 83, 84, 85, 86, 87]
batch_30:  [89]
batch_31:  [91]
batch_32:  [93, 94]
batch_33:  [96]
batch_34:  [98]

Приведенные выше отображаемые значения являются индексами изображений, но также могут быть индексом пакета или даже идентификаторами изображений.

Запустив

for i, batch in enumerate(data_loader_sampler):
    print("__sample__: ", i, len(batch[0]))

Видно, что партия содержит один образец вместо ожидаемого количества.

__sample__:  0 1
__sample__:  1 1
__sample__:  2 1
__sample__:  3 1
__sample__:  4 1
__sample__:  5 1
__sample__:  6 1
__sample__:  7 1
__sample__:  8 1
__sample__:  9 1
__sample__:  10 1
__sample__:  11 1
__sample__:  12 1
__sample__:  13 1
__sample__:  14 1
__sample__:  15 1
__sample__:  16 1
__sample__:  17 1
__sample__:  18 1
__sample__:  19 1
__sample__:  20 1
__sample__:  21 1
__sample__:  22 1
__sample__:  23 1
__sample__:  24 1
__sample__:  25 1
__sample__:  26 1
__sample__:  27 1
__sample__:  28 1
__sample__:  29 1
__sample__:  30 1
__sample__:  31 1
__sample__:  32 1
__sample__:  33 1
__sample__:  34 1

Что я действительно пытаюсь предотвратить, так это следующее поведение, которое возникает из-за

for i, batch in enumerate(data_loader_iter):
    print("__iter__: ", i, sum([k["masks"].shape[0] for k in batch[1]]))

который

__iter__:  0 2510
__iter__:  1 2060
__iter__:  2 2203
__iter__:  3 2815
ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/usr/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "/blip/venv/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 328, in reduce_storage
    fd, size = storage._share_fd_()
RuntimeError: falseINTERNAL ASSERT FAILED at "../aten/src/ATen/MapAllocator.cpp":300, please report a bug to PyTorch. unable to write to file </torch_431207_56>
Traceback (most recent call last):
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 990, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 107, in get
    if not self._poll(timeout):
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 424, in _poll
    r = wait([self], timeout)
  File "/usr/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 431257) is killed by signal: Bus error. It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "so.py", line 170, in <module>
    for i, batch in enumerate(data_loader_iter):
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1186, in _next_data
    idx, data = self._get_data()
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1152, in _get_data
    success, data = self._try_get_data()
  File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1003, in _try_get_data
    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 431257) exited unexpectedly

что всегда происходит, когда количество объектов в пакете превышает ~ 2500.

Немедленным обходным путем было бы установить низкий уровень batch_size, мне просто нужно более оптимальное решение.

Можете ли вы установить workers=0, чтобы получить лучшую трассировку?

jodag 16.03.2022 17:31

Да, только для отладки.

Índio 16.03.2022 18:10

Пожалуйста, предоставьте трассировку стека с workers=0

jodag 16.03.2022 18:42

Итерация по data_loader_sampler точно такая же.

Índio 16.03.2022 19:02
Стоит ли изучать PHP в 2023-2024 годах?
Стоит ли изучать PHP в 2023-2024 годах?
Привет всем, сегодня я хочу высказать свои соображения по поводу вопроса, который я уже много раз получал в своем сообществе: "Стоит ли изучать PHP в...
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
В JavaScript одним из самых запутанных понятий является поведение ключевого слова "this" в стрелочной и обычной функциях.
Приемы CSS-макетирования - floats и Flexbox
Приемы CSS-макетирования - floats и Flexbox
Здравствуйте, друзья-студенты! Готовы совершенствовать свои навыки веб-дизайна? Сегодня в нашем путешествии мы рассмотрим приемы CSS-верстки - в...
Тестирование функциональных ngrx-эффектов в Angular 16 с помощью Jest
В системе управления состояниями ngrx, совместимой с Angular 16, появились функциональные эффекты. Это здорово и делает код определенно легче для...
Концепция локализации и ее применение в приложениях React ⚡️
Концепция локализации и ее применение в приложениях React ⚡️
Локализация - это процесс адаптации приложения к различным языкам и культурным требованиям. Это позволяет пользователям получить опыт, соответствующий...
Пользовательский скаляр GraphQL
Пользовательский скаляр GraphQL
Листовые узлы системы типов GraphQL называются скалярами. Достигнув скалярного типа, невозможно спуститься дальше по иерархии типов. Скалярный тип...
3
4
131
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Если то, что вы пытаетесь решить, действительно:

ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).

Вы можете попробовать изменить размер выделенной общей памяти с помощью

# mount -o remount,size=<whatever_is_enough>G /dev/shm

Однако, поскольку это не всегда возможно, одним из решений вашей проблемы будет

class SyntheticDataset(Dataset):

    def __init__(self, image_ids):
        self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
        self.num_classes = 9

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, indices):
        worker_info = torch.utils.data.get_worker_info()

        batch = []
        for i in indices:
            sample = self.get_sample(i)
            batch.append(sample)
        gc.collect()
        return batch

    def get_sample(self, idx: int):

        image_id = torch.as_tensor(idx)
        image = torch.randint(0, 255, (H, W))

        num_objects = idx
        image = torch.randint(0, 255, (3, H, W))
        masks = torch.randint(0, 255, (num_objects, H, W))

        target = {}
        target["image_id"] = image_id

        areas = torch.randint(100, 20000, (1, num_objects), dtype=torch.int64)
        boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
        labels = torch.randint(1, self.num_classes, (1, num_objects), dtype=torch.int64)
        iscrowd = torch.zeros(len(labels), dtype=torch.int64)

        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = areas
        target["iscrowd"] = iscrowd
        target["masks"] = masks

        return image, target, image_id

и

class BalancedObjectsSampler(BatchSampler):
    """Samples either batch_size images or batches num_objs_per_batch objects.

    Args:
        data_source (list): contains tuples of (img_id).
        batch_size (int): batch size.
        num_objs_per_batch (int): number of objects in a batch.
    Return
        yields the batch_ids/image_ids/image_indices

    """

    def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
        self.data_source = data_source
        self.sampler = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_objs_per_batch = num_objs_per_batch
        self.batch_count = math.ceil(len(self.data_source) / self.batch_size)

        obj_count = 0
        batch = []
        batches = []
        batches_sums = []
        for i, (k, s) in enumerate(self.data_source.iteritems()):

            if (
                len(batch) < self.batch_size
                and obj_count + s < self.num_objs_per_batch
                and i < len(self.data_source) - 1
            ):
                batch.append(s)
                obj_count += s
            else:
                batches.append(len(batch))
                batches_sums.append(obj_count)
                obj_count = 0
                batch = []

        self.batches = batches
        self.batch_count = len(batches)

    def __iter__(self):
        batch = []
        img_counts_id = 0
        for idx, (k, s) in enumerate(self.data_source.iteritems()):
            if len(batch) < self.batches[img_counts_id] and idx < len(self.data_source):
                batch.append(s)
            elif len(batch) == self.batches[img_counts_id]:
                gc.collect()
                yield batch
                batch = []
                if img_counts_id < self.batch_count - 1:
                    img_counts_id += 1
                else:
                    break

        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        if self.drop_last:
            return len(self.data_source) // self.batch_size
        else:
            return (len(self.data_source) + self.batch_size - 1) // self.batch_size

Поскольку __getitem__ SyntheticDataset получал список индексов, самым простым решением было бы просто перебирать индексы и получать список образцов. Возможно, вам просто придется по-разному сопоставлять выходные данные, чтобы передать их вашей модели.

Для BalancedObjectsSampler я рассчитал размер каждого пакета в __init__ и использовал его в __iter__ для сборки пакетов.

ПРИМЕЧАНИЕ. Это все равно не удастся, если ваш num_workers > 0 для вас пытается упаковать не более 1500 объектов в пакет — и обычно один рабочий загружает один пакет за раз. Следовательно, вы должны переоценить свой num_objs_per_batch при рассмотрении возможности использования многопроцессорной обработки.

Так что это правда, большое спасибо за то, что нашли время. Вы правильно подошли к вопросу.

Índio 17.03.2022 20:28

Другие вопросы по теме