首页 > 解决方案 > __init__() 得到了一个意外的关键字参数“chi1”

问题描述

我是python新手。我正在尝试将参数导入我的类“ inverse_model ”。我调用一个函数“ get_model s”来做到这一点。但它给了我错误“ init () got an unexpected keyword argument 'zz'

我很感激帮助。请看下面的代码:

    def get_models(args):
    
    
    zz=torch.tensor(args.chi_Initialize)
    inverse_net = inverse_model(in_channels=len(args.chi),zz=zz,resolution_ratio=args.resolution_ratio,nonlinearity=args.nonlinearity)
    
    return inverse_net


class inverse_model(nn.Module):
    def __init__(self, in_channels,zz,resolution_ratio=6,nonlinearity="tanh"):
        super(inverse_model, self).__init__()
        self.in_channels = in_channels
        self.zz=zz
        self.resolution_ratio = resolution_ratio #vertical scale mismtach between seismic and EI
        self.activation =  nn.ReLU() if nonlinearity=="relu" else nn.Tanh() 

标签: pythonpytorch

解决方案


在默认参数之后,python 不允许有非默认参数。

将您的构造函数修改为

def __init__(self, in_channels,
                 chi1,chi2,chi3,chi4,chi5,chi6,chi7,chi8,chi9,
                 chi10,chi11,chi12, resolution_ratio=6,nonlinearity="tanh"):

更新答案:Revision 1(更新问题)

import torch
from torch import nn


class inverse_model(nn.Module):
    def __init__(self, in_channels, zz, resolution_ratio=6, nonlinearity="tanh"):
        super(inverse_model, self).__init__()
        self.in_channels = in_channels
        self.zz = zz
        self.resolution_ratio = resolution_ratio  # vertical scale mismtach between seismic and EI
        self.activation = nn.ReLU() if nonlinearity == "relu" else nn.Tanh()

def get_models(args):
    zz = torch.tensor(args.chi_Initialize)
    inverse_net = inverse_model(in_channels=len(args.chi), zz=zz, resolution_ratio=args.resolution_ratio,
                                nonlinearity=args.nonlinearity)

    return inverse_net

exit 0作为状态返回。


推荐阅读