tensorflow - 自定义对象检测 TFLite 模型错误 (pascal voc) - ValueError: The size of the train_data (0) could not be less than batch_size (2)
问题描述
我正在尝试使用 tflite 模型制造商在 jupyter notebook 中构建自定义对象检测模型,但我遇到了一些问题。
我正在使用 pascal_voc(不是 csv)获取图像,因此我将训练数据/测试数据拆分为不同的文件
os.mkdir('C:/Users/user/Desktop/GradProject5/train_test_split/train')
os.mkdir('C:/Users/user/Desktop/GradProject5/train_test_split/test')
image_paths = os.listdir('C:/Users/user/anaconda3/envs/DarkflowTest/data/dataset')
random.shuffle(image_paths)
for i, image_path in enumerate(image_paths):
if i < int(len(image_paths) * 0.8):
shutil.copy(f'C:/Users/user/anaconda3/envs/DarkflowTest/data/dataset/{image_path}', 'C:/Users/user/Desktop/GradProject5/train_test_split/train')
shutil.copy(f'C:/Users/user/anaconda3/envs/DarkflowTest/data/annotations/{image_path.replace("jpg", "xml")}', 'C:/Users/user/Desktop/GradProject5/train_test_split/train')
else:
shutil.copy(f'C:/Users/user/anaconda3/envs/DarkflowTest/data/dataset/{image_path}', 'C:/Users/user/Desktop/GradProject5/train_test_split/test')
shutil.copy(f'C:/Users/user/anaconda3/envs/DarkflowTest/data/annotations/{image_path.replace("jpg", "xml")}', 'C:/Users/user/Desktop/GradProject5/train_test_split/test')
test_image_dir='C:/Users/user/Desktop/GradProject/\train_test_split/test/'
#annotations_dir = 'C:/Users/user/anaconda3/envs/DarkflowTest/data/annotations/'
train_data=object_detector.DataLoader.from_pascal_voc(train_image_dir+'image/',train_image_dir+'xml/',label_map={1:"pill",2:"text"})
test_datal=object_detector.DataLoader.from_pascal_voc(test_image_dir+'image/',test_image_dir+'xml/',label_map={1:"pill",2:"text"})
然后用 DataLoader 加载训练数据和测试数据
model = object_detector.create(train_data, model_spec=spec, batch_size=2, train_whole_model=True)
然后我尝试创建一个模型,但我收到了这个错误。
[[1]:https://i.stack.imgur.com/ZBQtk.png][1]
----> 1 model = object_detector.create(train_data, model_spec=spec, batch_size=2, train_whole_model=True) 中的 ValueError Traceback (最近一次调用最后一次)
~\anaconda3\lib\site-packages\tensorflow_examples\lite\model_maker\core\task\object_detector.py in create(cls, train_data, model_spec, validation_data, epochs, batch_size, train_whole_model, do_train) 285 if do_train: 286 tf.compat .v1.logging.info('重新训练模型...') --> 287 object_detector.train(train_data, validation_data, epochs, batch_size) 288 else: 289 object_detector.create_model()
~\anaconda3\lib\site-packages\tensorflow_examples\lite\model_maker\core\task\object_detector.py in train(self, train_data, validation_data, epochs, batch_size) 137 # TODO(b/171449557): 将此上传到父级班级。138 if len(train_data) < batch_size: --> 139 raise ValueError('The size of the train_data (%d) could not be less ' 140 ' than batch_size (%d). 为了解决这个问题,设置' 141 ' batch_size 变小或增大'
ValueError:train_data (0) 的大小不能小于 batch_size (2)。要解决此问题,请将batch_size 设置得更小或增加train_data 的大小。
我是否因为
train_data=object_detector.DataLoader.from_pascal_voc(train_image_dir+'image/',train_image_dir+'xml/',label_map={1:"pill",2:"text"})
无法加载火车数据而出错..?我认为我做的一切都是正确的,但仍然在为这个错误而苦苦挣扎。如果你知道这个问题的解决方案,请帮助我!
解决方案
推荐阅读
- spring-boot - 使用 Consul Cluster 实现容错
- firefox - Mozilla Firefox 在 Debian 中崩溃
- mysql - 合并两个单独的 mysql 查询
- hyperledger-fabric - 我们可以使用 fabric couchdb 中的视图吗?
- java - Spring-Boot 1.5.16 JBoss 6.4.X 缺少 JTA 依赖项
- c# - 做 Html.BeginForm 时维护模型
- mongodb - MongoDb 查询数组元素
- python - 如何使用 python 将索引添加到我的 Azure 搜索服务?
- qt - 如何将录音保存到 *.wav?
- r - 将网页中的内容另存为 data.frame