首页 > 解决方案 > tensorflow 在使用 tf.map_fn 时消耗太多内存

问题描述

我有以下代码:

    @tf.function
    def c(self, point_instances, instances, alpha):
        def inner_comp(j):
            print(j.shape)
            print(tf.transpose(instances).shape)
            print(alpha.shape)
            print(tf.linalg.matvec(tf.transpose(instances), alpha).shape)
            tmp = tf.tensordot(j, tf.linalg.matvec(tf.transpose(instances), alpha), 1)
            print(tmp.shape)
            return tmp

        return tf.reduce_max(tf.abs(tf.map_fn(inner_comp, point_instances)))

我注意到当tf.map_fn完成第二次迭代时,由于内存不足,它被杀死了。所以控制台输出如下:

(245245,)
(245245, 460)
(460,)
(245245,)
()

(245245,)
(245245, 1040)
(1040,)
(245245,)
()


Killed

虽然我的数据是大规模的,但输出的inner_comp始终是标量。我试图分配 200 GB 的内存,但它仍然被杀死。原因是什么?

标签: python-3.xtensorflow

解决方案


推荐阅读