python - 属性错误:使用 tf.keras fit_generator() 时,“NoneType”对象没有属性“shape”
问题描述
我有超过 10000 张图像的数据集,我正在使用 tf.keras DataGenerator 批量加载数据。但是,当我使用 model.fit_generator 拟合模型时出现错误:“NoneType”对象没有属性“shape”。
这是代码片段:
import math
import random
import cv2
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import Sequence
from tensorflow.keras.applications.mobilenet import preprocess_input
class DataGenerator(Sequence):
def __init__(self, dataset, batch_size=30, shuffle=True, predict=False):
self.dataset = dataset
self.batch_size=batch_size
self.shuffle=shuffle
self.predict=predict
self.on_epoch_end()
def __len__(self):
return math.ceil(len(self.dataset) /self.batch_size)
def __getitem__(self, index):
indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
image_batch = [self.dataset[i][1]['dicom'] for i in indexes]
bbox_batch = [self.dataset[i][1]['boxes'] for i in indexes]
X = self.__generate_X(image_batch)
if self.predict:
return X
else:
masks = self.__generate_masks(bbox_batch)
return X, masks
def __generate_X(self, image_batch):
X = np.zeros((len(image_batch), IMAGE_WIDTH, IMAGE_HEIGHT, 1))
for k, image_path in enumerate(image_batch):
img = dicom.read_file(image_path).pixel_array
img = cv2.resize(img, dsize=(IMAGE_HEIGHT, IMAGE_WIDTH), interpolation=cv2.INTER_CUBIC)
img = np.expand_dims(img, axis=-1)
X[k] = preprocess_input(np.array(img, dtype=np.float32))
def __generate_masks(self, bbox_batch):
masks = np.zeros((len(bbox_batch), IMAGE_WIDTH, IMAGE_HEIGHT))
width_factor = IMAGE_WIDTH/imageWidth
height_factor = IMAGE_HEIGHT/imageHeight
for k, bbox_items in enumerate(bbox_batch):
if len(bbox_items) > 0:
for idx, val in enumerate(bbox_items):
x1 = round(val[0]* width_factor)
x2 = round((val[0]+val[2])* width_factor)
y1 = round(val[1]*height_factor)
y2 = round((val[1]+val[3])*height_factor)
masks[k][y1:y2, x1:x2]=1
def on_epoch_end(self):
self.indexes = np.arange(len(self.dataset))
if self.shuffle == True:
np.random.shuffle(self.indexes)
model = create_model()
model.compile()
train_gen = DataGenerator(X_train, batch_size=30, shuffle=True, predict=False)
val_gen= DataGenerator(X_val, batch_size=30, shuffle=True, predict=False)
model.fit_generator(train_gen, validation_data = val_gen, epochs=1, shuffle=True, verbose=1)
输入:X_train 和 X_val 是 numpy 数组 Tensorflow 版本:1.15.0 Keras 版本:2.2.4 这是我在使用 fit_generator 时遇到的错误
AttributeError Traceback (most recent call last)
<ipython-input-52-b30d342db2da> in <module>
----> 1 model.fit_generator(train_gen, validation_data = val_gen, epochs=1, verbose=1)
2
~\Anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1294 shuffle=shuffle,
1295 initial_epoch=initial_epoch,
-> 1296 steps_name='steps_per_epoch')
1297
1298 def evaluate_generator(self,
~\Anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\keras\engine\training_generator.py in model_iteration(model, data, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch, mode, batch_size, steps_name, **kwargs)
255 # `batch_size` used for validation data if validation
256 # data is NumPy/EagerTensors.
--> 257 batch_size = int(nest.flatten(batch_data)[0].shape[0])
258
259 # Callbacks batch begin.
AttributeError: 'NoneType' object has no attribute 'shape'
我非常感谢任何解决此问题的指导。
解决方案
在函数 def __generate_X(self, image_batch): 和 def __generate_masks(self, bbox_batch): 中没有返回语句
X = self.__generate_X(image_batch)
if self.predict:
return X
else:
masks = self.__generate_masks(bbox_batch)
return X, masks
这就是为什么 X 和 mask 只不过是一个None对象
推荐阅读
- apache-spark - emr 上的 pyspark 使用自动广播(即使已禁用)和用于简单 sql 查询的嵌套连接
- html - 将内联块元素添加到表格跨越列时,HTML 表格列的宽度会发生变化
- java - 更改变量时重绘不起作用
- c# - 我不知道为什么,但我的附加力不起作用我有一个刚体 2d,代码看起来正确但它仍然不起作用?
- angular - 从 Observable Array of Objects 迭代 *ngFor 时如何修复 InvalidPipeArgument 错误
- javascript - 如何使用 JavaScript 语言在 Jmeter WebDriver Sampler 中设置 InternetExplorerOptions?
- apache - mod_jk 使用 Lucee 和 Apache 生成不正确的重定向查询字符串
- system-verilog - DPI-C 中 Struct 中的动态数组
- node.js - 使用 Node.js 从 txt 文件中读取坐标
- android - com.google.android.gms.internal.firebase-perf.zza 中缺少方法