首页 > 解决方案 > 如何在 keras 网络中手动设置参数,比 model.set_weight() 更有效?

问题描述

我正在进行机器学习研究,探索用于训练神经网络的非反向传播选项。我已经实现了一个用于训练 keras 模型的遗传算法,但我想做一些改进。

目前我正在通过提取参数来训练模型,model.get_weights()在遗传算法中修改它们,然后再次设置权重,使用model.set_weights(),进行损失计算。这是通过将网络传递给执行训练并返回历史记录的训练函数来完成的 history = genetic_algorithm(model, *args)

在分析我的代码时,model.set_weights()占代码运行时间的 29%。有什么方法可以以更高效的方式修改模型参数?

这个问题与我提出的另一个问题有关。我将它们分开以使每个问题更简洁。

编辑:

在我的算法中,评估占用了大约 96% 的时间。几乎所有这些时间都花在了 keras 或 tensorflow 上。评估包括设置网络参数、通过网络执行前向传递以及使用损失函数(分类交叉熵)评估输出。

model.set_weights()如前所述,占代码运行时间的 29%,所以我假设剩下的大部分是通过网络的前向传播。

评估代码:

def evaluate_population(population: np.ndarray,
                        model: keras.Model,
                        loss_function,
                        inputs: np.ndarray,
                        target_outputs: np.ndarray,
                        node_indices: np.ndarray) -> np.ndarray:
    losses = np.zeros(len(population))
    params = model.get_weights()
    for i in range(len(population)):
        losses[i] = evaluate_chromosome(population[i], model, loss_function, inputs,
                                        target_outputs, params, node_indices)

    return losses
def evaluate_chromosome(chromosome: np.ndarray,
                        model: keras.Model,
                        loss_function,
                        inputs: np.ndarray,
                        target_outputs: np.ndarray,
                        params: np.ndarray,
                        node_indices: np.ndarray):
    if node_indices is None:
        set_params(model, chromosome, params)
    else:
        set_nbn_params(model, chromosome, params, node_indices)
    predicted_outputs = model(inputs)
    return loss_function(target_outputs, np.squeeze(predicted_outputs.numpy()))
def set_params(model: keras.Model, chromosome: np.ndarray, dummy_params):
    params = []
    i = 0
    for layer in dummy_params:
        ls = layer.size
        flat_layer = chromosome[i:i+ls]
        i += ls
        params.append(np.reshape(flat_layer, layer.shape))

    model.set_weights(params)

标签: tensorflowmachine-learningkerasgenetic-algorithm

解决方案


推荐阅读