python - 如何将 numpy 数组中的分类数据加载到 Indicator 或 Embedding 列中?
问题描述
使用 Tensorflow 1.8.0,每当我们尝试构建分类列时都会遇到问题。这是一个演示问题的完整示例。它按原样运行(仅使用数字列)。取消注释指标列定义和数据会生成一个堆栈跟踪,以tensorflow.python.framework.errors_impl.InternalError: Unable to get element as bytes.
import tensorflow as tf
import numpy as np
def feature_numeric(key):
return tf.feature_column.numeric_column(key=key, default_value=0)
def feature_indicator(key, vocabulary):
return tf.feature_column.indicator_column(
tf.feature_column.categorical_column_with_vocabulary_list(
key=key, vocabulary_list=vocabulary ))
labels = ['Label1','Label2','Label3']
model = tf.estimator.DNNClassifier(
feature_columns=[
feature_numeric("number"),
# feature_indicator("indicator", ["A","B","C"]),
],
hidden_units=[64, 16, 8],
model_dir='./models',
n_classes=len(labels),
label_vocabulary=labels)
def train(inputs, training):
model.train(
input_fn=tf.estimator.inputs.numpy_input_fn(
x=inputs,
y=training,
shuffle=True
), steps=1)
inputs = {
"number": np.array([1,2,3,4,5]),
# "indicator": np.array([
# ["A"],
# ["B"],
# ["C"],
# ["A", "A"],
# ["A", "B", "C"],
# ]),
}
training = np.array(['Label1','Label2','Label3','Label2','Label1'])
train(inputs, training)
尝试使用嵌入票价并没有更好的效果。仅使用数字输入,我们可以成功扩展到数千个输入节点,实际上我们已经在预处理器中临时扩展了我们的分类特征来模拟指标。
文档categorical_column_*()
中indicator_column()
充斥着我们非常确定我们没有使用的特性的引用(原型输入,不管bytes_list
是什么),但也许我们错了?
解决方案
这里的问题与“指标”输入数组的参差不齐的形状有关(一些元素的长度为 1,一个为长度 2,一个为长度 3)。如果你用一些非词汇字符串填充你的输入列表(例如,我使用“Z”,因为你的词汇是“A”、“B”、“C”),你会得到预期的结果:
inputs = {
"number": np.array([1,2,3,4,5]),
"indicator": np.array([
["A", "Z", "Z"],
["B", "Z", "Z"],
["C", "Z", "Z"],
["A", "A", "Z"],
["A", "B", "C"]
])
}
您可以通过打印结果张量来验证这是否有效:
dense = tf.feature_column.input_layer(
inputs,
[
feature_numeric("number"),
feature_indicator("indicator", ["A","B","C"]),
])
with tf.train.MonitoredTrainingSession() as sess:
print(dense)
print(sess.run(dense))
推荐阅读
- graphics - 是否存在广义顶点和片段着色器的概念?
- spring - WebApplicationContext的初始化过程
- c# - 将类理解为自定义数据类型 C#
- python-3.x - 生成依赖数据
- python - 无需模块即可生成更好的名称
- node.js - JSON 变量返回未定义
- assembly - 虚拟内存——如何找到 n、m 和 p 的值(参考 Bryant 和 O'Hallaron 表示法)
- python - 对“全局”和“非本地”代码(PYTHON/TKINTER)感到困惑(UnboundLocalError :)
- node.js - 使用 NodeJS Express 服务器发送直通流以响应时出错
- java - Java Spring boot微服务可以导出为安装文件吗?