python - 如何在张量流中导出 tf.estimator.LinearRegressor 中 bucketized_column 的权重?
问题描述
我正在研究 Google Crash ML 的原因。我在“功能交叉”一章中遇到了麻烦。
https://developers.google.com/machine-learning/crash-course/feature-crosses/programming-exercise
我试图从中获得交叉特征的权重linear_regressor
。
# here I change _ to linear_model
linear_model = train_model(
learning_rate=1.0,
steps=500,
batch_size=100,
feature_columns=construct_feature_columns(),
training_examples=training_examples,
training_targets=training_targets,
validation_examples=validation_examples,
validation_targets=validation_targets)
Weight_bucketized_longitude= linear_model.get_variable_value('linear/linear_model/bucketized_longitude/weights')
print(Weight_bucketized_longitude)
但是,我收到如下错误消息:
错误信息:
NotFoundError:在检查点中找不到关键线性/线性模型/bucketized_longitude/权重
看起来路径是错误的。该路径适用于numeric_column
,但不适用于bucketized_column
。
你能帮忙指出正确的路径吗?谢谢。
#
我尝试了 Geeocode 的方法。但是,我仍然收到错误消息。
Weight_bucketized_longitude= linear_model.get_variable_value('linear/linear_model/bucketized_longitude/weights')
() 中的 AttributeErrorTraceback (最近一次调用最后一次) ----> 1 Weight_bucketized_longitude= >linear_model.get_variable_value(["linear", "linear_model", >"bucketized_longitude", "weights"])
/usr/local/lib/python2.7/dist->packages/tensorflow/python/estimator/estimator.pyc in >get_variable_value(self, name) 252 _check_checkpoint_available(self.model_dir) 253 with context.graph_mode(): -- > 254 return training.load_variable(self.model_dir, name) 255 256 def get_variable_names(self):
/usr/local/lib/python2.7/dist->packages/tensorflow/python/training/checkpoint_utils.pyc in >load_variable(ckpt_dir_or_file, name) 77 """ 78 # TODO(b/29227106):在正确放置并删除 >this。---> 79 if name.endswith(":0"): 80 name = name[:-2] 81 reader = load_checkpoint(ckpt_dir_or_file)
AttributeError: 'list' 对象没有属性 'endswith'
解决方案
问题是linear_model.get_variable_value()
必须传递带有变量名称的字符串列表。从文档中:
获取变量值
get_variable_value(name)
返回名称给定的变量的值。
Args: name:字符串或字符串列表,张量的名称。返回: Numpy 数组 - 张量的值。
引发:ValueError:如果 Estimator 尚未生成检查点。
因此,您的代码应更改如下:
Weight_bucketized_longitude= linear_model.get_variable_value(["linear", "linear_model", "bucketized_longitude", "weights"])
推荐阅读
- pyqt - QTableView:防止用户导航离开特定行
- python - Cassandra 的 execute_concurrent 无法正常工作
- docker - 如何让docker重置远程服务器上的图像?
- sql - SQL 格式数字有括号
- less - 如何在 Vuetify 样式中使用手写笔块级导入
- firebase - 我可以将一个孩子设置为 .write: true 而其他孩子受到限制吗?
- sql-server - 除非我重新启动 sql server,否则 SSIS 包需要永远执行
- java - 此服务器不支持 Project facet Dynamic Web Module 4.0
- c++ - 在 Ubuntu 的终端上运行 C++ 文件
- spring - 如何避免在 requestMapping 方法中使用实体造成的漏洞?