Когда я загружаю набор данных, мне интересно, есть ли какой-нибудь быстрый способ найти количество выборок или пакетов в этом наборе данных. Я знаю, что если я загружаю набор данных с помощью with_info=True
, я могу видеть, например, total_num_examples=6000,
, но эта информация недоступна, если я разделяю набор данных.
В настоящее время я подсчитываю количество образцов следующим образом, но мне интересно, есть ли лучшее решение:
train_subsplit_1, train_subsplit_2, train_subsplit_3 = tfds.Split.TRAIN.subsplit(3)
cifar10_trainsub3 = tfds.load("cifar10", split=train_subsplit_3)
cifar10_trainsub3 = cifar10_trainsub3.batch(1000)
n = 0
for i, batch in enumerate(cifar10_trainsub3.take(-1)):
print(i, n, batch['image'].shape)
n += len(batch['image'])
print(i, n)
Если возможно узнать длину, вы можете использовать:
tf.data.experimental.cardinality(dataset)
но проблема в том, что набор данных TF изначально загружается лениво. Таким образом, мы можем не знать размер набора данных заранее. Действительно, вполне возможно иметь набор данных, представляющий бесконечный набор данных!
Если это достаточно маленький набор данных, вы также можете просто перебрать его, чтобы получить длину. Я использовал следующую уродливую маленькую конструкцию раньше, но это зависит от того, достаточно ли мал набор данных, чтобы мы могли загрузить его в память, и это действительно не улучшение по сравнению с вашим циклом for
выше!
dataset_length = [i for i,_ in enumerate(dataset)][-1] + 1