"Using a tf.Tensor as a Python bool is not allowed" when using tf.keras.metrics.Accuracy

Problem description

I have TensorFlow 1.14 and I want to compute some classification metrics.

I am using tf.keras.metrics in the following way:

tf.keras.metrics.Accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),
                          tf.argmax(support_y, axis=1))

This gives me the error:

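{TypeError}Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.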

I tried using tf.contrib.metrics instead, but it only has precision_at_recall and recall_at_precision, not standalone precision and recall.
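If the goal is standalone precision and recall, both can be computed directly from predicted and true class indices such as the argmax outputs above. The sketch below is plain Python rather than TensorFlow, assumes single-label multi-class data, and treats one class at a time as the positive class (one-vs-rest). The function name `precision_recall` is hypothetical and not part of any TensorFlow API.

```python
from typing import Sequence, Tuple


def precision_recall(y_true: Sequence[int], y_pred: Sequence[int],
                     positive_class: int) -> Tuple[float, float]:
    """Hypothetical helper: one-vs-rest precision and recall for one class."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    tp = fp = fn = 0
    for t, p in zip(y_true, y_pred):
        if p == positive_class and t == positive_class:
            tp += 1  # predicted this class, and it was correct
        elif p == positive_class:
            fp += 1  # predicted this class, but the label was another class
        elif t == positive_class:
            fn += 1  # the label was this class, but something else was predicted
    # Precision = TP / (TP + FP), recall = TP / (TP + FN).
    # 0/0 is treated as 0.0 so a class that never occurs does not raise.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


if __name__ == "__main__":
    labels = [0, 1, 2, 1, 1, 1]
    preds = [0, 1, 1, 1, 2, 0]
    print(precision_recall(labels, preds, positive_class=1))
```

In the usage example, class 1 is predicted three times with two of them correct (precision 2/3), and it appears four times in the labels with two of them found (recall 0.5).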

Edit 1

I tried the following, but it didn't work:

import tensorflow as tf

a = tf.random.uniform((32, 10), 0, 1, dtype=tf.float32)
b = tf.random.uniform((32, 10), 0, 1, dtype=tf.float32)

a_softmax = tf.nn.softmax(a)
b_softmax = tf.nn.softmax(b)

a_argmax = tf.argmax(a_softmax, axis=-1)
b_argmax = tf.argmax(b_softmax, axis=-1)

acc = tf.keras.metrics.Accuracy()(a_argmax, b_argmax)

with tf.Session() as sess:
    sess.run([acc])

It gives me the following error:

Traceback (most recent call last):
  File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1356, in _do_call
    return fn(*args)
  File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.NotFoundError: Container localhost does not exist. (Could not find resource: localhost/total)
     [[{{node AssignAddVariableOp}}]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:/Users/96171/Desktop/dementia_cleanedup/dementia/maml_finn_copy/try_tf.py", line 15, in <module>
    sess.run(acc)
  File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 950, in run
    run_metadata_ptr)
  File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1350, in _do_run
    run_metadata)
  File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Container localhost does not exist. (Could not find resource: localhost/total)
     [[node AssignAddVariableOp (defined at /Users/96171/Desktop/dementia_cleanedup/dementia/maml_finn_copy/try_tf.py:12) ]]

Original stack trace for 'AssignAddVariableOp':
  File "/Users/96171/Desktop/dementia_cleanedup/dementia/maml_finn_copy/try_tf.py", line 12, in <module>
    acc = tf.keras.metrics.Accuracy()(a_argmax, b_argmax)
  File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\metrics.py", line 170, in __call__
    update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
  File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\utils\metrics_utils.py", line 73, in decorated
    update_op = update_state_fn(*args, **kwargs)
  File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\metrics.py", line 551, in update_state
    matches, sample_weight=sample_weight)
  File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\metrics.py", line 314, in update_state
    update_total_op = self.total.assign_add(value_sum)
  File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py", line 1108, in assign_add
    name=name)
  File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\gen_resource_variable_ops.py", line 68, in assign_add_variable_op
    "AssignAddVariableOp", resource=resource, value=value, name=name)
  File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py", line 3616, in create_op
    op_def=op_def)
  File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()


Process finished with exit code 1

Tags: python, tensorflow, keras

Solution


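tf.metrics.Accuracy creates an object whose state is usually updated many times, so you cannot pass y_pred and y_true to the class itself: the constructor's arguments configure the metric (for example name and dtype) rather than receiving data. Create an instance first, then call the instance with the labels and predictions:

acc = tf.keras.metrics.Accuracy()(
    tf.argmax(support_y, axis=1),                            # y_true
    tf.argmax(tf.nn.softmax(support_pred, axis=1), axis=1))  # y_pred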
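The stateful design can be illustrated without TensorFlow. The traceback in Edit 1 shows Keras accumulating into a `total` variable (`self.total.assign_add(...)`); the plain-Python class below is a simplified analogue, not TensorFlow's implementation, that keeps a running `total` of matches and a `count` of samples. The class name `RunningAccuracy` is hypothetical, and its method names only loosely mirror the Keras `update_state` / `result` convention.

```python
from typing import Sequence


class RunningAccuracy:
    """Hypothetical, simplified stand-in for a stateful accuracy metric.

    The constructor takes no data; batches are fed through update_state(),
    and result() reports accuracy over everything seen so far.
    """

    def __init__(self) -> None:
        self.total = 0  # number of matching predictions seen so far
        self.count = 0  # number of predictions seen so far

    def update_state(self, y_true: Sequence[int], y_pred: Sequence[int]) -> None:
        if len(y_true) != len(y_pred):
            raise ValueError("y_true and y_pred must have the same length")
        self.total += sum(1 for t, p in zip(y_true, y_pred) if t == p)
        self.count += len(y_true)

    def result(self) -> float:
        # Accuracy over all batches accumulated so far; 0.0 before any update.
        return self.total / self.count if self.count else 0.0

    def reset_state(self) -> None:
        self.total = 0
        self.count = 0

    def __call__(self, y_true: Sequence[int], y_pred: Sequence[int]) -> float:
        # Calling the instance updates the state, then returns the running result.
        self.update_state(y_true, y_pred)
        return self.result()


if __name__ == "__main__":
    metric = RunningAccuracy()                 # create the object once
    print(metric([0, 1, 2, 3], [0, 1, 0, 0]))  # batch 1: 2/4 -> 0.5
    print(metric([1, 1], [1, 1]))              # batch 2: (2 + 2) / (4 + 2)
```

The second call returns 4/6 rather than 1.0 because the instance remembers the first batch. That is why a metric object is normally created once and then fed repeatedly, with a reset between evaluation runs.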

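With eager execution (as in the example below, which uses the TensorFlow 2.x API), it works once both the outputs and the labels have gone through softmax and argmax: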

import tensorflow as tf
tf.random.set_seed(0)

a = tf.random.uniform((32, 10), 0, 1, dtype=tf.float32)
b = tf.random.uniform((32, 10), 0, 1, dtype=tf.float32)

a_softmax = tf.nn.softmax(a)
b_softmax = tf.nn.softmax(b)

a_argmax = tf.argmax(a_softmax, axis=-1)
b_argmax = tf.argmax(b_softmax, axis=-1)

tf.keras.metrics.Accuracy()(a_argmax, b_argmax)
<tf.Tensor: shape=(), dtype=float32, numpy=0.1875>
