python - 具有 3D-CNN s3dg 模型权重的 TensorFlow 训练循环
问题描述
我坚持为 s3dg 3D-CNN 模型创建一个训练循环。tensorflow 代码模型示例在这里,我正在关注:https ://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py
我最初习惯使用 Keras model.fit 和 model.train,所以这种风格的 tensorflow 对我来说是新的。
我正在按照这个示例创建一个训练循环。
我的训练循环如下所示:
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
logits = s3dg(x)
print("WHATS Y",y)
print("WHATS LOGITS",logits)
print("WHATS LOGITS SHAPES",(logits[1]['Predictions']))
loss_value = loss_fn(y, logits[1]['Predictions'])
grads = tape.gradient(loss_value, s3dg.trainable_variables)
optimizer.apply_gradients(zip(grads, s3dg.trainable_variables))
train_acc_metric.update_state(y, logits[1]['Predictions'])
return loss_value
@tf.function
def test_step(x, y):
val_logits = s3dg(x)
val_acc_metric.update_state(y, val_logits)
for epoch in tqdm.trange(EPOCHS, desc='Epoch Loop'):
print("\nStart of epoch %d" % (epoch,))
start_time = time.time()
# Iterate over the batches of the dataset.
for step, (x_batch_train, y_batch_train) in enumerate(train):
loss_value = train_step(x_batch_train, y_batch_train)
# Log every 50 batches.
if step % 50 == 0:
print(
"Training loss (for one batch) at step %d: %.4f"
% (step, float(loss_value))
)
print("Seen so far: %d samples" % ((step + 1) * BS))
# Display metrics at the end of each epoch.
train_acc = train_acc_metric.result()
print("Training acc over epoch: %.4f" % (float(train_acc),))
# Reset training metrics at the end of each epoch
train_acc_metric.reset_states()
# Run a validation loop at the end of each epoch.
for x_batch_val, y_batch_val in valid:
test_step(x_batch_val, y_batch_val)
val_acc = val_acc_metric.result()
val_acc_metric.reset_states()
print("Validation acc: %.4f" % (float(val_acc),))
print("Time taken: %.2fs" % (time.time() - start_time))
我被卡住了,因为来自 tensorflow 模型的 s3dg 代码示例没有显示如何访问可训练变量/可训练权重。所以 s3dg.trainable_variables 或 s3dg.trainable_weights 是未定义的,我不确定如何在训练循环中更新我的这些值。
grads = tape.gradient(loss_value, s3dg.trainable_variables)
optimizer.apply_gradients(zip(grads, s3dg.trainable_variables))
这是的输出
print("WHATS LOGITS",logits)
https://www.dropbox.com/s/w5nu6267j0b4wzn/Screen%20Shot%202021-09-26%20at%205.48.52%20PM.png?dl=0
当我遵循“NDHWC”的结构时,我输入到 s3dg 的形状似乎是正确的
Dropbox 上显示的图像中的形状显示 (50,1000) - 那是因为屏幕截图现在没有更新,但实际上形状是 (50,2) (第二个参数是参数https 中描述的类标签的数量: //github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py。
原来设置为1000,我改成2,因为我只对两个标签进行分类。
所以返回的 logits 是一个带有字典的张量,我访问 ['Predictions'] 以返回 loss_fn。
但是我无法正确更新权重。
解决方案
推荐阅读
- java - 通过Java设置springboot logback根记录器级别不起作用
- user-interface - 重新创建 Pinterest UI Flutter
- c# - 关闭代码文件时自动折叠区域
- angular - 带有可重用子组件示例的 Angular 8 反应式表单
- notepad++ - 大型 .A2L 文件,引用的翻译问题
- mysql - MySQL:如何在客户端和服务器端都启用本地加载数据
- angular - 对返回 ValidatorFn[] 类型的函数进行单元测试
- php - 当我在其他文件中调用它们时,类静态变量会重置
- java - Java Hibernate 的主键有问题吗?
- java - JDBC 拒绝使用指定的密码