首页 > 解决方案 > 如何使用 TensorFlow AutoGraph 获得取决于函数参数的输出大小?

问题描述

系统信息

我想构建一个依赖于作为函数参数的布尔值的输出。让我们用这段小代码来说明我正在尝试做的事情:

import tensorflow as tf

@tf.function
def output(output1, output2):  # returns (0,) plus one nested tuple per enabled flag
    first_conditional_output = ()
    second_conditional_output = ()

    for i in range(2):
        if output1:  # traceback points here (t.py:9) for a tf.Tensor flag: cond branches yield different output counts
            first_conditional_output = first_conditional_output + ([i, i, i],)  # appends one 3-element list per iteration

        if output2:
            second_conditional_output = second_conditional_output + ([i, i],)  # appends one 2-element list per iteration

    outputs = (0,)  # the leading 0 is always present
    if output1:
        outputs = outputs + (first_conditional_output,)  # tuple grows only when the flag is set
    if output2:
        outputs = outputs + (second_conditional_output,)

    return outputs

当我使用 Python 原生布尔值运行它时,它按预期工作:

print(output(True, False)) # (<tf.Tensor: shape=(), dtype=int32, numpy=0>, ([<tf.Tensor: shape=(), dtype=int32, numpy=0>, <tf.Tensor: shape=(), dtype=int32, numpy=0>, <tf.Tensor: shape=(), dtype=int32, numpy=0>], [<tf.Tensor: shape=(), dtype=int32, numpy=1>, <tf.Tensor: shape=(), dtype=int32, numpy=1>, <tf.Tensor: shape=(), dtype=int32, numpy=1>]))

但是当我用布尔张量运行它时:

print(output(tf.constant(True), tf.constant(False)))  # raises AssertionError: [3, 0]

我得到以下异常:

Traceback (most recent call last):
  File "t.py", line 24, in <module>
    print(output(tf.constant(True), tf.constant(False)))
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 618, in _call
    results = self._stateful_fn(*args, **kwds)
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2419, in __call__
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2667, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
AssertionError: in user code:

t.py:9 output  *
    if output1:
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/autograph/operators/control_flow.py:924 if_stmt
    basic_symbol_names, composite_symbol_names)
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/autograph/operators/control_flow.py:962 tf_if_stmt
    error_checking_orelse)
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:507 new_func
    return func(*args, **kwargs)
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:1177 cond
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/cond_v2.py:101 cond_v2
    name=scope)
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/cond_v2.py:220 _build_cond
    _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/cond_v2.py:633 _make_indexed_slices_indices_types_match
    assert len(set(outs_per_branch)) == 1, outs_per_branch

AssertionError: [3, 0]

我知道 AutoGraph 不允许 tf.cond 的两个分支具有不同的输出大小。那么,有人知道可以用什么变通方法来获得预期结果吗?

第一次更新

我按照建议尝试使用 tf.TensorArray,但仍然无法正常工作,只是出现了不同的错误:

import tensorflow as tf

@tf.function
def output(output1, output2):  # TensorArray-based variant of the flag-dependent output
    first_conditional_output = tf.TensorArray(tf.int32, size=0, dynamic_size=True, clear_after_read=False)
    second_conditional_output = tf.TensorArray(tf.int32, size=0, dynamic_size=True, clear_after_read=False)

    for i in range(2):
        if output1:
            first_conditional_output = first_conditional_output.write(i, [i, i, i])  # 3-element rows

        if output2:
            second_conditional_output = second_conditional_output.write(i, [i, i])  # 2-element rows

    outputs = tf.TensorArray(tf.int32, size=0, dynamic_size=True, clear_after_read=False)
    outputs = outputs.write(0, 0)  # first element written is a scalar, shape ()
    if output1:
        outputs = outputs.write(1, first_conditional_output.stack())  # t.py:18 -- ValueError: saw (None, 3) but expected ()
    if output2:
        outputs = outputs.write(2, second_conditional_output.stack())

    return outputs.stack()

出现错误:

Traceback (most recent call last):
  File "t.py", line 24, in <module>
    print(output(True, False))
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 627, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 506, in _initialize
    *args, **kwds))
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2446, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2667, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    t.py:18 output  *
        outputs = outputs.write(1, first_conditional_output.stack())
    /opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/util/tf_should_use.py:235 wrapped  **
        return _add_should_use_warning(fn(*args, **kwargs),
    /opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/tensor_array_ops.py:1159 write
        return self._implementation.write(index, value, name=name)
    /opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/tensor_array_ops.py:536 write
        self._check_element_shape(value.shape)
    /opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/tensor_array_ops.py:505 _check_element_shape
        (shape, self.element_shape))

    ValueError: Inconsistent shapes: saw (None, 3) but expected ()

我无法为 outputs 强制指定 element_shape 参数,因为 first_conditional_output 和 second_conditional_output 具有不同的形状。

标签: python, tensorflow, tensorflow2.0

解决方案


推荐阅读