python - 循环张量并将函数应用于每个元素
问题描述
我想遍历一个包含 列表的张量Int
,并对每个元素应用一个函数。在函数中,每个元素都会从 python 的字典中获取值。我已经尝试过使用 的简单方法tf.map_fn
,它将在函数上add
起作用,例如以下代码:
import tensorflow as tf
def trans_1(x):
return x+10
a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_1, a)
with tf.Session() as sess:
res = sess.run(b)
print(str(res))
# output: [11 12 13]
但是下面的代码抛出KeyError: tf.Tensor'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32
异常:
import tensorflow as tf
kv_dict = {1:11, 2:12, 3:13}
def trans_2(x):
return kv_dict[x]
a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_2, a)
with tf.Session() as sess:
res = sess.run(b)
print(str(res))
我的张量流版本是1.13.1
. 提前谢谢。
解决方案
有一种简单的方法可以实现,您正在尝试什么。
问题是传递给的函数map_fn
必须有张量作为其参数,张量作为返回值。但是,您的函数trans_2
将普通 pythonint
作为参数并返回另一个 python int
。这就是你的代码不起作用的原因。
但是,TensorFlow 提供了一种简单的方法来包装普通的 python 函数,即tf.py_func
,您可以在您的情况下使用它,如下所示:
import tensorflow as tf
kv_dict = {1:11, 2:12, 3:13}
def trans_2(x):
return kv_dict[x]
def wrapper(x):
return tf.cast(tf.py_func(trans_2, [x], tf.int64), tf.int32)
a = tf.constant([1, 2, 3])
b = tf.map_fn(wrapper, a)
with tf.Session() as sess:
res = sess.run(b)
print(str(res))
你可以看到我添加了一个包装函数,它需要张量参数并返回一个张量,这就是它可以在 map_fn 中使用的原因。使用强制转换是因为 python 默认使用 64 位整数,而 TensorFlow 使用 32 位整数。
推荐阅读
- javascript - 如何使用 joi 中间件执行 OR 语句并处理错误?
- react-router-dom - react-router-dom:如何在 url 中使用正则表达式捕获组创建路由并在组件中获取它们
- javascript - vue3 chart-js 在悬停时显示以前的数据
- ocr - 如何强制 tika 服务器使用 curl 排除 TesseractOCRParser
- google-cloud-platform - GCP中的服务帐户和服务代理有什么区别
- .net-core - 如何在 MS 团队机器人中安排和发送消息
- kubernetes - 错误:升级失败:错误验证“”:错误验证数据:ValidationError(Ingress.spec.rules[0].http):缺少必填字段“paths”
- excel - VBA复制粘贴突然重复粘贴
- sql-server - SQL中两个日期之间的周末日期和日期
- python-3.x - 从 Pandas 中的正则表达式中提取字符串以获取大型数据集