java - dl4j lstm 不成功
问题描述
我试图在此链接的页面中间复制练习: https ://d2l.ai/chapter_recurrent-neural-networks/sequence.html
该练习使用正弦函数在 -1 到 1 之间创建 1000 个数据点,并使用循环网络来逼近该函数。
下面是我使用的代码。我将回去研究更多为什么这不起作用,因为当我很容易能够使用前馈网络来近似这个函数时,它对我来说没有多大意义。
//get data
ArrayList<DataSet> list = new ArrayList();
DataSet dss = DataSetFetch.getDataSet(Constants.DataTypes.math, "sine", 20, 500, 0, 0);
DataSet dsMain = dss.copy();
if (!dss.isEmpty()){
list.add(dss);
}
if (list.isEmpty()){
return;
}
//format dataset
list = DataSetFormatter.formatReccurnent(list, 0);
//get network
int history = 10;
ArrayList<LayerDescription> ldlist = new ArrayList<>();
LayerDescription l = new LayerDescription(1,history, Activation.RELU);
ldlist.add(l);
LayerDescription ll = new LayerDescription(history, 1, Activation.IDENTITY, LossFunctions.LossFunction.MSE);
ldlist.add(ll);
ListenerDescription ld = new ListenerDescription(20, true, false);
MultiLayerNetwork network = Reccurent.getLstm(ldlist, 123, WeightInit.XAVIER, new RmsProp(), ld);
//train network
final List<DataSet> lister = list.get(0).asList();
DataSetIterator iter = new ListDataSetIterator<>(lister, 50);
network.fit(iter, 50);
network.rnnClearPreviousState();
//test network
ArrayList<DataSet> resList = new ArrayList<>();
DataSet result = new DataSet();
INDArray arr = Nd4j.zeros(lister.size()+1);
INDArray holder;
if (list.size() > 1){
//test on training data
System.err.println("oops");
}else{
//test on original or scaled data
for (int i = 0; i < lister.size(); i++) {
holder = network.rnnTimeStep(lister.get(i).getFeatures());
arr.putScalar(i,holder.getFloat(0));
}
}
//add originaldata
resList.add(dsMain);
//result
result.setFeatures(dsMain.getFeatures());
result.setLabels(arr);
resList.add(result);
//display
DisplayData.plot2DScatterGraph(resList);
你能解释一下我需要 1 分 10 隐藏和 1 出 lstm 网络来逼近正弦函数的代码吗?
我没有使用任何归一化作为函数已经是 -1:1 并且我使用 Y 输入作为特征,然后使用以下 Y 输入作为标签来训练网络。
您注意到我正在构建一个可以更轻松地构建网络的类,并且我尝试对问题进行许多更改,但我厌倦了猜测。
以下是我的结果的一些示例。蓝色是数据 红色是结果
解决方案
This is one of those times were you go from wondering why was this not working to how in the hell were my original results were as good as they were.
My failing was not understanding the documentation clearly and also not understanding BPTT.
With feed forward networks each iteration is stored as a row and each input as a column. An example is [dataset.size, network inputs.size]
However with recurrent input its reversed with each row being a an input and each column an iteration in time necessary to activate the state of the lstm chain of events. At minimum my input needed to be [0, networkinputs.size, dataset.size] But could also be [dataset.size, networkinputs.size, statelength.size]
在我之前的示例中,我使用这种格式的数据训练网络 [dataset.size, networkinputs.size, 1]。因此,根据我对低分辨率的理解,lstm 网络根本不应该工作,但至少以某种方式产生了一些东西。
将数据集转换为列表也可能存在一些问题,因为我也更改了为网络提供数据的方式,但我认为问题的大部分是数据结构问题。
推荐阅读
- spring - 使用嵌入式 H2 数据库进行测试 - 导入数据正常,运行测试时,它会再次尝试初始化数据,然后找不到表
- react-native - 单击列表项时如何导航到另一个活动
- powershell - Powershell脚本未创建站点文件夹
- c# - ASP.NET Core Linq
- html - 电子邮件模板 HTML 从 ul 中删除项目符号
- javascript - OverwriteModelError:编译后无法覆盖“团队”模型
- javascript - 在我的情况下,通过 JavaScript 从 EntityFramework 数据库获取数据的最佳和最有效的方法是什么?
- vba - Add printer to system using Powershell within VBA?
- python - 如果它们为 None,则调用不带可选参数的函数
- excel - 计算行并进行最大排列