cntk - CNTK LSTM 输入形状
问题描述
我们正在使用 CNTK 的 C# API 构建 LSTM 网络,但根据 CNTK 文档的当前级别,很难确定输入的正确形状/尺寸。
我们有一个时间序列,在每个时间 t 都有一个值(一个数字),我们希望使用时间序列的前 744 个值的序列来使用 LSTM 进行预测。此外,我们是否想要制作一个包含 25 个序列的 minibatch,CNTK.InputVariable 的形状应该如下所示:
[0] 744
[1] 1
[2] 25
或者
[0] 1
[1] 744
[2] 25
…然后,如果我们不是在每个时间 t 有一个值,而是有两个值,那么 CNTK.InputVariable 的形状会是什么样子?
解决方案
如果你使用循环网络(LSTM、GRU),那么你需要知道什么是静态轴和动态轴。静态轴用于描述输入数据形式(在第一种情况下,它是秩为 1 且大小为 1: 的向量new int {1}
)。动态轴用于指定输入数据(在您的情况下)的序列(在您的情况下为可变长度 744 new int {1}
)。要指示动态轴应用于序列,请在输入参数dynamicAxes 中指定: new[] { Axis.DefaultBatchAxis() }
var inputDimension = 1; //for two values is 2 etc.
var inputShape = new { inputDimension };
var input = Variable.InputVariable(inputShape, DataType.Double, "input", new[] { Axis.DefaultBatchAxis() });
并确保正确准备小批量(创建一个小批量的示例):
var device = DeviceDescriptor.CPUDevice;
var inputDimension = 1;
var outputDimension = 1;
var minibatchSize = 25;
var oneMinibatchFeaturesData = new List<List<double[]>>(minibatchSize)
{
new List<double[]> //first sequence
{
new double[] { 23 },//t=1. Array.Length = inputDimension
new double[] { 25 },//t=2
//...
new double[] { 65 },//t=744
},
new List<double[]> //second seqeunce
{
new double[] { 76 }, //t=1
new double[] { 236 },//t=2
//...
new double[] { 87 }, //t=744
},
//...
new List<double[]> //twenty fifth sequence
{
new double[] { 9 }, //t=1
new double[] { 2 },//t=2
//...
new double[] { 90 }, //t=744
},
};
var oneMinibatchLabelsData = new List<double[]>(minibatchSize)
{
new double[] { 1 },//label of first sequence. Array.Length = outputDimension
new double[] { 5 },//label of second sequence
//...
new double[] { 3 }//label of twenty fifth sequence
};
var features = Value.CreateBatchOfSequences(new[] { inputDimension }, oneMinibatchFeaturesData.Select(sequence => sequence.SelectMany(value => value)), device);
var labels = Value.CreateBatch(new[] { outputDimension }, oneMinibatchLabelsData.SelectMany(value => value), device);
序列的长度可以是任意的。一个小批量可能包含不同长度的序列。
LSTM 很难在这种长度的序列上训练。如果您的序列长度始终为 744,那么您可能应该使用输入维度为 744 的简单 FNN。
推荐阅读
- java - Spring Boot 2 不适用于 Postgres
- sum - 给定数组中大小为 k 的子集的异或的总和
- python - python程序的执行时间
- pug - 如何从 mixins(Jade/Pug)中的二维数组获取数据?
- html - 如何使用 Oculus Go 触控板在 AFrame 中移动
- python - Jupyter Notebook 不会在 Gitlab 上显示图表
- c - 另一个 .c 中的引用队列
- php - PHP预定义函数 - 计数 - 在多维数组中返回一个奇怪的(对我来说)值?
- javascript - 使用 d3.js 将百分比的高度值转换为 svg 中的百分比宽度值
- postgresql - PostgreSQL 9.6 CentOS 7 LibreTime - postgresql.service 的作业失败