python - Problem with TensorFlow ragged stack
Problem description
I am trying to use tf.ragged.stack in my model. While experimenting with it, I can do the following:
import tensorflow as tf
tensor = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
masks = tf.constant([[1, 1, 1], [0, 0, 0], [1, 0, 1]])
tf.ragged.stack([tf.boolean_mask(tensor, mask) for mask in masks])
It gives:
<tf.RaggedTensor [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [], [[1.0, 2.0], [5.0, 6.0]]]>
This is perfect and exactly what I want.
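The eager result above can be modeled in plain Python, which makes the ragged shape explicit: each mask row keeps the rows of `tensor` whose mask entry is non-zero, and the kept rows form one inner list whose length differs from row to row (3, 0 and 2 here). This is only an illustration of the semantics under that reading, not how TensorFlow implements it; the helper name `ragged_masked_rows` is hypothetical.

```python
def ragged_masked_rows(tensor, masks):
    """Plain-Python model (hypothetical helper) of
    tf.ragged.stack([tf.boolean_mask(tensor, m) for m in masks]).

    For every mask row, keep the rows of `tensor` whose mask entry is
    non-zero; each selection becomes one inner list, possibly empty.
    """
    return [[row for row, keep in zip(tensor, mask) if keep] for mask in masks]


tensor = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
masks = [[1, 1, 1], [0, 0, 0], [1, 0, 1]]
result = ragged_masked_rows(tensor, masks)
print(result)
# [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [], [[1.0, 2.0], [5.0, 6.0]]]
```

The inner lists have lengths 3, 0 and 2, which is exactly why the result cannot be an ordinary dense tensor and needs a RaggedTensor.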
However, as soon as I put similar code into my model, it fails:
tensor = tf.keras.layers.Dense(2, activation = 'elu', use_bias = False)(tf.keras.Input(shape=(None, 2), dtype='float32'))
tensor = tf.reshape(tensor, [3, 2])
masks = tf.keras.Input(shape=(None, 3), dtype='int32')
masks = tf.reshape(masks, [3,3])
rag = tf.ragged.stack([tf.boolean_mask(tensor, mask) for mask in masks])
The error is:
---------------------------------------------------------------------------
_FallbackException Traceback (most recent call last)
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py in concat_v2(values, axis, name)
1171 _ctx._context_handle, tld.device_name, "ConcatV2", name,
-> 1172 tld.op_callbacks, values, axis)
1173 return _result
_FallbackException: This function does not handle the case of the path where all inputs are not already EagerTensors.
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
<ipython-input-114-d762bdd6bb0d> in <module>
3 masks = tf.keras.Input(shape=(None, 3), dtype='int32')
4 masks = tf.reshape(masks, [3,3])
----> 5 rag = tf.ragged.stack([tf.boolean_mask(tensor, mask) for mask in masks])
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/ragged/ragged_concat_ops.py in stack(values, axis, name)
116 values = [values]
117 with ops.name_scope(name, 'RaggedConcat', values):
--> 118 return _ragged_stack_concat_helper(values, axis, stack_values=True)
119
120
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/ragged/ragged_concat_ops.py in _ragged_stack_concat_helper(rt_inputs, axis, stack_values)
185 if not ragged_tensor.is_ragged(rt_inputs[i]):
186 rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor(
--> 187 rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype)
188
189 # Convert the input tensors to all have the same ragged_rank.
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/ragged/ragged_tensor.py in from_tensor(cls, tensor, lengths, padding, ragged_rank, name, row_splits_dtype)
1779 # vector that contains no default values, and reshape the input tensor
1780 # to form the values for the RaggedTensor.
-> 1781 values_shape = array_ops.concat([[-1], input_shape[2:]], axis=0)
1782 values = array_ops.reshape(tensor, values_shape)
1783 const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
178 """Call target, and fall back on dispatchers if there is a TypeError."""
179 try:
--> 180 return target(*args, **kwargs)
181 except (TypeError, ValueError):
182 # Note: convert_to_eager_tensor currently raises a ValueError, not a
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py in concat(values, axis, name)
1604 dtype=dtypes.int32).get_shape().assert_has_rank(0)
1605 return identity(values[0], name=name)
-> 1606 return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
1607
1608
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py in concat_v2(values, axis, name)
1175 try:
1176 return concat_v2_eager_fallback(
-> 1177 values, axis, name=name, ctx=_ctx)
1178 except _core._SymbolicException:
1179 pass # Add nodes to the TensorFlow graph.
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py in concat_v2_eager_fallback(values, axis, name, ctx)
1207 "'concat_v2' Op, not %r." % values)
1208 _attr_N = len(values)
-> 1209 _attr_T, values = _execute.args_to_matching_eager(list(values), ctx)
1210 _attr_Tidx, (axis,) = _execute.args_to_matching_eager([axis], ctx, _dtypes.int32)
1211 _inputs_flat = list(values) + [axis]
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in args_to_matching_eager(l, ctx, default_dtype)
261 ret.append(
262 ops.convert_to_tensor(
--> 263 t, dtype, preferred_dtype=default_dtype, ctx=ctx))
264 if dtype is None:
265 dtype = ret[-1].dtype
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, dtype_hint, ctx, accepted_result_types)
1315 raise ValueError(
1316 "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
-> 1317 (dtype.name, value.dtype.name, value))
1318 return value
1319
ValueError: Tensor conversion requested dtype int32 for Tensor with dtype int64: <tf.Tensor 'strided_slice_142103:0' shape=(0,) dtype=int64>
Can anyone tell me what is going on?
My guess is that tf.ragged.stack does not work with placeholders such as tf.keras.Input.
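One way to test the placeholder guess outside Keras is a minimal sketch that traces the same computation with symbolic inputs through `tf.function` and a fixed `input_signature` (shapes `[3, 2]` and `[3, 3]`, mirroring the two `tf.reshape` calls above). It assumes a TF 2.x release where the int32/int64 problem from the traceback does not occur; it has not been checked against the asker's installation. The function name `masked_ragged_stack` is hypothetical, and `tf.unstack` is used instead of Python iteration over a symbolic tensor.

```python
import tensorflow as tf


# Symbolic inputs with fixed shapes, similar to the reshaped Keras inputs.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[3, 2], dtype=tf.float32),
    tf.TensorSpec(shape=[3, 3], dtype=tf.int32),
])
def masked_ragged_stack(tensor, masks):
    # Split the [3, 3] mask into three [3] rows; each boolean_mask call
    # returns a [None, 2] tensor because the kept row count is data-dependent.
    rows = [tf.boolean_mask(tensor, tf.cast(mask, tf.bool))
            for mask in tf.unstack(masks, num=3)]
    # Stack the variable-length selections into a single RaggedTensor.
    return tf.ragged.stack(rows)


rag = masked_ragged_stack(
    tf.constant([[1., 2.], [3., 4.], [5., 6.]]),
    tf.constant([[1, 1, 1], [0, 0, 0], [1, 0, 1]]))
print(rag.to_list())
```

If this runs on a given TF version, tf.ragged.stack itself handles graph (non-eager) tensors there, which points to a version or dtype issue rather than a general limitation with placeholders.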
Solution
Which version of TF are you using? The code works when I test it on Colab (with TF 2.4).
However, the main problem seems to come from the data types:
ValueError: Tensor conversion requested dtype int32 for Tensor with dtype int64: <tf.Tensor 'strided_slice_142103:0' shape=(0,) dtype=int64>
You need to cast the input to int32 so that tf.ragged.stack accepts it:
converted_masks = [tf.cast(tf.boolean_mask(tensor, mask), tf.int32) for mask in masks]
rag = tf.ragged.stack(converted_masks)
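An alternative sketch, not the approach proposed in this answer: `tf.ragged.boolean_mask` can build the same ragged result in a single call if `tensor` is first repeated once per mask row, so that the mask shape `[3, 3]` is a prefix of the data shape `[3, 3, 2]`. It assumes a TF 2.x release and uses the eager toy values from the question; it has not been tested with the asker's Keras model or TF version.

```python
import tensorflow as tf

tensor = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
masks = tf.constant([[1, 1, 1], [0, 0, 0], [1, 0, 1]])

# Repeat `tensor` once per mask row: shape [3, 3, 2].
num_masks = tf.shape(masks)[0]
data = tf.tile(tensor[tf.newaxis], [num_masks, 1, 1])

# tf.ragged.boolean_mask expects a boolean mask whose shape is a prefix
# of the data shape; here [3, 3] against [3, 3, 2].
rag = tf.ragged.boolean_mask(data, tf.cast(masks, tf.bool))
print(rag.to_list())
```

This keeps the float values unchanged, whereas casting the masked tensors to int32 truncates them.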