首页 > 解决方案 > 调用 tf.data.Dataset.map 时未更新全局变量

问题描述

我试图弄清楚何时应用地图(在地图执行后立即或要读取相应数据时),我设计了一个小测试。但是奇怪的事情发生在我身上。我期望cnt是5、5或0、5,但我只得到1。同时我打印了数据集中的每个元素,结果显示func应用于五个元素中的每一个,但是为什么cnt等于1,而不是5.

我检查了官方文档,它说:(https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map

此转换将 map_func 应用于此数据集的每个元素,并返回包含转换后元素的新数据集,其顺序与它们在输入中出现的顺序相同。map_func 可用于更改数据集元素的值和结构。

环境:tf2

cnt = 0
def func(x):
  x = x + 1
  global cnt
  cnt += 1
  return x
b = tf.data.Dataset.from_tensor_slices(list(range(5)))
b = b.map(func)
print(cnt)
it = iter(b)
for i in range(5):
  print(it.next())
print(cnt)
Output:
1
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(5, shape=(), dtype=int32)
1

可以在此处找到类似的问题:tensorflow map function not being invoked but the print from the original code does not be invoked 一次。

我究竟做错了什么?

标签: pythontensorflow

解决方案


推荐阅读