python - "Using a tf.Tensor as a Python bool is not allowed" when using tf.keras.metrics.Accuracy
Problem description
I have TensorFlow 1.14 and I want to compute some classification metrics. I am using tf.keras.metrics in the following way:
tf.keras.metrics.Accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),
tf.argmax(support_y, axis=1))
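This gives me the error:
{TypeError}Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.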
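The message contrasts `if t:` with `if t is not None:`. In Python, `if t:` calls the object's `__bool__` method, whereas `is not None` is a plain identity check that never calls it. A graph-mode tf.Tensor has no concrete value yet, so its `__bool__` raises this TypeError; presumably that is what happens when a tensor lands in a parameter the library tests with a plain `if`. The stdlib-only sketch below imitates that behaviour with a hypothetical `SymbolicValue` class (not a TensorFlow type):

```python
class SymbolicValue:
    """Hypothetical stand-in for a graph-mode tensor: its value is unknown
    until the graph runs, so it refuses to become a Python bool."""

    def __bool__(self):
        raise TypeError("Using a SymbolicValue as a Python bool is not allowed.")


def is_defined(t) -> bool:
    # `is not None` only compares identity; it never calls t.__bool__().
    return t is not None


def truthy(t) -> bool:
    # `if t:` calls t.__bool__(), so this raises TypeError for SymbolicValue.
    if t:
        return True
    return False
```

`is_defined(SymbolicValue())` returns `True`, while `truthy(SymbolicValue())` raises the TypeError.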
I tried using tf.contrib.metrics instead, but it only has precision_at_recall and recall_at_precision, not standalone precision and recall.
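For reference, standalone precision and recall for one class only need the predicted and true class indices (what `tf.argmax(..., axis=1)` produces). The sketch below is plain Python, not a TensorFlow API; the helper name `precision_recall` and the choice to return 0.0 for an undefined ratio are assumptions for illustration:

```python
from typing import Sequence, Tuple


def precision_recall(y_true: Sequence[int], y_pred: Sequence[int],
                     positive_class: int) -> Tuple[float, float]:
    """One-vs-rest precision and recall for `positive_class` (hypothetical helper).

    Both inputs are class indices. An undefined ratio (no predicted or no
    actual positives) is reported as 0.0, which is an assumption.
    """
    tp = fp = fn = 0
    for t, p in zip(y_true, y_pred):
        if p == positive_class and t == positive_class:
            tp += 1  # predicted positive, actually positive
        elif p == positive_class:
            fp += 1  # predicted positive, actually another class
        elif t == positive_class:
            fn += 1  # actually positive, predicted another class
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall
```

For example, `precision_recall([0, 1, 1, 2, 1], [0, 1, 2, 2, 1], positive_class=1)` gives precision 1.0 (no false positives) and recall 2/3 (one class-1 sample was predicted as class 2).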
Edit 1
I tried the following, but it did not work:
import tensorflow as tf
a = tf.random.uniform((32, 10), 0, 1, dtype=tf.float32)
b = tf.random.uniform((32, 10), 0, 1, dtype=tf.float32)
a_softmax = tf.nn.softmax(a)
b_softmax = tf.nn.softmax(b)
a_argmax = tf.argmax(a_softmax, axis=-1)
b_argmax = tf.argmax(b_softmax, axis=-1)
acc = tf.keras.metrics.Accuracy()(a_argmax, b_argmax)
with tf.Session() as sess:
    sess.run([acc])
It gives me the following error:
Traceback (most recent call last):
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1356, in _do_call
return fn(*args)
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1341, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1429, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.NotFoundError: Container localhost does not exist. (Could not find resource: localhost/total)
[[{{node AssignAddVariableOp}}]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:/Users/96171/Desktop/dementia_cleanedup/dementia/maml_finn_copy/try_tf.py", line 15, in <module>
sess.run(acc)
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 950, in run
run_metadata_ptr)
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1173, in _run
feed_dict_tensor, options, run_metadata)
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1350, in _do_run
run_metadata)
File "C:\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1370, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Container localhost does not exist. (Could not find resource: localhost/total)
[[node AssignAddVariableOp (defined at /Users/96171/Desktop/dementia_cleanedup/dementia/maml_finn_copy/try_tf.py:12) ]]
Original stack trace for 'AssignAddVariableOp':
File "/Users/96171/Desktop/dementia_cleanedup/dementia/maml_finn_copy/try_tf.py", line 12, in <module>
acc = tf.keras.metrics.Accuracy()(a_argmax, b_argmax)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\metrics.py", line 170, in __call__
update_op = self.update_state(*args, **kwargs) # pylint: disable=not-callable
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\utils\metrics_utils.py", line 73, in decorated
update_op = update_state_fn(*args, **kwargs)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\metrics.py", line 551, in update_state
matches, sample_weight=sample_weight)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\metrics.py", line 314, in update_state
update_total_op = self.total.assign_add(value_sum)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py", line 1108, in assign_add
name=name)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\gen_resource_variable_ops.py", line 68, in assign_add_variable_op
"AssignAddVariableOp", resource=resource, value=value, name=name)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 788, in _apply_op_helper
op_def=op_def)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py", line 3616, in create_op
op_def=op_def)
File "\Users\96171\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py", line 2005, in __init__
self._traceback = tf_stack.extract_stack()
Process finished with exit code 1
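Solution
tf.keras.metrics.Accuracy creates an object whose state is usually updated many times, so you cannot pass y_pred and y_true to its constructor. Create the object first, then call it with y_pred and y_true. Try: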
tf.keras.metrics.Accuracy()(tf.argmax(tf.nn.softmax(support_pred, axis=1), axis=1),
                            tf.argmax(support_y, axis=1))
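To make the "stateful object" point concrete: the traceback in Edit 1 shows that calling the metric runs `update_state`, which executes `self.total.assign_add(...)`, so the metric keeps running variables across calls, and its constructor only takes configuration (for tf.keras.metrics.Accuracy, `name` and `dtype`). The stdlib-only sketch below imitates that pattern with a hypothetical `StreamingAccuracy` class; it is not TensorFlow's implementation, and the `count` attribute and `reset_state` method are assumptions for illustration:

```python
class StreamingAccuracy:
    """Hypothetical pure-Python imitation of a stateful accuracy metric."""

    def __init__(self, name: str = "accuracy"):
        # Configuration only: data goes to update_state(), never to the constructor.
        self.name = name
        self.total = 0.0  # running number of matching elements
        self.count = 0.0  # running number of compared elements

    def update_state(self, y_true, y_pred):
        # Accumulate matches from this batch into the running state.
        matches = [1.0 if t == p else 0.0 for t, p in zip(y_true, y_pred)]
        self.total += sum(matches)
        self.count += len(matches)

    def result(self) -> float:
        # Accuracy over everything seen since the last reset.
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total = 0.0
        self.count = 0.0

    def __call__(self, y_true, y_pred) -> float:
        # Calling the object updates the state, then returns the running result.
        self.update_state(y_true, y_pred)
        return self.result()
```

Usage: `acc = StreamingAccuracy()` once, then `acc([0, 1, 2, 3], [0, 1, 0, 0])` returns 0.5 and a second call `acc([1, 1], [1, 1])` returns the cumulative 4/6, not the per-batch 1.0. Writing `StreamingAccuracy(y_true, y_pred)` would hand the data to the constructor's configuration parameters, which is the same mistake as in the question.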
It works when both the outputs and the labels go through softmax and argmax:
import tensorflow as tf
tf.random.set_seed(0)
a = tf.random.uniform((32, 10), 0, 1, dtype=tf.float32)
b = tf.random.uniform((32, 10), 0, 1, dtype=tf.float32)
a_softmax = tf.nn.softmax(a)
b_softmax = tf.nn.softmax(b)
a_argmax = tf.argmax(a_softmax, axis=-1)
b_argmax = tf.argmax(b_softmax, axis=-1)
tf.keras.metrics.Accuracy()(a_argmax, b_argmax)
<tf.Tensor: shape=(), dtype=float32, numpy=0.1875>
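Both snippets apply tf.nn.softmax before tf.argmax. Every softmax entry in a row is exp(x_i) divided by the same sum, so softmax preserves the ordering of the row and does not change which index is largest; argmax on the raw values gives the same class indices (up to floating-point rounding in extreme cases). The stdlib-only sketch below, with hypothetical helpers `softmax` and `argmax`, illustrates this for a single row:

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtracting the max is a standard stability trick; it does not change the output.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def argmax(values: Sequence[float]) -> int:
    # Index of the first maximum (Python's max() keeps the first maximal item).
    return max(range(len(values)), key=lambda i: values[i])


row = [0.2, 1.5, -0.3, 1.1]
print(argmax(row), argmax(softmax(row)))  # both are index 1
```

So in the answer's code, `tf.argmax(support_pred, axis=1)` should select the same predicted classes as the softmax-then-argmax version; the softmax only matters if you need the probabilities themselves.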