python - Tensorflow:在 tf.data.Dataset 中拆分字符串的奇怪行为
问题描述
我tf.data.Dataset
在 Tensorflow 中使用 API。我有 2 个 numpy 数组,其中data
2-d 和labels
1-d。我创建了一个Dataset
这样的:
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
val_dataset = dataset.map(lambda x, y: ({'reviews': x}, y))
我有一个我想使用的预处理函数,如下所示:
def preprocess(x, y):
# split on whitespace
x['reviews'] = tf.string_split(x['reviews'])
return x, y
我尝试这样使用map
:
dataset = dataset.map(preprocess)
但我回来了:
ValueError: Shape must be rank 1 but is rank 0 for 'StringSplit' (op: 'StringSplit') with input shapes: [], [].
我搜索了一下,发现有人在预处理函数中建议了这种方法:
x['reviews'] = tf.string_split([x['reviews']])
但我不清楚我为什么要这样做。它不会像以前那样出错,但是我的数据的形状不正确。例如,这是我在我的第一个元素中看到的dataset
:
({'sequence': array([[ 6391, 3352, 10236, 244, 1362, 244, 9350, 7649, 6391,
6324, 6063, 3620, 244, 8153, 6542, 10056, 7303, 1955,
1362, 6194, 10250, 6391, 550, 244, 7577, 850, 3620,
5807, 10325, 1362, 6542, 595, 9060, 9052, 9459, 351,
4676, 9354, 7648, 3082, 7694, 8497, 10703, 1610, 9454,
10236, 244, 7965, 8018, 9392, 6391, 6063, 2878, 1318,
3169, 8198, 9354, 4131, 3620, 3082, 3352, 9052, 8018,
7527, 3419, 1907, 8835, 796, 244, 8957, 4325, 8171,
9454, 7602, 4435, 7648, 3169, 2083, 9454, 4789, 9620,
9261, 556, 3524, 8497, 9174, 8299, 5871, 9052, 2888,
9846, 1610, 1362, 4930, 2150, 1362, 8018, 3867, 341,
7694, 8497, 6063, 3620, 244, 5807, 6089, 3169, 6350,
1174, 7694, 949, 1292, 244, 9052, 9440, 3690, 1362,
1907, 9011, 4156, 6081, 145, 1174, 7694, 9986, 949,
1292, 3169, 1455, 6372, 9760, 5013, 3169, 1455, 5942,
4365, 1362, 1907, 244, 5813, 244, 7994, 3525, 3550,
7509, 6372, 9760, 7860, 9052, 2888, 7694, 8497, 1610,
1316, 326, 1174, 3039, 3524, 9703, 3620, 6612, 1455,
556, 9011, 3169, 1927, 9052, 409, 4059, 9354, 700,
5503, 3550, 9052, 2083, 1963, 595, 3169, 7715, 10236,
9442, 1174, 10087, 3169, 5312, 7474, 9052, 3525, 3169,
5826, 7885, 6944, 7130, 5821, 2878, 7184, 153, 3169,
8633, 8574, 1283, 606, 7902, 6110, 3082, 6406, 3169,
8316, 6126, 688, 10236, 9440, 3082, 10584, 2143, 5460,
5809, 1362, 2878, 10439, 3419, 1907, 4598, 4156, 10239,
1450, 5514, 5010, 9350, 244, 651]])}, 0)
所以字典值是一个二维数组,而它应该只是一维。我哪里错了?
谢谢!
解决方案
不采用标量似乎是tf.string_split
. 请在https://github.com/tensorflow/tensorflow/issues提交问题
就解决方法而言,包含在列表中的建议是一个很好的建议,但您还需要在拆分后对其进行挤压,以便您拥有一个分量向量而不是二维张量。
import tensorflow as tf
tf.enable_eager_execution()
scalar = tf.constant('ab c de')
print(scalar.shape) # () scalar
vector = scalar[None]
print(vector.shape) # (1,) vector
output = tf.sparse.to_dense(tf.string_split(vector), default_value='')
print(output) # tf.Tensor([[b'ab' b'c' b'de']], shape=(1, 3), dtype=string)
squeezed = tf.squeeze(output, axis=0)
print(squeezed.shape) # (3,) vector
print(squeezed) # tf.Tensor([b'ab' b'c' b'de'], shape=(3,), dtype=string)
推荐阅读
- php - 如何使用 Composer 在 openclient 3 项目中安装“google/apiclient”包?
- python - 错误“TypeError:FirstForm() 缺少 1 个必需的位置参数:'request'”
- javascript - 如何正确地将异步应用到此函数,以便在显示返回数据的页面呈现之前调用它?
- c++ - Langford 序列 - 利用对称性/消除对称性
- react-native - 如何修复:TypeError:超级表达式必须为空或函数
- python - Python通过嵌套字典排序
- wordpress - 构建多个 Wordpress 网站
- amazon-web-services - 使用 .ebextensions 添加到 ngnix 配置
- r - 如何在 R 中输入约束符号而不会出现错误消息
- dataframe - 将 HDF5 作为 Dask Dataframe 读取时出错,为什么?