python - 如何细分/优化 xarray 数据集中的维度?
问题描述
摘要:我有一个以最初不可用维度的方式收集的数据集。我想将本质上是一大块未区分的数据并为其添加维度,以便对其进行查询、子集化等。这是以下问题的核心。
这是我拥有的一个 xarray 数据集:
<xarray.Dataset>
Dimensions: (chain: 1, draw: 2000, rows: 24000)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 ... 1993 1994 1995 1996 1997 1998 1999
* rows (rows) int64 0 1 2 3 4 5 6 ... 23994 23995 23996 23997 23998 23999
Data variables:
obs (chain, draw, rows) float64 4.304 3.985 4.612 ... 6.343 5.538 6.475
Attributes:
created_at: 2019-12-27T17:16:13.847972
inference_library: pymc3
inference_library_version: 3.8
这里的rows
维度对应于我需要恢复到数据的多个子维度。特别是,这 24,000 行对应于来自 240 个条件的 100 个样本(这 100 个样本位于连续块中)。这些条件是gate
、input
、growth medium
和的组合od
。
我想结束这样的事情:
<xarray.Dataset>
Dimensions: (chain: 1, draw: 2000, gate: 1, input: 4, growth_medium: 3, sample: 100, rows: 24000)
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 ... 1993 1994 1995 1996 1997 1998 1999
* rows *MultiIndex*
* gate (gate) int64 'AND'
* input (input) int64 '00', '01', '10', '11'
* growth_medium (growth_medium) 'standard', 'rich', 'slow'
* sample (sample) int64 0 1 2 3 4 5 6 7 ... 95 96 97 98 99
Data variables:
obs (chain, draw, gate, input, growth_medium, samples) float64 4.304 3.985 4.612 ... 6.343 5.538 6.475
Attributes:
created_at: 2019-12-27T17:16:13.847972
inference_library: pymc3
inference_library_version: 3.8
我有一个 pandas 数据框,它指定门、输入和生长介质的值——每一行给出一组门、输入和生长介质的值,以及一个索引,它指定rows
对应的 100 组的位置(在 中)样品出现。目的是该数据框是标记数据集的指南。
我查看了有关“重塑和重组数据”的 xarray 文档,但我不知道如何组合这些操作来完成我需要的操作。我怀疑我需要以某种方式将这些与 结合起来GroupBy
,但我不明白如何。谢谢!
后来:我有一个解决这个问题的方法,但是太恶心了,我希望有人能解释我有多错,还有什么更优雅的方法是可能的。
因此,首先,我将原始数据中的所有数据提取Dataset
为原始 numpy 形式:
foo = qm.idata.posterior_predictive['obs'].squeeze('chain').values.T
foo.shape # (24000, 2000)
然后我根据需要重新塑造它:
bar = np.reshape(foo, (240, 100, 2000))
这给了我大致想要的形状:有 240 种不同的实验条件,每个都有 100 个变体,对于这些变体中的每一个,我的数据集中都有 2000 个 Monte Carlo 样本。
现在,我从 Pandas 中提取了 240 个实验条件的信息DataFrame
:
import pandas as pd
# qdf is the original dataframe with the experimental conditions and some
# extraneous information in other columns
new_df = qdf[['gate', 'input', 'output', 'media', 'od_lb', 'od_ub', 'temperature']]
idx = pd.MultiIndex.from_frame(new_df)
最后,我DataArray
从 numpy 数组和 pandas重新组装了 a MultiIndex
:
xr.DataArray(bar, name='obs', dims=['regions', 'conditions', 'draws'],
coords={'regions': idx, 'conditions': range(100), 'draws': range(2000)})
结果DataArray
有这些坐标,正如我所希望的:
Coordinates:
* regions (regions) MultiIndex
- gate (regions) object 'AND' 'AND' 'AND' 'AND' ... 'AND' 'AND' 'AND'
- input (regions) object '00' '10' '10' '10' ... '01' '01' '11' '11'
- output (regions) object '0' '0' '0' '0' '0' ... '0' '0' '0' '1' '1'
- media (regions) object 'standard_media' ... 'high_osm_media_five_percent'
- od_lb (regions) float64 0.0 0.001 0.001 ... 0.0001 0.0051 0.0051
- od_ub (regions) float64 0.0001 0.0051 0.0051 2.0 ... 0.0003 2.0 2.0
- temperature (regions) int64 30 30 37 30 37 30 37 ... 37 30 37 30 37 30 37
* conditions (conditions) int64 0 1 2 3 4 5 6 7 ... 92 93 94 95 96 97 98 99
* draws (draws) int64 0 1 2 3 4 5 6 ... 1994 1995 1996 1997 1998 1999
不过,这太可怕了,而且我必须穿透所有漂亮的xarray
抽象层才能达到这一点,这似乎是错误的。尤其是因为这似乎不是科学工作流程的一个不寻常部分:获取相对原始的数据集以及需要与数据组合的元数据电子表格。那么我做错了什么?更优雅的解决方案是什么?
解决方案
给定起始数据集,类似于:
<xarray.Dataset>
Dimensions: (draw: 2, row: 24)
Coordinates:
* draw (draw) int32 0 1
* row (row) int32 0 1 2 3 4 5 6 7 8 9 ... 14 15 16 17 18 19 20 21 22 23
Data variables:
obs (draw, row) int32 0 1 2 3 4 5 6 7 8 ... 39 40 41 42 43 44 45 46 47
您可以连接几个纯 xarray 命令来细分维度(以相同的形状获取数据,但使用多索引),甚至重塑数据集。要细分维度,可以使用以下代码:
multiindex_ds = ds.assign_coords(
dim_0=["a", "b", "c"], dim_1=[0,1], dim_2=range(4)
).stack(
dim=("dim_0", "dim_1", "dim_2")
).reset_index(
"row", drop=True
).rename(
row="dim"
)
multiindex_ds
其输出是:
<xarray.Dataset>
Dimensions: (dim: 24, draw: 2)
Coordinates:
* draw (draw) int32 0 1
* dim (dim) MultiIndex
- dim_0 (dim) object 'a' 'a' 'a' 'a' 'a' 'a' ... 'c' 'c' 'c' 'c' 'c' 'c'
- dim_1 (dim) int64 0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1
- dim_2 (dim) int64 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3
Data variables:
obs (draw, dim) int32 0 1 2 3 4 5 6 7 8 ... 39 40 41 42 43 44 45 46 47
此外,多索引可以被取消堆叠,有效地重塑数据集:
reshaped_ds = multiindex_ds.unstack("dim")
reshaped_ds
输出:
<xarray.Dataset>
Dimensions: (dim_0: 3, dim_1: 2, dim_2: 4, draw: 2)
Coordinates:
* draw (draw) int32 0 1
* dim_0 (dim_0) object 'a' 'b' 'c'
* dim_1 (dim_1) int64 0 1
* dim_2 (dim_2) int64 0 1 2 3
Data variables:
obs (draw, dim_0, dim_1, dim_2) int32 0 1 2 3 4 5 ... 42 43 44 45 46 47
我认为仅此一项并不能完全满足您的需求,因为您想将一个维度转换为二维,其中一个是多索引。不过,所有的构建块都在这里。
例如,您可以按照此步骤(包括取消堆叠)regions
和conditions
然后按照此步骤(现在不取消堆叠)转换regions
为多索引。另一种选择是从一开始就使用所有维度,将它们取消堆叠,然后再次堆叠它们,将它们留conditions
在最终的多索引之外。
详细解答
答案结合了几个完全不相关的命令,要查看每个命令在做什么可能会很棘手。
assign_coords
第一步是创建新的维度和坐标并将它们添加到数据集中。这是必要的,因为下一个方法需要维度和坐标已经存在于数据集中。
assign_coords
在产生以下数据集后立即停止:
<xarray.Dataset>
Dimensions: (dim_0: 3, dim_1: 2, dim_2: 4, draw: 2, row: 24)
Coordinates:
* draw (draw) int32 0 1
* row (row) int32 0 1 2 3 4 5 6 7 8 9 ... 14 15 16 17 18 19 20 21 22 23
* dim_0 (dim_0) <U1 'a' 'b' 'c'
* dim_1 (dim_1) int32 0 1
* dim_2 (dim_2) int32 0 1 2 3
Data variables:
obs (draw, row) int32 0 1 2 3 4 5 6 7 8 ... 39 40 41 42 43 44 45 46 47
stack
数据集现在包含 3 个维度,最多可包含 24 个元素,但是,由于当前数据相对于这 24 个元素是平坦的,我们必须将它们堆叠成一个 24 个元素的多索引以使其形状兼容。
我发现assign_coords
后面stack
是最自然的解决方案,但是,另一种可能性是生成类似于上面的方式的多索引并直接使用多索引调用assign_coords
,从而使堆栈变得不必要。
此步骤将所有 3 个新维度合并为一个维度:
<xarray.Dataset>
Dimensions: (dim: 24, draw: 2, row: 24)
Coordinates:
* draw (draw) int32 0 1
* row (row) int32 0 1 2 3 4 5 6 7 8 9 ... 14 15 16 17 18 19 20 21 22 23
* dim (dim) MultiIndex
- dim_0 (dim) object 'a' 'a' 'a' 'a' 'a' 'a' ... 'c' 'c' 'c' 'c' 'c' 'c'
- dim_1 (dim) int64 0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1
- dim_2 (dim) int64 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3
Data variables:
obs (draw, row) int32 0 1 2 3 4 5 6 7 8 ... 39 40 41 42 43 44 45 46 47
请注意,现在我们根据需要有 2 个尺寸为 24 的尺寸。
reset_index
现在我们在数据集中有我们的最终维度作为坐标,我们希望这个新坐标是用于索引变量的坐标obs
。set_index
似乎是正确的选择,但是,我们的每个坐标索引本身(与set_index
文档中的示例不同,x
索引同时索引x
和a
坐标),这意味着set_index
在这种特殊情况下不能使用。使用的方法是reset_index
删除坐标row
而不删除维度row
。
在下面的输出中,可以看出现在row
是一个没有坐标的维度:
<xarray.Dataset>
Dimensions: (dim: 24, draw: 2, row: 24)
Coordinates:
* draw (draw) int32 0 1
* dim (dim) MultiIndex
- dim_0 (dim) object 'a' 'a' 'a' 'a' 'a' 'a' ... 'c' 'c' 'c' 'c' 'c' 'c'
- dim_1 (dim) int64 0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1
- dim_2 (dim) int64 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3
Dimensions without coordinates: row
Data variables:
obs (draw, row) int32 0 1 2 3 4 5 6 7 8 ... 39 40 41 42 43 44 45 46 47
rename
当前的数据集几乎是最后一个,唯一的问题是obs
变量仍然具有row
维度而不是所需的维度:dim
. 它看起来不像是预期的用途,rename
但它可以用于dim
吸收, row
产生所需的最终结果(multiindex_ds
如上所述)。
再次,set_index
似乎是选择的方法,但是,如果使用 , 而不是rename(row="dim")
,set_index(row="dim")
则多索引将折叠为由元组组成的索引:
<xarray.Dataset>
Dimensions: (draw: 2, row: 24)
Coordinates:
* draw (draw) int32 0 1
* row (row) object ('a', 0, 0) ('a', 0, 1) ... ('c', 1, 2) ('c', 1, 3)
Data variables:
obs (draw, row) int32 0 1 2 3 4 5 6 7 8 ... 39 40 41 42 43 44 45 46 47
推荐阅读
- java - JSch ssh响应中的转义码/随机字符
- python - discord bot 在生产中离线而不是在本地
- javascript - Vue.JS - 组合变量
- python - 为 pandas loc 构建查询(布尔)函数
- c++ - 在 Microsoft MFP(媒体基础平台)上绘图
- php - PHP HTTP 426 - file_get_contents 与 curl
- linux - 无法在 shell 命令中打印临时变量
- php - 将数据批量导入 mysql 数据库导致 504 网关超时。(PHP 和 CI)
- php - 通过 PHP 检查 client_ip 时,我仍然通过 Tor 收到我的 IP 地址
- android - java.lang.IllegalArgumentException:服务未注册