首页 > 解决方案 > dl4j lstm 不成功

问题描述

我试图在此链接的页面中间复制练习: https ://d2l.ai/chapter_recurrent-neural-networks/sequence.html

该练习使用正弦函数在 -1 到 1 之间创建 1000 个数据点,并使用循环网络来逼近该函数。

下面是我使用的代码。我将回去研究更多为什么这不起作用,因为当我很容易能够使用前馈网络来近似这个函数时,它对我来说没有多大意义。

      //get data
    ArrayList<DataSet> list = new ArrayList();
   
    DataSet dss = DataSetFetch.getDataSet(Constants.DataTypes.math, "sine", 20, 500, 0, 0);

    DataSet dsMain = dss.copy();

    if (!dss.isEmpty()){
        list.add(dss);
    }

   
    if (list.isEmpty()){

        return;
    }

    //format dataset
   list = DataSetFormatter.formatReccurnent(list, 0);

    //get network
    int history = 10;
    ArrayList<LayerDescription> ldlist = new ArrayList<>();
    LayerDescription l = new LayerDescription(1,history, Activation.RELU);
    ldlist.add(l);     
    LayerDescription ll = new LayerDescription(history, 1, Activation.IDENTITY, LossFunctions.LossFunction.MSE);
    ldlist.add(ll);

    ListenerDescription ld = new ListenerDescription(20, true, false);

    MultiLayerNetwork network = Reccurent.getLstm(ldlist, 123, WeightInit.XAVIER, new RmsProp(), ld);


    //train network
    final List<DataSet> lister = list.get(0).asList();
    DataSetIterator iter = new ListDataSetIterator<>(lister, 50);
    network.fit(iter, 50);
    network.rnnClearPreviousState();


    //test network
    ArrayList<DataSet> resList = new ArrayList<>();
    DataSet result = new DataSet();
    INDArray arr = Nd4j.zeros(lister.size()+1);     
    INDArray holder;

    if (list.size() > 1){
        //test on training data
        System.err.println("oops");

    }else{
        //test on original or scaled data
        for (int i = 0; i < lister.size(); i++) {

            holder = network.rnnTimeStep(lister.get(i).getFeatures());
            arr.putScalar(i,holder.getFloat(0));

        }
    }


    //add originaldata
    resList.add(dsMain);
    //result       
    result.setFeatures(dsMain.getFeatures());
  
    result.setLabels(arr);
    resList.add(result);

    //display
    DisplayData.plot2DScatterGraph(resList);

你能解释一下我需要 1 分 10 隐藏和 1 出 lstm 网络来逼近正弦函数的代码吗?

我没有使用任何归一化作为函数已经是 -1:1 并且我使用 Y 输入作为特征,然后使用以下 Y 输入作为标签来训练网络。

您注意到我正在构建一个可以更轻松地构建网络的类,并且我尝试对问题进行许多更改,但我厌倦了猜测。

以下是我的结果的一些示例。蓝色是数据 红色是结果

在此处输入图像描述

在此处输入图像描述

标签: javadeep-learningdl4jnd4j

解决方案


This is one of those times were you go from wondering why was this not working to how in the hell were my original results were as good as they were.

My failing was not understanding the documentation clearly and also not understanding BPTT.

With feed forward networks each iteration is stored as a row and each input as a column. An example is [dataset.size, network inputs.size]

However with recurrent input its reversed with each row being a an input and each column an iteration in time necessary to activate the state of the lstm chain of events. At minimum my input needed to be [0, networkinputs.size, dataset.size] But could also be [dataset.size, networkinputs.size, statelength.size]

在我之前的示例中,我使用这种格式的数据训练网络 [dataset.size, networkinputs.size, 1]。因此,根据我对低分辨率的理解,lstm 网络根本不应该工作,但至少以某种方式产生了一些东西。

将数据集转换为列表也可能存在一些问题,因为我也更改了为网络提供数据的方式,但我认为问题的大部分是数据结构问题。

以下是我的新结果 不完美,但考虑到这是 5 个训练阶段,非常好


推荐阅读