У меня есть набор данных типа карты, который используется для задач сегментации экземпляров. Набор данных очень несбалансирован в том смысле, что на одних изображениях всего 10 объектов, а на других — до 1200.
Как я могу ограничить количество объектов в пакете?
Минимальный воспроизводимый пример:
import math
import torch
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
W = 700
H = 1000
def collate_fn(batch) -> tuple:
return tuple(zip(*batch))
class SyntheticDataset(Dataset):
def __init__(self, image_ids):
self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
self.num_classes = 9
def __len__(self):
return len(self.image_ids)
def __getitem__(self, idx: int):
"""
returns single sample
"""
# print("idx: ", idx)
# deliberately left dangling
# id = self.image_ids[idx].item()
# image_id = self.image_ids[idx]
image_id = torch.as_tensor(idx)
image = torch.randint(0, 255, (H, W))
num_objects = random.randint(10, 1200)
image = torch.randint(0, 255, (3, H, W))
masks = torch.randint(0, 255, (num_objects, H, W))
target = {}
target["image_id"] = image_id
areas = torch.randint(100, 20000, (1, num_objects), dtype=torch.int64)
boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
labels = torch.randint(1, self.num_classes, (1, num_objects), dtype=torch.int64)
iscrowd = torch.zeros(len(labels), dtype=torch.int64)
target["boxes"] = boxes
target["labels"] = labels
target["area"] = areas
target["iscrowd"] = iscrowd
target["masks"] = masks
return image, target, image_id
class BalancedObjectsSampler(BatchSampler):
"""Samples either batch_size images or batches num_objs_per_batch objects.
Args:
data_source (list): contains tuples of (img_id).
batch_size (int): batch size.
num_objs_per_batch (int): number of objects in a batch.
Return
yields the batch_ids/image_ids/image_indices
"""
def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
self.data_source = data_source
self.sampler = data_source
self.batch_size = batch_size
self.drop_last = drop_last
self.num_objs_per_batch = num_objs_per_batch
self.batch_count = math.ceil(len(self.data_source) / self.batch_size)
def __iter__(self):
obj_count = 0
batch = []
batches = []
counter = 0
for i, (k, s) in enumerate(self.data_source.iteritems()):
if (
obj_count <= obj_count + s
and len(batch) <= self.batch_size - 1
and obj_count + s <= self.num_objs_per_batch
and i < len(self.data_source) - 1
):
# because of https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler
batch.append(i)
obj_count += s
else:
batches.append(batch)
yield batch
obj_count = 0
batch = []
counter += 1
obj_sums = {}
batch_size = 10
workers = 4
fake_image_ids = np.random.randint(1600000, 1700000, 100)
# assigning any in-range number objects count to each image
for i, k in enumerate(fake_image_ids):
obj_sums[k] = random.randint(10, 1200)
obj_counts = pd.Series(obj_sums)
train_dataset = SyntheticDataset(image_ids=fake_image_ids)
balanced_sampler = BalancedObjectsSampler(
data_source=obj_counts,
batch_size=batch_size,
num_objs_per_batch=1500,
drop_last=False,
)
data_loader_sampler = torch.utils.data.DataLoader(
train_dataset,
num_workers=workers,
collate_fn=collate_fn,
sampler=balanced_sampler,
)
data_loader_iter = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=workers,
collate_fn=collate_fn,
)
Итерация по balance_sampler
for i, bal_batch in enumerate(balanced_sampler):
print(f"batch_{i}: ", bal_batch)
урожаи
batch_0: [0]
batch_1: [2, 3]
batch_2: [5]
batch_3: [7]
batch_4: [9, 10]
batch_5: [12, 13, 14, 15]
batch_6: [17, 18]
batch_7: [20, 21, 22]
batch_8: [24, 25]
batch_9: [27]
batch_10: [29]
batch_11: [31]
batch_12: [33]
batch_13: [35, 36, 37]
batch_14: [39, 40]
batch_15: [42, 43]
batch_16: [45, 46]
batch_17: [48, 49, 50]
batch_18: [52, 53, 54]
batch_19: [56]
batch_20: [58, 59]
batch_21: [61, 62]
batch_22: [64]
batch_23: [66]
batch_24: [68]
batch_25: [70, 71]
batch_26: [73]
batch_27: [75, 76, 77]
batch_28: [79, 80]
batch_29: [82, 83, 84, 85, 86, 87]
batch_30: [89]
batch_31: [91]
batch_32: [93, 94]
batch_33: [96]
batch_34: [98]
Приведенные выше отображаемые значения являются индексами изображений, но также могут быть индексом пакета или даже идентификаторами изображений.
Запустив
for i, batch in enumerate(data_loader_sampler):
print("__sample__: ", i, len(batch[0]))
Видно, что партия содержит один образец вместо ожидаемого количества.
__sample__: 0 1
__sample__: 1 1
__sample__: 2 1
__sample__: 3 1
__sample__: 4 1
__sample__: 5 1
__sample__: 6 1
__sample__: 7 1
__sample__: 8 1
__sample__: 9 1
__sample__: 10 1
__sample__: 11 1
__sample__: 12 1
__sample__: 13 1
__sample__: 14 1
__sample__: 15 1
__sample__: 16 1
__sample__: 17 1
__sample__: 18 1
__sample__: 19 1
__sample__: 20 1
__sample__: 21 1
__sample__: 22 1
__sample__: 23 1
__sample__: 24 1
__sample__: 25 1
__sample__: 26 1
__sample__: 27 1
__sample__: 28 1
__sample__: 29 1
__sample__: 30 1
__sample__: 31 1
__sample__: 32 1
__sample__: 33 1
__sample__: 34 1
Что я действительно пытаюсь предотвратить, так это следующее поведение, которое возникает из-за
for i, batch in enumerate(data_loader_iter):
print("__iter__: ", i, sum([k["masks"].shape[0] for k in batch[1]]))
который
__iter__: 0 2510
__iter__: 1 2060
__iter__: 2 2203
__iter__: 3 2815
ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
Traceback (most recent call last):
File "/usr/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
obj = _ForkingPickler.dumps(obj)
File "/usr/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
File "/blip/venv/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 328, in reduce_storage
fd, size = storage._share_fd_()
RuntimeError: falseINTERNAL ASSERT FAILED at "../aten/src/ATen/MapAllocator.cpp":300, please report a bug to PyTorch. unable to write to file </torch_431207_56>
Traceback (most recent call last):
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 990, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/usr/lib/python3.8/multiprocessing/queues.py", line 107, in get
if not self._poll(timeout):
File "/usr/lib/python3.8/multiprocessing/connection.py", line 257, in poll
return self._poll(timeout)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 424, in _poll
r = wait([self], timeout)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 931, in wait
ready = selector.select(timeout)
File "/usr/lib/python3.8/selectors.py", line 415, in select
fd_event_list = self._selector.poll(timeout)
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
_error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 431257) is killed by signal: Bus error. It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "so.py", line 170, in <module>
for i, batch in enumerate(data_loader_iter):
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1186, in _next_data
idx, data = self._get_data()
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1152, in _get_data
success, data = self._try_get_data()
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1003, in _try_get_data
raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 431257) exited unexpectedly
что всегда происходит, когда количество объектов в пакете превышает ~ 2500.
Немедленным обходным путем было бы установить низкий уровень batch_size
, мне просто нужно более оптимальное решение.
Да, только для отладки.
Пожалуйста, предоставьте трассировку стека с workers=0
Итерация по data_loader_sampler
точно такая же.
Если то, что вы пытаетесь решить, действительно:
ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
Вы можете попробовать изменить размер выделенной общей памяти с помощью
# mount -o remount,size=<whatever_is_enough>G /dev/shm
Однако, поскольку это не всегда возможно, одним из решений вашей проблемы будет
class SyntheticDataset(Dataset):
def __init__(self, image_ids):
self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
self.num_classes = 9
def __len__(self):
return len(self.image_ids)
def __getitem__(self, indices):
worker_info = torch.utils.data.get_worker_info()
batch = []
for i in indices:
sample = self.get_sample(i)
batch.append(sample)
gc.collect()
return batch
def get_sample(self, idx: int):
image_id = torch.as_tensor(idx)
image = torch.randint(0, 255, (H, W))
num_objects = idx
image = torch.randint(0, 255, (3, H, W))
masks = torch.randint(0, 255, (num_objects, H, W))
target = {}
target["image_id"] = image_id
areas = torch.randint(100, 20000, (1, num_objects), dtype=torch.int64)
boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
labels = torch.randint(1, self.num_classes, (1, num_objects), dtype=torch.int64)
iscrowd = torch.zeros(len(labels), dtype=torch.int64)
target["boxes"] = boxes
target["labels"] = labels
target["area"] = areas
target["iscrowd"] = iscrowd
target["masks"] = masks
return image, target, image_id
и
class BalancedObjectsSampler(BatchSampler):
"""Samples either batch_size images or batches num_objs_per_batch objects.
Args:
data_source (list): contains tuples of (img_id).
batch_size (int): batch size.
num_objs_per_batch (int): number of objects in a batch.
Return
yields the batch_ids/image_ids/image_indices
"""
def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
self.data_source = data_source
self.sampler = data_source
self.batch_size = batch_size
self.drop_last = drop_last
self.num_objs_per_batch = num_objs_per_batch
self.batch_count = math.ceil(len(self.data_source) / self.batch_size)
obj_count = 0
batch = []
batches = []
batches_sums = []
for i, (k, s) in enumerate(self.data_source.iteritems()):
if (
len(batch) < self.batch_size
and obj_count + s < self.num_objs_per_batch
and i < len(self.data_source) - 1
):
batch.append(s)
obj_count += s
else:
batches.append(len(batch))
batches_sums.append(obj_count)
obj_count = 0
batch = []
self.batches = batches
self.batch_count = len(batches)
def __iter__(self):
batch = []
img_counts_id = 0
for idx, (k, s) in enumerate(self.data_source.iteritems()):
if len(batch) < self.batches[img_counts_id] and idx < len(self.data_source):
batch.append(s)
elif len(batch) == self.batches[img_counts_id]:
gc.collect()
yield batch
batch = []
if img_counts_id < self.batch_count - 1:
img_counts_id += 1
else:
break
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self) -> int:
if self.drop_last:
return len(self.data_source) // self.batch_size
else:
return (len(self.data_source) + self.batch_size - 1) // self.batch_size
Поскольку __getitem__
SyntheticDataset получал список индексов, самым простым решением было бы просто перебирать индексы и получать список образцов. Возможно, вам просто придется по-разному сопоставлять выходные данные, чтобы передать их вашей модели.
Для BalancedObjectsSampler я рассчитал размер каждого пакета в __init__
и использовал его в __iter__
для сборки пакетов.
ПРИМЕЧАНИЕ. Это все равно не удастся, если ваш num_workers > 0
для вас пытается упаковать не более 1500 объектов в пакет — и обычно один рабочий загружает один пакет за раз. Следовательно, вы должны переоценить свой num_objs_per_batch
при рассмотрении возможности использования многопроцессорной обработки.
Так что это правда, большое спасибо за то, что нашли время. Вы правильно подошли к вопросу.
Можете ли вы установить
workers=0
, чтобы получить лучшую трассировку?