python - 如何使用 TensorFlow tf.data.Dataset flat_map 生成派生数据集?
问题描述
我有一个 Pandas DataFrame,我正在将它的一部分加载到一个 tf.data 数据集中:
dataset = tf.data.Dataset.from_tensor_slices((
df.StringColumn.values,
df.IntColumn1.values,
df.IntColumn2.values,
))
现在我想做的是使用类似于flat_map
生成派生数据集的东西,该数据集获取每一行中的数据,并在派生数据集中为原始数据集中的每一行生成一堆行。
但flat_map
似乎只是在lambda
函数中传递了占位符张量。
如果这很重要,我正在使用 TensorFlow 2.0 alpha 0。
编辑:
我想要的是能够写出这样的东西:
derived = dataset.flat_map(replicate)
def replicate(s, i1, i2):
return [[0, s, i1, i2],
[0.25, s, i1, i2],
[0.5, s, i1, i2],
[0.75, s, i1, i2]]
...然后derived
是一个具有四列和四倍行数的数据集dataset
。
但是当我尝试这个时,s
它不是一个值,它是一个字符串占位符张量。
编辑2:
好的,我的意思是该replicate
函数需要知道它正在复制的行的值:
slice_count = 16
def price(frac, total, size0, price0, size1, price1, size2, price2, size3, price3):
total_per_slice = total / slice_count
start = frac * total_per_slice
finish = start + total_per_slice
price = \
(price0 * (min(finish, size0) - max(start, 0) if 0 < finish and start < size0 else 0)) + \
(price1 * (min(finish, size1) - max(start, size0) if size0 < finish and start < size1 else 0)) + \
(price2 * (min(finish, size2) - max(start, size1) if size1 < finish and start < size2 else 0)) + \
(price3 * (min(finish, size3) - max(start, size2) if size2 < finish and start < size3 else 0))
def replicate(size0, price0, size1, price1, size2, price2, size3, price3):
total = size0 + size1 + size2 + size3
return [[
price(frac, total, size0, price0, size1, price1, size2, price2, size3, price3),
frac / slice_count] for frac in range(slice_count)]
derived = dataset.flat_map(replicate)
仅仅能够传递占位符是不够的。这是我能做的事情吗,或者如果我能以某种方式将其转换为 TensorFlow 的计算图,它是否可行,或者它只是无法按照我尝试的方式进行?
解决方案
可能有很长的路要走,但您也可以使用.concatenate()
withapply()
来实现“平面映射”
像这样的东西:
def replicate(ds):
return (ds.map(lambda s,i1,i2: (s, i1, i2, tf.constant(0.0)))
.concatenate(ds.map(lambda s,i1,i2: (s, i1, i2, tf.constant(0.25))))
.concatenate(ds.map(lambda s,i1,i2: (s, i1, i2, tf.constant(0.5))))
.concatenate(ds.map(lambda s,i1,i2: (s, i1, i2, tf.constant(0.75)))))
derived = dataset.apply(replicate)
应该给你你期望的输出
推荐阅读
- c# - 当前面有带有 HttpPost 属性的操作时,如何让 @Url.Action("actionXXX") 执行相同的命名操作
- windows - HKLM:\SOFTWARE\Microsoft\Windows NT\CurrentVersion\NetworkList\DefaultMediaCost\Default 注册表项有什么作用?
- xcode - 观看连接抛出 NSFileReadNoSuchFileError -> ENOENT
- java - 如何配置 maven 和 intellij 以包含 groovy 和 java 的参数编译器标志
- intellij-idea - IntelliJ IDEA 系列 IDE 中终端命令输入的自动化?
- reactjs - 如何添加自定义样式以响应 react-pagination-table 标题
- python - 我的 python 代码没有将 JSON 数据正确导入 CSV
- javascript - 淘汰赛表:突出显示表行
- vb.net - 如何初始化包含结构和简单元素混合的结构对象?
- angularjs - AngularJs ngNotify 点击关闭