tensorflow - 在构建 tf_estimator 时面临 tensorflow2.0 中的错误
问题描述
import custom_model as CM
import input_pipeline as IP
import tensorflow as tf
def custom_estimator(features, labels, mode):
logits = CM.model_net(features=features, n_classes=5)
prediction = tf.keras.layers.Activation('softmax')(logits)
preds_dict = {'class': tf.argmax(input=prediction, axis=1),
'probabilities': prediction,
'logits': logits}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode,
predictions=preds_dict)
# Compute loss
labels = tf.reshape(labels, (BATCH_SIZE, 5))
loss = tf.keras.losses.categorical_crossentropy(y_true=labels,
y_pred=prediction)
# Compute evaluation metrics
accuracy = custom_accuracy(labels, prediction)
metrics = {'accuracy': accuracy}
tf.summary.scalar('accuracy', accuracy)
if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec(mode, loss=loss,
eval_metric_ops=metrics)
optimizer = tf.keras.optimizers.Adam()
train_op = optimizer.minimize(loss)
return tf.estimator.EstimatorSpec(mode, loss=loss,
train_op=train_op)
# Build tf_estimator
classifier = tf.estimator.Estimator(model_fn=custom_estimator,
model_dir=model_dir)
# Train the estimator
TRAIN_FILES, TRAIN_LABELS = IP.map_file_to_label(data_dir=data_dir)
TRAIN = classifier.train(input_fn=lambda:
IP.imgs_input_fn(TRAIN_FILES, labels=TRAIN_LABELS,
perform_shuffle=True, repeat_count=EPOCHS,
batch_size=BATCH_SIZE),
steps=int(len(TRAIN_LABELS)/BATCH_SIZE))
这是我在使用 TensorFlow-2.0 时遇到的错误。此处附有相同的错误图像和代码。请帮忙。如果我输入 var_list=None 那么错误是“ValueError: Passed in object of type , not tf.Tensor”
解决方案
首先,我认为这个 Estimator 代码示例不符合 TensorFlow 2.0。在任何情况下,如果您使用的是 1.x 版本,请替换:
train_op = optimizer.minimize(loss)
有了这个:
train_op = optimizer.minimize(
loss=average_loss, global_step=tf.train.get_global_step())
如果确实,您使用的是 TensorFlow 2.0,则替换为:
train_op = optimizer.minimize(
loss=average_loss, global_step=tf.compat.v1.train.get_global_step())
推荐阅读
- aws-pinpoint - 如何使用 aws pinpoint 发送带有自定义模板的电子邮件
- javascript - 使嵌套的 mongodb 查询更快
- reactjs - 为什么 this.props.children 无法识别
- asp.net - WCF 服务迁移到 IIS 10 - 一致 503
- r - 提升分类树 gbm 字符变量
- azure - 创建 Azure Policy 和 Blueprint 所需的最低 IAM 角色是什么
- javascript - 单击循环中的 Var 图像
- makefile - How to make backtick commands fail in Makefiles?
- ruby - 如何分配变量而不引用它
- azure - 更新 web.config 以在 Azure 应用服务的 Azure DevOps 发布管道中添加特定环境的重写规则