首页 > 解决方案 > 自定义对象检测 TFLite 模型错误 (pascal voc) - ValueError: The size of the train_data (0) could not be less than batch_size (2)

问题描述

我正在尝试使用 tflite 模型制造商在 jupyter notebook 中构建自定义对象检测模型,但我遇到了一些问题。

我正在使用 pascal_voc(不是 csv)获取图像,因此我将训练数据/测试数据拆分为不同的文件


os.mkdir('C:/Users/user/Desktop/GradProject5/train_test_split/train')
os.mkdir('C:/Users/user/Desktop/GradProject5/train_test_split/test')

image_paths = os.listdir('C:/Users/user/anaconda3/envs/DarkflowTest/data/dataset')
random.shuffle(image_paths)

for i, image_path in enumerate(image_paths):
    if i < int(len(image_paths) * 0.8):
        shutil.copy(f'C:/Users/user/anaconda3/envs/DarkflowTest/data/dataset/{image_path}', 'C:/Users/user/Desktop/GradProject5/train_test_split/train')
        shutil.copy(f'C:/Users/user/anaconda3/envs/DarkflowTest/data/annotations/{image_path.replace("jpg", "xml")}', 'C:/Users/user/Desktop/GradProject5/train_test_split/train')
    else:
        shutil.copy(f'C:/Users/user/anaconda3/envs/DarkflowTest/data/dataset/{image_path}', 'C:/Users/user/Desktop/GradProject5/train_test_split/test')
        shutil.copy(f'C:/Users/user/anaconda3/envs/DarkflowTest/data/annotations/{image_path.replace("jpg", "xml")}', 'C:/Users/user/Desktop/GradProject5/train_test_split/test')
test_image_dir='C:/Users/user/Desktop/GradProject/\train_test_split/test/'
#annotations_dir = 'C:/Users/user/anaconda3/envs/DarkflowTest/data/annotations/'
train_data=object_detector.DataLoader.from_pascal_voc(train_image_dir+'image/',train_image_dir+'xml/',label_map={1:"pill",2:"text"})
test_datal=object_detector.DataLoader.from_pascal_voc(test_image_dir+'image/',test_image_dir+'xml/',label_map={1:"pill",2:"text"})

然后用 DataLoader 加载训练数据和测试数据

model = object_detector.create(train_data, model_spec=spec, batch_size=2, train_whole_model=True)

然后我尝试创建一个模型,但我收到了这个错误。

[[1]:https://i.stack.imgur.com/ZBQtk.png][1]


----> 1 model = object_detector.create(train_data, model_spec=spec, batch_size=2, train_whole_model=True) 中的 ValueError Traceback (最近一次调用最后一次)

~\anaconda3\lib\site-packages\tensorflow_examples\lite\model_maker\core\task\object_detector.py in create(cls, train_data, model_spec, validation_data, epochs, batch_size, train_whole_model, do_train) 285 if do_train: 286 tf.compat .v1.logging.info('重新训练模型...') --> 287 object_detector.train(train_data, validation_data, epochs, batch_size) 288 else: 289 object_detector.create_model()

~\anaconda3\lib\site-packages\tensorflow_examples\lite\model_maker\core\task\object_detector.py in train(self, train_data, validation_data, epochs, batch_size) 137 # TODO(b/171449557): 将此上传到父级班级。138 if len(train_data) < batch_size: --> 139 raise ValueError('The size of the train_data (%d) could not be less ' 140 ' than batch_size (%d). 为了解决这个问题,设置' 141 ' batch_size 变小或增大'

ValueError:train_data (0) 的大小不能小于 batch_size (2)。要解决此问题,请将batch_size 设置得更小或增加train_data 的大小。

我是否因为 train_data=object_detector.DataLoader.from_pascal_voc(train_image_dir+'image/',train_image_dir+'xml/',label_map={1:"pill",2:"text"}) 无法加载火车数据而出错..?我认为我做的一切都是正确的,但仍然在为这个错误而苦苦挣扎。如果你知道这个问题的解决方案,请帮助我!

标签: tensorflow

解决方案


推荐阅读