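python - TF2.0: a tf.data Dataset built from a Keras ImageDataGenerator does not work
Problem description
I'm trying to use a dataset with TF2.0 and the Keras ImageDataGenerator, but when I try to use it, I get an error. Here is what I'm doing. I have a data folder with 4 subfolders, one per class. I assume the folder names will be used as the labels, just like with the old Keras approach. Each of the 4 folders holds 72 or so images.
Here is the code I'm using: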
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

augment = True
if augment:
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0,
        rotation_range=20,
        zoom_range=0.15,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')  # set validation split
else:
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        horizontal_flip=True,
        fill_mode='nearest')  # set validation split

# DATA_PATH is the data folder with one subfolder per class.
# Draw a single batch to inspect the shapes.
images, labels = next(train_datagen.flow_from_directory(DATA_PATH))
print(images.dtype, images.shape)
print(labels.dtype, labels.shape)
input_shape = images.shape[1:]
print("InputShape:", input_shape)
img_shape = (input_shape[0], input_shape[1])

ds = tf.data.Dataset.from_generator(train_datagen.flow_from_directory,
                                    args=[DATA_PATH], output_types=(tf.float32, tf.float32))
print("DS:", ds)
This produces:
Found 324 images belonging to 4 classes.
float32 (32, 256, 256, 3)
float32 (32, 4)
InputShape: (256, 256, 3)
DS: <DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.float32, tf.float32)>
So this looks right to me. But when I try to use it in my model:
history = model.fit(ds, epochs=10, verbose=1)
it gives me this error:
Epoch 1/10
Traceback (most recent call last):
File "C:/Users/gus/Documents/ImageSimularity/FoodTrainer.py", line 75, in <module>
history = model.fit(ds, epochs=10, verbose=1)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 728, in fit
use_multiprocessing=use_multiprocessing)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 324, in fit
total_epochs=epochs)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 123, in run_one_epoch
batch_outs = execution_function(iterator)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 86, in execution_function
distributed_function(input_fn))
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 457, in __call__
result = self._call(*args, **kwds)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 503, in _call
self._initialize(args, kwds, add_initializers_to=initializer_map)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 408, in _initialize
*args, **kwds))
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\function.py", line 1848, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\function.py", line 2150, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\function.py", line 2041, in _create_graph_function
capture_by_value=self._capture_by_value),
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\framework\func_graph.py", line 915, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\eager\def_function.py", line 358, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 66, in distributed_function
model, input_iterator, mode)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 112, in _prepare_feed_values
inputs, targets, sample_weights = _get_input_from_iterator(inputs)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 149, in _get_input_from_iterator
distribution_strategy_context.get_strategy(), x, y, sample_weights)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\distribute\distributed_training_utils.py", line 308, in validate_distributed_dataset_inputs
x_values_list = validate_per_replica_inputs(distribution_strategy, x)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\distribute\distributed_training_utils.py", line 356, in validate_per_replica_inputs
validate_all_tensor_shapes(x, x_values)
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\keras\distribute\distributed_training_utils.py", line 373, in validate_all_tensor_shapes
x_shape = x_values[0].shape.as_list()
File "C:\Users\gus\Anaconda3\envs\TF2\lib\site-packages\tensorflow_core\python\framework\tensor_shape.py", line 1171, in as_list
raise ValueError("as_list() is not defined on an unknown TensorShape.")
ValueError: as_list() is not defined on an unknown TensorShape.
1/Unknown - 0s 10ms/step
1/Unknown - 0s 10ms/step
Process finished with exit code 1
It seems to start running, but then stops because nothing is being produced.
Solution
Using a tf.data.Dataset together with the Keras ImageDataGenerator is a bit tricky. You can use Keras's built-in fit_generator method instead.
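What makes tf.data tricky here shows up in the question's `DS:` line: `from_generator` cannot infer shapes from a Python generator, so both components are `<unknown>`, and Keras fails when it calls `as_list()` on them. If you would rather keep tf.data, a minimal sketch is to declare `output_shapes` and pass `steps_per_epoch`, since the generator never ends. The generator `fake_batches`, the random data, the 64x64 image size and the tiny model are hypothetical stand-ins for `flow_from_directory` and your real model.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES, BATCH_SIZE, IMG_SIZE = 4, 32, 64  # hypothetical sizes for a quick run


def fake_batches():
    """Hypothetical stand-in for flow_from_directory: yields (images, one-hot labels) forever."""
    rng = np.random.default_rng(0)
    while True:
        images = rng.random((BATCH_SIZE, IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
        labels = np.eye(NUM_CLASSES, dtype=np.float32)[rng.integers(0, NUM_CLASSES, BATCH_SIZE)]
        yield images, labels


ds = tf.data.Dataset.from_generator(
    fake_batches,
    output_types=(tf.float32, tf.float32),
    # Without output_shapes both components are <unknown>, which is what
    # triggers "as_list() is not defined on an unknown TensorShape".
    # The batch dimension is None because the last batch may be smaller.
    output_shapes=([None, IMG_SIZE, IMG_SIZE, 3], [None, NUM_CLASSES]),
)
print(ds.element_spec)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy")

# The generator is infinite, so steps_per_epoch marks where an epoch ends.
history = model.fit(ds, epochs=2, steps_per_epoch=3, verbose=0)
```

With a real folder, the same idea applies with `lambda: train_datagen.flow_from_directory(DATA_PATH)` as the generator and the image size and class count from the question.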
To use fit_generator, you can skip this part:
# ds = tf.data.Dataset.from_generator(train_datagen.flow_from_directory,
# args=[DATA_PATH], output_types=(tf.float32, tf.float32))
and use the Keras generator directly:
train_generator = train_datagen.flow_from_directory(
    DATA_PATH,
    target_size=(256, 256),  # match the model's input shape (256x256 is also the default)
    batch_size=32,
    class_mode='categorical')  # 4 classes -> one-hot labels of shape (batch, 4)
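As the question assumed, `flow_from_directory` takes the labels from the subfolder names: the folders are sorted alphabetically and mapped to indices, and with `class_mode='categorical'` each label is a one-hot row, which is why the question's labels have shape `(32, 4)`. The sketch below checks the mapping on a throwaway directory. The class names, file counts and empty placeholder files are hypothetical, and it assumes a TF 2.x release that still ships `tf.keras.preprocessing.image.ImageDataGenerator` (newer Keras 3 based releases may no longer provide it).

```python
import os
import tempfile

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Hypothetical data folder: one subfolder per class, 81 files each (324 in total).
root = tempfile.mkdtemp()
for class_name in ["soup", "bread", "cake", "apple"]:
    os.makedirs(os.path.join(root, class_name))
    for i in range(81):
        # Empty placeholders are enough here: the directory scan only checks the
        # file extension; images are decoded when a batch is actually requested.
        open(os.path.join(root, class_name, f"img{i}.png"), "wb").close()

train_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory(
    root,
    target_size=(256, 256),
    batch_size=32,
    class_mode="categorical",  # one-hot labels, shape (batch, 4)
)

print(train_generator.class_indices)  # {'apple': 0, 'bread': 1, 'cake': 2, 'soup': 3}
print(train_generator.samples, len(train_generator))  # 324 11
```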
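Finally, train with the fit_generator call mentioned above (in TF 2.1 and later, fit_generator is deprecated and model.fit accepts the generator directly):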
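model.fit_generator(
    train_generator,
    steps_per_epoch=len(train_generator),  # one full pass over the training images
    epochs=50,
    validation_data=validation_generator,  # built like train_generator from a separate validation folder (not shown)
    validation_steps=len(validation_generator))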
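`steps_per_epoch` is the number of batches Keras draws before it considers an epoch finished. The generator loops over the images indefinitely, so one full pass over the question's 324 images at the default `batch_size` of 32 takes ceil(324 / 32) = 11 steps, which is what `len(train_generator)` returns; the last batch holds only the remaining 4 images. A minimal sketch of that arithmetic (the helper name `steps_for` is hypothetical):

```python
import math


def steps_for(num_samples: int, batch_size: int) -> int:
    """Batches needed to see every sample once; the last batch may be partial."""
    return math.ceil(num_samples / batch_size)


steps = steps_for(324, 32)
last_batch = 324 - (steps - 1) * 32
print(steps, last_batch)  # 11 4
```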
The documentation on this topic is very good, and I recommend checking it out. Cheers!