首页 > 解决方案 > TensorFlow 数据集映射中的条件与“基本”Python 行为不一致

问题描述

考虑到我有一个文件data.csv,其中包含:

feature0,feature1,label
True,0.1,class_1
False,2.7,class_2
False,10.1,class_3

我想将其加载为数据集并将其label转换为布尔值,使其为真,class_1否则为假。这是我的代码:

import tensorflow as tf

data = tf.data.experimental.make_csv_dataset(
    'data.csv',
    32,
    label_name='label',
    shuffle=False,
    num_epochs=1)

def view(ds, num_batches=1):
    
    for f, l in ds.take(num_batches):
        print('Features:')
        print(f)
        print('Labels:')
        print(l)

def process_labels(features, label):
    
    if label == 'class_1':
        label = True
    else:
        label = False
    
    # label = label=='class_1'
    
    return features, label

view(data.map(process_labels))

这会引发错误:InvalidArgumentError: Input to reshape is a tensor with 3 values, but the requested shape has 1 [[{{node Reshape}}]]. 这是为什么?更令人困惑的是,当我将 if~else 替换为已注释掉的单行代码时,label = label=='class_1'问题就消失了。这里发生了什么事?

我正在使用 TensorFlow 2.4.1 和 Python 3.8.5。

标签: tensorflowtensorflow2.0tensorflow-datasets

解决方案


的第二个参数tf.data.experimental.make_csv_dataset是批量大小,这意味着创建的数据集具有以下形状:(batch_size, feature). 您在该数据集上映射的任何函数都应该适用于一批数据,而不仅仅是数据集的一个元素。

label = label=='class_1'由于广播而工作,但您以前的功能没有。

您有两种方法可以使该功能起作用:

  • 要么编写一个处理批量数据的函数(即,您的工作解决方案)

  • 调用unbatch数据集。正如文档所述,这可能会对性能产生负面影响:

    注意:unbatch 需要一个数据副本来将批处理张量分割成更小的、未批处理的张量。在优化性能时,尽量避免不必要地使用 unbatch。

    data.unbatch().map(process_labels).batch(32)
    

推荐阅读