python - 应用没有 tf.Estimator 的特征列 (Tensorflow 2.0.0-rc0)
问题描述
在 Tensorflow tf.Estimator 和 tf.feature_column 文档中,有很好的文档说明如何将特征列与 Estimator 一起使用,以便对正在使用的数据集中的分类特征进行一次性编码。
但是,我想将我的特征列直接“应用”到我从 .csv 文件创建的 tf.dataset(有两列:UserID、MovieID),甚至不需要定义模型或估计器。(原因:我想检查我的数据管道中到底发生了什么,即我希望能够通过我的管道运行一批样本,然后在输出中查看特征是如何编码的。)
这是我到目前为止所尝试的:
column_names = ['UserID', 'MovieID']
user_col = tf.feature_column.categorical_column_with_hash_bucket(key='UserID', hash_bucket_size=1000)
movie_col = tf.feature_column.categorical_column_with_hash_bucket(key='MovieID', hash_bucket_size=1000)
feature_columns = [tf.feature_column.indicator_column(user_col), tf.feature_column.indicator_column(movie_col)]
feature_layer = tf.keras.layers.DenseFeatures(feature_columns=feature_columns)
def process_csv(line):
fields = tf.io.decode_csv(line, record_defaults=[tf.constant([], dtype=tf.int32)]*2, field_delim=";")
features = dict(zip(column_names, fields))
return features
ds = tf.data.TextLineDataset(csv_filepath)
ds = ds.map(process_csv, num_parallel_calls=4)
ds = ds.batch(10)
ds.map(lambda x: feature_layer(x))
但是 map 调用的最后一行会引发以下错误:
ValueError: Column dtype 和 SparseTensors dtype 必须兼容。键:MovieID,列 dtype:,张量 dtype:
我不确定这个错误是什么意思...我还尝试使用我定义的 feature_layer 定义一个 tf.keras 模型,然后在我的数据集上运行 .predict() - 而不是使用 ds.map(lambda x:特征层(x)):
model = tf.keras.Sequential([feature_layer])
model.compile()
model.predict(ds)
但是,这会导致与上述完全相同的错误。有人知道出了什么问题吗?是否有更简单的方法来实现这一目标?
解决方案
刚刚发现问题:tf.feature_column.categorical_column_with_hash_bucket() 采用可选参数 dtype,默认设置为 tf.dtypes.string。但是,我的列的数据类型是数字(tf.dtypes.int32)。这解决了这个问题:
tf.feature_column.categorical_column_with_hash_bucket(key='UserID', hash_bucket_size=1000, dtype=tf.dtypes.int32)
推荐阅读
- reactjs - How to yield inside an external callback function in react redux saga?
- css - React Native headerTitleStyle 不会居中
- c++ - 如何为 iOS 的 cmake 项目查找(配置)Qt?
- asp.net-mvc - MVC IIS 内存泄漏
- c# - 在 Post 操作发生后使用模型绑定更新 TextBox
- java - 如何使用在一个类中使用子类和在另一个类中使用超类的参数覆盖方法?
- javascript - 在 Google 表格中循环函数
- c++ - 当多个流处于活动状态时,gRPC grpc_completion_queue_next() 似乎是不公平的
- docker - 我的错误:无法在 http+docker://localhost 连接到 Docker 守护程序 - 它正在运行吗?
- sql - 我如何在sql server中的where条件中使用case语句