python - Preprocessing a MultiIndex DataFrame with forecast data for LSTM/RNN
Problem description
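RNNs and LSTMs need a sequence to be defined for each feature data point.
Forecast data (for example, weather forecasts) is characterized by having both a calculation timestamp and a forecast timestamp (here dt_calc and dt_fore). Such data might produce a DataFrame like the following: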
import pandas as pd

data = pd.DataFrame([[2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [9, 8], [8, 9], [5, 4], [3, 3]],
                    index=pd.MultiIndex.from_tuples([
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 00:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 01:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 02:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 03:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 04:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 00:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 01:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 02:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 03:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 04:00:00'), 0)
                    ],
                    names=['dt_calc', 'dt_fore', 'positional_index']), columns=['temp', 'temp_2'])
For a sequence length of 2, the dataset to be used in an LSTM or RNN should look like this:
data = pd.DataFrame([[[2, 4], [3, 5]], [[4, 6], [5, 7]], [[6, 8], [7, 9]], [[8, 10], [9, 11]], [[12, 9], [13, 8]], [[9, 8], [8, 9]], [[8, 5], [9, 4]], [[5, 3], [4, 3]]],
                    index=pd.MultiIndex.from_tuples([
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 01:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 02:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 03:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 04:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 01:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 02:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 03:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 04:00:00'), 0)
                    ],
                    names=['dt_calc', 'dt_fore', 'positional_index']), columns=['temp', 'temp_2'])
And for a sequence length of 3:
data = pd.DataFrame([[[2, 4, 6], [3, 5, 7]], [[4, 6, 8], [5, 7, 9]], [[6, 8, 10], [7, 9, 11]], [[12, 9, 8], [13, 8, 9]], [[9, 8, 5], [8, 9, 4]], [[8, 5, 3], [9, 4, 3]]],
                    index=pd.MultiIndex.from_tuples([
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 02:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 03:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 04:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 02:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 03:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 04:00:00'), 0)
                    ],
                    names=['dt_calc', 'dt_fore', 'positional_index']), columns=['temp', 'temp_2'])
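This DataFrame can then easily be converted into a numpy array of sequences.
The important point of this question is to respect the timestamps, because here the sequences are defined by time periods, not by index positions.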
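The conversion to a numpy array mentioned above could look roughly like the following. This is a minimal sketch, assuming each cell holds a list of length n; the small frame named `sequences` is a hypothetical stand-in for the target DataFrame, and the (samples, timesteps, features) layout is an assumption about what the downstream recurrent layer expects.

```python
import numpy as np
import pandas as pd

# Hypothetical miniature of the sequence frame: every cell is a list of length n=2.
sequences = pd.DataFrame(
    {"temp": [[2, 4], [4, 6], [6, 8]], "temp_2": [[3, 5], [5, 7], [7, 9]]}
)

# Stacking the nested lists gives shape (samples, features, timesteps).
stacked = np.array(sequences.to_numpy().tolist())

# Assumed target layout is (samples, timesteps, features), so swap the last two axes.
batch = stacked.transpose(0, 2, 1)

print(batch.shape)
print(batch[0])
```

Each entry of `batch` then holds one window, with one row per time step and one column per feature.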
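Edit: Following a good suggestion from Shubham Sharma, I will outline another example to clarify why the timestamps matter. With irregular intervals in dt_fore, the input would look like this: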
data = pd.DataFrame([[2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [9, 8], [8, 9], [5, 4], [3, 3]],
                    index=pd.MultiIndex.from_tuples([
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 00:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 01:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 02:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 03:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 04:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 00:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 01:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 02:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 04:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 05:00:00'), 0)
                    ],
                    names=['dt_calc', 'dt_fore', 'positional_index']), columns=['temp', 'temp_2'])
For n=2, this should be restructured for LSTM/RNN use as follows:
data = pd.DataFrame([[[2, 4], [3, 5]], [[4, 6], [5, 7]], [[6, 8], [7, 9]], [[8, 10], [9, 11]], [[12, 9], [13, 8]], [[9, 8], [8, 9]], [[5, 3], [4, 3]]],
                    index=pd.MultiIndex.from_tuples([
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 01:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 02:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 03:00:00'), 0),
                        (pd.Timestamp('2019-07-02 00:00:00'), pd.Timestamp('2019-07-02 04:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 01:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 02:00:00'), 0),
                        (pd.Timestamp('2019-07-04 00:00:00'), pd.Timestamp('2019-07-04 05:00:00'), 0)
                    ],
                    names=['dt_calc', 'dt_fore', 'positional_index']), columns=['temp', 'temp_2'])
Solution
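We can define a generator function that groups the DataFrame by the dt_calc column and uses a rolling operation with a window of size n to aggregate the columns into lists, which yields the sequences.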
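def seq(n):
    # Uses the module-level `data` frame; iterating over rolling windows needs pandas 1.1 or later.
    df = data.reset_index()
    for g in df.groupby('dt_calc', sort=False).rolling(n):
        # Full windows become one array per column; incomplete windows yield [] and are dropped below.
        yield g[data.columns].to_numpy().T if len(g) == n else []

pd.DataFrame(seq(2), index=data.index, columns=data.columns).dropna()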
# n=2
                                                    temp   temp_2
dt_calc    dt_fore             positional_index
2019-07-02 2019-07-02 01:00:00 0                  [2, 4]   [3, 5]
           2019-07-02 02:00:00 0                  [4, 6]   [5, 7]
           2019-07-02 03:00:00 0                  [6, 8]   [7, 9]
           2019-07-02 04:00:00 0                 [8, 10]  [9, 11]
2019-07-04 2019-07-04 01:00:00 0                 [12, 9]  [13, 8]
           2019-07-04 02:00:00 0                  [9, 8]   [8, 9]
           2019-07-04 03:00:00 0                  [8, 5]   [9, 4]
           2019-07-04 04:00:00 0                  [5, 3]   [4, 3]

# n=3
                                                       temp      temp_2
dt_calc    dt_fore             positional_index
2019-07-02 2019-07-02 02:00:00 0                  [2, 4, 6]   [3, 5, 7]
           2019-07-02 03:00:00 0                  [4, 6, 8]   [5, 7, 9]
           2019-07-02 04:00:00 0                 [6, 8, 10]  [7, 9, 11]
2019-07-04 2019-07-04 02:00:00 0                 [12, 9, 8]  [13, 8, 9]
           2019-07-04 03:00:00 0                  [9, 8, 5]   [8, 9, 4]
           2019-07-04 04:00:00 0                  [8, 5, 3]   [9, 4, 3]
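Note that rolling(n) with an integer window counts rows, so on the irregular dt_fore example from the question it would most likely still pair the 02:00 and 04:00 forecasts. The following is a minimal sketch, not part of the original answer, of one way to honor the timestamps instead. The helper name `make_sequences` is hypothetical, and the expected spacing of one hour between forecasts (the `step` parameter) is an assumption taken from the example data.

```python
import numpy as np
import pandas as pd


def make_sequences(frame, n, step=pd.Timedelta(hours=1)):
    """Hypothetical helper: length-n windows per dt_calc, skipping windows
    whose dt_fore stamps are not exactly `step` apart."""
    rows, keys = [], []
    for _, group in frame.groupby(level="dt_calc", sort=False):
        group = group.sort_index(level="dt_fore")
        stamps = group.index.get_level_values("dt_fore").to_numpy()
        # regular[i] is True when row i+1 follows row i by exactly `step`.
        regular = np.diff(stamps) == step.to_timedelta64()
        values = group.to_numpy()
        for end in range(n - 1, len(group)):
            start = end - n + 1
            if not regular[start:end].all():
                continue  # the window spans a gap in the forecast timestamps
            rows.append([col.tolist() for col in values[start:end + 1].T])
            keys.append(group.index[end])
    index = pd.MultiIndex.from_tuples(keys, names=frame.index.names)
    return pd.DataFrame(rows, index=index, columns=frame.columns)


# The irregular example from the question (03:00 is missing on 2019-07-04).
calc = pd.to_datetime(["2019-07-02"] * 5 + ["2019-07-04"] * 5)
fore = pd.to_datetime([
    "2019-07-02 00:00", "2019-07-02 01:00", "2019-07-02 02:00",
    "2019-07-02 03:00", "2019-07-02 04:00",
    "2019-07-04 00:00", "2019-07-04 01:00", "2019-07-04 02:00",
    "2019-07-04 04:00", "2019-07-04 05:00",
])
irregular = pd.DataFrame(
    [[2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [9, 8], [8, 9], [5, 4], [3, 3]],
    index=pd.MultiIndex.from_arrays(
        [calc, fore, [0] * 10], names=["dt_calc", "dt_fore", "positional_index"]
    ),
    columns=["temp", "temp_2"],
)

result = make_sequences(irregular, 2)
print(result)
```

For n=2 this keeps seven windows and drops the one that would bridge 02:00 and 04:00 on 2019-07-04, which matches the restructured frame shown in the question's edit. The sketch does not handle the case where no window survives the filter.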