TPUEmbedding performance on GCP TPU

Problem description

I am testing the performance of TPUEmbedding on a GCP TPU v3 using a single TPU core. I find I can only get about 1-2 GB/s of memory bandwidth, which is very low compared to the spec (900 GB/s). I would like to know what is wrong with the code. This uses tensorflow '2.3.0-dev20200620'.

To run the code, you need to set the environment variable TPU_IP.
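As a sanity check, here is a rough back-of-the-envelope calculation (my own sketch; it assumes each of the batch * nnz lookups reads one full float32 row of width em, and takes the 900 GB/s figure from the spec):

batch, nnz, em = 16384, 30, 128
bytes_per_step = batch * nnz * em * 4   # float32 rows gathered per step: ~251.7 MB
print(bytes_per_step / 900e9 * 1e3)     # ~0.28 ms per step at the quoted spec
print(bytes_per_step / 1.5e9)           # ~0.17 s per step at ~1.5 GB/s

So at spec bandwidth a step should take well under a millisecond, while the measured 1-2 GB/s corresponds to steps of roughly 0.13-0.25 s.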

import time
import tensorflow as tf
import itertools
import numpy as np
import os
import sys

from tensorflow.python.ops import init_ops_v2
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.util import nest

batch = 16384        # examples per step
nnz = 30             # embedding lookups (ids) per example
em = 128             # embedding dimension
features = 1000000   # vocabulary size
feature_watched_values = np.random.randint(0, features, (batch * nnz,))
batch_size = batch * nnz  # total number of ids enqueued per step
resolver = None

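# A single table of features x em float32 rows, i.e. about 512 MB of parameters.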
table_test = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=features,
        dim=em,
        initializer=None,
        combiner='sum',
        name='test')
feature_config = (
        tpu_embedding_v2_utils.FeatureConfig(
            table=table_test, name='watched'))

def get_strategy():
    resolver = tpu_cluster_resolver.TPUClusterResolver(tpu="grpc://" + os.environ["TPU_IP"])
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    # Pin the computation to a single TPU core (one replica).
    device_assignment = device_assignment_lib.DeviceAssignment.build(
        topology, computation_shape=[1, 1, 1, 1], num_replicas=1)

    return tpu_strategy.TPUStrategy(resolver, device_assignment=device_assignment)

def create_strategy_and_mid_level():
    strategy = get_strategy()
    with strategy.scope():
        # The mid-level TPUEmbedding object must be created under the strategy scope.
        optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
        embedding = tpu_embedding_v2.TPUEmbedding(
            feature_config=feature_config,
            batch_size=batch_size,
            optimizer=optimizer)

    return strategy, embedding, optimizer

strategy, embedding, optimizer = create_strategy_and_mid_level()
training = False

def create_dense_input_fn(strategy, include_weights=False, weight=0.5):
    def input_fn(ctx):
        del ctx
        # Dense ids; each batch holds batch * nnz ids for the single feature.
        return dataset_ops.DatasetV2.from_tensor_slices(
            feature_watched_values).repeat().batch(batch_size)
    return input_fn

def get_replica_numpy(structured, strategy, replica_id):

    def select_replica(x):
        x = strategy.experimental_local_results(x)
        if len(x) == 1:
            return x[0]

        return x[replica_id]

    return nest.map_structure(select_replica, structured)

input_fn = create_dense_input_fn(strategy)
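# TPUEmbedding requires the distributed dataset to stay on the host (no device prefetch).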
dist = strategy.experimental_distribute_datasets_from_function(
        input_fn,
        options=distribute_lib.InputOptions(
            experimental_prefetch_to_device=False))
dist_iter = iter(dist)

@tf.function
def test_fn():
    def step():
        activation = embedding.dequeue()
        shard0 = get_replica_numpy(activation, strategy, 0)
        # Combine the nnz rows of each example by summing them, giving [batch, em].
        res = tf.math.reduce_sum(tf.reshape(shard0, [batch, nnz, em]), axis=1)
        print("RES device: ", res.device)  # trace-time only: shows where res is placed
        return res

    # enqueue happens on the host, outside strategy.run; dequeue runs inside it.
    embedding.enqueue(next(dist_iter), training=False)
    return strategy.run(step)

def test_dense_lookup():
    steps = 4
    warmups = 1
    # Warm up first so tracing/compilation time is not measured.
    for _ in range(warmups):
        test_fn()
    start = time.time()
    for _ in range(steps):
        res = test_fn()
    end0 = time.time()
    res.numpy()  # pull the result to the host so every queued step has finished
    end = time.time()

    # float32 bytes gathered from the embedding table per step
    total_bytes = batch * nnz * em * tf.float32.size
    print("Test batch = ", batch, " nnz = ", nnz, ", em = ", em)
    print(" RES shape: ", res.shape)
    print("Loop time without / with the final sync: ", end0 - start, end - start)
    print("TPU: total bytes {0}, mem bw {1:.3f} GB/s".format(
        total_bytes, total_bytes * steps / (end - start) / 1.0e9))

test_dense_lookup()

print("done")

Tags: tensorflow, memory, embedding, bandwidth, tpu

Solution

