How to print the current evaluation step during TensorFlow evaluation?

Problem description

I am trying to get the index of the current batch during TensorFlow evaluation, but it doesn't work. My code evaluates an image classifier.

import tensorflow as tf  # TensorFlow 1.x graph API (tf.GraphKeys, collections)

def _create_local(name, shape, collections=None, validate_shape=True,
                  dtype=tf.float32):
    """Creates a new local variable.
    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      collections: A list of collection names to which the Variable will be added.
      validate_shape: Whether to validate the shape of the variable.
      dtype: Data type of the variables.
    Returns:
      The created variable.
    """
    # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES
    collections = list(collections or [])
    collections += [tf.GraphKeys.LOCAL_VARIABLES]
    # Zero-initialized, non-trainable variable
    return tf.Variable(
        initial_value=tf.zeros(shape, dtype=dtype),
        name=name,
        trainable=False,
        collections=collections,
        validate_shape=validate_shape)

My evaluation function:

def _get_evaluation(channel):
    with tf.name_scope("eval"):
        result = _create_local('result_list', shape=[10, 4], dtype=tf.float32)
        # ... my code ...
        # current evaluation step (FLAGS.batch_size comes from the script's flags, not shown)
        current_step = tf.range(result[-1, -1], result[-1, -1] + FLAGS.batch_size)

    return current_step
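The `tf.range` line is trying to produce the indices of the samples in the current batch. That only works if its start value advances by `batch_size` after every batch. `result` is created by `_create_local` with `tf.zeros`, so unless the elided code updates `result[-1, -1]`, it stays 0 and every batch yields the same range. The arithmetic the range needs is sketched below in plain Python; `batch_indices` is a hypothetical helper, and `step` is assumed to be a 0-based batch counter.

```python
def batch_indices(step, batch_size):
    """Sample indices covered by evaluation batch `step` (0-based)."""
    start = step * batch_size
    return list(range(start, start + batch_size))


# With a batch size of 4, successive evaluation steps cover consecutive blocks
for step in range(3):
    print(step, batch_indices(step, 4))
```

The point is that the start of the range must come from a counter that is incremented once per batch, not from a value that is never written.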

Tags: python, tensorflow

Solution


Just pass `c_value` to `_create_local`, so that the step counter is created as a local variable too:

def _get_evaluation(channel):
    with tf.name_scope("eval"):
        result = _create_local('result_list', shape=[10, 4], dtype=tf.float32)

        # current evaluation step, stored as a local (non-trainable) int32 variable
        c_value = _create_local('c_value', shape=(1,), dtype=tf.int32)

    return c_value
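A minimal sketch of the answer's idea, assuming TensorFlow 2.x with the `tf.compat.v1` graph API (TF 1.15 should behave the same): the counter is a local, non-trainable int32 variable, `tf.compat.v1.assign_add` advances it once per `sess.run`, and the batch indices are derived from it. `run_eval`, `num_batches` and the `eval_step` name are hypothetical; in the question's code the counter would be `c_value` and the loop would be the real evaluation loop. Local variables are initialized by `tf.compat.v1.local_variables_initializer()`, not by the global initializer.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as in the question's TF 1.x code


def run_eval(num_batches, batch_size):
    """Runs `num_batches` fake eval steps; returns (step, first_idx, last_idx) per batch."""
    graph = tf1.Graph()
    with graph.as_default():
        # Local, non-trainable int32 counter (the role c_value plays in the answer)
        step = tf1.Variable(
            0,
            dtype=tf.int32,
            trainable=False,
            collections=[tf1.GraphKeys.LOCAL_VARIABLES],
            name="eval_step")
        # Increment once per sess.run; the result is the new (1-based) step
        next_step = tf1.assign_add(step, 1)
        # Indices of the samples covered by the current batch
        start = (next_step - 1) * batch_size
        indices = tf.range(start, start + batch_size)
        # Local variables need their own initializer
        init_local = tf1.local_variables_initializer()

    records = []
    with tf1.Session(graph=graph) as sess:
        sess.run(init_local)
        for _ in range(num_batches):
            # One run = one increment, even though two tensors are fetched
            s, idx = sess.run([next_step, indices])
            print("eval step", s, "indices", idx[0], "to", idx[-1])
            records.append((int(s), int(idx[0]), int(idx[-1])))
    return records
```

For example, `run_eval(3, 4)` prints `eval step 1 indices 0 to 3`, then steps 2 and 3 covering indices 4 to 7 and 8 to 11. Fetching the counter in the same `sess.run` as the evaluation ops keeps the printed step in sync with the batch being evaluated.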


