Implementing rv_continuous

Problem description

I'm trying to wrap scipy.stats.truncnorm in my own subclass of scipy.stats.rv_continuous. I provided code for _argcheck, _get_support, _pdf and _rvs, but I get the error

_parse_args() missing 4 required positional arguments: 'a', 'b', 'mu', and 'sig'

I suspect it has to do with shapes or with the implementation of _parse_args, but I can't figure out how to fix it (I have already read How do you use scipy.stats.rv_continuous?).
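On the shapes suspicion: when no shapes string is passed to the constructor, rv_continuous infers the shape parameters from the signatures of _pdf and _cdf (every argument after x) and generates _parse_args from those names. Methods such as _argcheck, _get_support and _rvs are then called with all of the shapes, in that order. The sketch below only inspects that inference; shape_demo_gen is a hypothetical stub, not code from the question, and its _pdf is a placeholder rather than a real density.

```python
from scipy.stats import rv_continuous


class shape_demo_gen(rv_continuous):
    # Hypothetical stub: only the _pdf signature matters for shape inference.
    def _pdf(self, x, a, b, mu, sig):
        # Placeholder body, never evaluated in this sketch.
        return 0.0 * x


shape_demo = shape_demo_gen(name='shape_demo')

# Shapes are read off the _pdf signature (everything after x).
print(shape_demo.shapes)   # a, b, mu, sig
print(shape_demo.numargs)  # 4

# Equivalent, with the shapes spelled out explicitly instead of inferred.
explicit = shape_demo_gen(name='shape_demo', shapes='a, b, mu, sig')
print(explicit.numargs)    # 4
```

Passing shapes explicitly documents the parameter names, but every overridden method must still accept all four of them.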

I'm using scipy 1.5.2 and Python 3.8.5.

Code:

from scipy.stats import *
import scipy.stats

class truncgauss_gen(rv_continuous):
    ''' a and b are bounds of true support
    mu is mean
    sig is std dev
    '''
    def _argcheck(self, a, b, sig): return (a < b) and (sig > 0)
        
    def _get_support(self, a, b):   return a, b
        
    def _pdf(self, x, a, b, mu, sig):   return scipy.stats.truncnorm.pdf(x, (a - mu) / sig, (b - mu) / sig, ac=mu, scale=sig)

    def _rvs(self, a, b, mu, sig, size):    return scipy.stats.truncnorm.rvs((a - mu) / sig, (b - mu) / sig, ac=mu, scale=sig, size=size)
    
        
truncgauss = truncgauss_gen(name='truncgauss', momtype=1)

if __name__ == '__main__':
    print(scipy.__version__)    
    
    tg = truncgauss()
    dat = tg.rvs(a=-5.1, b=10.07, mu=2.3, n=10)
    print(dat)

Traceback:

1.5.2
Traceback (most recent call last):
  File "testDistr.py", line 41, in <module>
    tg = truncgauss()
  File ".../opt/anaconda3/lib/python3.8/site-packages/scipy/stats/_distn_infrastructure.py", line 780, in __call__
    return self.freeze(*args, **kwds)
  File ".../opt/anaconda3/lib/python3.8/site-packages/scipy/stats/_distn_infrastructure.py", line 777, in freeze
    return rv_frozen(self, *args, **kwds)
  File ".../opt/anaconda3/lib/python3.8/site-packages/scipy/stats/_distn_infrastructure.py", line 424, in __init__
    shapes, _, _ = self.dist._parse_args(*args, **kwds)
TypeError: _parse_args() missing 4 required positional arguments: 'a', 'b', 'mu', and 'sig'
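The traceback points at tg = truncgauss(), so the call fails before _argcheck or _pdf is ever reached: calling a distribution object freezes it, and the frozen object passes its arguments straight to _parse_args, which needs a value for every shape. The next line would fail as well, because a frozen distribution's rvs accepts only size and random_state; shape values (including the missing sig) belong either in the freeze call or in an rvs call on the unfrozen object, and the sample count is size, not n. A minimal reproduction, using a hypothetical one-shape distribution rate_exp_gen (an exponential with rate lam, not from the question):

```python
import numpy as np
from scipy.stats import rv_continuous


class rate_exp_gen(rv_continuous):
    # Hypothetical one-shape distribution: exponential with rate lam.
    # The shape lam is inferred from the _pdf/_cdf signatures.
    def _pdf(self, x, lam):
        return lam * np.exp(-lam * x)

    def _cdf(self, x, lam):
        return 1.0 - np.exp(-lam * x)


rate_exp = rate_exp_gen(a=0.0, name='rate_exp')  # support is [0, inf)

# 1) Calling the object freezes it; without the shape this fails like the question.
message = None
try:
    rate_exp()
except TypeError as exc:
    message = str(exc)
print(message)  # names the missing shape argument 'lam'

# 2) Freeze with every shape; the frozen rvs then takes only size/random_state.
frozen = rate_exp(lam=2.0)
sample_frozen = frozen.rvs(size=5, random_state=0)

# 3) Or skip freezing and pass the shape straight to rvs.
sample_direct = rate_exp.rvs(lam=2.0, size=5, random_state=0)
print(np.allclose(sample_frozen, sample_direct))  # True
```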

Tags: python, scipy.stats

Solution


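I'm not quite sure what the problem was, but passing all of the shape parameters (a, b, mu, sig) to every method, as below, seems to work.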

import scipy.stats
from scipy.stats import rv_continuous

class truncG_gen(rv_continuous):
    # Shapes are inferred from the _pdf/_cdf signatures (a, b, mu, sig),
    # so every overridden method must accept all four of them.
    def _argcheck(self, a, b, mu, sig):
        # & instead of `and`, so the check also works element-wise on arrays
        return (a < b) & (sig > 0)

    def _get_support(self, a, b, mu, sig):
        return a, b

    def _pdf(self, x, a, b, mu, sig):
        return scipy.stats.truncnorm.pdf(x, (a - mu) / sig, (b - mu) / sig, loc=mu, scale=sig)

    def _cdf(self, x, a, b, mu, sig):
        return scipy.stats.truncnorm.cdf(x, (a - mu) / sig, (b - mu) / sig, loc=mu, scale=sig)

    def _rvs(self, a, b, mu, sig, size=None, random_state=None):
        return scipy.stats.truncnorm.rvs((a - mu) / sig, (b - mu) / sig, loc=mu, scale=sig,
                                         size=size, random_state=random_state)

truncG = truncG_gen(name='truncG', momtype=1)
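A usage sketch for truncG that also avoids the freezing problem from the question. The values of a, b and mu come from the question; sig=1.5 is an assumption, since the question never passes one. The class is repeated so the block runs on its own. It draws the same seeded sample with and without freezing, and checks the density against scipy.stats.truncnorm called directly.

```python
import numpy as np
import scipy.stats
from scipy.stats import rv_continuous


class truncG_gen(rv_continuous):
    # Same class as in the answer, repeated so this block is self-contained.
    def _argcheck(self, a, b, mu, sig):
        return (a < b) & (sig > 0)

    def _get_support(self, a, b, mu, sig):
        return a, b

    def _pdf(self, x, a, b, mu, sig):
        return scipy.stats.truncnorm.pdf(x, (a - mu) / sig, (b - mu) / sig, loc=mu, scale=sig)

    def _cdf(self, x, a, b, mu, sig):
        return scipy.stats.truncnorm.cdf(x, (a - mu) / sig, (b - mu) / sig, loc=mu, scale=sig)

    def _rvs(self, a, b, mu, sig, size=None, random_state=None):
        return scipy.stats.truncnorm.rvs((a - mu) / sig, (b - mu) / sig, loc=mu,
                                         scale=sig, size=size, random_state=random_state)


truncG = truncG_gen(name='truncG', momtype=1)

a, b, mu = -5.1, 10.07, 2.3   # values from the question
sig = 1.5                     # assumption: the question never sets sig

# Shapes go straight to rvs on the unfrozen object; the sample count is size.
direct = truncG.rvs(a=a, b=b, mu=mu, sig=sig, size=10, random_state=42)

# Or freeze with every shape first; the frozen rvs takes only size/random_state.
tg = truncG(a=a, b=b, mu=mu, sig=sig)
frozen = tg.rvs(size=10, random_state=42)

print(np.allclose(direct, frozen))                  # True
print(bool(np.all((direct >= a) & (direct <= b))))  # True

# The density should agree with calling truncnorm directly (points inside the support).
x = np.linspace(-5.0, 10.0, 7)
expected = scipy.stats.truncnorm.pdf(x, (a - mu) / sig, (b - mu) / sig, loc=mu, scale=sig)
print(np.allclose(tg.pdf(x), expected))             # True
```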
