python - 从 TensorFlow 数据集中拆分数据的问题
问题描述
我正在尝试从Oxford Flowers 102 数据集中下载数据,并使用 tfds API 将其拆分为训练、验证和测试集。这是我的代码:
# Split numbers
train_split = 60
test_val_split = 20
splits = tfds.Split.ALL.subsplit([train_split,test_val_split, test_val_split])
# TODO: Create a training set, a validation set and a test set.
(training_set, validation_set, test_set), dataset_info = tfds.load('oxford_flowers102', split=splits, as_supervised=True, with_info=True)
问题是当我打印出来时dataset_info
,我的测试、训练和验证集得到以下数字
total_num_examples=8189,
splits={
'test': 6149,
'train': 1020,
'validation': 1020,
},
问题:如何将数据拆分为训练集中的 6149 和测试和验证集中的 1020?
解决方案
这似乎是数据集本身的一个错误。特别是因为数据集的总大小是 8189,而 6149 不是总数的 60% 而是 75%,所以你根本没有执行任何拆分。他们可能以错误的方式标记了分裂。此外,即使我尝试使用此处描述的不同方式(https://github.com/tensorflow/datasets/blob/master/docs/splits.md)加载数据集,我也得到了同样的错误拆分。
一个简单的解决方案是将测试集作为训练集传递给模型,反之亦然,但您不会获得所需的百分比。否则,您可以加载整个数据集(训练+测试+验证),然后自行拆分。
df_all, summary = tfds.load('oxford_flowers102', split='train+test+validation', with_info=True)
# check if the dataset loaded truly contains everything
df_all_length = [i for i,_ in enumerate(df_all)][-1] + 1
print(df_all_length)
>>out: 8189 # length is fine
train_size = int(0.6 * df_all_length)
val_test_size = int(0.2 * df_all_length)
# split whole dataset
df_train = df_all.take(train_size)
df_test = df_all.skip(train_size)
df_valid = df_test.skip(val_test_size)
df_test = df_test.take(val_test_size)
df_train_length = [i for i,_ in enumerate(df_train)][-1] + 1
df_val_length = [i for i,_ in enumerate(df_val)][-1] + 1
df_test_length = [i for i,_ in enumerate(df_test)][-1] + 1
# check sizes
print('Train: ', df_train_length)
print('Validation :', df_valid_length)
print('Test :', df_test_length)
>>out: 4913 #(true 60% of 8189)
>>out: 1638 #(true 20% of 8189)
>>out: 1638
推荐阅读
- unity3d - Unity 使预制件与周围环境发生碰撞,但不会与同一预制件的其他对象发生碰撞
- angular - 如何在 Angular 应用程序中区分具有相同 URL 但不同 ACCEPT/内容类型的两个端点?
- python - TensorFlow 检查点变量未保存
- ios - 没有给定快速通道问题的 ipa 文件
- python - 使用聚合折叠多索引数据框中的行
- java - onClick方法在android studio中不起作用
- flask - 将烧瓶邮件与 Amazon SES 一起使用
- javascript - 来自 React 类组件的 Memoize 回调作为闭包
- apache - 如何配置 Apache 以避免将多个斜杠 (/) 重定向到单个斜杠
- wpf - 如何使更改的列数填充整个宽度并在 WPF 中水平居中?