首页 > 解决方案 > 使用 numpy 多项式或 astropy 多项式时如何实现 asdf 扩展?

问题描述

我想在 ASDF 文件中存储一些数据,但需要用到一个尚未实现的扩展。我打算在现有扩展的基础上进行扩展,因此从 Astropy 的扩展入手。

我知道如何为 ASDF 编写一个可用的扩展。但关键在于:生成的 ASDF 文件应当始终与由 astropy 多项式创建的 ASDF 文件完全一致。我的目的不是创建一个专门存储 numpy 多项式的新扩展。另一方面,读取该 ASDF 文件时应始终得到一个 numpy 多项式。

以下是我目前的起点:

import asdf
from astropy.modeling import models, fitting
from numpy.polynomial import Polynomial as P

# these 2 polynomials are equal: both are the degree-2 polynomial whose three
# coefficients are all zero (numpy coef [0, 0, 0]; Polynomial1D's default
# parameters c0, c1, c2)
poly_np = P([0,0,0])
poly_astropy = models.Polynomial1D(degree=2)

# this is the usual way to save an astropy polynomial
target = asdf.AsdfFile({'astropy_poly':poly_astropy})
# inline is just for readability: arrays are written into the YAML text
# instead of separate binary blocks
target.write_to('poly_astropy.yaml',all_array_storage='inline')

# does not work since numpy polynomials are not 'known' by asdf: no registered
# asdf type handles numpy.polynomial.Polynomial, so the object cannot be
# represented and these two lines raise (the failure the question is about)
target = asdf.AsdfFile({'numpy_poly':poly_np})
target.write_to('poly_np.yaml',all_array_storage='inline')

我尝试修改 astropy 的 polynomial.py 中的 PolynomialType 类,使其也接受类型 “numpy.polynomial.polynomial.Polynomial”。但问题依旧:该对象仍然无法被表示(序列化)。那么我需要在哪里修改,才能让我的 polynomial.py 正常工作?还是说我覆盖 astropy 类的方式本身就是错误的?

import numpy as np
from numpy.polynomial import Polynomial as P
from numpy.testing import assert_array_equal

from asdf import yamlutil

from astropy import modeling
from astropy.io.misc.asdf.tags.transform.basic import TransformType

class PolynomialType_np(TransformType):
    """ASDF type for ``transform/polynomial`` that also handles numpy polynomials.

    Writing accepts astropy ``Polynomial1D``/``Polynomial2D`` models and numpy
    ``Polynomial`` objects. All of them produce the same ``coefficients`` node,
    so a file written from a numpy polynomial looks like one written from an
    astropy polynomial.

    Reading returns a numpy ``Polynomial`` for a 1-D coefficient array. A 2-D
    coefficient matrix has no numpy ``Polynomial`` equivalent, so it is
    returned as an astropy ``Polynomial2D`` model instead.

    NOTE(review): TransformType's own to_tree/from_tree wrappers may read
    astropy-model-only attributes (e.g. ``name``) on the numpy objects
    handled here; confirm against the installed astropy version.
    """

    name = "transform/polynomial"
    types = ['astropy.modeling.models.Polynomial1D',
             'astropy.modeling.models.Polynomial2D',
             'numpy.polynomial.polynomial.Polynomial']

    # from asdf file to polynomial (numpy for 1-D, astropy for 2-D)
    @classmethod
    def from_tree_transform(cls, node, ctx):
        """Build a polynomial from ``node['coefficients']``.

        Returns
        -------
        numpy.polynomial.Polynomial or astropy.modeling.models.Polynomial2D
            A numpy polynomial for 1-D coefficients, or an astropy 2-D
            polynomial for a square 2-D coefficient matrix.

        Raises
        ------
        TypeError
            If a 2-D coefficient array is not square.
        NotImplementedError
            If the coefficients are neither 1-D nor 2-D.
        """
        coefficients = np.asarray(node['coefficients'])
        if coefficients.ndim == 1:
            return P(coefficients)
        if coefficients.ndim == 2:
            # numpy's Polynomial only accepts 1-D coefficients, so 2-D data
            # (as written for Polynomial2D models) is read back as Polynomial2D.
            return cls._polynomial2d_from_matrix(coefficients)
        raise NotImplementedError(
            "Asdf currently only supports 1D or 2D polynomial transform.")

    @staticmethod
    def _polynomial2d_from_matrix(coefficients):
        """Return an astropy Polynomial2D built from an (n+1, n+1) matrix."""
        rows, cols = coefficients.shape
        if rows != cols:
            raise TypeError("Coefficients must be an (n+1, n+1) matrix")
        degree = rows - 1
        # Polynomial2D only has parameters c{i}_{j} with i + j <= degree. The
        # remaining matrix entries are the zero padding that
        # to_tree_transform writes.
        params = {'c{0}_{1}'.format(i, j): coefficients[i, j]
                  for i in range(rows)
                  for j in range(rows - i)}
        return modeling.models.Polynomial2D(degree, **params)

    # from any supported polynomial to asdf
    @classmethod
    def to_tree_transform(cls, model, ctx):
        """Serialize *model* to a tagged tree holding its ``coefficients``.

        Raises
        ------
        TypeError
            If *model* is not a numpy ``Polynomial`` or an astropy
            ``Polynomial1D``/``Polynomial2D``.
        """
        if isinstance(model, P):
            # The tag stores plain coefficients in increasing degree, with no
            # variable mapping. A numpy polynomial whose domain differs from
            # its window maps x linearly before evaluating. Fold that map into
            # the coefficients first; otherwise a different function would be
            # written.
            poly = model
            if not np.array_equal(poly.domain, poly.window):
                poly = poly.convert()
            coefficients = np.asarray(poly.coef)
        elif isinstance(model, modeling.models.Polynomial1D):
            coefficients = np.array(model.parameters)
        elif isinstance(model, modeling.models.Polynomial2D):
            degree = model.degree
            coefficients = np.zeros((degree + 1, degree + 1))
            for i in range(degree + 1):
                # only terms with i + j <= degree exist on the model
                for j in range(degree + 1 - i):
                    name = 'c' + str(i) + '_' + str(j)
                    coefficients[i, j] = getattr(model, name).value
        else:
            raise TypeError(
                "Unsupported polynomial type: {0}".format(
                    type(model).__name__))
        node = {'coefficients': coefficients}
        return yamlutil.custom_tree_to_tagged_tree(node, ctx)

    # astropy classmethod extended to numpy polynomials
    @classmethod
    def assert_equal(cls, a, b):
        """Assert that two supported polynomials have equal coefficients."""
        # TODO: If models become comparable themselves, remove this.
        astropy_polys = (modeling.models.Polynomial1D,
                         modeling.models.Polynomial2D)
        supported = astropy_polys + (P,)
        assert isinstance(a, supported) and isinstance(b, supported)
        if isinstance(a, astropy_polys) and isinstance(b, astropy_polys):
            # TransformType.assert_equal is written for astropy models
            # (presumably it compares attributes such as ``name``, which
            # numpy polynomials lack), so only call it for astropy pairs.
            TransformType.assert_equal(a, b)
        # numpy keeps coefficients in ``coef`` (not ``coeff``); astropy
        # keeps them in ``parameters``.
        a_coef = a.coef if isinstance(a, P) else a.parameters
        b_coef = b.coef if isinstance(b, P) else b.parameters
        assert_array_equal(a_coef, b_coef)

标签: python numpy astropy

解决方案


以下是@Iguananaut 建议的两种解决方案:

解决方案 1

此方案从注册表中移除 astropy 自带的 PolynomialType,从而强制让上面自定义的多项式类型生效。

# the code from above and then the following
from astropy.io.misc.asdf.extension import AstropyAsdfExtension
from astropy.io.misc.asdf.types import _astropy_asdf_types
# Import the class explicitly. None of the snippets bind the bare name
# ``astropy`` (they only use ``from astropy... import ...``), so spelling out
# ``astropy.io.misc.asdf.tags.transform.polynomial.PolynomialType`` would
# raise NameError.
from astropy.io.misc.asdf.tags.transform.polynomial import PolynomialType

# Remove astropy's own PolynomialType from the extension's type registry, so
# it no longer competes with PolynomialType_np for the polynomial tag.
# (This assumes PolynomialType_np was registered in the same set when it was
# defined.)
# NOTE: _astropy_asdf_types is private to astropy and may change between
# versions.
_astropy_asdf_types.remove(PolynomialType)

# this will work now
target = asdf.AsdfFile({'numpy_poly': poly_np},
                       extensions=AstropyAsdfExtension())
target.write_to('poly_np.yaml', all_array_storage='inline')

解决方案 2

此方案创建 PolynomialType 的一个子类,并在其中添加写入 numpy 多项式的功能。由于读取时实际上并不需要得到 numpy 多项式,因此读取结果仍为 astropy 多项式。

import numpy as np
from numpy.polynomial import Polynomial as P
from numpy.testing import assert_array_equal

import asdf
from asdf import yamlutil

from astropy import modeling
from astropy.io.misc.asdf.tags.transform.polynomial import PolynomialType

from astropy.io.misc.asdf.extension import AstropyAsdfExtension

class PolynomialTypeNumpy(PolynomialType):
    """Write numpy ``Polynomial`` objects with astropy's polynomial tag.

    Only serialization is customized. Reading goes through the inherited
    ``PolynomialType`` logic, so such files are read back as astropy
    polynomial models.
    """

    # Without this, the subclass would inherit PolynomialType.types (the
    # astropy Polynomial1D/Polynomial2D models). numpy polynomials would then
    # still be unknown to asdf, and the ``to_tree`` below, which reads
    # ``.coef``, would be applied to astropy models that have no such attribute.
    types = ['numpy.polynomial.polynomial.Polynomial']

    @classmethod
    def to_tree(cls, model, ctx):
        """Serialize numpy polynomial *model* into a ``coefficients`` node.

        Overrides ``to_tree`` itself rather than ``to_tree_transform``.
        NOTE(review): presumably this skips astropy's transform base-member
        handling, which expects astropy model attributes; confirm against
        the installed astropy version.
        """
        # The tag stores plain coefficients with no variable mapping. If the
        # numpy polynomial maps domain -> window, fold that map into the
        # coefficients so the stored values describe the same function.
        poly = model
        if not np.array_equal(poly.domain, poly.window):
            poly = poly.convert()
        coefficients = np.asarray(poly.coef)
        node = {'coefficients': coefficients}
        return yamlutil.custom_tree_to_tagged_tree(node, ctx)

    # could/should add assert_equal from above


# With PolynomialTypeNumpy defined above, a numpy polynomial can now be
# written through the regular astropy ASDF extension.
target = asdf.AsdfFile(
    {'numpy_poly': P([0, 0, 0])},
    extensions=AstropyAsdfExtension(),
)
target.write_to(
    'poly_np.yaml',
    all_array_storage='inline',
)


推荐阅读