首页 > 解决方案 > Tensorflow 2.3,Tensorflow 数据集,TypeError:() 接受 1 个位置参数,但给出了 4 个

问题描述

我使用 tf.data.TextLineDataset 读取 4 个大文件,并使用 tf.data.Dataset.zip 压缩这 4 个文件并创建“数据集”。但是,我不能将“数据集”传递给 dataset.map 以使用 tf.compat.v1.string_split 并使用 \t 分隔符拆分,最后使用批处理、预取并最终输入我的模型。

这是我的代码:

d1 = tf.data.TextLineDataset("File1.raw")
d2 = tf.data.TextLineDataset("File2.raw")
d3 = tf.data.TextLineDataset("File3.raw")
d4 = tf.data.TextLineDataset("File4.raw")
dataset = tf.data.Dataset.zip((d1,d2,d3,d4))
dataset = dataset.map(lambda string: tf.compat.v1.string_split([string],sep='\t').values)

这是错误信息:

packages/tensorflow/python/autograph/impl/api.py", line 339, in _call_unconverted
return f(*args, **kwargs)
TypeError: <lambda>() takes 1 positional argument but 4 were given

我该怎么办?

标签: pythontensorflowdeep-learningtensorflow2.0tensorflow-datasets

解决方案


tf.data.Dataset.zip函数同时迭代任意数量的数据集对象。换句话说,如果您压缩四个数据集,您将在每次迭代中获得四个项目(每个数据集中一个)。这也解释了收到的错误 OP

TypeError: <lambda>() takes 1 positional argument but 4 were given

被映射的函数需要能够处理四个参数,因为它被应用于四个数据集的压缩包。下面的代码包含一个函数,它接受四个参数(数据集)并将它们拆分为\t. 您可以将其映射到压缩数据集。我tf.data.TextLineDataset用样本数据集替换了这些对象。

import tensorflow as tf

d1 = tf.data.Dataset.from_tensors(["foo\t1"])
d2 = tf.data.Dataset.from_tensors(["foo\t2"])
d3 = tf.data.Dataset.from_tensors(["foo\t3"])
d4 = tf.data.Dataset.from_tensors(["foo\t4"])

def split_by_tab(text1, text2, text3, text4):
    sep = "\t"
    return (
        tf.strings.split(text1, sep=sep),
        tf.strings.split(text2, sep=sep),
        tf.strings.split(text3, sep=sep),
        tf.strings.split(text4, sep=sep),
    )

dataset = tf.data.Dataset.zip((d1,d2,d3,d4))
dataset = dataset.map(split_by_tab)

作为替代方案,我可以合并这些文件并创建一个非常大的文件,然后从中随机播放、批处理和预取行。对?还有其他解决方案吗?

这些文件可以合并,但如果它们很大,则可能不值得这样做。我没有意识到这些功能被拆分到多个文件中。在这种情况下,压缩是一个合理的做法。

还有一个tensorflow_text可能与此问题相关的库。可能值得一试。


推荐阅读