python-3.x - 如何在 tensorflow 2 Tensorflow 2 / Keras 中进行自定义验证步骤?
问题描述
我对验证数据有疑问。我有这个神经网络,我将数据分为 train_generator、val_generator、test_generator。
我制作了一个定制模型。
class MyModel(tf.keras.Model):
def __init__(self):
def __call__(.....)
def train_step(....)
然后我有:
train_generator = DataGenerator(....)
val_generator = DataGenerator(....)
test_generator = DataGenerator(....)
然后 :
model = MyModel()
model.compile(optimizer=keras.optimizers.Adam(clipnorm=5.),
metrics=["accuracy"])
model.fit(train_generator, validation_data = val_generator, epochs=40)
好的,程序没有给我任何错误但我的问题是:我怎么知道我的validation_data会发生什么?它的处理方式是否与 train_step 函数中的 train_data (train_generator) 相同?还是我需要指定如何处理验证数据?
如果有帮助,我还将参加 MyModel 课程
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel2, self).__init__()
self.dec2 = Decoder2()
def __call__(self, y_hat, **kwargs):
print(y_hat.shape)
z_hat = self.dec2(y_hat)
return z_hat
def train_step(self, dataset):
with tf.GradientTape() as tape:
y_hat = dataset[0]
z_true = dataset[1]
z_pred = self(y_hat, training=True)
#print("This is z_true : ", z_true.shape)
#print("This is z_pred : ", z_pred.shape)
loss = tf.reduce_mean(tf.abs(tf.cast(z_pred, tf.float64) - tf.cast(z_true, tf.float64)))
print("loss: ", loss)
global_loss.append(loss)
# Compute gradients. TRE SA FAC GRADIENT CLIPPING
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update metrics (includes the metric that tracks the loss)
self.compiled_metrics.update_state(z_true, z_pred)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics}
解决方案
您必须在 MyModel 类中添加一个 test_step(self, data) 函数,如您在此处看到的那样:提供您自己的评估步骤
推荐阅读
- iccube - icCube:ic3Table 和图表中的剪切标签
- php - Laravel 中文件管理器包的问题
- regex - 一元自然数的匹配范围
- python - 从 DataFrame 到嵌套的 Json 对象
- amazon-web-services - ec2 实例上的 docker-compose 返回 Permission denied: '/etc/grub.d'
- r - 将 ggpredict() 和 ggplot2() 与缩放的连续变量一起使用并尝试对它们进行缩放
- javascript - 我在反应中创建可拖动组件时出错
- asp.net - ASP.NET Core mvc 应用程序中的 FFMPEG 录制
- c++ - 使用 for_each 删除向量的元素
- python - 在另一个函数中使用变量