deep-learning - 如何使用带有 DeepLearning4j 的二维数组的 LSTM
问题描述
我正在尝试学习如何将 LSTM 与 deeplearning4j 库一起使用。
我创建了一个虚拟场景,我想根据我收集的数据获得输出(3 个类)。
如果有人好奇,我从这里(http://www.osservatoriodioropa.it/meteoropa/NOAAMO.TXT )得到数据:)
回到场景。我创建了 2 个矩阵,一个带有特征,另一个带有我想要输出的类,就像测试一样。
当我尝试分类器时,我得到了
Exception in thread "main" java.lang.IllegalStateException: 3D input expected to RNN layer expected, got 2
我认为是因为 RnnOutputLayer 需要一个 3d 矩阵,但我无法理解如何填充它。如何将 2d 矩阵转换为将前一个事件与新事件相关联的 3d 矩阵?数据是一个时间序列,我也想根据前几天来关联新一天的分类。(我知道数据可能不适合这种情况,并且有更好的方法来做到这一点,但这只是学习如何使用 LSTM,而不是如何分类这个特定的数据集)
这是到目前为止的代码
public class Test {
public static void main(String args[]) {
int events = 5;
int features = 6;
int classes = 3;
double[][] featureMatrix = new double[events][features];
double[][] labelMatrix = new double[events][classes];
for (int i = 0; i < events; i++) {
for (int f = 0; f < features; f++) {
featureMatrix[i][f] = getFeature(i, f);
}
for (int c = 0; c < classes; c++) {
labelMatrix[i][c] = getResult(i, c);
}
}
INDArray trainingIn = Nd4j.create(featureMatrix);
INDArray trainingOut = Nd4j.create(labelMatrix);
DataSet myData = new DataSet(trainingIn, trainingOut);
MultiLayerNetwork multiLayerNetwork = createModel(features,classes);
multiLayerNetwork.init();
multiLayerNetwork.fit(myData);
}
private static double getFeature(int i, int f) {
//dummy
return 1.;
}
private static double getResult(int i, int c) {
//dummy
return 1.;
}
public static MultiLayerNetwork createModel(int inputNum, int outputNum) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.trainingWorkspaceMode(ENABLED).inferenceWorkspaceMode(ENABLED)
.seed(123456)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new RmsProp.Builder().learningRate(0.05).rmsDecay(0.002).build())
.l2(0.0005)
.weightInit(WeightInit.XAVIER)
.activation(Activation.TANH)
.list()
.layer(new LSTM.Builder().name("1").nIn(inputNum).nOut(inputNum).build())
.layer(new LSTM.Builder().name("2").nIn(inputNum).nOut(inputNum).build())
.layer(new RnnOutputLayer.Builder().name("output").nIn(inputNum).nOut(outputNum)
.activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
return net;
}
}
解决方案
推荐阅读
- python - Python的if条件表达式中的大django查询集
- python - 如何使用 Python 和 re 从字符串中提取确切的单词?
- react-select - react-select(AsyncSelewhen 打字)不删除
- sas - 防止 SAS 自动删除字符串中的尾随空格
- python - Python - 替换文本中的缩写
- get - 通过 GM_xmlhttpRequest 访问另一个网站上的窗口属性
- python-3.x - 使用 google flow 2.0 时无法将用户重定向到 auth_uri
- puppeteer - 如何将 html 元素添加到当前页面?木偶/卡罗
- reactjs - 如何根据另一个 API 结果调用 API?
- python - 没有名为“apscheduler”的模块