首页 > 解决方案 > 使用 Estimator API 在大字符串上训练 RNN

问题描述

我们想让我们的估算器预测一个单词或下一个工作,就像智能​​手机中的键盘一样。我们想在一些文本文件上训练它。

所以我们继续查看 tensorflow API 发现

estimator = RNNEstimator(
    head=tf.contrib.estimator.regression_head(),
    sequence_feature_columns=[token_emb],
    rnn_cell_fn=rnn_cell_fn)

这似乎是为 RNN 创建估计器的便捷方式。现在,我们面临特征列的问题。我们正在这样设置它们

token_sequence = sequence_categorical_column_with_hash_bucket(
    key="text", hash_bucket_size=num_of_categories, dtype=tf.string)
token_emb = embedding_column(categorical_column=token_sequence, 
    dimension=8)

where'text'在我们的输入函数中定义

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"text": features},
    y=labels,
    batch_size=batch_size,
    num_epochs=None,
    shuffle=True)

其中features只是从我们的原始文本中采样的一长串 40 个字符的序列。

问题

  1. 无论如何都可以在字符串输入上使用特征列吗?文档并没有真正放弃太多。
  2. 怎么处理标签?目前,我们收到一个错误,因为它们从不转换为整数
  3. 即使将任意整数作为标签引入,我们在调用estimator.train(input_fn=train_input_fn, steps=100)which时仍然会出错

    '给定类型:{}'.format(type(features))) ValueError: features 应该是Tensors 的字典。给定类型:

所以我们在这里肯定做错了什么。任何帮助表示赞赏:)

标签: pythontensorflowrnntensorflow-estimator

解决方案


在StateSavingRnnEstimator 单元测试中有一个简短的示例,将字符串单词特征作为 SparseTensor 传递给每步一个标签的分类Estimator(即语言建模)。这看起来大概是您想要做的事情,但需要注意的是,有问题的 Estimator 已被弃用;从中获取想法并定义自己的想法可能是有意义的。model_fn


推荐阅读