generator - 通过 ImageDataGenerator 将数据导入 tensorflow 自动编码器
问题描述
当我尝试通过将图像导入为 numpy 数组来训练自动编码器时,训练进行得很快,第一个时期本身的训练损失 < 0,结果也不错。
但是当我通过 ImageDataGenerator 导入相同的数据时,起始损失在 32000 左右,随着训练的进行,它会非常缓慢地下降,并且在 50 个 epoch 之后,它会在 31000 左右饱和。我使用 mse 作为 Adam Optimiser 的损失函数。我尝试了不同的损失函数,但问题仍然存在,例如非常高的值在开始时会很快饱和到非常高的值。欢迎任何建议。谢谢。
以下是我的代码。
from convautoencoder import ConvAutoencoder
from tensorflow.keras.optimizers import Adam
import numpy as np
import cv2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.config import experimental
from tensorflow.python.client import device_lib
devices = experimental.list_physical_devices('GPU')
experimental.set_memory_growth(devices[0], True)
EPOCHS = 5000
BS = 4
trainAug = ImageDataGenerator()
valAug = ImageDataGenerator()
# initialize the training generator
trainGen = trainAug.flow_from_directory(
config.TRAIN_PATH,
class_mode="input",
classes=None,
target_size=(64, 64),
color_mode="grayscale",
shuffle=True,
batch_size=BS)
# initialize the validation generator
valGen = valAug.flow_from_directory(
config.TRAIN_PATH,
class_mode="input",
classes=None,
target_size=(64, 64),
color_mode="grayscale",
shuffle=False,
batch_size=BS)
# initialize the testing generator
testGen = valAug.flow_from_directory(
config.TRAIN_PATH,
class_mode="input",
classes=None,
target_size=(64, 64),
color_mode="grayscale",
shuffle=False,
batch_size=BS)
mc = ModelCheckpoint('best_model_1.h5', monitor='val_loss', mode='min', save_best_only=True)
print("[INFO] building autoencoder...")
(encoder, decoder, autoencoder) = ConvAutoencoder.build(64, 64, 1)
opt = Adam(learning_rate= 0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-04, amsgrad=False)
autoencoder.compile(loss="hinge", optimizer=opt)
H = autoencoder.fit( trainGen, validation_data=valGen, epochs=EPOCHS, batch_size=BS ,callbacks=[ mc])
解决方案
好的。这是一个愚蠢的错误。
添加重新缩放因子 rescale=1。/255 到 imageDataGenerator 解决了这个问题。
推荐阅读
- javascript - php 使用 Pusher 或 socket.io 显示在线用户
- python - 获取列表中项目的全长窗口
- argo-workflows - Argo UI 和 RBAC - 在命名空间安装中按组显示工作流
- python - 如何转换矩阵内的区域?
- c# - 创建指数平均值
- scala - 如何在 Scala 应用程序中修复“接口中的静态方法需要 -target:jvm-1.8”?
- typescript - 动态导入打字稿快递
- c# - 在 WPF C# 中的 DatagridComboboxColumn 正文中使用 Datagrid
- c# - LookupAccountSid() 在 Server 2016 上引发 System.AccessViolationException
- javascript - 是否可以使用 Web Auido API 在不扭曲声音的情况下向 MediaElementAudioSourceNode 添加效果?