首页 > 解决方案 > Tensorflow:在 tf.data.Dataset 中拆分字符串的奇怪行为

问题描述

tf.data.Dataset在 Tensorflow 中使用 API。我有 2 个 numpy 数组,其中data2-d 和labels1-d。我创建了一个Dataset这样的:

dataset = tf.data.Dataset.from_tensor_slices((data, labels))
val_dataset = dataset.map(lambda x, y: ({'reviews': x}, y))

我有一个我想使用的预处理函数,如下所示:

def preprocess(x, y):
    # split on whitespace
    x['reviews'] = tf.string_split(x['reviews'])
    return x, y

我尝试这样使用map

dataset = dataset.map(preprocess)

但我回来了:

ValueError: Shape must be rank 1 but is rank 0 for 'StringSplit' (op: 'StringSplit') with input shapes: [], [].

我搜索了一下,发现有人在预处理函数中建议了这种方法:

x['reviews'] = tf.string_split([x['reviews']])

但我不清楚我为什么要这样做。它不会像以前那样出错,但是我的数据的形状不正确。例如,这是我在我的第一个元素中看到的dataset

({'sequence': array([[ 6391,  3352, 10236,   244,  1362,   244,  9350,  7649,  6391,
         6324,  6063,  3620,   244,  8153,  6542, 10056,  7303,  1955,
         1362,  6194, 10250,  6391,   550,   244,  7577,   850,  3620,
         5807, 10325,  1362,  6542,   595,  9060,  9052,  9459,   351,
         4676,  9354,  7648,  3082,  7694,  8497, 10703,  1610,  9454,
        10236,   244,  7965,  8018,  9392,  6391,  6063,  2878,  1318,
         3169,  8198,  9354,  4131,  3620,  3082,  3352,  9052,  8018,
         7527,  3419,  1907,  8835,   796,   244,  8957,  4325,  8171,
         9454,  7602,  4435,  7648,  3169,  2083,  9454,  4789,  9620,
         9261,   556,  3524,  8497,  9174,  8299,  5871,  9052,  2888,
         9846,  1610,  1362,  4930,  2150,  1362,  8018,  3867,   341,
         7694,  8497,  6063,  3620,   244,  5807,  6089,  3169,  6350,
         1174,  7694,   949,  1292,   244,  9052,  9440,  3690,  1362,
         1907,  9011,  4156,  6081,   145,  1174,  7694,  9986,   949,
         1292,  3169,  1455,  6372,  9760,  5013,  3169,  1455,  5942,
         4365,  1362,  1907,   244,  5813,   244,  7994,  3525,  3550,
         7509,  6372,  9760,  7860,  9052,  2888,  7694,  8497,  1610,
         1316,   326,  1174,  3039,  3524,  9703,  3620,  6612,  1455,
          556,  9011,  3169,  1927,  9052,   409,  4059,  9354,   700,
         5503,  3550,  9052,  2083,  1963,   595,  3169,  7715, 10236,
         9442,  1174, 10087,  3169,  5312,  7474,  9052,  3525,  3169,
         5826,  7885,  6944,  7130,  5821,  2878,  7184,   153,  3169,
         8633,  8574,  1283,   606,  7902,  6110,  3082,  6406,  3169,
         8316,  6126,   688, 10236,  9440,  3082, 10584,  2143,  5460,
         5809,  1362,  2878, 10439,  3419,  1907,  4598,  4156, 10239,
         1450,  5514,  5010,  9350,   244,   651]])}, 0)

所以字典值是一个二维数组,而它应该只是一维。我哪里错了?

谢谢!

标签: pythontensorflowdataset

解决方案


不采用标量似乎是tf.string_split. 请在https://github.com/tensorflow/tensorflow/issues提交问题

就解决方法而言,包含在列表中的建议是一个很好的建议,但您还需要在拆分后对其进行挤压,以便您拥有一个分量向量而不是二维张量。

import tensorflow as tf
tf.enable_eager_execution()
scalar = tf.constant('ab c de')
print(scalar.shape)  # () scalar
vector = scalar[None]
print(vector.shape)  # (1,) vector
output = tf.sparse.to_dense(tf.string_split(vector), default_value='')
print(output)  # tf.Tensor([[b'ab' b'c' b'de']], shape=(1, 3), dtype=string)
squeezed = tf.squeeze(output, axis=0)
print(squeezed.shape)  # (3,) vector
print(squeezed)  # tf.Tensor([b'ab' b'c' b'de'], shape=(3,), dtype=string)

推荐阅读