python - 如何在 tf.data.Dataset.from_generator 中保留字典键?
问题描述
我有这样的非矩形数据:
samples_train = {'data': [np.array([[1,1]]), np.array([[1,1],[2,2]]), np.array([[1,1],[2,2],[3,3]])],
'labels': [1,2,3]}
它是一个包含数组列表的字典shape=[variable, 2]
。
由于我有一个自定义训练循环,我想通过键“数据”和“标签”访问数据(我有存储的其他键),因此是 dict 格式。
我特别不想将它们填充到一个常见的序列长度(到目前为止,我确实填充了它们,并且上述from_tensor_slices
方法适用于填充的相同长度的序列)。但现在我需要它们而不是填充。
如果我尝试:
ds = tf.data.Dataset.from_tensor_slices(samples_train)
我得到这个错误,这在某种程度上是有道理的:
ValueError:无法将非矩形 Python 序列转换为张量。
所以这个问题的答案建议如下:
ds = tf.data.Dataset.from_generator(
lambda: iter(zip(samples_train['data'], samples_train['labels'])),
output_types=(tf.float32, tf.float32)
)
通过检查可以正常工作:
for batch in ds:
print(batch)
--> 输出:
(<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[1., 1.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
(<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[1., 1.],
[2., 2.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
(<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[1., 1.],
[2., 2.],
[3., 3.]], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=3.0>)
但是这样一来,我就松开了我的 dict 键。
但是,我希望能够像这样访问它们:
for batch in ds:
print(batch['data'])
print(batch['labels'])
如何在数据集中保留这些 dict 键?
解决方案
您可以编写一个生成器函数来生成字典,如下所示:
def my_generator(my_dict):
for data in zip(*[my_dict[key] for key in my_dict]):
yield {key:d for key,d in zip(my_dict.keys(), data)}
output_types
并在from_generator
函数中设置正确。
结果是
>>> ds = tf.data.Dataset.from_generator(
lambda: my_generator(samples_train),
output_types={"data": tf.float32, "labels": tf.float32})
>>> for batch in ds:
print(batch['data'])
print(batch['labels'])
tf.Tensor([[1. 1.]], shape=(1, 2), dtype=float32)
tf.Tensor(1.0, shape=(), dtype=float32)
tf.Tensor(
[[1. 1.]
[2. 2.]], shape=(2, 2), dtype=float32)
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(
[[1. 1.]
[2. 2.]
[3. 3.]], shape=(3, 2), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)
推荐阅读
- javascript - 如何在 Angular2+ 中为 *ngFor 元素设置锚点?
- python-3.x - 为什么我的蜘蛛没有爬取所有元素?
- actions-on-google - 将已发布的代理版本恢复回对话流
- c++11 - 如何通过引用将 time_t 传递给函数
- javascript - 在 Safari 浏览器(macbook)的 javascript 中找不到变量
- c++ - 如何以链式方式在两个容器上构建迭代器
- python - 使用 fiona 和 geopandas 打开一些 gdb 文件时出错
- python - Python 使用了错误的包版本
- jquery - 在 URL 上打开不同的团队成员
- postgresql - 自日期以来的日历天数(以天为单位)