首页 > 解决方案 > 了解如何在 Python 中正确使用类

问题描述

我试图了解如何在 Python 中使用具体的类示例(尤其是在 Pytorch 中)。我有以下两个类:

class Context:

    def __init__(self):
        self._saved_tensors = ()
    def save_for_backward(self, *args):
        self._saved_tensors = args
    @property
    def saved_tensors(self):
        return self._saved_tensors



class MSE(Function):

@staticmethod
def forward(ctx, yhat, y):
   
    ctx.save_for_backward(yhat,y)

    return (1/(yhat.shape[0])*((yhat-y)**2)).sum()
@staticmethod
def backward(ctx, grad_output):

    yhat, y = ctx.saved_tensors
    out_weight =grad_output*(2/(yhat.shape[0])*(y-yhat)).sum()
    out_bias =(2/(yhat.shape[0])*(y-yhat)).sum()
    return out_weight,out_bias

当我将其应用于以下示例时:

ctx=LinearFunction()
#identity matrix of size 3
yhat=torch.eye(3)
y=torch.ones(3,3)
grad_output=1
loss=MSE(ctx,yhat,y)
loss.backward(ctx,grad_output)

它返回给我一个错误:

4 grad_output=1
5 loss=MSE(ctx,yhat,y)
6 loss.backward(ctx,grad_output)

<ipython-input-38-0c861407599f> in backward(ctx, grad_output)
   13 def backward(ctx, grad_output):
   15 yhat, y = ctx.saved_tensors
   16 out_weight = (2/(yhat.shape[0])*(y-yhat)).sum()
   17 out_bias =(2/(yhat.shape[0])*(y-yhat)).sum()

ValueError: not enough values to unpack (expected 2, got 0)

并且代码显示了一个指向第 6 行和第 15 行的箭头,表明错误来自那里。我知道代码需要 yhat 和 y 但我已经定义了它们。那么为什么编译器不使用它们呢?有人可以帮我解决这个问题吗?谢谢你。

标签: pythonclasspytorch

解决方案


推荐阅读