首页 > 解决方案 > 如何在 PostgreSQL 中使用数组和循环加速自定义窗口函数?

问题描述

我目前正在学习 UDF,并在下面编写了 PostgreSQL UDF 来计算平均偏差 (MAD)。它是任何窗口的平均值和当前值之间的平均绝对差。在 python pandas/numpy 中,要找到 MAD,我们可以这样写:

series_mad = abs(series - series.mean()).mean()

其中 series 是一组数字,而 series_mad 是表示该系列的 MAD 的单个数值。

我正在尝试使用 Windows 和 UDF 在 PostgreSQL 中编写它。到目前为止,这就是我所拥有的:

CREATE TYPE misc_tuple AS (
    arr_store numeric[],
    ma_period integer
);

CREATE OR REPLACE FUNCTION mad_func(prev misc_tuple, curr numeric, ma_period integer)
    RETURNS misc_tuple AS $$
    BEGIN
        IF curr is null THEN
            RETURN (null::numeric[], -1);
        ELSEIF prev.arr_store is null THEN
            RETURN (ARRAY[curr]::numeric[], ma_period);
        ELSE
            -- accumulate new values in array
            prev.arr_store := array_append(prev.arr_store, curr);
            RETURN prev;
        END IF;
    END;
    $$ LANGUAGE plpgsql;

CREATE OR REPLACE FUNCTION mad_final(prev misc_tuple)
    RETURNS numeric AS $$
    DECLARE
        total_len integer;
        count numeric;
        mad_val numeric;
        mean_val numeric;
    BEGIN
        count := 0;
        mad_val := 0;
        mean_val := 0;
        total_len := array_length(prev.arr_store, 1);
        -- first loop to find the mean of the series
        FOR i IN greatest(1,total_len-prev.ma_period+1)..total_len
        LOOP 
            mean_val := mean_val + prev.arr_store[i];
            count := count + 1;
        END LOOP;
        mean_val := mean_val/NULLIF(count,0);
        -- second loop to subtract mean from each value 
        FOR i IN greatest(1,total_len-prev.ma_period+1)..total_len
        LOOP 
            mad_val := mad_val + abs(prev.arr_store[i]-mean_val);
        END LOOP;
        RETURN mad_val/NULLIF(count, 0);
    END;
    $$ LANGUAGE plpgsql;

CREATE OR REPLACE AGGREGATE mad(numeric, integer) (
    SFUNC = mad_func,
    STYPE = misc_tuple,
    FINALFUNC = mad_final
);

这就是我测试性能的方式:

-- find rolling 12-period MAD
SELECT x,
       mad(x, 12) OVER (ROWS 12-1 PRECEDING)
FROM generate_series(0,1000000) as g(x);

目前,在我的桌面(i5 4670、3.4 GHz、16 GB RAM)上大约需要 45-50 秒。我仍在学习 UDF,所以我不确定我还能对我的函数做些什么来让它更快。我还有其他一些类似的 UDF - 但不使用数组的 UDF 在相同的 1m 行上花费 <15 秒。我的猜测可能是我没有有效地循环数组,或者可以对 UDF 中的 2 个循环做一些事情。

我可以对此 UDF 进行任何更改以使其更快吗?

标签: sqlpostgresqluser-defined-functionswindow-functions

解决方案


您的示例代码不起作用,类型定义中有一个额外的逗号,并且您cnt在其中一个函数中使用了未定义的变量。

为什么将 12 指定为聚合本身和 ROWS PRECEDING 的参数?这似乎是多余的。

您与 numpy 的比较似乎不太恰当,因为这不是滑动窗口功能。

我还有一些其他类似的 UDF - 但是那些不使用数组并且它们在相同的 1m 行上花费 <15 秒

它们也用作滑动窗口功能吗?也是用plpgsql写的?你能展示一个和它的用法吗?

pl/pgsql 通常不是一种非常有效的语言,尤其是在处理大型数组时。尽管在您的使用中,数组永远不会变得非常大,所以我认为这不会是一个特别的问题。

一种提高效率的方法是用 C 而不是 pl/pgsql 编写代码,使用 INTERNAL 数据类型而不是 SQL 复合类型。

另一种改进这种特殊用法的方法(大量窗口,每个窗口都很小)可能是为此聚合实现MINVFUNC 函数和朋友,这样它就不必为每一行从头开始重新启动聚合。

这是一个示例反函数,它根本不会改变输出,但确实将运行时间减少了大约一半:

CREATE OR REPLACE FUNCTION mad_invfunc(prev misc_tuple, curr numeric, ma_period integer)
    RETURNS misc_tuple AS $$
    BEGIN
            -- remove prev value
            prev.arr_store := prev.arr_store[2:];
            RETURN prev;
    END;
    $$ LANGUAGE plpgsql;
CREATE OR REPLACE AGGREGATE mad(numeric, integer) (
    SFUNC = mad_func,
    STYPE = misc_tuple,
    FINALFUNC = mad_final,
    MSFUNC = mad_func,
    MSTYPE = misc_tuple,
    MFINALFUNC = mad_final,
    MINVFUNC = mad_invfunc
);

如果我将类型从任何地方更改numericdouble precision任何地方,它们都会再次将运行时间缩短一半。因此,虽然阵列上的循环可能效率不高,但当仅使用 12 个成员窗口时,它们并不是主要瓶颈。


推荐阅读