python - 如何从 TensorFlow 中的函数返回张量的值?
问题描述
我正在 Keras 开展一个深度学习项目,并且已经使用 TensorFlow 后端实现了一个敏感度函数,因为如果我想使用它来评估模型,就需要这样做。但是,我无法从张量中提取值。我想返回它,以便我可以在其他函数中使用这些值。理想情况下,返回值应该是int
. 每当我评估函数时,我只得到张量对象本身,而不是它的真实值。
我曾尝试创建会话并进行评估,但无济于事。我能够以这种方式很好地打印该值,但我无法将该值分配给另一个变量。
def calculate_tp(y, y_pred):
TP = 0
FP = 0
TN = 0
FN = 0
for i in range(5):
true = K.equal(y, i)
preds = K.equal(y_pred, i)
TP += K.sum(K.cast(tf.boolean_mask(preds, tf.math.equal(true, True)), 'int32'))
FP += K.sum(K.cast(tf.boolean_mask(true, tf.math.equal(~preds, True)), 'int32'))
TN += K.sum(K.cast(tf.boolean_mask(~preds, tf.math.equal(true, True)), 'int32'))
FN += K.sum(K.cast(tf.boolean_mask(true, tf.math.equal(preds, False)), 'int32'))
"""with tf.Session() as sess:
TP = TP.eval()
FP = FP.eval()
FN = FN.eval()
FP = FP.eval()
print(TP, FP, TN, FN)
#sess.run(FP)"""
return TP / (TP + FN)
解决方案
好的,可能是因为在您的尝试中 TP 始终为 0 吗?
如果我尝试:
y = np.array([0, 0, 0, 0, 0, 1, 1, 1 ,1 ,1])
y_pred = np.array([0.01, 0.005, 0.5, 0.09, 0.56, 0.999, 0.89, 0.987 ,0.899 ,1])
def calculate_tp(y, y_pred):
TP = 0
FP = 0
TN = 0
FN = 0
for i in range(5):
true = K.equal(y, i)
preds = K.equal(y_pred, i)
TP += K.sum(K.cast(tf.boolean_mask(preds, tf.math.equal(true, True)), 'int32'))
FP += K.sum(K.cast(tf.boolean_mask(true, tf.math.equal(~preds, True)), 'int32'))
TN += K.sum(K.cast(tf.boolean_mask(~preds, tf.math.equal(true, True)), 'int32'))
FN += K.sum(K.cast(tf.boolean_mask(true, tf.math.equal(preds, False)), 'int32'))
TP = TP.eval(session=tf.Session())
FP = FP.eval(session=tf.Session())
TN = TN.eval(session=tf.Session())
FN = FN.eval(session=tf.Session())
print(TP, FP, TN, FN)
results = TP / (TP + FN)
return results
res = calculate_tp(y, y_pred)
print(res)
#Outputs :
#0 5 5 5
#1 9 9 9
#1 9 9 9
#1 9 9 9
#1 9 9 9
#0.1
它给了我一个浮点数,就像你想要的那样。
有帮助吗?
推荐阅读
- asp.net - 我想获得选定的单选按钮值
- c# - 即使不使用 RedirectStandardError/RedirectStandardOutput,Process.WaitForExit 也会挂起
- ansible - Ansible 和 Jinja2 变量组合
- c++ - 未找到函数 setenv()?
- java - JPA 数据库连接问题
- spring-boot - 无法连接到 datastax cassandra
- php - PHP Mysql查询优化报告
- spring-mvc - 在角度为 5 的下拉菜单中显示来自远程服务器的文件列表
- bash - Shell 脚本从文件中获取日期部分,然后将其与过去 2 天进行比较
- python - 用 django 创建类似按钮的最佳方法是什么?