首页 > 解决方案 > 如何在渴望模式下访问张量值

问题描述

我在我的数据集上使用地图功能。在映射的函数中,我想访问张量的值以在“if”中使用它。

但我现在完全看到了访问张量的方法。

我处于渴望模式并拥有 tensorflow 2.1(因为 anaconda 不支持任何较新版本)。

这是我的意思的简单示例代码:

def f1(C):
    print("every numba")
    #Access C somehow
    #if C < 2:
    #   C = C-1
    return C+2

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset2 = dataset.map(f1)

标签: pythontensorflowtensorflow2.0

解决方案


我想像这样的方法可能对你有用。

def f1(C):
    print("print ", C)
    if C < 2:
       C = C-1
    return C

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map( lambda x: tf.py_function(
                                    f1,
                                    inp=[x], Tout=tf.int64))
for x in dataset:
    print(x)

推荐阅读