首页 > 解决方案 > TF map_fn 太慢了

问题描述

我想在 tensorflow 中创建一个自定义层,它应该将函数f应用于传入的张量。因此,如果批次由张量组成,T = [T1, T2, ..., Tn]它应该返回张量[f(T1), f(T2), ..., f(Tn)]

这样做的预期方法似乎是使用该tf.map_fn功能。但是,我注意到这个函数非常慢。下面是一个 MWE,它在我的笔记本电脑上产生以下性能:

有什么方法可以加快批量大小的迭代?

import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

class Identity(tf.keras.layers.Layer):
    def __init__(self,  **kwargs):
        super(Identity, self).__init__(**kwargs)

    def call(self, inputs):
        output = tf.map_fn(lambda x: x, inputs)
#        output = inputs
        return output  

    def compute_output_shape(self, input_shape):
        return input_shape

model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation=tf.nn.relu),
        Identity(),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=100)

标签: pythontensorflow

解决方案


推荐阅读