python - 如何在子类化 keras 层/模型时正确使用`@tf.function`?
问题描述
我有一个自定义tf.keras.layers.Layer
,它只使用 TF 运算符进行某种位解包(将整数转换为布尔值(0 或 1 浮点数))。
class CharUnpack(keras.layers.Layer):
def __init__(self, name="CharUnpack", *args, **kwargs):
super(CharUnpack, self).__init__(trainable=False, name=name, *args, **kwargs)
# Range [7, 6, ..., 0] to bit-shift integers
self._shifting_range = tf.reshape(
tf.dtypes.cast(
tf.range(7, -1, -1, name='shifter_range'),
tf.uint8,
name='shifter_cast'),
(1, 1, 8),
name='shifter_reshape')
# Constant value 0b00000001 to use as bitwise and operator
self._selection_bit = tf.constant(0x01, dtype=tf.uint8, name='and_selection_bit')
def call(self, inputs):
return tf.dtypes.cast(
tf.reshape(
tf.bitwise.bitwise_and(
tf.bitwise.right_shift(
tf.expand_dims(inputs, 2),
self._shifting_range,
),
self._selection_bit,
),
[x if x else -1 for x in self.compute_output_shape(inputs.shape)]
),
tf.float32
)
def compute_output_shape(self, input_shape):
try:
if len(input_shape) > 1:
output_shape = tf.TensorShape(tuple(list(input_shape[:-1]) + [input_shape[-1] * 8]))
else:
output_shape = tf.TensorShape((input_shape[0] * 8,))
except TypeError:
output_shape = input_shape
return output_shape
def compute_output_signature(self, input_signature):
return tf.TensorSpec(self.compute_output_shape(input_signature.shape), tf.float32)
如本TF 指南所示,我尝试对这一层进行基准测试以提高时间性能。
inputs = tf.zeros([64, 400], dtype=tf.uint8)
eager = CharUnpack()
@tf.function
def fun(x):
eager(x)
# Warm-up
eager(inputs)
fun(inputs)
print("Function:", timeit.timeit(lambda: fun(inputs), number=100))
print("Eager:", timeit.timeit(lambda: eager(inputs), number=100))
Function: 0.01062483999885444
Eager: 0.12658399900101358
如你所见,我可以得到 10 倍的加速!!!所以,我将@tf.function
装饰器添加到我的CharUnpack.call
方法中:
+ @tf.function
def call(self, inputs):
return tf.dtypes.cast(
现在我希望 theeager
和 thefun
调用都花费相似的时间,但我没有得到任何改善。
Function: 0.009667591999459546
Eager: 0.10346330100037449
此外,在这个SO 答案的第 2.1 节中,模型默认是图形编译的(这应该是逻辑),但情况似乎并非如此......
如何正确使用@tf.function
装饰器使我的层始终图形编译?
解决方案
推荐阅读
- javascript - 在 Bing 地图中显示来自数据源 URL 的所有业务位置
- python - 如何检查 TPU 设备类型是 v2 还是 v3?
- javascript - 编译 ejs 时 /workspace/Frontend/todoApp/views/board.ejs 中出现意外的标识符?
- python - 如何使用 Django 配置 Celery,以便工作人员能够正确接收任务?
- nestjs - 在nestjs中是否可以为同一个ROUTE指定多个处理程序?
- unity3d - 如何将 Mixamo 动画应用到从 Unity 商店下载的头像
- spring - 如何使用 Spring webflux 实现/迁移 OncePerRequestFilter
- flash - 将站点移至 GCP,现在我的 flash 对象无法正确呈现
- vb.net - 从一个值中获取所有可能的组合
- terraform - 我应该对 .terraform 文件夹进行版本控制吗?