tensorflow - TensorFlow 数据集映射中的条件与“基本”Python 行为不一致
问题描述
考虑到我有一个文件data.csv
,其中包含:
feature0,feature1,label
True,0.1,class_1
False,2.7,class_2
False,10.1,class_3
我想将其加载为数据集并将其label
转换为布尔值,使其为真,class_1
否则为假。这是我的代码:
import tensorflow as tf
data = tf.data.experimental.make_csv_dataset(
'data.csv',
32,
label_name='label',
shuffle=False,
num_epochs=1)
def view(ds, num_batches=1):
for f, l in ds.take(num_batches):
print('Features:')
print(f)
print('Labels:')
print(l)
def process_labels(features, label):
if label == 'class_1':
label = True
else:
label = False
# label = label=='class_1'
return features, label
view(data.map(process_labels))
这会引发错误:InvalidArgumentError: Input to reshape is a tensor with 3 values, but the requested shape has 1 [[{{node Reshape}}]]
. 这是为什么?更令人困惑的是,当我将 if~else 替换为已注释掉的单行代码时,label = label=='class_1'
问题就消失了。这里发生了什么事?
我正在使用 TensorFlow 2.4.1 和 Python 3.8.5。
解决方案
的第二个参数tf.data.experimental.make_csv_dataset
是批量大小,这意味着创建的数据集具有以下形状:(batch_size, feature)
. 您在该数据集上映射的任何函数都应该适用于一批数据,而不仅仅是数据集的一个元素。
label = label=='class_1'
由于广播而工作,但您以前的功能没有。
您有两种方法可以使该功能起作用:
要么编写一个处理批量数据的函数(即,您的工作解决方案)
调用
unbatch
数据集。正如文档所述,这可能会对性能产生负面影响:注意:unbatch 需要一个数据副本来将批处理张量分割成更小的、未批处理的张量。在优化性能时,尽量避免不必要地使用 unbatch。
data.unbatch().map(process_labels).batch(32)
推荐阅读
- python - 创建空日期列
- javascript - 隐藏特定 URL 中的元素
- python - 为什么要为此 tkinter 事件绑定调用多个函数?(解决了)
- html - 如何在 Laravel 代码之前加载 CSS?使用顺风
- python-3.x - 尝试在 TkSheet 中创建搜索栏
- android - ANDROID_SDK_ROOT=undefined 即使在 ubuntu 环境中设置后
- testing - 如何为 Vuetify 的 v-alert 创建单元测试?
- reactjs - 在 redux 存储更改后,React 组件没有重新渲染和更新
- python - 如何解决 Python 中 PIL 的“必须作为“序列”警告的问题?
- node.js - 是否可以重命名 pm2 正在使用的文件而不必删除它?