python - 用于张量流中不同数量元素的标签和预测的精度和召回 eval_metrics
问题描述
我在 Tensorflow 中将精度和召回率注册为 eval_metrics 时遇到问题。我的标签和预测没有相同数量的元素,所以我不能使用已经内置的函数。我有计算精度和召回率的功能,但我似乎无法获得precision_update_op 和recall_update_op。有什么想法可以从标签、预测和前面提到的计算精度和召回函数中获得吗?谢谢
解决方案
这是一个简单的示例,说明如何构建自己的指标。我将演示mean
,您应该也能够适应上述内容。
def mean_metrics(values):
""" For mean, there are two variables that are
required to hold the sum and the total number of variables"""
# total sum
total = tf.Variable(initial_value=0., dtype=tf.float32, name='total')
# total count
count = tf.Variable(initial_value=0., dtype=tf.float32, name='count')
# Update total op by updating total with the sum of the values
update_total_op = tf.assign_add(total, tf.cast(tf.reduce_sum(values), tf.float32))
# Update count op by updating the total size of the values
update_count_op = tf.assign_add(count, tf.cast(tf.size(tf.squeeze(values)), tf.float32))
# Mean
mean = tf.div(total, count, 'value')
# Mean update op
update_op = tf.div(update_total_op, update_count_op, 'value')
return mean, update_op
测试上面的代码:
tf.reset_default_graph()
values = tf.placeholder(tf.float32, shape=[None])
mean, mean_op = mean_metrics(values)
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(sess.run([mean, mean_op], {values:[1.,2.,3.]}))
print(sess.run([mean, mean_op], {values:[4.,5.,6.]}))
#output
#[nan, 2.0]
#[2.0, 3.5]
推荐阅读
- stanford-nlp - 带节的依赖树
- events - 是否有可能监听 MakerDAO 的 LogNote 事件?
- next.js - Next JS:在路由更改之前警告用户未保存的表单
- python - 执行在不同虚拟环境/venv下构建的python脚本?
- python - 将颜色打印到控制台时出现python间距问题
- heroku - 几个heroku postgres问题(刚开始,我迷路了)
- tensorflow - 如何管理 model.provide_groundtruth 的批次
- linux - 由于 NGRAPH_VERSION,无法在 Raspberry Pi 上 CMake NGraph
- python - 属性错误:“浮动”对象没有属性
- java - 根据旋转角度更改 ImageView 位置