首页 > 解决方案 > 字符串到 Tensorflow 中的 one_hot 张量

问题描述

我在 tensorflow doc 中找到了以下函数来计算词汇表并将其应用于字符串张量,但它仍在使用tf.session,我无法使其工作tf.function

import tensorflow as tf
import tensorflow_transform as tft


@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.string),))
def string_to_one_hot(labels):
    codes = tft.compute_and_apply_vocabulary(labels)
    return tf.one_hot(codes, depth=tf.cast(tf.reduce_max(codes), tf.int32))


test_labels = tf.constant(['a', 'b', 'a', 'c'])
test_one_hot = string_to_one_hot(test_labels)

> tensorflow.python.framework.errors_impl.InvalidArgumentError:  You must feed a value for placeholder tensor 'compute_and_apply_vocabulary/vocabulary/Placeholder' with dtype string
     [[node compute_and_apply_vocabulary/vocabulary/Placeholder (defined at /Users/clementwalter/.pyenv/versions/keras_fsl/lib/python3.6/site-packages/tensorflow_transform/analyzer_nodes.py:102) ]] [Op:__inference_string_to_one_hot_52]

编辑

我已经能够通过直接使用哈希工具来构建这样的功能。但是我不得不使用硬编码的 bucket_size/depth 参数。有任何想法吗?

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.string),))
def string_to_one_hot(labels):
    one_hot = tf.one_hot(tf.strings.to_hash_bucket_fast(labels, 1024), depth=1024)
    return tf.boolean_mask(one_hot, tf.reduce_sum(one_hot, axis=0) > 0, axis=1)

标签: pythonfunctiontensorflowsession

解决方案


好的,我想我找到了正确的答案:

def string_to_one_hot(labels):
    colnames, codes = tf.unique(support_labels_name)
    return colnames, tf.one_hot(codes, depth=tf.size(colnames))

推荐阅读