首页 > 解决方案 > 循环张量并将函数应用于每个元素

问题描述

我想遍历一个包含 列表的张量Int,并对每个元素应用一个函数。在函数中,每个元素都会从 python 的字典中获取值。我已经尝试过使用 的简单方法tf.map_fn,它将在函数上add起作用,例如以下代码:

import tensorflow as tf

def trans_1(x):
    return x+10

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_1, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
# output: [11 12 13]

但是下面的代码抛出KeyError: tf.Tensor'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32异常:

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
    return kv_dict[x]

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_2, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))

我的张量流版本是1.13.1. 提前谢谢。

标签: pythontensorflowtensor

解决方案


有一种简单的方法可以实现,您正在尝试什么。

问题是传递给的函数map_fn必须有张量作为其参数,张量作为返回值。但是,您的函数trans_2将普通 pythonint作为参数并返回另一个 python int。这就是你的代码不起作用的原因。

但是,TensorFlow 提供了一种简单的方法来包装普通的 python 函数,即tf.py_func,您可以在您的情况下使用它,如下所示:

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
    return kv_dict[x]

def wrapper(x):
    return tf.cast(tf.py_func(trans_2, [x], tf.int64), tf.int32)

a = tf.constant([1, 2, 3])
b = tf.map_fn(wrapper, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))

你可以看到我添加了一个包装函数,它需要张量参数并返回一个张量,这就是它可以在 map_fn 中使用的原因。使用强制转换是因为 python 默认使用 64 位整数,而 TensorFlow 使用 32 位整数。


推荐阅读