How to implement an asdf extension when using numpy polynomials or astropy polynomials?


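I want to store some data in an ASDF file and use an extension that is not implemented yet. I want to extend an existing extension, so I tried starting from the Astropy extension.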

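I know how to write a working extension for ASDF. The key problem, however, is that the ASDF file should always look like an ASDF file created from an astropy polynomial; creating a new extension that stores numpy polynomials is not what I'm after. On the other hand, reading the ASDF file should always give back a numpy polynomial.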
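What makes this goal workable is that both classes store coefficients in ascending powers of x: numpy's `Polynomial.coef[i]` and astropy's `Polynomial1D` parameter `c{i}` both multiply `x**i`. So a numpy polynomial can be written as the same `{'coefficients': ...}` node that the astropy `transform/polynomial` tag uses, and rebuilt from it. A minimal numpy-only sketch with asdf left out; `poly_to_node` and `node_to_poly` are hypothetical helper names, and only the 1D case is covered.

```python
import numpy as np
from numpy.polynomial import Polynomial


def poly_to_node(poly):
    """Hypothetical helper: build a tree node for a numpy Polynomial.

    Mirrors the {'coefficients': ...} layout used by to_tree_transform in the
    question; coefficients are in ascending powers of x, as in Polynomial1D.
    """
    return {'coefficients': np.asarray(poly.coef)}


def node_to_poly(node):
    """Hypothetical helper: rebuild a numpy Polynomial from such a node."""
    coefficients = np.asarray(node['coefficients'])
    if coefficients.ndim != 1:
        # a 2D array would come from a Polynomial2D, which numpy's 1D class cannot hold
        raise ValueError("only 1D coefficient arrays map onto numpy.polynomial.Polynomial")
    return Polynomial(coefficients)


p = Polynomial([1.0, 2.0, 3.0])  # 1 + 2*x + 3*x**2
node = poly_to_node(p)
q = node_to_poly(node)
print(node['coefficients'])  # [1. 2. 3.]
print(q(2.0))                # 17.0
```

Rejecting 2D coefficient arrays is a design choice of this sketch: a `Polynomial2D` node has no numpy `Polynomial` equivalent, so failing loudly seems safer than silently flattening.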


import asdf
from astropy.modeling import models, fitting
from numpy.polynomial import Polynomial as P

# these 2 polynomials are equal
poly_np = P([0,0,0])
poly_astropy = models.Polynomial1D(degree=2)

# this is the usual way how to save an astropy polynomial
target = asdf.AsdfFile({'astropy_poly':poly_astropy})
# inline is just for readability...

# does not work since numpy polynomials are not 'known' by asdf
target = asdf.AsdfFile({'numpy_poly':poly_np})

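I tried changing the PolynomialType class in astropy's polynomial.py so that it accepts the type 'numpy.polynomial.polynomial.Polynomial', but the object still cannot be represented. Where do I need to make changes to get my polynomial.py working? Or is the way I'm overriding the astropy class wrong?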

import numpy as np
from numpy.polynomial import Polynomial as P
from numpy.testing import assert_array_equal

from asdf import yamlutil

from astropy import modeling
from astropy.io.misc.asdf.tags.transform.basic import TransformType

class PolynomialType_np(TransformType):
    name = "transform/polynomial"
    types = ['astropy.modeling.models.Polynomial1D',
             'astropy.modeling.models.Polynomial2D',
             'numpy.polynomial.polynomial.Polynomial']

    # from asdf file to np polynomial
    @classmethod
    def from_tree_transform(cls, node, ctx):
        coefficients = np.asarray(node['coefficients'])
        return P(coefficients)

    # from any polynomial to asdf
    @classmethod
    def to_tree_transform(cls, model, ctx):
        # np.polynomial added
        if isinstance(model, np.polynomial.polynomial.Polynomial):
            coefficients = model.coef
        elif isinstance(model, modeling.models.Polynomial1D):
            coefficients = np.array(model.parameters)
        elif isinstance(model, modeling.models.Polynomial2D):
            degree = model.degree
            coefficients = np.zeros((degree + 1, degree + 1))
            for i in range(degree + 1):
                for j in range(degree + 1):
                    if i + j < degree + 1:
                        name = 'c' + str(i) + '_' + str(j)
                        coefficients[i, j] = getattr(model, name).value
        node = {'coefficients': coefficients}
        return yamlutil.custom_tree_to_tagged_tree(node, ctx)

    # astropy classmethod updated with np.arrays
    @classmethod
    def assert_equal(cls, a, b):
        # TODO: If models become comparable themselves, remove this.
        astropy_polys = (modeling.models.Polynomial1D, modeling.models.Polynomial2D)
        numpy_poly = np.polynomial.polynomial.Polynomial
        assert (isinstance(a, astropy_polys + (numpy_poly,)) and
                isinstance(b, astropy_polys + (numpy_poly,)))
        if isinstance(a, astropy_polys) and isinstance(b, astropy_polys):
            # TransformType.assert_equal compares astropy model attributes (e.g. the name),
            # which numpy polynomials do not have, so only call it for astropy models
            TransformType.assert_equal(a, b)
            assert_array_equal(a.parameters, b.parameters)
        elif isinstance(a, astropy_polys) and isinstance(b, numpy_poly):
            assert_array_equal(a.parameters, b.coef)
        elif isinstance(a, numpy_poly) and isinstance(b, astropy_polys):
            assert_array_equal(a.coef, b.parameters)
        else:
            # both are numpy polynomials
            assert_array_equal(a.coef, b.coef)

Tags: python, numpy, astropy


Here are the two solutions suggested by @Iguananaut:

Solution 1


# the code from above and then the following
from astropy.io.misc.asdf.extension import AstropyAsdfExtension
from astropy.io.misc.asdf.types import _astropy_asdf_types


#this will work now
target = asdf.AsdfFile({'numpy_poly':poly_np},extensions=AstropyAsdfExtension())

Solution 2

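In this solution you create a subclass of PolynomialType that adds support for writing numpy polynomials. Since there is no real need to read them back as numpy polynomials, they are read back as astropy polynomials.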

import numpy as np
from numpy.polynomial import Polynomial as P
from numpy.testing import assert_array_equal

import asdf
from asdf import yamlutil

from astropy import modeling
from astropy.io.misc.asdf.tags.transform.polynomial import PolynomialType

from astropy.io.misc.asdf.extension import AstropyAsdfExtension

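class PolynomialTypeNumpy(PolynomialType):
    # handle only numpy polynomials here; astropy models keep using PolynomialType
    types = ['numpy.polynomial.polynomial.Polynomial']

    # overriding to_tree (not to_tree_transform) skips TransformType's handling of
    # astropy-only attributes such as the model name
    @classmethod
    def to_tree(cls, model, ctx):
        coefficients = model.coef
        node = {'coefficients': coefficients}
        return yamlutil.custom_tree_to_tagged_tree(node, ctx)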

    # could/should add assert_equal from above

# And then this works.
target = asdf.AsdfFile({'numpy_poly': P([0, 0, 0])}, extensions=AstropyAsdfExtension())
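The asymmetry in Solution 2 (numpy polynomials in, astropy polynomials out) follows from how type classes are selected: when writing, by the Python type of the object via each class's `types` list; when reading, by the YAML tag. `PolynomialTypeNumpy` shares the `transform/polynomial` tag and only overrides `to_tree`, so whichever of the two classes handles the tag on reading, the inherited `from_tree` runs. A dependency-free toy of that dispatch; the registry dicts, `AstropyLikePoly`, `NumpyLikePoly` and the helper functions are hypothetical, and real asdf resolution is more involved.

```python
# Toy dispatch: writers keyed by Python type, readers keyed by tag (all names hypothetical).
TAG = "transform/polynomial"


class AstropyLikePoly:
    """Stand-in for astropy's Polynomial1D: coefficients c0, c1, ... as parameters."""

    def __init__(self, parameters):
        self.parameters = list(parameters)


class NumpyLikePoly:
    """Stand-in for numpy.polynomial.Polynomial: coefficients in .coef."""

    def __init__(self, coef):
        self.coef = list(coef)


def astropy_to_tree(model):
    return {'coefficients': list(model.parameters)}


def numpy_to_tree(poly):
    # the only part Solution 2 adds: take the coefficients from .coef instead
    return {'coefficients': list(poly.coef)}


def from_tree(node):
    # shared reader for the tag: always rebuilds the astropy-style object
    return AstropyLikePoly(node['coefficients'])


WRITERS = {AstropyLikePoly: astropy_to_tree, NumpyLikePoly: numpy_to_tree}
READERS = {TAG: from_tree}


def write(obj):
    return {'tag': TAG, 'node': WRITERS[type(obj)](obj)}


def read(tagged):
    return READERS[tagged['tag']](tagged['node'])


tagged = write(NumpyLikePoly([0, 0, 0]))
restored = read(tagged)
print(tagged)                                      # {'tag': 'transform/polynomial', 'node': {'coefficients': [0, 0, 0]}}
print(type(restored).__name__)                     # AstropyLikePoly
print(write(AstropyLikePoly([0, 0, 0])) == tagged)  # True
```

The last line is the property the question asks for: on disk, the numpy polynomial is indistinguishable from the equivalent astropy one.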
