python - 如何减少内存的使用?
问题描述
这是我的代码示例
def normalize_3D(input):
for i in range(input.shape[0]):
s = tf.concat([tf.reshape(input[i, 9, 0], shape=[1, 1]),
tf.reshape(input[i, 9, 1], shape=[1, 1]),
tf.reshape(input[i, 9, 2], shape=[1, 1])], axis=1)
output = input[i, :, :] - s
output2 = output / tf.sqrt(tf.square(input[i, 9, 0] - input[i, 0, 0]) +
tf.square(input[i, 9, 1] - input[i, 0, 1]) +
tf.square(input[i, 9, 2] - input[i, 0, 2]))
output2 = tf.reshape(output2, [1, input.shape[1], input.shape[2]])
if i == 0:
output3 = output2
else:
output3 = tf.concat([output3, output2], axis=0)
return output3
像这个示例一样,我多次使用“for”状态来计算只有几批的数据。但是,当我编写代码时,我注意到它使用了大量内存并且出现了错误消息。我的一些预测只是显示“nan”,然后程序就卡住了。
当我计算批处理数据时,有什么方法可以减少这种内存滥用?
解决方案
您的功能可以用更简单、更有效的方式表达,如下所示:
import tensorflow as tf
def normalize_3D(input):
shift = input[:, 9]
scale = tf.norm(input[:, 9] - input[:, 0], axis=1, keepdims=True)
output = (input - tf.expand_dims(shift, 1)) / tf.expand_dims(scale, 1)
return output
推荐阅读
- scala - sun.security.provider.certpath.SunCertPathBuilderException 将 Akka HTTP 更新到 10.1.14 到 10.2.4
- java - WebView 有时使用 URL 不显示 PDF 内容
- memory - 来自地址的 QEMU KVM 转储格式字符串
- python - 当主脚本也使用 argparse 时,寻找一种方法来显示导入脚本的参数
- python - 默认的 __iter__ 函数是什么?
- python - Flask 成功返回后返回 500 内部错误
- spring-boot - Spring Boot:无法从 https URL 检索文件。失败并显示消息 javax.net.ssl.SSLException : Unsupported or unrecognized SSL message
- javascript - 如何进行 api 调用并在控制台中获取和显示数据
- fluentui-react - 如何在 fluentui-react 的 ProgressIndicator 中使用样式
- c++ - 用于更改结构或类中的 const 变量值的宏或 C++ 模板