Цель: выполнить итерацию по dataloader, чтобы получить доступ к объекту torch.Tensor data['image'] для прогнозов, например:
for data in dataloader:
    # Pull the three fields of the batch out by key.
    image = data['image']
    slide = data['slide_id']
    filename = data['filename']
    # predict
Я подозреваю, что проблема в методе ApplicationDataset.collate().
При этом возникают две ошибки, связанные с collate():
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'pathlib.PosixPath'>
KeyError: 0
Ошибка TypeError вызвана тем, что tile_filenames имеет тип List[Path].
def get_dataloader(slide_ids: List[str], tile_filenames: List[Path]) -> DataLoader:
    """Build a loader that yields the tiles one per batch, in their given order.

    Args:
        slide_ids: Slide identifier for each tile.
        tile_filenames: Path of each tile image.

    Returns:
        A DataLoader over an ApplicationDataset built from the two lists.
    """
    # NOTE(review): the traceback in this post passes through
    # fastai/data/load.py, where create_batch calls fa_collate -- confirm
    # which DataLoader class is imported, since collate_fn may not be
    # honoured by it.
    return DataLoader(
        ApplicationDataset(slide_ids, tile_filenames),
        batch_size=1,
        shuffle=False,
        num_workers=1,
        collate_fn=ApplicationDataset.collate,
    )
Класс ApplicationDataset:
from pathlib import Path
from typing import List
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
class ApplicationDataset(Dataset):
    """Dataset of image tiles, each paired with the id of its slide.

    A sample is a dict with the keys 'image', 'slide_id' and 'filename'.
    """

    def __init__(self, slide_ids: List[str], tile_filenames: List[Path]):
        """Store the two lists; they are indexed with the same position.

        Args:
            slide_ids: Slide identifier for each tile.
            tile_filenames: Path of each tile image.
        """
        self.slide_ids = slide_ids
        self.tile_filenames = tile_filenames

    def __len__(self):
        """Return the number of tiles."""
        return len(self.tile_filenames)

    def __getitem__(self, idx):
        """Read tile ``idx`` from disk and return it as a sample dict.

        Returns:
            A dict with 'image' (the result of ``read_image``), 'slide_id'
            and 'filename' (the stored path object, unchanged).
        """
        filename = self.tile_filenames[idx]
        image = read_image(str(filename))
        # NOTE: 'filename' stays a Path here. torch's default_collate rejects
        # pathlib objects, so batching depends on collate() stringifying it.
        return {
            'image': image,
            'slide_id': self.slide_ids[idx],
            'filename': filename,
        }

    @staticmethod
    def collate(batch):
        """Merge a list of sample dicts into one batch.

        Args:
            batch: Samples as returned by ``__getitem__``.

        Returns:
            A tuple ``(images, slide_ids, filenames)``: the images stacked
            along a new leading dimension, the slide ids, and the filenames
            as strings. Slide ids are returned as a plain list when any of
            them is a string, otherwise as a tensor.
        """
        images = torch.stack([sample['image'] for sample in batch], dim=0)

        slide_ids = [sample['slide_id'] for sample in batch]
        # torch.tensor() cannot hold strings, so only non-string ids are
        # converted; string ids stay a list.
        if not any(isinstance(slide_id, str) for slide_id in slide_ids):
            slide_ids = torch.tensor(slide_ids)

        filenames = [str(sample['filename']) for sample in batch]
        return images, slide_ids, filenames
Трассировка (traceback):
(venv) me@laptop:~/BitBucket/project$ python app/container/application.py
Traceback (most recent call last):
File "/home/me/BitBucket/project/app/container/application.py", line 89, in <module>
setup_inference(file_path_params, tile_params, fast_ai_params, dataloader)
File "/home/me/BitBucket/project/app/container/application.py", line 65, in setup_inference
predictions = predict_tiles(file_path_params, tile_params, dataloader, model)
File "/home/me/BitBucket/project/app/container/model_code/predict.py", line 58, in predict_tiles
grouped_tile_images = group_tile_images(dataloader)
File "/home/me/BitBucket/project/app/container/model_code/predict.py", line 40, in group_tile_images
for data in dataloader:
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 127, in __iter__
for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
data = self._next_data()
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1085, in _next_data
return self._process_data(data)
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1111, in _process_data
data.reraise()
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/_utils.py", line 428, in reraise
raise self.exc_type(msg)
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 164, in create_batch
try: return (fa_collate,fa_convert)[self.prebatched](b)
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 51, in fa_collate
return (default_collate(t) if isinstance(b, _collate_types)
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 73, in default_collate
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 73, in <dictcomp>
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 85, in default_collate
raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'pathlib.PosixPath'>
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
data = fetcher.fetch(index)
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 34, in fetch
data = next(self.dataset_iter)
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 138, in create_batches
yield from map(self.do_batch, self.chunkify(res))
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 168, in do_batch
def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 166, in create_batch
if not self.prebatched: collate_error(e,b)
File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 75, in collate_error
if i == 0: shape_a, type_a = item[idx].shape, item[idx].__class__.__name__
KeyError: 0
Дайте знать, если нужно предоставить дополнительные сведения.
Проблема в методе __getitem__: он возвращает в словаре значение типа Path, а default_collate не умеет объединять такие объекты в батч. Преобразуйте Path в строку перед тем, как вернуть словарь.
def __getitem__(self, idx):
    """Read tile ``idx`` and return a sample whose values can be collated.

    The filename is returned as a string rather than a Path, because
    default_collate raises TypeError on pathlib objects.

    Returns:
        A dict with 'image', 'slide_id' and 'filename' (a str).

    Raises:
        ValueError: If the image is not 3-dimensional with a trailing
            shape of (256, 256).
    """
    filename = self.tile_filenames[idx]
    image = read_image(str(filename))
    # Explicit check instead of assert, which is stripped under `python -O`.
    if len(image.shape) != 3 or tuple(image.shape[1:]) != (256, 256):
        raise ValueError(
            f"expected a 3-D image with trailing shape (256, 256), "
            f"got {tuple(image.shape)} for {filename}"
        )
    return {
        'image': image,
        'slide_id': self.slide_ids[idx],
        'filename': str(filename),
    }
Если после получения батча из даталоадера вам всё же нужен объект Path, преобразуйте строки обратно:
for data in dataloader:
    image, slide, filename = data['image'], data['slide_id'], data['filename']
    # Depending on how the batch was collated, filename may be a single
    # string or a sequence of strings; turn either form back into Path.
    if isinstance(filename, str):
        filename = Path(filename)
    else:
        filename = [Path(file) for file in filename]