Porting pytorch code to tf2.0: what is the equivalent of x_batch.requires_grad = True in tf2.0?

Problem description

I am trying to port the pytorch code from this repo to tf2. The overall logic of the code is to optimize random Gaussian input batches so that their per-layer activation statistics match the BatchNorm statistics of a pretrained model (i.e. generate distilled data for quantization); my attempt below follows that structure.

My problem is that, in order to update the data, it has to be trainable, so I convert x_batch with x_batch = tf.Variable(x_batch, trainable=True). But a Variable in tf is not iterable, so the code fails when it reaches optimizer.apply_gradients(zip(gradients, x_batch)).
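To make the failing call easier to see in isolation, here is a stripped-down sketch with a toy variable and a toy loss (illustrative only, not the actual distillation loss):

import tensorflow as tf

x_batch = tf.Variable(tf.random.normal([4, 32, 32, 3]), trainable=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.5)

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.square(x_batch))  # toy loss, stands in for the real one
gradients = tape.gradient(loss, x_batch)

# this is the call that fails: zip over a single gradient tensor and a single
# Variable does not yield (gradient, variable) pairs
optimizer.apply_gradients(zip(gradients, x_batch))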

In pytorch this is relatively simple and can be done with:

for x_batch in dataloader:    
    x_batch.requires_grad = True
    .
    .
    .
    # update the distilled data
    loss.backward()
    optimizer.step()
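From what I understand of tf.GradientTape, the tf2 counterpart should be to wrap the batch in a tf.Variable, compute the loss while the tape is recording, and hand the optimizer an explicit list of (gradient, variable) pairs, roughly like the sketch below (compute_loss is a placeholder, not a function from the repo; optimizer and train_dataset are as in my attempt further down):

for x_batch in train_dataset.batch(32):
    x_batch = tf.Variable(x_batch)                      # plays the role of requires_grad = True
    with tf.GradientTape() as tape:
        loss = compute_loss(x_batch)                    # placeholder for the distillation loss
    gradients = tape.gradient(loss, [x_batch])          # pass a list of variables, get a list back
    # update the distilled data
    optimizer.apply_gradients(zip(gradients, [x_batch]))

I am not sure whether this is the idiomatic way to do it in tf2.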

My attempt in tf2.0 is as follows.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import (Dense, Conv2D, BatchNormalization, Activation,
                                     Input, MaxPooling2D, Dropout, Flatten)
from tensorflow.keras.models import Model
from tensorflow.keras import regularizers
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import cifar10
import numpy as np

def vgg_block(x, filters, layers, name, weight_decay):
    for i in range(layers):
        x = Conv2D(filters, (3, 3), padding='same', kernel_initializer='he_normal',
                     kernel_regularizer=regularizers.l2(weight_decay), name=f'{name}_conv_{i}')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x


def vgg8(x, weight_decay=1e-4, is_classifier=False):
    x = vgg_block(x, 16, 2, 'block_1', weight_decay)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = vgg_block(x, 32, 2, 'block_2', weight_decay)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = vgg_block(x, 64, 2, 'block_3', weight_decay)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Flatten()(x)
    x = Dense(512, kernel_initializer='he_normal', activation='relu', name = 'dense_1')(x)
    x = Dense(10)(x)
    return x
    
def l2_loss(A, B):
    """
    L-2 loss between A and B normalized by length.
    Shape of A should be (features_num, ), shape of B should be (batch_size, features_num)
    """
    # pytorch : (A - B).norm()**2 / B.size(0)
    diff = A-B
    l2loss = tf.nn.l2_loss(diff)  # sum(t ** 2) / 2
    return l2loss / B.shape[0]
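# Note (my own check, not from the original repo): tf.nn.l2_loss already divides
# the sum of squares by 2, so this function returns (A - B).norm()**2 / (2 * B.shape[0])
# rather than exactly the PyTorch expression quoted above. Toy example:
#   A = tf.constant([1., 2.]); B = tf.constant([[0., 0.], [2., 2.]])
#   squared diffs sum to 6, so l2_loss(A, B) = 6 / 2 / 2 = 1.5
#   (the PyTorch formula would give 6 / 2 = 3.0)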
    

inputs = Input((32, 32, 3))
model = Model(inputs, vgg8(inputs))

eps = 1.0e-6
bn_stats = []

for layer in model.layers:
    if isinstance(layer, BatchNormalization):
        bn_gamma, bn_beta, bn_mean, bn_var = layer.get_weights()
        #print(bn_mean.shape, bn_var.shape)  # tf.reshape(w, [-1]) 
        bn_stats.append((bn_mean, tf.math.sqrt(bn_var+eps)))

    

extractor = tf.keras.models.Model(inputs=model.inputs,
                        outputs=[layer.output for layer in model.layers if isinstance(layer, Conv2D)])


class UniformDataset(keras.utils.Sequence):
    """
    get random uniform samples with mean 0 and variance 1
    """
    def __init__(self, length, size):
        self.length = length
        self.size = size
    
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        sample = tf.random.normal(self.size)
        return sample

data = UniformDataset(length=100, size=(32, 32,3))

# convert to tf.data iterator
train_iter = iter(data)
train_data = []
for x in train_iter:
    train_data.append(x)
train_data = tf.stack(train_data, axis=0)
train_dataset = tf.data.Dataset.from_tensor_slices(train_data)
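# (Aside, not from the original code: the same random data could presumably be built
#  directly with tf.data, e.g.
#  train_dataset = tf.data.Dataset.from_tensor_slices(tf.random.normal([100, 32, 32, 3])),
#  but I kept the Sequence -> stack conversion above.)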

refined_gaussian = []
iterations = 500

for x_batch in train_dataset.batch(32):
    x_batch = tf.Variable(x_batch, trainable=True)  # make x_batch trainable
    
    outputs = extractor(x_batch)
    
    input_mean = tf.zeros([1,3]) 
    input_std = tf.ones([1,3])   
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.5)

    with tf.GradientTape(persistent=True) as tape:
        for it in range(iterations):
            mean_loss = 0
            std_loss = 0

            for cnt, (bn_stat, output) in enumerate(zip(bn_stats, outputs)):  
                output = tf.reshape(output, [output.shape[0], output.shape[-1], -1])
                tmp_mean = tf.math.reduce_mean(output, axis=2)
                tmp_std = tf.math.reduce_std(output, axis=2) + eps
                bn_mean, bn_std = bn_stat[0], bn_stat[1]
                mean_loss += l2_loss(bn_mean, tmp_mean)
                std_loss += l2_loss(bn_std, tmp_std)
            
            #print('mean_loss', mean_loss, 'std_loss', std_loss)
            x_reshape = tf.reshape(x_batch, [x_batch.shape[0], 3, -1])
            tmp_mean = tf.math.reduce_mean(x_reshape, axis=2)
            tmp_std = tf.math.reduce_std(x_reshape, axis=2) + eps

            mean_loss += l2_loss(input_mean, tmp_mean)
            std_loss += l2_loss(input_std, tmp_std)
            loss = mean_loss + std_loss
            gradients = tape.gradient(loss, x_batch)

            optimizer.apply_gradients(zip(gradients, x_batch))
            
        refined_gaussian.append(x_batch)
        

Tags: python-3.x, tensorflow, deep-learning, quantization
