python - 如何在一个简单的示例中使用 Python pytorch_forecasting 来预测时间序列?
问题描述
我想学习如何在一个简单的例子中使用pytorch_forecasting 。假设我们有一个只有 4 列的时间序列,即t,x(t),y(t),z(t)
. 这样的时间序列代表了粒子的轨迹。更具体地说:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
nt=os.cpu_count(); #number of threads that my CPU offers
def plotS(xx,clr): #scatterplot of several sets of points with different colors
xx=[( x if isinstance(x,np.ndarray) else
x.to_numpy() if isinstance(x,pd.DataFrame) else
x.numpy() if torch.is_tensor(x)
else 'Error!').astype('float32') for x in xx]
fig=plt.figure(figsize=(15,10));
ax=plt.axes(projection='3d');
for i in range(len(xx)):
ax.scatter3D(xx[i][:,0],xx[i][:,1],xx[i][:,2],c=clr[i],s=4) #c=color
x=np.concatenate(xx);
ax.set_box_aspect((np.ptp(x[:,0]),np.ptp(x[:,1]),np.ptp(x[:,2]))) #aspect ratio 1:1:1
plt.tight_layout(); plt.show()
X=pd.DataFrame({'t':np.linspace(0,12*np.pi,10**4)}); #creating our dataset
#X['x(t)'], X['y(t)'], X['z(t)'] = np.cos(X['t'])*2*X['t'], np.sin(X['t'])*X['t'], 3*X['t']; #easy example
X['x(t)'], X['y(t)'], X['z(t)'] = np.sin(X['t']), np.sin(2*X['t']), X['t']/10 #hard example
X0,X1 = X.iloc[:int(0.7*len(X)),:], X.iloc[int(0.7*len(X)):,:] #split data into train and test part
plotS([X0.iloc[:,1:],X1.iloc[:,1:]],['b','g']) #visualize both parts
从前 70% 的点(蓝色部分),我希望预测最后 30%(绿色部分)。我对此的笨拙尝试pytorch_forecasting
如下:
import torch, torchvision, pytorch_lightning, pytorch_forecasting
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
X['t']=np.arange(len(X)); #set the time to be consecutive integers
X['groups']=0; #add a column, which marks all rows as coming from the same time-series
#X=X.drop([100,101,102,2000,2005,5123])
k,l=100,1; tt=(X['t'].max()-X['t'].min());
y0= X.to_numpy()[X['t']<0.7*tt,1:4];
y1= X.to_numpy()[X['t']>0.7*tt,1:4];
X0= TimeSeriesDataSet(data=X[X['t']<0.7*tt],time_idx='t',target=['x(t)','y(t)','z(t)'],
group_ids=['groups'], max_encoder_length=k, max_prediction_length=l, allow_missing_timesteps=True,
time_varying_known_reals=['t'],time_varying_unknown_reals=['x(t)','y(t)','z(t)']);
X1=X[X['t']>0.7*tt]; #X1.iloc[:,1:]=0;
X1= TimeSeriesDataSet.from_dataset(X0,X1,predict=False,stop_randomization=True);
XX0=X0.to_dataloader(train=True, batch_size=64,num_workers=nt);
XX1=X1.to_dataloader(train=False,batch_size=64,num_workers=nt);
ann= TemporalFusionTransformer.from_dataset(X0,learning_rate=0.001,hidden_size=32,lstm_layers=4);
trn=pytorch_lightning.Trainer(max_epochs=1,gpus=0);
trn.fit(ann,XX0,XX0);
y2=torch.cat(ann.predict(XX1),dim=1).numpy(); #cheating here: we use training values to predict
print(y1.shape,y2.shape);
plotS([y0,y1,y2[:,:3]],['b','g','r']);
结果相当不错:
然而,计算是作弊的,因为ann.predict(XX1)
它使用测试值来预测测试值。如何正确使用时间融合转换器,即仅从时间列预测测试值?
其次,我作弊是因为我将所有时间都转换为连续整数。如果我注释掉该行X['t']=np.arange(len(X))
,我会得到错误。我如何预测时间序列不是整数。
第三,如果我将代码更改为k,l=100,10;
,我会得到糟糕的结果:红色的预测变成了一条线。为什么?如何正确预测 l>1?
注意:我只是在学习 pytorch_forecasting,所以我很感激有关它的功能的一些提示。我确实阅读了文档,但没有简单的示例可以帮助我。
解决方案
推荐阅读
- excel - 使用与我看到的略有不同的 VBA 范围
- python - 如何在 5 秒后停止此循环?
- c# - 检查用户名是否已经在数据库中
- javascript - lit-html:连接字符串以使用 html``
- python - 如何从 javadoc 注释中删除 { 和 } 之间的 @link 标记及其内容?
- highcharts - Highchart 地图显示了一些已选择的状态,并具有选择其他状态的可行性
- javascript - 如何使用 fetch 发送数组?(Javascript)
- c# - 解释一些关于动态连接字符串的代码
- android - 在 genymotion 设备中运行时膨胀类 androidx.constraintlayout.widget.ConstraintLayout 时出错
- html - 如何摆脱自动填充html?