首页 > 解决方案 > 如何减少内存的使用?

问题描述

这是我的代码示例

def normalize_3D(input):
    for i in range(input.shape[0]):
        s = tf.concat([tf.reshape(input[i, 9, 0], shape=[1, 1]),
                       tf.reshape(input[i, 9, 1], shape=[1, 1]),
                       tf.reshape(input[i, 9, 2], shape=[1, 1])], axis=1)

        output = input[i, :, :] - s
        output2 = output / tf.sqrt(tf.square(input[i, 9, 0] - input[i, 0, 0]) +
                                   tf.square(input[i, 9, 1] - input[i, 0, 1]) +
                                   tf.square(input[i, 9, 2] - input[i, 0, 2]))
        output2 = tf.reshape(output2, [1, input.shape[1], input.shape[2]])
        if i == 0:
            output3 = output2
        else:
            output3 = tf.concat([output3, output2], axis=0)

    return output3

像这个示例一样,我多次使用“for”状态来计算只有几批的数据。但是,当我编写代码时,我注意到它使用了大量内存并且出现了错误消息。我的一些预测只是显示“nan”,然后程序就卡住了。

当我计算批处理数据时,有什么方法可以减少这种内存滥用?

标签: pythontensorflow

解决方案


您的功能可以用更简单、更有效的方式表达,如下所示:

import tensorflow as tf

def normalize_3D(input):
    shift = input[:, 9]
    scale = tf.norm(input[:, 9] - input[:, 0], axis=1, keepdims=True)
    output = (input - tf.expand_dims(shift, 1)) / tf.expand_dims(scale, 1)
    return output

推荐阅读