python - 使用 numpy 多项式或 astropy 多项式时如何实现 asdf 扩展?
问题描述
我想在 ASDF 文件中存储一些数据并想使用未实现的扩展。我想扩展其他扩展,所以我尝试从Astropy扩展开始。
我知道如何为 ASDF 编写一个有效的扩展。但是,关键问题是 ASDF 文件应该总是看起来像一个由天文多项式创建的 ASDF 文件。创建一个存储 numpy 多项式的新扩展不是我的目的。另一方面,asdf 文件应始终输出一个 numpy 多项式。
这就是我开始的工作:
import asdf
from astropy.modeling import models, fitting
from numpy.polynomial import Polynomial as P
# these 2 polynomials are equal
poly_np = P([0,0,0])
poly_astropy = models.Polynomial1D(degree=2)
# this is the usual way how to save an astropy polynomial
target = asdf.AsdfFile({'astropy_poly':poly_astropy})
# inline is just for readability...
target.write_to('poly_astropy.yaml',all_array_storage='inline')
# does not work since numpy polynomials are not 'known' by asdf
target = asdf.AsdfFile({'numpy_poly':poly_np})
target.write_to('poly_np.yaml',all_array_storage='inline')
我试图将polynomial.py中的 PolynomialType 类从 astropy 更改为接受类型“numpy.polynomial.polynomial.Polynomial”。但问题仍然是无法表示对象。那么我需要在哪里进行更改才能使我的 polynomial.py 正常工作?或者,也许我覆盖 astropy 类的方式是错误的?
import numpy as np
from numpy.polynomial import Polynomial as P
from numpy.testing import assert_array_equal
from asdf import yamlutil
from astropy import modeling
from astropy.io.misc.asdf.tags.transform.basic import TransformType
class PolynomialType_np(TransformType):
name = "transform/polynomial"
types = ['astropy.modeling.models.Polynomial1D',
'astropy.modeling.models.Polynomial2D',
'numpy.polynomial.polynomial.Polynomial']
# from asdf file to np polynomial
@classmethod
def from_tree_transform(cls, node, ctx):
coefficients = np.asarray(node['coefficients'])
return P(coefficients)
# from any polynomial to asdf
@classmethod
def to_tree_transform(cls, model, ctx):
# np.polynomial added
if isinstance(model, np.polynomial.polynomial.Polynomial):
coefficients = p.coef
elif isinstance(model, modeling.models.Polynomial1D):
coefficients = np.array(model.parameters)
elif isinstance(model, modeling.models.Polynomial2D):
degree = model.degree
coefficients = np.zeros((degree + 1, degree + 1))
for i in range(degree + 1):
for j in range(degree + 1):
if i + j < degree + 1:
name = 'c' + str(i) + '_' + str(j)
coefficients[i, j] = getattr(model, name).value
node = {'coefficients': coefficients}
return yamlutil.custom_tree_to_tagged_tree(node, ctx)
# astropy classmethod updated with np.arrays
@classmethod
def assert_equal(cls, a, b):
# TODO: If models become comparable themselves, remove this.
TransformType.assert_equal(a, b)
assert (isinstance(a, (modeling.models.Polynomial1D, modeling.models.Polynomial2D, np.polynomial.polynomial.Polynomial)) and
isinstance(b, (modeling.models.Polynomial1D, modeling.models.Polynomial2D, np.polynomial.polynomial.Polynomial)))
if (isinstance(a, (modeling.models.Polynomial1D, modeling.models.Polynomial2D)) and
isinstance(b, (modeling.models.Polynomial1D, modeling.models.Polynomial2D))):
assert_array_equal(a.parameters, b.parameters)
elif (isinstance(a, (modeling.models.Polynomial1D, modeling.models.Polynomial2D)) and
isinstance(b, np.polynomial.polynomial.Polynomial)):
assert_array_equal(a.parameters, b.coeff)
elif (isinstance(b, (modeling.models.Polynomial1D, modeling.models.Polynomial2D)) and
isinstance(a, np.polynomial.polynomial.Polynomial)):
assert_array_equal(a.coeff, b.parameters)
elif (isinstance(a, np.polynomial.polynomial.Polynomial) and
isinstance(b, np.polynomial.polynomial.Polynomial)):
assert_array_equal(a.coeff, b.coeff)
解决方案
以下是@Iguananaut 建议的两种解决方案:
解决方案 1
这是您PolynomialType
强制覆盖注册表的解决方案。
# the code from above and then the following
from astropy.io.misc.asdf.extension import AstropyAsdfExtension
from astropy.io.misc.asdf.types import _astropy_asdf_types
_astropy_asdf_types.remove(
astropy.io.misc.asdf.tags.transform.polynomial.PolynomialType)
#this will work now
target = asdf.AsdfFile({'numpy_poly':poly_np},extensions=AstropyAsdfExtension())
target.write_to('poly_np.yaml',all_array_storage='inline')
解决方案 2
这是您创建一个子类的解决方案,您可以在PolynomialType
其中添加添加 numpy 多项式的功能。由于实际上没有必要将它们读取为 numpy 多项式,因此将它们读取为 astropy 多项式。
import numpy as np
from numpy.polynomial import Polynomial as P
from numpy.testing import assert_array_equal
import asdf
from asdf import yamlutil
from astropy import modeling
from astropy.io.misc.asdf.tags.transform.polynomial import PolynomialType
from astropy.io.misc.asdf.extension import AstropyAsdfExtension
class PolynomialTypeNumpy(PolynomialType):
@classmethod
def to_tree(cls, model, ctx):
coefficients = model.coef
node = {'coefficients': coefficients}
return yamlutil.custom_tree_to_tagged_tree(node, ctx)
# could/should add assert_equal from above
# And then this works.
target = asdf.AsdfFile({'numpy_poly':P([0,0,0])},
extensions=AstropyAsdfExtension())
target.write_to('poly_np.yaml',all_array_storage='inline')
推荐阅读
- android - 检测wifi列表Flutter
- swift - 没有这样的模块“AWSAppSync”
- forms - 如何在opencart 3.0.2的注册页面上合并名字和姓氏字段?
- python-3.7 - 使用正则表达式模式从文本文件中提取 url
- xamarin.forms - 加载列表项后如何从配置文件保存的模型中设置值
- python-2.7 - 使用python将数据插入SQLite3数据库时如何避免插入重复数据?
- javascript - Javascript 中是否有 Function.caller 的生产安全版本?
- javascript - 在 ESRI 地图中禁用双击放大
- jquery - 表 TD 使用 jquery.ui re-sizable 事件以百分比调整大小
- java - 在 Spring AOP 中,joinPoint.proceed 和 method.invoke 有什么区别?