首页 > 解决方案 > 从 Tensorflow CSV 服务函数中拆分字符串

问题描述

我有一些涉及字符串拆分的功能逻辑的处理功能。

来自数据集的工作张量:<tf.Tensor 'arg0:0' shape=(1,) dtype=string>

来自服务函数的无效张量:<tf.Tensor 'DecodeCSV:1' shape=(?, 1) dtype=string>

服务功能:

def csv_serving_input_fn():
    csv_row = tf.placeholder(shape=[None], dtype=tf.string)
    features = parse_csv(csv_row, is_serving=True)
    ...
    return tf.estimator.export.ServingInputReceiver(
        features=process_features(features),
        receiver_tensors={'csv_row': csv_row}
    )

解析函数

def parse_csv(csv_row, is_serving=False):
    columns = tf.decode_csv(tf.expand_dims(csv_row, -1), record_defaults=HEADER_DEFAULTS)
    return dict(zip(HEADER, columns))

形状失败(?,)

def process_features(features):
    x = tf.string_split(features['text'])
    ....


错误:

ValueError: Shape must be rank 1 but is rank 2 for 'StringSplit' (op: 'StringSplit') with input shapes: [?,1], [].

什么是正确的服务功能?

类似问题: 在 tensorflow 中拆分字符串

标签: pythontensorflowgoogle-cloud-ml

解决方案


文档对此不是很清楚,但string_split用于 1 级向量。

请改用strings.split


推荐阅读