首页 > 解决方案 > 如何使用带有 DeepLearning4j 的二维数组的 LSTM

问题描述

我正在尝试学习如何将 LSTM 与 deeplearning4j 库一起使用。

我创建了一个虚拟场景,我想根据我收集的数据获得输出(3 个类)。

如果有人好奇,我从这里(http://www.osservatoriodioropa.it/meteoropa/NOAAMO.TXT )得到数据:)

回到场景。我创建了 2 个矩阵,一个带有特征,另一个带有我想要输出的类,就像测试一样。

当我尝试分类器时,我得到了

Exception in thread "main" java.lang.IllegalStateException: 3D input expected to RNN layer expected, got 2

我认为是因为 RnnOutputLayer 需要一个 3d 矩阵,但我无法理解如何填充它。如何将 2d 矩阵转换为将前一个事件与新事件相关联的 3d 矩阵?数据是一个时间序列,我也想根据前几天来关联新一天的分类。(我知道数据可能不适合这种情况,并且有更好的方法来做到这一点,但这只是学习如何使用 LSTM,而不是如何分类这个特定的数据集)

这是到目前为止的代码

public class Test {

    public static void main(String args[]) {
        int events = 5;
        int features = 6;
        int classes = 3;

        double[][] featureMatrix = new double[events][features];
        double[][] labelMatrix = new double[events][classes];

        for (int i = 0; i < events; i++) {
            for (int f = 0; f < features; f++) {
                featureMatrix[i][f] = getFeature(i, f);
            }
            for (int c = 0; c < classes; c++) {
                labelMatrix[i][c] = getResult(i, c);
            }
        }

        INDArray trainingIn = Nd4j.create(featureMatrix);
        INDArray trainingOut = Nd4j.create(labelMatrix);

        DataSet myData = new DataSet(trainingIn, trainingOut);
        MultiLayerNetwork multiLayerNetwork = createModel(features,classes);
        multiLayerNetwork.init();
        multiLayerNetwork.fit(myData);

    }

    private static double getFeature(int i, int f) {
        //dummy
        return 1.;
    }

    private static double getResult(int i, int c) {
        //dummy

        return 1.;
    }

    public static MultiLayerNetwork createModel(int inputNum, int outputNum) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .trainingWorkspaceMode(ENABLED).inferenceWorkspaceMode(ENABLED)
                .seed(123456)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new RmsProp.Builder().learningRate(0.05).rmsDecay(0.002).build())
                .l2(0.0005)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.TANH)
                .list()

                .layer(new LSTM.Builder().name("1").nIn(inputNum).nOut(inputNum).build())
                .layer(new LSTM.Builder().name("2").nIn(inputNum).nOut(inputNum).build())

                .layer(new RnnOutputLayer.Builder().name("output").nIn(inputNum).nOut(outputNum)
                        .activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        return net;
    }

}

标签: deep-learninglstmdeeplearning4jnd4j

解决方案


推荐阅读