tensorflow - DNNClassifier - 训练后如何获取各层的参数
问题描述
def Classifier(parameters):
learning = parameters[0]
layers = parameters[1]
nodes = parameters[2]
hidden_layers = [nodes for i in range(layers)]
activation_function = tf.nn.sigmoid if parameters[3] == 0 else tf.nn.relu
age_var = tf.feature_column.numeric_column('Age')
shape_var = tf.feature_column.numeric_column('Shape')
margin_var = tf.feature_column.numeric_column('Margin')
density_var =tf.feature_column.numeric_column('Density')
features = [age_var,shape_var,margin_var,density_var]
return tf.estimator.DNNClassifier(hidden_units=hidden_layers,
n_classes=2,
feature_columns=features,
activation_fn=activation_function,
model_dir='/tmp/'+uuid.uuid4().hex,
optimizer=tf.train.AdamOptimizer(learning_rate=learning),
config=tf.contrib.learn.RunConfig(save_checkpoints_steps=250,
save_checkpoints_secs=None,
save_summary_steps=500))
训练如上定义的模型后,是否可以得到各层的参数?
如果是的话,你能给我这个命令吗
我是新手/正在学习 tensorflow
解决方案
我简单的方法是遍历函数DNNClassifier
返回的模型的变量名称列表Classifier
并调用get_variable_value
传递每个变量的名称:
for variable_name in model.get_variable_names():
print('Parameter Name: ', variable_name, ' Parameter Value: ', model.get_variable_value(variable_name))
但这仍然需要您选择每个单层的权重(默认情况下/hiddenlayer_0/...
,它的模型隐藏层/hiddenlayer_1/...
的名称是 等)。DNNClassifier
您将不得不遍历这些参数名称并仅获取名称中包含的那些参数的值hiddenlayer_<num>
,这意味着执行一些字符串模式匹配。例如:
hidden_layer_0_params = {}
for variable_name in model.get_variable_names():
if variable_name.startswith("dnn/hiddenlayer0"):
hidden_layer_0_params[variable_name] = model.get_variable_value(variable_name)
不过,您可以帮助改善这一点!一种方法是先创建和编译一个单独的tf.keras
模型,然后使用该tf.keras.estimator.model_to_estimator
函数将其转换为所需的估计器实例。
其优势的原因是tf.keras
它的模型参数命名更好一些,并且默认情况下它们没有dnn/hiddenlayer_
图层参数名称的前缀。例如,tf.keras
模型的参数名称将没有前缀,并且layer_with_weights-0/bias
默认格式化。
这是一个很好的教程,可以从 keras 模型创建估计器实例并将其用于训练和评估。
推荐阅读
- android - CollapsingToolbarLayout 内的 ViewPager
- reactjs - Nextjs 公用文件夹
- hadoop - 为 Hadoop Distcp 作业设置 YARN 应用程序名称
- php - 其他模型中的 Laravel 电子邮件验证
- r - 使用 RDCOMClient 提取 Outlook 电子邮件正文时遇到错误
- mysql - 如何在“将 Springboot 与 MySql 连接”时解决连接问题
- c# - SearchItemInfo 已过时。DNN 7.1 中已弃用
- angularjs - 在两个窗口之间共享变量。Angularjs
- excel - 如何使用 VBA Excel 合并两个以上的字符串而不会遇到运行时错误?
- c# - 我希望在加载大文件时避免表单冻结以显示