python - 如何从 tensorflow 数据集中选择特定列?
问题描述
我正在用 tf.data.Dataset 预处理的 CSV 文件中的数据训练我的 Tensorflow 模型。但是,我希望模型分成三个分支,对应于一组不同的 csv 列,并且 model.fit 需要为每个输出提供单独的数据集。CSV 文件的所有列都需要经过相同的预处理,因此最有效的准备方法是加载整个文件,对其进行处理,然后将数据集拆分为三个部分。但是,我正在努力寻找一种方法。
我希望 dataset.map 允许我使用以下操作选择一些列:
dset = dset.map(lambda x: x[[1, 2, 3, 7]])
但似乎 tensorflow 将其解释为x[1][2][3][7]
。
我发现创建单独数据集的唯一可行方法是从一开始就这样做:
y = []
for cls, keys in output_classes.items():
tmp = tf.data.experimental.CsvDataset(data_path, [tf.int32 for i in keys], select_cols=keys)
[...]
y.append(tmp)
y = tf.data.Dataset.zip(tuple(y))
不幸的是,它会产生很多不必要的开销并极大地减慢训练速度。
有没有办法通过特征子集拆分 tf.data.Dataset 对象?
解决方案
尝试tf.gather
:
tf.gather(tf.constant([1,2,3,4]), [1,2,3])
# ouputs : array([2, 3, 4])
如果您有高维数据,请使用tf.gather_nd
.
推荐阅读
- gem5 - gem5 FS 模式以超级用户身份运行失败,并显示“IOError:找不到系统文件的路径”。
- php - 从当前工作目录获取路径而不是从根目录
- javascript - 来自子组件的 $ref 在父组件(vuejs)中始终未定义
- python - 匹配两个列表之间的元素,然后使用第三个列表上的匹配位置
- string - 从分隔块解析数据
- git - 如何在 Mac OS 中安装 GitExtensions?
- php - 如何过滤与服务相关的活动
- javascript - Javascript正则表达式匹配相等数量的连续1和0
- protractor - 我们如何使用标签在黄瓜场景中设置优先级
- java - 在 selenium webdriver 中编写代码以获取此错误