TensorFlow 2.4.1: memory- and runtime-efficient SUM(MULTIPLY(X, Y))

Problem description

I am trying to multiply two very large tensors and sum the result over one axis. See the script below. I know of two options:

  1. tf.reduce_sum(x1 * x2, axis=0). This is very fast, but beyond a certain size it OOMs on my GPU.
  2. tf.einsum("a...,a...->...", x1, x2). This is about 1000x slower but more memory-efficient, presumably because it accumulates the result iteratively as it reads and multiplies the data.
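To make the comparison concrete, the two options compute the same contraction. A small sketch with toy shapes of my own choosing, mirroring the question's broadcasting pattern:

```python
import tensorflow as tf

# Toy shapes mirroring the question's pattern: shared leading axis "a",
# trailing axes that broadcast against each other.
x1 = tf.random.normal((6, 2, 3, 1, 5))
x2 = tf.random.normal((6, 1, 1, 4, 5))

dense = tf.reduce_sum(x1 * x2, axis=0)      # materializes the full broadcast product
eins = tf.einsum("a...,a...->...", x1, x2)  # same contraction via einsum

print(dense.shape)  # (2, 3, 4, 5)
```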

Is there a middle ground that is both fast and memory-efficient? I am trying to increase n_a 1000-fold (which would require 8 TiB for the full product). Note that the inputs are about 48x smaller, since the product only reaches full size after broadcasting. One idea: Grappler (the TF graph optimizer) could multiply x1 and x2 in segments and sum the results iteratively. I have heard Grappler has a memory-optimization option [1], but it is not mentioned in the tf.config.optimizer.set_experimental_options() documentation [2]. Does anyone have any ideas?
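The segment-and-sum idea does not have to wait for Grappler; it can be written by hand. A minimal sketch (the names chunked_multiply_and_sum and n_chunks are mine, and as written the leading dimension must be divisible by n_chunks):

```python
import tensorflow as tf

@tf.function
def chunked_multiply_and_sum(x1, x2, n_chunks=10):
    """Sum x1 * x2 over axis 0 without materializing the full product.

    Only one chunk of the broadcast product lives in memory at a time,
    trading some speed for a lower peak-memory footprint.
    """
    x1_chunks = tf.split(x1, n_chunks, axis=0)
    x2_chunks = tf.split(x2, n_chunks, axis=0)
    total = tf.reduce_sum(x1_chunks[0] * x2_chunks[0], axis=0)
    for a, b in zip(x1_chunks[1:], x2_chunks[1:]):
        total += tf.reduce_sum(a * b, axis=0)
    return total
```

Tuning n_chunks trades peak memory against kernel-launch overhead: n_chunks=1 is exactly option 1, while n_chunks equal to the leading dimension approaches the einsum-style iterative accumulation.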

I have already started using multiple GPUs via tf.distribute.Strategy, but 8 TiB is too much even then. Also, the actual @tf.function I am using is heavier than the one in the script, and it OOMs when running that multiply_and_sum.

Script

import argparse
import numpy as np
import tensorflow as tf
from timeit import default_timer as timer


parser = argparse.ArgumentParser(description='Train the Keller Attention Network')
parser.add_argument("-einsum", action='store_true',
                            help="use einsum")
FLAGS = parser.parse_args()

@tf.function
def einsum(x1, x2):
    print("Tracing einsum before running the benchmark...")
    return tf.einsum("a...,a...->...", x1, x2)

@tf.function
def multiply_and_sum(x1, x2):
    print("Tracing multiply and sum before running the benchmark...")
    return tf.reduce_sum(x1 * x2, axis=0)

def print_expected_bytes(factors):
    bytes_per_float32 = 4
    total_B = bytes_per_float32 * np.prod(factors)
    total_GiB = total_B / 2 ** 30

    print(f"Expected Memory Usage: {total_GiB:.2f}GiB.")

def benchmark(func, n_runs):
    # Allow us to see memory usage via `nvidia-smi`
    gpus = tf.config.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    # Input Shapes
    n_a = 100
    n_b = 5
    n_c = 48
    n_d = 48
    n_e = 1000
    n_f = 2

    # Input.
    x1 = tf.random.normal((n_a, n_b, n_c, 1, n_e, n_f))
    x2 = tf.random.normal((n_a, 1, 1, n_d, n_e, 1))


    # Print expected bytes to anticipate OOMs.
    print_expected_bytes([n_a, n_b, n_c, n_d, n_e, n_f])

    # Trace the function first so compilation isn't timed.
    func(x1, x2)
    print("Done tracing.")

    # Benchmark func.
    start = timer()
    for _ in range(n_runs):
        result = func(x1, x2)
    average_time_s = (timer() - start) / n_runs
    print(f"{func.__name__}: {average_time_s:.4f} seconds.")

def main():
    if FLAGS.einsum:
        benchmark(einsum, 3)
    else:
        benchmark(multiply_and_sum, 100)


if __name__ == "__main__":
    main()

Script output

(tensorflow2_latest_p37) ubuntu@ip-172-31-53-116:~/code/audiofocus$ python memory_benchmark.py -einsum
Expected Memory Usage: 8.58GiB.
Tracing einsum before running the benchmark...
Done tracing.
einsum: 2.1754 seconds.
(tensorflow2_latest_p37) ubuntu@ip-172-31-53-116:~/code/audiofocus$ python memory_benchmark.py
Expected Memory Usage: 8.58GiB.
Tracing multiply and sum before running the benchmark...
Done tracing.
multiply_and_sum: 0.0001 seconds.
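A caveat about the 0.0001 s figure: TensorFlow dispatches GPU kernels asynchronously, so the timer can stop before the reduction has actually finished. Forcing a host-side sync before reading the clock (here via .numpy(), a common trick; benchmark_synced is my own name) gives a more honest number:

```python
from timeit import default_timer as timer

import tensorflow as tf

def benchmark_synced(func, x1, x2, n_runs):
    func(x1, x2)  # warm-up (and tracing, for a tf.function)
    start = timer()
    for _ in range(n_runs):
        result = func(x1, x2)
    # Copying the result to host blocks until pending device work is done.
    _ = result.numpy()
    return (timer() - start) / n_runs
```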

Tags: python, tensorflow, out-of-memory

Solution

