DL4J: Binarizing the middle layer of an autoencoder for semantic hashing

Problem description

I am trying to implement semantic hashing for the MNISTAutoencoder example using DL4J. How would I binarize the middle-layer activations? Ideally, I am looking for some change to my network setup that gives (almost) binary middle-layer activations out of the box. Alternatively, I would also be happy with some "recipe" for binarizing the current ReLU activations. Which of the two approaches is preferable in terms of generalization ability?

My current network setup is:

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(12345)
    .weightInit(WeightInit.XAVIER)
    .updater(new AdaGrad(0.05))
    .activation(Activation.RELU)
    .l2(0.0001)
    .list()
    // encoder: 784 -> 250
    .layer(new DenseLayer.Builder().nIn(784).nOut(250)
               .build())
    // code layer: 250 -> 10 (the middle layer to binarize)
    .layer(new DenseLayer.Builder().nIn(250).nOut(10)
               .build())
    // decoder: 10 -> 250
    .layer(new DenseLayer.Builder().nIn(10).nOut(250)
               .build())
    // reconstruction: 250 -> 784
    .layer(new OutputLayer.Builder().nIn(250).nOut(784)
               .activation(Activation.LEAKYRELU)
               .lossFunction(LossFunctions.LossFunction.MSE)
               .build())
    .build();

After 30 epochs, typical middle-layer activations look like this:

[[   11.3044,   12.3678,    7.3547,    1.6518,    1.0068,         0,    5.4340,    2.1388,    2.0708,    2.5764]]
[[    9.9051,   12.5345,   11.1941,    4.7900,    1.2935,         0,    7.9786,    4.1915,    3.1802,    7.5659]]
[[    6.4629,   11.1013,   10.8903,    5.4528,    0.8009,         0,    9.4881,    3.6684,    6.4524,    7.2334]]
[[    2.3953,    0.2429,    3.7125,    4.1561,    0.8607,         0,   11.2486,    7.0178,    2.8771,    2.1996]]
[[         0,    1.6378,    0.8993,    0.3347,    0.7708,         0,    3.7053,         0,    1.6704,    2.1380]]
[[         0,    1.5158,    0.7937,         0,    0.8190,         0,    4.7548,    0.0655,    1.4635,    1.8173]]
[[    6.8344,    5.9989,   10.1286,    2.8528,    1.1178,         0,    9.1865,   10.3677,    5.3564,    4.3420]]
[[    7.0942,    7.0364,    4.8538,    0.5096,    0.0442,         0,    8.4336,    8.2783,    5.6474,    3.8944]]
[[    3.6895,   14.9696,    6.5351,    8.0446,         0,         0,   12.7816,   12.7445,    7.8495,    3.8600]]
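
For reference, a snapshot like the one above can be printed with something along these lines (a minimal sketch; net and features are assumed names for the trained MultiLayerNetwork and a [batchSize, 784] input batch, not part of the original setup):

// feedForward() returns the input at index 0 followed by each layer's
// output, so the 10-unit middle layer (layer 1) sits at index 2.
java.util.List<INDArray> activations = net.feedForward(features);
INDArray code = activations.get(2);
System.out.println(code);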

Tags: autoencoder, deeplearning4j, dl4j

Solution

This can be achieved by assigning a custom IActivation function to the middle layer. For example:

// Requires org.nd4j.linalg.indexing.BooleanIndexing,
// org.nd4j.linalg.indexing.conditions.Conditions and
// org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative.
public static class ActivationBinary extends BaseActivationFunction {
    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        // Binarize in place: negative activations become -1, all others +1.
        BooleanIndexing.replaceWhere(in, -1.0, Conditions.lessThan(0));
        BooleanIndexing.replaceWhere(in, 1.0, Conditions.greaterThanOrEqual(0));
        return in;
    }

    @Override
    public org.nd4j.common.primitives.Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        this.assertShape(in, epsilon);
        // The step function has a zero gradient almost everywhere, so use
        // tanh's derivative as a reasonable smooth approximation.
        Nd4j.getExecutioner().execAndReturn(new TanhDerivative(in, epsilon, in));
        return new org.nd4j.common.primitives.Pair<>(in, null);
    }

    @Override
    public int hashCode() {
        return 1;
    }

    @Override
    public boolean equals(Object obj) {
        return obj instanceof ActivationBinary;
    }

    @Override
    public String toString() {
        return "Binary";
    }
}
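
To use it, assign an instance of the custom activation to the middle layer so it overrides the ReLU default set on the builder; a minimal sketch against the configuration above (only the middle layer changes):

    .layer(new DenseLayer.Builder().nIn(250).nOut(10)
               .activation(new ActivationBinary())
               .build())

With that in place, the network trains end to end with the binarization in the forward pass, while the tanh surrogate keeps gradients flowing; the resulting ±1 codes can then be used directly as hash bits for semantic hashing.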
