python - 将数据从自定义数据生成器传递到 model.fit()
问题描述
我正在做端到端映射。因为我必须传递两个图像(输入和输出),所以我创建了一个自定义生成器。我的生成器得到两个不同分辨率的相同图像。现在我只能得到 5 张图像传递给训练,但我想传递整个生成器,以便我的所有数据都得到训练。由于我是使用生成器和产量的新手,因此我不是传递整个生成器的正确方法。
import os
import numpy as np
import cv2
class image_gen():
def __init__(self, idir,odir,batch_size, shuffle = True):
self.batch_index=0
self.idir=idir
self.odir=odir# directory containing input images
self.batch_size=batch_size #batch size is number of samples in a batch
self.shuffle=shuffle # set to True to shuffle images, False for no shuffle
self.label_list=[] # initialize list to hold sequential list of total labels generated
self.image_list=[] # initialize list to hold sequential list of total images filenames generated
self.i_list=os.listdir(self.idir)
self.o_list=os.listdir(self.odir)# list of images in directory
def get_images(self): # gets a batch of input images, resizes input image to make target images
while True:
input_image_batch=[]
output_image_batch=[]# initialize list to hold a batch of target images
sample_count=len(self.i_list) # determine total number of images available
for i in range(self.batch_index * self.batch_size, (self.batch_index + 1) * self.batch_size ): #iterate for a batch
j=i % sample_count # cycle j value over range of available images
k=j % self.batch_size # cycle k value over batch size
if self.shuffle: # if shuffle select a random integer between 0 and sample_count-1 to pick as the image=label pair
m=np.random.randint(low=0, high=sample_count-1, size=None, dtype=int)
else:
m=j # no shuffle
#input
path_to_in_img=os.path.join(self.idir,self.i_list[m])
path_to_out_img=os.path.join(self.odir,self.o_list[m])
# define the path to the m th image
input_image=cv2.imread(path_to_in_img)
input_image=cv2.resize( input_image,(3200,3200))#create the target image from the input image
output_image=cv2.imread(path_to_out_img)
output_image=cv2.resize(output_image,(3200,3200))
input_image_batch.append(input_image)
output_image_batch.append(output_image)
input_image_array=np.array(input_image_batch)
input_image_array = input_image_array / 255.0
output_image_array=np.array(output_image_batch)
output_image_array = output_image_array /255.0
self.batch_index= self.batch_index + 1
yield (input_image_array, output_image_array )
if self.batch_index * self.batch_size > sample_count:
break
这就是我获取图像的方式
batch_size=5
idir=r'D:\\train'
odir=r'D:\\Train\\train'#
shuffle=True
gen=image_gen(idir,odir,batch_size,shuffle=True) # instantiate an instance of the class
input_images,output_images = next(gen.get_images())
这就是我训练的方式。这样我只训练 5 张图像而不是整个数据集
model.fit(input_images,output_images,validation_data = (valin_images,valout_images),batch_size= 5,epochs = 100)
当我尝试传递整个数据集时
model.fit(gen(),validation_data = (valin_images,valout_images),batch_size= 5,epochs = 1)
我收到一个错误“image_gen”对象不可调用。我应该如何将生成器传递给 model.fit()
解决方案
出现此问题的原因是因为当您尝试访问 a 时会引发此错误,就image_gen
好像它是一个函数一样,但实际上它是一个类的对象。
在您提供的第一个片段中,您实际上访问了该类的方法,该方法确实是一个生成器,它产生了一些numpy
可以作为模型输入的数组。然而,由于第一段中描述的错误,第二个片段失败了。
您的问题的两种可能的解决方案如下:
- 使用
Keras Sequence()
发电机。 - 将函数用作生成器 (
def my_generator(...)
)。
我个人推荐第一个解决方案,因为Sequence()
生成器确保您在一个时期内每个样本只训练一次,这在简单函数生成器的情况下不满足。
- 解决方案
Keras Sequence()
:
您需要覆盖Sequence
该类,然后覆盖其方法。TensorFlow 官方文档中的一个完整示例是:
from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math
# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.
class CIFAR10Sequence(Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __len__(self):
return math.ceil(len(self.x) / self.batch_size)
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) *
self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) *
self.batch_size]
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
您可以使用上面的代码作为解决方案的起点。顺便说一句,您的网络很可能不会以如此大的图像尺寸进行训练,您也可以尝试降低它们。
简单生成器的解决方案可能是:
def my_generator(path_to_dataset, other_argument): ... ... yield image_1, image_2 train_generator = my_generator(path_to_train,argument_1) val_generator = my_generator(path_to_val,argument_2) model.fit(train_generator, steps_per_epoch=len(training_samples) // BATCH_SIZE, epochs=10, validation_data=val_generator, validation_steps=len(validation_samples) // BATCH_SIZE)
推荐阅读
- angularjs - 如何检测我处于测试模式 angularjs
- android - java.lang.IllegalArgumentException 仅在 Oreo 上崩溃
- c# - ASP.NET Core 2.1 本地化选项 SupportedCultures 始终只包含英文
- javascript - Express, 如何加载 JS 文件
- updates - Azure B2C 以编程方式更新自定义用户属性
- json - Elastica 和 Json 字段
- python-2.7 - 将舍入浮点数转换为字符串 pandas 的问题
- python - 使属性只读的更短方法
- mysql - 获取一个巨大的 gzip 文件的最后一行
- bash - 在 bash 中读取和输出矩阵文件