python - 如何使用 TensorFlow AutoGraph 获得取决于函数参数的输出大小?
问题描述
系统信息
- 操作系统:Linux Ubuntu 20.04
- Python版本:3.7.5
- TensorFlow 版本:2.2.0
我想根据作为函数参数传入的布尔值来构建输出。下面这段小代码说明了我想做的事情:
import tensorflow as tf
@tf.function
def output(output1, output2):
    """Build a tuple whose structure depends on two boolean flags.

    The result always starts with ``0``. If ``output1`` is truthy, a tuple of
    two 3-element lists (one per loop step) is appended. If ``output2`` is
    truthy, a tuple of two 2-element lists is appended.

    NOTE: with Python ``bool`` arguments the ``if`` statements are evaluated
    while tracing, so the returned structure simply follows the flags. With
    ``tf.Tensor`` booleans, AutoGraph converts each ``if`` into ``tf.cond``,
    whose two branches must yield the same structure. Here they do not (for
    example 3 values versus 0), so tracing fails with an AssertionError.
    """
    first_conditional_output = ()
    second_conditional_output = ()
    for step in range(2):
        if output1:
            first_conditional_output += ([step] * 3,)
        if output2:
            second_conditional_output += ([step] * 2,)

    # Leading element is always present; the optional groups follow in order.
    outputs = (0,)
    if output1:
        outputs += (first_conditional_output,)
    if output2:
        outputs += (second_conditional_output,)
    return outputs
当我使用 Python 原生布尔值调用它时,它按预期工作:
print(output(True, False)) # (<tf.Tensor: shape=(), dtype=int32, numpy=0>, ([<tf.Tensor: shape=(), dtype=int32, numpy=0>, <tf.Tensor: shape=(), dtype=int32, numpy=0>, <tf.Tensor: shape=(), dtype=int32, numpy=0>], [<tf.Tensor: shape=(), dtype=int32, numpy=1>, <tf.Tensor: shape=(), dtype=int32, numpy=1>, <tf.Tensor: shape=(), dtype=int32, numpy=1>]))
但是当我用布尔张量运行它时:
print(output(tf.constant(True), tf.constant(False)))
我得到以下异常:
Traceback (most recent call last):
File "t.py", line 24, in <module>
print(output(tf.constant(True), tf.constant(False)))
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
result = self._call(*args, **kwds)
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 618, in _call
results = self._stateful_fn(*args, **kwds)
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2419, in __call__
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2667, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper
raise e.ag_error_metadata.to_exception(e)
AssertionError: in user code:
t.py:9 output *
if output1:
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/autograph/operators/control_flow.py:924 if_stmt
basic_symbol_names, composite_symbol_names)
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/autograph/operators/control_flow.py:962 tf_if_stmt
error_checking_orelse)
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:507 new_func
return func(*args, **kwargs)
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:1177 cond
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/cond_v2.py:101 cond_v2
name=scope)
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/cond_v2.py:220 _build_cond
_make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/cond_v2.py:633 _make_indexed_slices_indices_types_match
assert len(set(outs_per_branch)) == 1, outs_per_branch
AssertionError: [3, 0]
我知道 AutoGraph 不允许 `tf.cond` 的两个分支具有不同的输出大小(结构)。那么,有人知道能得到预期结果的变通方法吗?
第一次更新
我按照建议尝试使用 `tf.TensorArray`,但仍然无法正常工作,只是出现了另一个错误:
import tensorflow as tf
def _new_int32_array():
    """Return an empty, growable int32 TensorArray whose elements remain readable."""
    return tf.TensorArray(
        tf.int32, size=0, dynamic_size=True, clear_after_read=False
    )


@tf.function
def output(output1, output2):
    """Pack flag-dependent results into one TensorArray and stack it.

    Index 0 always holds the scalar ``0``. Index 1 receives the stacked
    3-element rows when ``output1`` is truthy, and index 2 receives the stacked
    2-element rows when ``output2`` is truthy.

    NOTE: a TensorArray enforces one element shape for all its entries. Because
    index 0 is a scalar, writing the stacked ``(None, 3)`` value at index 1
    raises ``ValueError: Inconsistent shapes``.
    """
    first_conditional_output = _new_int32_array()
    second_conditional_output = _new_int32_array()
    for step in range(2):
        if output1:
            first_conditional_output = first_conditional_output.write(step, [step] * 3)
        if output2:
            second_conditional_output = second_conditional_output.write(step, [step] * 2)

    outputs = _new_int32_array().write(0, 0)
    if output1:
        outputs = outputs.write(1, first_conditional_output.stack())
    if output2:
        outputs = outputs.write(2, second_conditional_output.stack())
    return outputs.stack()
出现错误:
Traceback (most recent call last):
File "t.py", line 24, in <module>
print(output(True, False))
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
result = self._call(*args, **kwds)
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 627, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 506, in _initialize
*args, **kwds))
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2446, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2667, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
t.py:18 output *
outputs = outputs.write(1, first_conditional_output.stack())
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/util/tf_should_use.py:235 wrapped **
return _add_should_use_warning(fn(*args, **kwargs),
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/tensor_array_ops.py:1159 write
return self._implementation.write(index, value, name=name)
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/tensor_array_ops.py:536 write
self._check_element_shape(value.shape)
/opt/anaconda3/envs/transformers/lib/python3.7/site-packages/tensorflow/python/ops/tensor_array_ops.py:505 _check_element_shape
(shape, self.element_shape))
ValueError: Inconsistent shapes: saw (None, 3) but expected ()
我无法为 `outputs` 指定 `element_shape` 参数,因为 `first_conditional_output` 和 `second_conditional_output` 的形状不同。
解决方案
推荐阅读
- php - 在 PHP 上保存日期差异的结果
- android - 前台服务和 JobScheduler 的区别
- java - 在场景之间实现类似横向滚动的过渡
- docker - Docker stack deploy cannot deploy service to different node in swarm cluster
- python - 如何从熊猫数据框中以列形式存在的json格式数据中检索某些键和值
- ruby-on-rails - 为什么 Active Storage 在页面上嵌入文本消息?
- end-user - 在没有 Visual Studio 的最终用户计算机上运行程序
- ios - 无法在表格视图中快速显示检索到的 API 数据
- pytest - Codecov:错误处理覆盖率报告
- pgadmin-4 - PGadmin V4.17 - Windows 10 - 无法为表“错误消息”创建选择、更新脚本错误获取脚本的 SQl:'attname'