首页 > 解决方案 > 如何在张量流中导出 tf.estimator.LinearRegressor 中 bucketized_column 的权重?

问题描述

我正在研究 Google Crash ML 的原因。我在“功能交叉”一章中遇到了麻烦。

https://developers.google.com/machine-learning/crash-course/feature-crosses/programming-exercise

我试图从中获得交叉特征的权重linear_regressor

# here I change _ to linear_model 
linear_model = train_model(
               learning_rate=1.0,
               steps=500,
               batch_size=100,
               feature_columns=construct_feature_columns(),
               training_examples=training_examples,
               training_targets=training_targets,
               validation_examples=validation_examples,
               validation_targets=validation_targets)

Weight_bucketized_longitude= linear_model.get_variable_value('linear/linear_model/bucketized_longitude/weights')   
print(Weight_bucketized_longitude)

但是,我收到如下错误消息:

错误信息:

NotFoundError:在检查点中找不到关键线性/线性模型/bucketized_longitude/权重

看起来路径是错误的。该路径适用于numeric_column,但不适用于bucketized_column

你能帮忙指出正确的路径吗?谢谢。

#

我尝试了 Geeocode 的方法。但是,我仍然收到错误消息。

Weight_bucketized_longitude= linear_model.get_variable_value('linear/linear_model/bucketized_longitude/weights')   

() 中的 AttributeErrorTraceback (最近一次调用最后一次) ----> 1 Weight_bucketized_longitude= >linear_model.get_variable_value(["linear", "linear_model", >"bucketized_longitude", "weights"])

/usr/local/lib/python2.7/dist->packages/tensorflow/python/estimator/estimator.pyc in >get_variable_value(self, name) 252 _check_checkpoint_available(self.model_dir) 253 with context.graph_mode(): -- > 254 return training.load_variable(self.model_dir, name) 255 256 def get_variable_names(self):

/usr/local/lib/python2.7/dist->packages/tensorflow/python/training/checkpoint_utils.pyc in >load_variable(ckpt_dir_or_file, name) 77 """ 78 # TODO(b/29227106):在正确放置并删除 >this。---> 79 if name.endswith(":0"): 80 name = name[:-2] 81 reader = load_checkpoint(ckpt_dir_or_file)

AttributeError: 'list' 对象没有属性 'endswith'

标签: pythontensorflow

解决方案


问题是linear_model.get_variable_value()必须传递带有变量名称的字符串列表。从文档中:

获取变量值

get_variable_value(name)

返回名称给定的变量的值。

Args: name:字符串或字符串列表,张量的名称。返回: Numpy 数组 - 张量的值。

引发:ValueError:如果 Estimator 尚未生成检查点。

因此,您的代码应更改如下:

Weight_bucketized_longitude= linear_model.get_variable_value(["linear", "linear_model", "bucketized_longitude", "weights"])

推荐阅读