tensorflow2.0 - 使用 tf.estimator.DNNClassifier 调整超参数
问题描述
我已经使用 DNNClassifier 类实现了以下模型。模型参数化如下
classifier = tf.estimator.DNNClassifier(
hidden_units=[60, 30, 20],
feature_columns=feature_columns,
n_classes=len(labels),
label_vocabulary=labels,
batch_norm=True,
optimizer=lambda: tf.keras.optimizers.Adam(
learning_rate=tf.compat.v1.train.exponential_decay(
learning_rate=0.1,
global_step=tf.compat.v1.train.get_global_step(),
decay_steps=10000,
decay_rate=0.96)
)
)
现在我想做一些超参数调整(例如学习率、单元数等)。
DNNClassifier
,作为预制的估算器类,继承自Estimator
该类。但是,虽然Estimator
有一个params
参数来传递超参数,DNNClassifier
但没有。那么使用 进行超参数调整的首选方法应该是什么DNNClassifier
?
解决方案
首先,您需要一个估算器的输入函数,假设您使用 pandas 数据帧来保存数据,(data_df 和 label_df 是数据帧)您可以编写如下内容:
def make_input_fn(data_df, label_df, num_epochs=10, shuffle=True, batch_size=32):
def input_function():
ds = tf.data.Dataset.from_tensor_slices((dict(data_df), label_df))
if shuffle:
ds = ds.shuffle(1024)
ds = ds.batch(batch_size).repeat(num_epochs)
return ds
return input_function
然后使用上面的代码创建两个输入函数,一个用于训练,一个用于验证,如下所示:
train_input_fn = make_input_fn(X_train, y_train)
val_input_fn = make_input_fn(X_val, y_val, num_epochs=1, shuffle=False)
最后训练您定义的分类器并使用验证集对其进行评估。多次运行此管道以调整您的超参数。
classifier = tf.estimator.DNNClassifier(
hidden_units=[60, 30, 20],
feature_columns=feature_columns,
n_classes=len(labels),
label_vocabulary=labels,
batch_norm=True,
optimizer=lambda: tf.keras.optimizers.Adam(
learning_rate=tf.compat.v1.train.exponential_decay(
learning_rate=0.1,
global_step=tf.compat.v1.train.get_global_step(),
decay_steps=10000,
decay_rate=0.96)
)
)
# Train Classifier.
classifier.train(train_input_fn)
# Evaluate Classifier.
result = classifier.evaluate(val_input_fn)
print(result)
推荐阅读
- jquery - inIframe 函数说明
- javascript - JavaScript TypeError:无法读取未定义的属性“id”
- javascript - 将所选语言保存在本地存储中?
- python - 多单元格的值错误问题尺寸必须相等,但为 20 和 13
- node.js - Nodejs:为什么我的 Promise 在 PUT 请求中不起作用?重新爱过
- r - 在多面图-ggplot2 上添加 x 轴标签
- php - 在 system.cfg php 中查找和替换值
- ecmascript-6 - ECMA6 类不适用于带有 Webpack 和 Babel 的 IE11
- python - 如何将 pandas DataFrame 导出到 Microsoft Access?
- ios - 登录后继续处理收据