首页 > 解决方案 > 如何训练权重限制为特定值的神经网络?

问题描述

我正在尝试训练一个权重只能具有某些值的网络。但是,我这样做的方式需要很长时间,例如 MNIST 上的 3 层全连接网络每个 epoch 需要 5 小时。有没有更快的方法来做到这一点?

我正在使用 tf.keras 来构建我的网络。我添加了一个自定义 tf.constraint,它在更新权重时对可能的权重值列表进行二进制搜索。我从这里找到了适用于我的应用程序的二进制搜索代码。为了将二进制搜索功能应用于所有参数,我使用“tf.map_fn”。

这是约束类:

from tensorflow.python.keras.constraints import Constraint
import tensorflow as tf

# binary search function
def find(weights, query, shape):
    vals = tf.map_fn(lambda x: weights[tf.argmin(tf.cast(x >= weights, dtype=tf.int32)[1:] - tf.cast(x >= weights, dtype=tf.int32)[:-1])], tf.reshape(query,[-1]))
    return tf.reshape(vals, shape)

class WeightQuantizeClip(Constraint):
    # weights parameter holds the possible weight values
    def __init__(self, weights = []):

        self.weights = tf.convert_to_tensor(weights)

    def __call__(self, p):
        p = find(self.weights, p, p.shape)
        return p

    def get_config(self):
        return {'name': self.__class__.__name__}

当我训练具有上述约束的网络时,权重仅来自可能的权重值,但训练时间会大大增加。没有二分搜索功能,我的 GPU 得到了充分利用,但是当我使用二分搜索功能进行训练时,利用率下降到 2%。谁能帮我这个?

标签: tensorflow

解决方案


从您的描述看来,剪辑操作的某些部分在 CPU 上执行,这需要非常慢的 RAM-VRAM 通信。

但是,如果您尝试进行传统的 NN 量化,实际上为此目的构建了一个完整的 TF 模块,您可能想检查一下,也许它涵盖了您的用例。

https://www.tensorflow.org/api_docs/python/tf/quantization/quantize


推荐阅读