xarray.groupby() with ragged arrays

Problem description

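I have a semi-large dataset in which each trajectory [traj] has a different number of observations [obs]. Because the number of observations ranges from ~100 to 100000, they are combined into a ragged (jagged) array. As you can see below, some variables are defined per trajectory (id, count, buoy_diameter), while others are defined per observation (lon, lat, time, u, v, ids).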

import xarray as xr

# define data with variable attributes
data_vars = {
             'id': (['traj'], [11,22,33], {'long_name': 'trajectory id'}),
             'count': (['traj'], [2,3,4], {'long_name': 'number of observations per trajectory'}),
             'buoy_diameter': (['traj'], [0.3,0.4,0.5], {'units': 'm', 'long_name': 'drogue length'}),
             
             'u': (['obs'], [-1.1,-1.2,            # traj 1
                             -2.1,-2.2,-2.3,       # traj 2
                             -3.1,-3.2,-3.3,-3.4], # traj 3
                          {'units': 'm/s', 'long_name':'zonal velocity'}),
             'v': (['obs'], [1.2,1.3,          # traj 1
                             2.1,2.2,2.3,      # traj 2
                             3.1,3.2,3.3,3.4], # traj 3
                          {'units': 'm/s', 'long_name':'meridional velocity'}),
             'ids': (['obs'], [0,0,      # traj 1
                               1,1,1,    # traj 2
                               2,2,2,2], # traj 3
                          {'long_name': 'index of trajectory per observations'}),
             }

# define coordinates
coords = {  
            'lon': (['obs'], [10.,12.,         # traj 1
                              30.,34.,35.,      # traj 2
                              60.,61.,62.,63.]), # traj 3
            'lat': (['obs'], [0.,1.,           # traj 1
                              12.,13.,14.,      # traj 2
                              20.,22.,23.,24.]), # traj 3
            'time': (['obs'], [0.,1.,          # traj 1
                               0.,1.,2.,        # traj 2
                               0.,1.,2.,3.]),    # traj 3
         }

# global attributes (placeholder; the original post does not show them)
attrs = {}

# create dataset
ds = xr.Dataset(data_vars=data_vars,
                coords=coords,
                attrs=attrs)
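
The ragged layout used above can be made explicit with a small NumPy sketch. It assumes that `ids` is simply the positional trajectory index repeated `count` times, which matches the example values; the helper `obs_slice` is a hypothetical name introduced here for illustration only.

```python
import numpy as np

# Per-trajectory observation counts, copied from the example dataset.
count = np.array([2, 3, 4])

# Per-observation trajectory index: trajectory k is repeated count[k] times.
ids = np.repeat(np.arange(count.size), count)

# Start offset of every trajectory in the flat obs dimension, plus the end.
offsets = np.concatenate(([0], np.cumsum(count)))


def obs_slice(k):
    """Return the slice of the obs dimension that belongs to trajectory k."""
    return slice(int(offsets[k]), int(offsets[k + 1]))


# The 'v' values from the example dataset.
v = np.array([1.2, 1.3, 2.1, 2.2, 2.3, 3.1, 3.2, 3.3, 3.4])

print(ids.tolist())              # [0, 0, 1, 1, 1, 2, 2, 2, 2]
print(v[obs_slice(1)].tolist())  # [2.1, 2.2, 2.3]
```

With the offsets in hand, the observations of one trajectory are a contiguous slice of the obs dimension, so no search over `ids` is needed to find them.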

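Each trajectory has an associated id, so I can use groupby('id'), but each group then contains all of the observations, so the groups cannot be used to map() a function over each trajectory.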

gr = ds.groupby('id')
print(gr[11])
<xarray.Dataset>
Dimensions:        (traj: 1, obs: 9)
Coordinates:
    lon            (obs) float64 10.0 12.0 30.0 34.0 35.0 60.0 61.0 62.0 63.0
    lat            (obs) float64 0.0 1.0 12.0 13.0 14.0 20.0 22.0 23.0 24.0
    time           (obs) float64 0.0 1.0 0.0 1.0 2.0 0.0 1.0 2.0 3.0
Dimensions without coordinates: traj, obs
Data variables:
    id             (traj) int64 11
    count          (traj) float64 2.0
    buoy_diameter  (traj) float64 0.3
    u              (obs) float64 -1.1 -1.2 -2.1 -2.2 -2.3 -3.1 -3.2 -3.3 -3.4
    v              (obs) float64 1.2 1.3 2.1 2.2 2.3 3.1 3.2 3.3 3.4
    ids            (obs) int64 0 0 1 1 1 2 2 2 2

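On the other hand, I also have a per-observation ids variable that I can use with groupby('ids'), but, as you can imagine by now, each group then contains all of the trajectories.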

gr = ds.groupby('ids')
print(gr[0])
<xarray.Dataset>
Dimensions:        (traj: 3, obs: 2)
Coordinates:
    lon            (obs) float64 10.0 12.0
    lat            (obs) float64 0.0 1.0
    time           (obs) float64 0.0 1.0
Dimensions without coordinates: traj, obs
Data variables:
    id             (traj) int64 11 22 33
    count          (traj) float64 2.0 3.0 4.0
    buoy_diameter  (traj) float64 0.3 0.4 0.5
    u              (obs) float64 -1.1 -1.2
    v              (obs) float64 1.2 1.3
    ids            (obs) int64 0 0

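What I would like to obtain is a DatasetGroupBy in which each group contains only one trajectory and the observations of that trajectory. I can get this manually with an extra .sel():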

for i, gr in ds.groupby('ids'):
    correct_gr = gr.sel(traj=[i])
    print(correct_gr)

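For reference, the first group contains one value for each per-trajectory variable and 2 observations. (The second group is (traj: 1, obs: 3) and the third is (traj: 1, obs: 4).)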

<xarray.Dataset>
Dimensions:        (traj: 1, obs: 2)
Coordinates:
    lon            (obs) float64 10.0 12.0
    lat            (obs) float64 0.0 1.0
    time           (obs) float64 0.0 1.0
Dimensions without coordinates: traj, obs
Data variables:
    id             (traj) int64 11
    count          (traj) float64 2.0
    buoy_diameter  (traj) float64 0.3
    u              (obs) float64 -1.1 -1.2
    v              (obs) float64 1.2 1.3
    ids            (obs) int64 0 0
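
The manual loop above could be wrapped in a small generator so that a function can be applied per trajectory. This is only a sketch, not a built-in xarray feature: `iter_trajectories` is a hypothetical helper, the dataset is a reduced copy of the example, and the code assumes that `ids` holds the positional index of each observation's trajectory along the `traj` dimension.

```python
import xarray as xr

# Reduced copy of the example dataset: 3 trajectories, 9 observations.
ds = xr.Dataset(
    data_vars={
        "id": (["traj"], [11, 22, 33]),
        "count": (["traj"], [2, 3, 4]),
        "v": (["obs"], [1.2, 1.3, 2.1, 2.2, 2.3, 3.1, 3.2, 3.3, 3.4]),
        "ids": (["obs"], [0, 0, 1, 1, 1, 2, 2, 2, 2]),
    }
)


def iter_trajectories(ds):
    """Yield (trajectory index, single-trajectory Dataset) pairs.

    Hypothetical helper: assumes `ids` is the positional index of the
    trajectory along `traj` for every observation.
    """
    for i, group in ds.groupby("ids"):
        # `group` already holds only the observations of trajectory i;
        # isel trims the per-trajectory variables to the matching entry.
        yield int(i), group.isel(traj=[int(i)])


# Example use: mean of v for every trajectory, keyed by trajectory id.
mean_v = {
    int(sub["id"].item()): float(sub["v"].mean())
    for _, sub in iter_trajectories(ds)
}
print(mean_v)
```

Using isel with a list keeps `traj` as a length-1 dimension, so each yielded dataset has the same (traj: 1, obs: n) shape as the reference output above.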

Tags: python-xarray

Solution

