首页 > 解决方案 > python中的简单感知器模型

问题描述

我在运行 fit 函数时遇到这种类型的错误。许多人说它在 python 2.7 中运行。我想知道如何在 python 3 中完成它。还有其他方法可以做到吗?

class Perceptron:

    def __init__(self):
        self.w=None
        self.b=None

    def model(self,x):
        return 1 if (np.dot(self.w,x)>=self.b) else 0

    def predict(self,X):
        Y=[]
        for x in X:
            result = self.model(x)
            Y.append(result)
        return np.array(Y)

    def fit(self, X, Y, epochs = 1, lr=1):
        self.w = np.ones(X.shape[1])
        self.b = 0

        accuracy = {}
        max_accuracy = 0

        wt_matrix = []

        for i in range(epochs):
            for x, y in zip(X,Y):
                y_pred = self.model(x)
                if y==1 and y_pred == 0:
                    self.w = self.w +lr* x
                    self.b = self.b + lr*1
                elif y==0 and y_pred== 1:
                    self.w = self.w-lr*x
                    self.b = self.b-lr*1
            wt_matrix.append(self.w)
            accuracy[i] =  accuracy_score(self.predict(X),Y)
            if(accuracy[i]>max_accuracy):
                max_accuracy = accuracy[i]
                chkptw=self.w
                chkptb=self.b
        self.w =chkptw
        self.b=chkptb

        print(max_accuracy)



        plt.plot(accuracy.values())
        plt.ylim([0,1])
        plt.show   

        return np.array(wt_matrix) 

这是代码:

wt_matrix = perceptron.fit(X_train,Y_train,100)

当我调用该函数时,它显示了这种类型的错误

TypeError                                 Traceback (most recent call last)
<ipython-input-76-8b850a516f0e> in <module>()

----> 1 wt_matrix = perceptron.fit(X_train,Y_train,100)


8 frames

/usr/local/lib/python3.6/dist-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
     83 
     84     """
---> 85     return array(a, dtype, copy=False, order=order)
     86 
     87 

TypeError: float() argument must be a string or a number, not 'dict_values'

标签: pythonmachine-learningdeep-learningneural-networkperceptron

解决方案


这是一个简单的类型转换问题。改变

plt.plot(accuracy.values())

plt.plot(list(accuracy.values()))

推荐阅读