Parallelizing a custom function that works along a dimension of an xarray Dataset

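Problem description

I have a function aggregate_xr_dataset() that takes an xarray Dataset with 4 dimensions (regions, y, x, time) and 2 variables (ts_var, values_var) and aggregates the data within each region (along the regions dimension). The resulting xarray Dataset has 3 dimensions: time, regions, types.

The code shown below is a simplified version of what the function does internally, but the inputs and outputs are the same.

I would like to parallelize this function so that it works on each region in parallel and writes the results to a new xarray Dataset. I have looked into xr.apply_ufunc() but can't quite figure out how to use it for my case. Is this possible with xr.apply_ufunc()?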

import xarray as xr
import numpy as np

# aggregation function 
def aggregate_xr_dataset(ds, n_types = 2):
    
    #prepare a resultant xr dataset 
    regions = ds['regions'].values
    time_steps = ds['time'].values
    types = [f'type_{i}' for i in range(n_types)] 

    data = np.zeros((len(time_steps), len(regions), n_types))
    
    res_ts_var_da = xr.DataArray(data, [('time', time_steps),
                                            ('regions', regions),
                                            ('types', types)])
    
    data = np.zeros((len(regions), len(types)))

    res_values_var_da = xr.DataArray(data, [('regions', regions),
                                                ('types', types)])

    # Aggregation in every region...
    for region in regions:
        # Get ts and values of current region 
        regional_ds = ds.sel(regions = region)
        
        regional_ts_da = regional_ds['ts_var']
        regional_value_da = regional_ds['values_var']

        # Preprocess - stack and transpose dimensions 
        regional_ts_da = regional_ts_da.stack(x_y = ['x', 'y']) 
        regional_ts_da = regional_ts_da.transpose(transpose_coords= True) 

        regional_value_da = regional_value_da.stack(x_y = ['x', 'y'])
        regional_value_da = regional_value_da.transpose(transpose_coords= True)

        # Aggregation
        ## values 
        res_values_var_da.loc[region, types[0]] = regional_value_da[dict(x_y=slice(None, 3))].sum()
        res_values_var_da.loc[region, types[1]] = regional_value_da[dict(x_y=slice(3, 9))].sum()
        ## ts 
        res_ts_var_da.loc[:, region, types[0]] = regional_ts_da[dict(x_y=slice(None, 3))].sum()
        res_ts_var_da.loc[:, region, types[1]] = regional_ts_da[dict(x_y=slice(3, 9))].sum()

    # Create resulting dataset 
    res_ds = xr.Dataset({"values": res_values_var_da,
                             "ts": res_ts_var_da}) 
    
    return res_ds


# xarray dataset 
regions = ['reg_1','reg_2']

x_locations = np.arange(3)
y_locations = np.arange(3)

time = np.arange(3)

## time series data
ts_var = np.array([ [[[1, 1, 1] for i in range(3)] for i in range(3)],
                            
                    [[[1, 1, 1] for i in range(3)] for i in range(3)]
                   
                  ])

ts_var_da = xr.DataArray(ts_var, 
                        coords=[regions, y_locations, x_locations, time], 
                        dims=['regions', 'y', 'x','time'])

## values data
values_var = np.array([  [[1, 1, 1] for i in range(3)],

                         [[1, 1, 1] for i in range(3)]
                ])
values_var_da = xr.DataArray(values_var, 
                            coords=[regions, y_locations, x_locations], 
                            dims=['regions', 'y', 'x'])

ds = xr.Dataset({'ts_var': ts_var_da,
                'values_var': values_var_da})

# Function call 
aggregate_xr_dataset(ds)

Tags: python, python-xarray

Solution
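
One possible approach, offered as a sketch rather than a definitive answer: the per-region loop can be replaced by a single xr.apply_ufunc() call that treats the stacked x_y dimension as a core dimension, so every region (and every time step) is handled by one vectorized NumPy call. The helper names _sum_groups and aggregate_vectorized are hypothetical and introduced here only for illustration. One assumption to note: the sketch sums over the x_y cells only, so ts keeps a separate value per time step, whereas the loop in the question calls .sum() without a dimension and therefore also sums over time before broadcasting the scalar.

```python
import numpy as np
import xarray as xr


def _sum_groups(arr):
    """Sum two fixed slices of the last axis (the stacked x_y cells).

    apply_ufunc moves the core dimension to the last axis, so `arr` has
    shape (..., n_cells) and the result has shape (..., 2).
    """
    first = arr[..., :3].sum(axis=-1)    # cells 0-2 -> type_0
    second = arr[..., 3:9].sum(axis=-1)  # cells 3-8 -> type_1
    return np.stack([first, second], axis=-1)


def aggregate_vectorized(ds):
    """Hypothetical loop-free variant of aggregate_xr_dataset()."""
    # Same stacking as in the loop version: x is the outer level of x_y.
    stacked = ds.stack(x_y=["x", "y"])

    # x_y is consumed as the core dimension; a new 'types' dimension is produced.
    res = xr.apply_ufunc(
        _sum_groups,
        stacked,
        input_core_dims=[["x_y"]],
        output_core_dims=[["types"]],
    )

    res = res.assign_coords(types=["type_0", "type_1"])
    res = res.rename({"ts_var": "ts", "values_var": "values"})
    # Match the dimension order of the original result for the time series.
    res["ts"] = res["ts"].transpose("time", "regions", "types")
    return res


# Small all-ones dataset with the same layout as in the question.
example_ds = xr.Dataset(
    {
        "ts_var": (("regions", "y", "x", "time"), np.ones((2, 3, 3, 3))),
        "values_var": (("regions", "y", "x"), np.ones((2, 3, 3))),
    },
    coords={
        "regions": ["reg_1", "reg_2"],
        "y": np.arange(3),
        "x": np.arange(3),
        "time": np.arange(3),
    },
)

example_res = aggregate_vectorized(example_ds)
```

For this all-ones dataset, values comes out as 3 for type_0 and 6 for type_1 in each region, and ts holds the same two numbers at every time step. If dask is installed, the same call can in principle be distributed over regions by chunking the input first (for example ds.chunk({"regions": 1})) and passing dask="parallelized" together with dask_gufunc_kwargs={"output_sizes": {"types": 2}} and output_dtypes to xr.apply_ufunc(); the x_y core dimension then has to stay in a single chunk. Treat that variant as an untested suggestion. For small inputs the vectorized call alone is likely to be faster than any per-region parallelism.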

