python - 如何塑造 TFRecordDataset 以满足模型 API?
问题描述
我正在基于此代码构建模型以进行噪声抑制。我对 vanilla 实现的问题是它一次加载所有数据,当训练数据变得非常大时,这不是最好的主意。我的输入文件(在链接代码中表示为training.h5
)超过 30 GB。
我决定改为使用tf.data
应该允许我处理大型数据集的界面;我的问题是我不知道如何正确塑造TFRecordDataset
以使其满足模型 API 的要求。
如果您检查model.fit(x_train, [y_train, vad_train]
,它基本上需要以下内容:
- x_train,形状
[nb_sequences, window, 42]
- y_train,形状
[nb_sequences, window, 22]
- vad_train,形状
[nb_sequences, window, 1]
window
一个通常修复(在代码中:) 2000
,因此唯一的变量nb_sequences
源于您的数据集有多大。但是,对于tf.data
,我们不提供x
and y
,而只提供x
(参见模型 API 文档)。
将 tfrecord 保存到文件
为了使代码可重现,我使用以下代码创建了输入文件:
writer = tf.io.TFRecordWriter(path='example.tfrecord')
for record in data:
feature = {}
feature['X'] = tf.train.Feature(float_list=tf.train.FloatList(value=record[:42]))
feature['y'] = tf.train.Feature(float_list=tf.train.FloatList(value=record[42:64]))
feature['vad'] = tf.train.Feature(float_list=tf.train.FloatList(value=[record[64]]))
example = tf.train.Example(features=tf.train.Features(feature=feature))
serialized = example.SerializeToString()
writer.write(serialized)
writer.close()
data
是我们的训练数据,形状为[10000, 65]
。我的在这里example.tfrecord
可用。它是 3 MB,实际上是 30 GB 以上。
您可能会注意到,在链接代码中,numpy 数组具有 shape [x, 87]
,而我的是[x, 65]
. 没关系 - 其余部分不在任何地方使用。
使用 tf.data.TFRecordDataset 加载数据集
我想tf.data
通过一些预取来“按需”加载数据,无需将其全部保存在内存中。我的尝试:
import datetime
import numpy as np
import h5py
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import GRU
from tensorflow.keras import regularizers
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
from tensorflow.keras.layers import concatenate
def load_dataset(path):
def _parse_function(example_proto):
keys_to_features = {
'X': tf.io.FixedLenFeature([42], tf.float32),
'y': tf.io.FixedLenFeature([22], tf.float32),
'vad': tf.io.FixedLenFeature([1], tf.float32)
}
features = tf.io.parse_single_example(example_proto, keys_to_features)
return (features['X'], (features['y'], features['vad']))
dataset = tf.data.TFRecordDataset(path).map(_parse_function)
return dataset
def my_crossentropy(y_true, y_pred):
return K.mean(2 * K.abs(y_true - 0.5) * K.binary_crossentropy(y_pred, y_true), axis=-1)
def mymask(y_true):
return K.minimum(y_true + 1., 1.)
def msse(y_true, y_pred):
return K.mean(mymask(y_true) * K.square(K.sqrt(y_pred) - K.sqrt(y_true)), axis=-1)
def mycost(y_true, y_pred):
return K.mean(mymask(y_true) * (10 * K.square(K.square(K.sqrt(y_pred) - K.sqrt(y_true))) + K.square(
K.sqrt(y_pred) - K.sqrt(y_true)) + 0.01 * K.binary_crossentropy(y_pred, y_true)), axis=-1)
def my_accuracy(y_true, y_pred):
return K.mean(2 * K.abs(y_true - 0.5) * K.equal(y_true, K.round(y_pred)), axis=-1)
class WeightClip(Constraint):
'''Clips the weights incident to each hidden unit to be inside a range
'''
def __init__(self, c=2.0):
self.c = c
def __call__(self, p):
return K.clip(p, -self.c, self.c)
def get_config(self):
return {'name': self.__class__.__name__,
'c': self.c}
def build_model():
reg = 0.000001
constraint = WeightClip(0.499)
main_input = Input(shape=(None, 42), name='main_input')
tmp = Dense(24, activation='tanh', name='input_dense', kernel_constraint=constraint, bias_constraint=constraint)(
main_input)
vad_gru = GRU(24, activation='tanh', recurrent_activation='sigmoid', return_sequences=True, name='vad_gru',
kernel_regularizer=regularizers.l2(reg), recurrent_regularizer=regularizers.l2(reg),
kernel_constraint=constraint, recurrent_constraint=constraint, bias_constraint=constraint)(tmp)
vad_output = Dense(1, activation='sigmoid', name='vad_output', kernel_constraint=constraint,
bias_constraint=constraint)(vad_gru)
noise_input = concatenate([tmp, vad_gru, main_input])
noise_gru = GRU(48, activation='relu', recurrent_activation='sigmoid', return_sequences=True, name='noise_gru',
kernel_regularizer=regularizers.l2(reg), recurrent_regularizer=regularizers.l2(reg),
kernel_constraint=constraint, recurrent_constraint=constraint, bias_constraint=constraint)(noise_input)
denoise_input = concatenate([vad_gru, noise_gru, main_input])
denoise_gru = GRU(96, activation='tanh', recurrent_activation='sigmoid', return_sequences=True, name='denoise_gru',
kernel_regularizer=regularizers.l2(reg), recurrent_regularizer=regularizers.l2(reg),
kernel_constraint=constraint, recurrent_constraint=constraint, bias_constraint=constraint)(
denoise_input)
denoise_output = Dense(22, activation='sigmoid', name='denoise_output', kernel_constraint=constraint,
bias_constraint=constraint)(denoise_gru)
model = Model(inputs=main_input, outputs=[denoise_output, vad_output])
model.compile(loss=[mycost, my_crossentropy],
metrics=[msse],
optimizer='adam', loss_weights=[10, 0.5])
return model
model = build_model()
dataset = load_dataset('example.tfrecord')
我的数据集现在具有以下形状:
<MapDataset shapes: ((42,), ((22,), (1,))), types: (tf.float32, (tf.float32, tf.float32))>
我认为这是 Model API 所期望的(剧透:它没有)。
model.fit(dataset.batch(10))
给出以下错误:
ValueError: Error when checking input: expected main_input to have 3 dimensions, but got array with shape (None, 42)
有道理,我window
这里没有。同时,它似乎没有得到预期的正确形状Model(inputs=main_input, outputs=[denoise_output, vad_output])
。
如何修改load_dataset
以使其符合 Model API 对tf.data
?
解决方案
鉴于您的模型有 1 个输入和 2 个输出,您tf.data.Dataset
应该有两个条目:
1)形状的输入数组(window, 42)
2)两个数组的元组,每个数组的形状(window, 22)
和(window, 1)
编辑:更新的答案 - 你已经返回两个元素元组
我刚刚注意到您的数据集有这两个条目(类似于上面描述的条目),唯一不同的是形状。
您需要执行的唯一操作是将数据批处理两次:
首先 - 恢复窗口参数。第二 - 将批次传递给模型。
window_size = 1
batch_size = 10
dataset = load_dataset('example.tfrecord')
model.fit(dataset.batch(window_size).batch(batch_size)
这应该有效。
下面是一个旧答案,我错误地假设了您的数据集形状:
旧答案,我假设您要返回三个元素元组:
假设您从 shape和的三元素元组开始(42,)
,这可以在相同的批处理操作中实现,并使用返回二元素元组的函数进行丰富:(22,)
(1,)
custom_reshape
window_size = 1
batch_size = 10
dataset = load_dataset('example.tfrecord')
dataset = dataset.batch(window_size).batch(batch_size)
# Change output format
def custom_reshape(x, y, vad):
return x, (y, vad)
dataset = dataset.map(custom_reshape)
简而言之,给定这个数据集的形状,你可以调用:
model.fit(dataset.batch(window_size).batch(10).map(custom_reshape)
它也应该可以工作。
祝你好运。再次为大惊小怪感到抱歉。
推荐阅读
- php - 如何按变量之一的数值对此 JSON 进行排序
- node.js - 为什么我的递归函数(涉及循环中的异步调用)不返回调用者函数以继续循环?
- python - 如何将数字从for循环放入列表
- php - 基于 Elasticsearch 嵌套对象的过滤和计数操作
- python - 打印更新的字符串而不显示前一个
- node.js - 如何自定义 sequelize 枚举错误信息?
- python - Heroku Local 工作正常,它连接到 Heroku 托管的 Postgresql DB。一旦尝试部署,得到错误代码=H14 状态=503
- python - 无法在 Ubuntu 20.04 上使用 pip3 安装 pytorch
- c++ - 是否可以根据提供的大小调整环形缓冲区的大小,同时在大小减小/增加的情况下保留元素?
- android - 出现错误在 android 中运行第一个颤振项目