首页 > 解决方案 > 如何在 tf2.0 上同时进行多个提取?

问题描述

在 tf1 中,我可以定义所需提取的列表并使用

sess.run(myList, feed_dict)

获取由 tf1 通过图形同时计算的列表的所有元素。如何在 tf2.0 中做到这一点?

tf1 中的示例代码:

import tensorflow as tf
a = [None]*5
for i in range(5):
    a[i] = tf.Variable(tf.random.normal([3,3]))
fetch_list = [None]*5
for i in range(5):
    fetch_list[i] = tf.add(tf.gather(a, i), tf.ones([3,3]))

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(fetch_list)

没有检查上面的代码是否运行,但我希望你明白这一点。谢谢

标签: tensorflowtensorflow2.0

解决方案


因为默认情况下 tf 2.x 急切地执行,你可以这样做:

a = [None]*5
for i in range(5):
    a[i] = tf.Variable(tf.random.normal([3,3]))
fetch_list = [None]*5
for i in range(5):
    fetch_list[i] = tf.add(tf.gather(a, i), tf.ones([3,3]))

然后fetch_list像以前一样填充。

根据您的真实单词示例的复杂性,您还可以考虑使用@tf.function在推送数据之前构建执行图以帮助类似于 tf1 的优化(这是一个巨大的过度简化,但您明白了)。

您可能会考虑稍微简化/修改代码以促进这一点。可能最好只使用张量而不是张量列表。很难确切地建议如何完成此操作,因为我不知道您为示例简化了什么。

例如,如果我们认为您fetch_list是一个(5,3,3)张量而不是 5 个(3,3)张量的列表,那么我相信您会意识到您的简化示例代码(或多或少)归结为如下所示:

@tf.function
def get_list(n):
  return tf.random.normal((n,3,3))

fetch_list = get_list(5)

推荐阅读