python - OperatorNotAllowedInGraphError:不允许迭代 `tf.Tensor`:AutoGraph 确实将此函数转换为自定义模型
问题描述
我正在尝试通过一些调整来实现时间序列的“注意力就是你所需要的”论文,但我收到了这个错误:
tf.Tensor
OperatorNotAllowedInGraphError:不允许迭代:AutoGraph 确实转换了此函数。
代码:
import tensorflow as tf
from tensorflow import keras
class Attention(tf.keras.layers.Layer):
def __init__(self, dk, dv, num_heads, filter_size):
super().__init__()
self.dk = dk
self.dv = dv
self.num_heads = num_heads
self.conv_q = tf.keras.layers.Conv1D(dk * num_heads, filter_size, padding='causal')
self.conv_k = tf.keras.layers.Conv1D(dk * num_heads, filter_size, padding='causal')
self.dense_v = tf.keras.layers.Dense(dv * num_heads)
self.dense1 = tf.keras.layers.Dense(dv, activation='relu')
self.dense2 = tf.keras.layers.Dense(dv)
def split_heads(self, x, batch_size, dim):
x = tf.reshape(x, (batch_size, -1, self.num_heads, dim))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs):
batch_size, time_steps, _ = tf.shape(inputs)
q = self.conv_q(inputs)
k = self.conv_k(inputs)
v = self.dense_v(inputs)
q = self.split_heads(q, batch_size, self.dk)
k = self.split_heads(k, batch_size, self.dk)
v = self.split_heads(v, batch_size, self.dv)
mask = 1 - tf.linalg.band_part(tf.ones((batch_size, self.num_heads, time_steps, time_steps)), -1, 0)
dk = tf.cast(self.dk, tf.float32)
score = tf.nn.softmax(tf.matmul(q, k, transpose_b=True)/tf.math.sqrt(dk) + mask * -1e9)
outputs = tf.matmul(score, v)
outputs = tf.transpose(outputs, perm=[0, 2, 1, 3])
outputs = tf.reshape(outputs, (batch_size, time_steps, -1))
outputs = self.dense1(outputs)
outputs = self.dense2(outputs)
return outputs
class Transformer(tf.keras.models.Model):
"""
Time Series Transformer Model
"""
def __init__(self, dk, dv, num_heads, filter_size):
super().__init__()
self.attention = Attention(dk, dv, num_heads, filter_size)
self.dense_sigma = tf.keras.layers.Dense(1)
def call(self, inputs):
outputs = self.attention(inputs)
sigma = self.dense_sigma(outputs)
return sigma
Mymodel= Transformer(3,3,4,3)
Mymodel.compile(loss="mean_squared_error",
optimizer=keras.optimizers.Adam(learning_rate=1e-4),)
Mymodel.fit(X_train,Y_train,epochs=10,batch_size=32)
#X_train & Y_train are numpy array with shape ( batch_size , timesteps , no.of features )
完全错误:
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
984 except Exception as e: # pylint:disable=broad-except
985 if hasattr(e, "ag_error_metadata"):
--> 986 raise e.ag_error_metadata.to_exception(e)
987 else:
988 raise
OperatorNotAllowedInGraphError: in user code:
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:855 train_function *
return step_function(self, iterator)
<ipython-input-20-56419dd4aeb6>:61 call *
outputs = self.attention(inputs)
<ipython-input-3-a936077354d3>:24 call *
batch_size, time_steps, _ = tf.shape(inputs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:520 __iter__
self._disallow_iteration()
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:513 _disallow_iteration
self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:491 _disallow_when_autograph_enabled
" indicate you are trying to use an unsupported feature.".format(task))
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
解决方案
似乎在图形模式下,为了解包张量,它会尝试迭代结果。无论如何,您可以改用它:
batch_size = tf.shape(inputs)[0]
time_steps = tf.shape(inputs)[1]
我的第一个建议是使用,但是我修改了我的答案,因为这里.shape
有来自 tensorflow 文档的提示:
tf.shape(x)
并且x.shape
在急切模式下应该是相同的。在 内tf.function
,直到执行时间才可能知道所有维度。因此,在为图形模式定义自定义层和模型时,更喜欢动态tf.shape(x)
而不是静态x.shape
。
推荐阅读
- grub2 - 早期配置中的 GRUB2 菜单
- python - 如何使用 Selenium 从 devtools Network 面板中检索“Initiator”字段?
- css - 具有渐变叠加 + 背景大小的可重复背景图像
- sql - 连接两个sql结果
- html - 网页突然不显示图像或下载脚本:431错误
- flutter - Flutter,如何制作简单的截图分享?
- google-apps-script - 将一堆单张文件编译成大型电子表格的最简单方法是什么?
- sql - 在 SQL 查询代码中放置未声明的变量会生成输入框?
- javascript - 如何在 ReactJS 中启用悬停
- html - Angular/ HTML5 到 iOS WKWebView 通信