tensorflow - 如何在 tf2.0 上同时进行多个提取?
问题描述
在 tf1 中,我可以定义所需提取的列表并使用
sess.run(myList, feed_dict)
获取由 tf1 通过图形同时计算的列表的所有元素。如何在 tf2.0 中做到这一点?
tf1 中的示例代码:
import tensorflow as tf
a = [None]*5
for i in range(5):
a[i] = tf.Variable(tf.random.normal([3,3]))
fetch_list = [None]*5
for i in range(5):
fetch_list[i] = tf.add(tf.gather(a, i), tf.ones([3,3]))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(fetch_list)
没有检查上面的代码是否运行,但我希望你明白这一点。谢谢
解决方案
因为默认情况下 tf 2.x 急切地执行,你可以这样做:
a = [None]*5
for i in range(5):
a[i] = tf.Variable(tf.random.normal([3,3]))
fetch_list = [None]*5
for i in range(5):
fetch_list[i] = tf.add(tf.gather(a, i), tf.ones([3,3]))
然后fetch_list
像以前一样填充。
根据您的真实单词示例的复杂性,您还可以考虑使用@tf.function
在推送数据之前构建执行图以帮助类似于 tf1 的优化(这是一个巨大的过度简化,但您明白了)。
您可能会考虑稍微简化/修改代码以促进这一点。可能最好只使用张量而不是张量列表。很难确切地建议如何完成此操作,因为我不知道您为示例简化了什么。
例如,如果我们认为您fetch_list
是一个(5,3,3)
张量而不是 5 个(3,3)
张量的列表,那么我相信您会意识到您的简化示例代码(或多或少)归结为如下所示:
@tf.function
def get_list(n):
return tf.random.normal((n,3,3))
fetch_list = get_list(5)
推荐阅读
- c++ - 在 for 循环内部或外部声明的 C++ 互斥锁
- python - PyCharm 警告我为 classmethod 函数创建的 classmethod 创建一个对象
- reactjs - 如何构建在 Apache 上运行的反应应用程序(具有多个环境)
- javascript - Mongoose,更新时属性返回不够
- php - 使用 PHP 从文本区域中取两三行
- python - 为什么我在这个 Python 程序中得到一个 TypeError?
- python - 使用 BeautifulSoup 抓取 Yahoo Finance 返回 None
- python - 复制并粘贴到帧缓冲区
- android - Kotlin Flow 与 LiveData
- haskell - 在 Haskell 中获取系统时间的方法是什么?