tensorflow - 如何训练权重限制为特定值的神经网络?
问题描述
我正在尝试训练一个权重只能具有某些值的网络。但是,我这样做的方式需要很长时间,例如 MNIST 上的 3 层全连接网络每个 epoch 需要 5 小时。有没有更快的方法来做到这一点?
我正在使用 tf.keras 来构建我的网络。我添加了一个自定义 tf.constraint,它在更新权重时对可能的权重值列表进行二进制搜索。我从这里找到了适用于我的应用程序的二进制搜索代码。为了将二进制搜索功能应用于所有参数,我使用“tf.map_fn”。
这是约束类:
from tensorflow.python.keras.constraints import Constraint
import tensorflow as tf
# binary search function
def find(weights, query, shape):
vals = tf.map_fn(lambda x: weights[tf.argmin(tf.cast(x >= weights, dtype=tf.int32)[1:] - tf.cast(x >= weights, dtype=tf.int32)[:-1])], tf.reshape(query,[-1]))
return tf.reshape(vals, shape)
class WeightQuantizeClip(Constraint):
# weights parameter holds the possible weight values
def __init__(self, weights = []):
self.weights = tf.convert_to_tensor(weights)
def __call__(self, p):
p = find(self.weights, p, p.shape)
return p
def get_config(self):
return {'name': self.__class__.__name__}
当我训练具有上述约束的网络时,权重仅来自可能的权重值,但训练时间会大大增加。没有二分搜索功能,我的 GPU 得到了充分利用,但是当我使用二分搜索功能进行训练时,利用率下降到 2%。谁能帮我这个?
解决方案
从您的描述看来,剪辑操作的某些部分在 CPU 上执行,这需要非常慢的 RAM-VRAM 通信。
但是,如果您尝试进行传统的 NN 量化,实际上为此目的构建了一个完整的 TF 模块,您可能想检查一下,也许它涵盖了您的用例。
https://www.tensorflow.org/api_docs/python/tf/quantization/quantize
推荐阅读
- display - 谷歌珊瑚开发板不适用于 HDMI LCD
- javascript - echo alert() in php show everything in the string. octobercms plugin
- javascript - 嵌套对象存储为状态变量,当旧值被修改时删除新值
- c++ - 如何将opencv mat灰度图像转换为pytorch张量?
- coq - 小于自然数关系
- android - 如何在暂停的协程流中使用回调?
- python - 如何计算python中一个类中的__repr__函数调用次数?
- node.js - 无谷歌翻译的 NPM 包出现 405 访问错误
- php - 使用 jQuery、AJAX、PHP 删除数据库中的行
- ios - iOS:如何从磁盘上的照片文件中提取缩略图和元数据