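python - Why does a custom image generator in keras give the error "object cannot be interpreted as an integer"?
Problem description
I used a template for a custom image generator in keras so that I can use hdf5 files as input. Initially the code gave a "shape" error, so I only added the import from tensorflow.python.keras.utils.data_utils import Sequence, following this post. Now I use it in the form below, which you can also see in my colab notebook: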
from numpy.random import uniform, randint
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
import numpy as np
from tensorflow.python.keras.utils.data_utils import Sequence

class CustomImagesGenerator(Sequence):
    def __init__(self, x, zoom_range, shear_range, rescale, horizontal_flip, batch_size):
        self.x = x
        self.zoom_range = zoom_range
        self.shear_range = shear_range
        self.rescale = rescale
        self.horizontal_flip = horizontal_flip
        self.batch_size = batch_size
        self.__img_gen = ImageDataGenerator()
        self.__batch_index = 0

    def __len__(self):
        # steps_per_epoch, if unspecified, will use the len(generator) as a number of steps.
        # hence this
        return np.floor(self.x.shape[0]/self.batch_size)

    # @property
    # def shape(self):
    #     return self.x.shape

    def next(self):
        return self.__next__()

    def __next__(self):
        start = self.__batch_index*self.batch_size
        stop = start + self.batch_size
        self.__batch_index += 1
        if stop > len(self.x):
            raise StopIteration
        transformed = np.array(self.x[start:stop])  # loads from hdf5
        for i in range(len(transformed)):
            zoom = uniform(self.zoom_range[0], self.zoom_range[1])
            transformations = {
                'zx': zoom,
                'zy': zoom,
                'shear': uniform(-self.shear_range, self.shear_range),
                'flip_horizontal': self.horizontal_flip and bool(randint(0,2))
            }
            transformed[i] = self.__img_gen.apply_transform(transformed[i], transformations)
        import pdb;pdb.set_trace()
        return transformed * self.rescale
I call the generator as follows:
import h5py
import tables
in_hdf5_file = tables.open_file("gdrive/My Drive/Colab Notebooks/dataset.hdf5", mode='r')
images = in_hdf5_file.root.train_img
my_gen = CustomImagesGenerator(
    images,
    zoom_range=[0.8, 1],
    batch_size=32,
    shear_range=6,
    rescale=1./255,
    horizontal_flip=False
)
classifier.fit_generator(my_gen, steps_per_epoch=100, epochs=1, verbose=1)
Importing Sequence resolved the "shape" error, but now I get this error:
Exception in thread Thread-5:
Traceback (most recent call last):
  File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/data_utils.py", line 742, in _run
    sequence = list(range(len(self.sequence)))
TypeError: 'numpy.float64' object cannot be interpreted as an integer
How can I solve this? I suspect it may again be a conflict within the keras packages, and I don't know how to resolve it.
Solution
In your case, use model.fit().
Example:
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
import tables
#define your model
...
#load your data from an hdf5 file
in_hdf5_file = tables.open_file("path/to/your/dataset.hdf5", mode='r')
x = in_hdf5_file.root.train_img[:]
y = in_hdf5_file.root.train_labels[:]
yourModel.fit(x, to_categorical(y, 3), epochs=2, batch_size=5)
For more information, read my comment on your original post, or feel free to ask.
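The call to_categorical(y, 3) in the example above turns integer class labels into one-hot rows for 3 classes. A minimal pure-Python sketch of that encoding, for illustration only: the helper name one_hot is hypothetical, and the real keras function returns a NumPy array rather than nested lists.

```python
def one_hot(labels, num_classes):
    """Return one row per label, with 1.0 at the label's index and 0.0 elsewhere."""
    rows = []
    for label in labels:
        row = [0.0] * num_classes
        row[int(label)] = 1.0
        rows.append(row)
    return rows


# Three samples belonging to classes 0, 2 and 1 (made-up labels).
encoded = one_hot([0, 2, 1], 3)
print(encoded)
```

Each row has exactly one non-zero entry, which is the shape of target a softmax output layer with 3 units would typically be trained against.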
Edit: I fixed your generator; it now only needs the path to your hdf5 file.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import *
from tensorflow.keras.utils import to_categorical
import numpy as np
from tensorflow.python.keras.utils.data_utils import Sequence
import tensorflow as tf
import tables

#define your model
...

#training
def h5data_generator(path, batch_size=1):
    batch_index = 0
    while True:
        with tables.open_file(path, mode='r') as f:
            # wrap around once every sample has been served
            if batch_index >= f.root.train_img.shape[0]:
                batch_index = 0
            # slicing reads only this batch from disk
            x = f.root.train_img[batch_index:batch_index + batch_size]
            y = f.root.train_labels[batch_index:batch_index + batch_size]
        # advance by a whole batch so consecutive batches do not overlap
        batch_index += batch_size
        yield (x, to_categorical(y, 3))

my_gen = h5data_generator("path/to/your/dataset.hdf5")
# note: in TensorFlow 2.1 and later fit_generator() is deprecated and fit() accepts generators directly
yourModel.fit_generator(my_gen, steps_per_epoch=100, epochs=20, verbose=1)
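The problem with your generator is that each step outputs the wrong data: it does not output (x, y), and it cannot, because it only outputs x (the images, in your case). Also, because it subclasses Sequence, keras tries to treat it as an object that implements the Sequence API, which your generator does not. Besides, it does not have to be a class; a plain Python generator is enough, as shown by the example in keras itself, in the docstring of fit_generator() (fit_generator.__doc__):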
Fits the model on data yielded batch-by-batch by a Python generator.
The generator is run in parallel to the model, for efficiency.
For instance, this allows you to do real-time data augmentation
on images on CPU in parallel to training your model on GPU.
The use of `keras.utils.Sequence` guarantees the ordering
and guarantees the single use of every input per epoch when
using `use_multiprocessing=True`.
Arguments:
generator: A generator or an instance of `Sequence`
(`keras.utils.Sequence`)
object in order to avoid duplicate data
when using multiprocessing.
The output of the generator must be either
- a tuple `(inputs, targets)`
- a tuple `(inputs, targets, sample_weights)`.
This tuple (a single output of the generator) makes a single batch.
Therefore, all arrays in this tuple must have the same length (equal
to the size of this batch). Different batches may have different
sizes.
For example, the last batch of the epoch is commonly smaller than the
others, if the size of the dataset is not divisible by the batch size.
The generator is expected to loop over its data
indefinitely. An epoch finishes when `steps_per_epoch`
batches have been seen by the model.
steps_per_epoch: Total number of steps (batches of samples)
to yield from `generator` before declaring one epoch
finished and starting the next epoch. It should typically
be equal to the number of samples of your dataset
divided by the batch size.
Optional for `Sequence`: if unspecified, will use
the `len(generator)` as a number of steps.
epochs: Integer, total number of iterations on the data.
verbose: Verbosity mode, 0, 1, or 2.
callbacks: List of callbacks to be called during training.
validation_data: This can be either
- a generator for the validation data
- a tuple (inputs, targets)
- a tuple (inputs, targets, sample_weights).
validation_steps: Only relevant if `validation_data`
is a generator. Total number of steps (batches of samples)
to yield from `generator` before stopping.
Optional for `Sequence`: if unspecified, will use
the `len(validation_data)` as a number of steps.
validation_freq: Only relevant if validation data is provided. Integer
or `collections.Container` instance (e.g. list, tuple, etc.). If an
integer, specifies how many training epochs to run before a new
validation run is performed, e.g. `validation_freq=2` runs
validation every 2 epochs. If a Container, specifies the epochs on
which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
validation at the end of the 1st, 2nd, and 10th epochs.
class_weight: Dictionary mapping class indices to a weight
for the class.
max_queue_size: Integer. Maximum size for the generator queue.
If unspecified, `max_queue_size` will default to 10.
workers: Integer. Maximum number of processes to spin up
when using process-based threading.
If unspecified, `workers` will default to 1. If 0, will
execute the generator on the main thread.
use_multiprocessing: Boolean.
If `True`, use process-based threading.
If unspecified, `use_multiprocessing` will default to `False`.
Note that because this implementation relies on multiprocessing,
you should not pass non-picklable arguments to the generator
as they can't be passed easily to children processes.
shuffle: Boolean. Whether to shuffle the order of the batches at
the beginning of each epoch. Only used with instances
of `Sequence` (`keras.utils.Sequence`).
Has no effect when `steps_per_epoch` is not `None`.
initial_epoch: Epoch at which to start training
(useful for resuming a previous training run)
Returns:
A `History` object.
Example:
```python
def generate_arrays_from_file(path):
    while 1:
        f = open(path)
        for line in f:
            # create numpy arrays of input data
            # and labels, from each line in the file
            x1, x2, y = process_line(line)
            yield ({'input_1': x1, 'input_2': x2}, {'output': y})
        f.close()

model.fit_generator(generate_arrays_from_file('/my_file.txt'),
                    steps_per_epoch=10000, epochs=10)
```
Raises:
ValueError: In case the generator yields data in an invalid format.
For more information, see the keras GitHub page, specifically fit_generator(), or again feel free to ask.
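As for the TypeError in the question's traceback itself: it is raised by len(self.sequence), because Python requires __len__ to return an integer, while np.floor returns a numpy.float64. The sketch below reproduces the same failure with a plain Python float so that it runs without NumPy or keras; the class names are hypothetical and stand in for the generator's __len__ only.

```python
class BrokenLen:
    """Mimics the original __len__: the division result is a float."""

    def __init__(self, n_samples, batch_size):
        self.n_samples = n_samples
        self.batch_size = batch_size

    def __len__(self):
        # True division yields a float, just as np.floor(n / batch_size) does.
        return self.n_samples / self.batch_size


class FixedLen(BrokenLen):
    def __len__(self):
        # Floor division on ints always returns an int, which len() accepts.
        return self.n_samples // self.batch_size


def try_len(obj):
    """Return len(obj), or a short description of the TypeError it raises."""
    try:
        return len(obj)
    except TypeError as exc:
        return "TypeError: " + str(exc)


broken_result = try_len(BrokenLen(100, 32))
fixed_result = try_len(FixedLen(100, 32))
print(broken_result)
print(fixed_result)
```

In the original class, the equivalent change would presumably be to wrap the result in int(...) or to use // in __len__; that removes this particular exception, although the generator would still need to return (x, y) batches.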
Edit 2: You can also pass batch_size to h5data_generator(), which sets the size of the batches drawn from the dataset.
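To see how batch_size affects what a generator of this kind yields, here is a small sketch that uses in-memory lists as a stand-in for the hdf5 arrays. The function name batch_generator and the toy data are assumptions made for illustration; the slicing and wrap-around logic is the part that carries over.

```python
def batch_generator(images, labels, batch_size=1):
    """Yield (x, y) batches forever, wrapping around at the end of the data."""
    n_samples = len(images)
    start = 0
    while True:
        if start >= n_samples:
            # Every sample has been served once, so start over.
            start = 0
        stop = start + batch_size
        # Slicing past the end is safe and simply gives a shorter last batch.
        yield images[start:stop], labels[start:stop]
        start = stop


# Toy dataset: 10 "images" with labels in {0, 1, 2}.
images = list(range(10))
labels = [i % 3 for i in images]

gen = batch_generator(images, labels, batch_size=4)
batches = [next(gen) for _ in range(4)]
for x, y in batches:
    print(len(x), x, y)
```

With 10 samples and batch_size=4 the third batch holds only 2 samples, which matches the docstring's remark that the last batch of an epoch is commonly smaller than the others. Because the generator never stops, steps_per_epoch decides where one epoch ends.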