python - 训练损失减少,但当我使用 tensorflow.data api 时模型无法学习
问题描述
我正在训练一个有 4 个类的分类模型。当我将数据生成器与numpy
数组一起使用时,模型训练得很好,损失减少了,预测也很好。
但是,当我用相同的参数为完全相同数量的 epoch 训练相同的模型时,训练损失确实减少到几乎相同的值,但即使在训练数据集上的预测准确率也低于 50%。
我探索了以下链接:
- Keras 模型在更改为使用 tf.data api 后无法学习任何东西（https://github.com/tensorflow/tensorflow/issues/22190）
我对损失函数进行了调整（从 categorical_crossentropy 改为 tf.keras.losses.sparse_categorical_crossentropy），但问题仍然存在。我正在使用 tensorflow-gpu 2.0。
编辑:这是我用来生成数据集的类:
import tensorflow as tf
from glob import glob
import os
from tensorflow.keras.utils import Sequence
from tensorflow.keras.preprocessing.image import ImageDataGenerator
这是我的代码:
import math
import numpy as np
import cv2
"""
This generator is an abstract class as
"""
class TfrecordLoader(Sequence):
    """Keras ``Sequence`` that serves batches of parsed video tfrecords.

    Every ``*.tfrecord`` file found directly under ``dataset_path`` is
    treated as one serialized example following the DICOM feature schema
    declared in ``__init__``.  Intended to be subclassed for more specific
    training cases by overriding ``_create_iterator``.
    """

    def __init__(self, dataset_path, batch_size=10):
        """Discover the tfrecord files and build the batched iterator.

        Args:
            dataset_path: Directory containing ``*.tfrecord`` files.
            batch_size: Number of records per served batch.
        """
        # Load the tfrecords.
        video_paths = glob(os.path.join(dataset_path, "*.tfrecord"))
        self._dataset_len = len(video_paths)

        # Deserialize the tfrecords based on the schema below.
        def _parse_dicom_feature_record_function(serialized_data):
            """Parse one serialized example into a feature dict."""
            dicom_feature_description = {
                'NUM_FRAMES': tf.io.FixedLenFeature([], tf.int64),
                'HEIGHT': tf.io.FixedLenFeature([], tf.int64),
                'WIDTH': tf.io.FixedLenFeature([], tf.int64),
                'CHANNEL_DEPTH': tf.io.FixedLenFeature([], tf.int64),
                # 'LVEF': tf.io.FixedLenFeature([], dtype=tf.float32),
                'IMAGE_QUALITY': tf.io.FixedLenFeature([], tf.string),
                'RECORD_NAME': tf.io.FixedLenFeature([], tf.string),
                'VIEW_TYPE': tf.io.FixedLenFeature([], tf.string),
                # Raw frame bytes; decoded lazily in
                # convert_tfrecorddata_to_dict.
                'FRAMES_ARRAY': tf.io.FixedLenFeature([], tf.string),
            }
            return tf.io.parse_single_example(serialized_data, dicom_feature_description)

        self._video_dataset = tf.data.TFRecordDataset(video_paths).map(
            _parse_dicom_feature_record_function)
        self._batch_size = batch_size
        self._iterator = self._create_iterator()

    def _create_iterator(self):
        """Return a fresh shuffled, batched iterator over the dataset.

        The shuffle buffer spans the whole dataset so each epoch is fully
        reshuffled.  Subclasses may override this for more specific
        training cases based on this dataset.
        """
        return iter(self._video_dataset.shuffle(self._dataset_len).batch(self._batch_size))

    def __len__(self):
        """Number of full batches per epoch (any final partial batch is dropped)."""
        return self._dataset_len // self._batch_size

    def __getitem__(self, idx):
        """Return the next batch of parsed records.

        NOTE(review): ``idx`` is ignored -- batches are served in iterator
        order, so this Sequence is not random-access as Keras assumes;
        the internal ``shuffle()`` supplies the randomness instead.
        """
        try:
            return next(self._iterator)
        except StopIteration:
            # Defensive restart: __len__ floors while batch() yields a
            # ceiling number of batches, so exhaustion is unexpected, but
            # a crash mid-epoch would be worse than a reshuffle.
            self._iterator = self._create_iterator()
            return next(self._iterator)

    def on_epoch_end(self):
        """Rebuild the iterator so the next epoch sees a new shuffle."""
        print ('\n\n ------------------------- EPOCH END ------------------------ \n\n')
        self._iterator = self._create_iterator()

    def convert_tfrecorddata_to_dict(self, record, batch_index):
        """Convert one record of a parsed batch into a plain python dict.

        Args:
            record: Batched feature dict as produced by ``__getitem__``.
            batch_index: Index of the record within the batch.

        Returns:
            Dict of scalar metadata plus ``FRAMES_ARRAY`` decoded from raw
            uint8 bytes into a numpy array shaped
            (NUM_FRAMES, HEIGHT, WIDTH, CHANNEL_DEPTH).
        """
        converted_dict = {}
        converted_dict['NUM_FRAMES'] = int(record['NUM_FRAMES'][batch_index])
        converted_dict['HEIGHT'] = int(record['HEIGHT'][batch_index])
        converted_dict['WIDTH'] = int(record['WIDTH'][batch_index])
        converted_dict['CHANNEL_DEPTH'] = int(record['CHANNEL_DEPTH'][batch_index])
        # converted_dict['LVEF'] = float(record['LVEF'][batch_index])
        converted_dict['IMAGE_QUALITY'] = record['IMAGE_QUALITY'][batch_index].numpy().decode()
        converted_dict['RECORD_NAME'] = record['RECORD_NAME'][batch_index].numpy().decode()
        converted_dict['VIEW_TYPE'] = record['VIEW_TYPE'][batch_index].numpy().decode()
        # Convert the raw frames data from a string to a flat array of
        # uint8 values, then restore the video's original 4-D shape.
        raw_frames_data = tf.io.decode_raw(record['FRAMES_ARRAY'][batch_index], tf.uint8)
        converted_dict['FRAMES_ARRAY'] = tf.reshape(raw_frames_data, [converted_dict['NUM_FRAMES'],
                                                                      converted_dict['HEIGHT'],
                                                                      converted_dict['WIDTH'],
                                                                      converted_dict['CHANNEL_DEPTH']]).numpy()
        return converted_dict
def main():
    """Smoke-test the loader: show the first frame of one record per batch."""
    print (tf.__version__)
    record_generator = TfrecordLoader('test_tfrecords/',10)
    # Cycle through the dataset, previewing one video per batch.
    for record_batch in record_generator:
        # Convert the first record in the batch to a plain dict.
        first_record = record_generator.convert_tfrecorddata_to_dict(record_batch, 0)
        # Display the opening frame of that video.
        first_frame = first_record['FRAMES_ARRAY'][0].squeeze()
        cv2.imshow(first_record['RECORD_NAME'], first_frame)
        cv2.waitKey(200)
        cv2.destroyAllWindows()
# Stray trailing backtick (markdown paste artifact) removed -- it was a
# syntax error that prevented the script from running at all.
if __name__ == '__main__':
    main()
解决方案
推荐阅读
- cloudkit - 云套件:如何仅从您的联系人组获取公共数据库订阅通知
- python - 使用 Python 进行缓慢的 SQLite 更新
- c# - c# group by first word
- visual-studio - 用于 Javascript/AJAX 的 Visual Studio Code 镜头
- javascript - Discord Bot 多次响应事件
- php - 大小动态数组PHP,POST方式,用Jquery
- node.js - Jest 进程在所有测试执行之前退出,没有任何反馈
- java - Char变量减去char变量在java中解析为Int
- laravel - 登录 Laravel 8 应用程序时出现“未找到列” - 假设缺少“id”列?
- jmeter - 如何解决 Jmeter 和软件之间的问题