python - 当我尝试在感知器算法中更新权重向量时,为什么会出现 matmul 不匹配错误?
问题描述
我开始学习 ML 和神经网络,并开始为手写数字识别实施感知器算法。所以代码可以正常工作,直到我尝试更新我的权重(权重)向量。我想 numpy 数组向量大小存在一些问题,但我不知道如何解决这个问题。
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import scipy as sp
from sklearn.datasets import load_digits
#This function passes the values X,Y to the Perceptron algorith and plots the graph of accuracy after each itearation.
def digit(digit_to_recognize=5):
# loading the usps digits dataset from sklearn repository
n_example = 100
X, Y = load_digits(n_class=10, return_X_y=True)
plt.matshow(X[n_example,:].reshape(8,8));
plt.xticks([]);plt.yticks([]);
plt.title(Y[n_example])
plt.savefig("usps_example.png")
# transforming the 10-class labels into binary form
y = sp.sign((Y==digit_to_recognize)* 1.0 - .5)
_, acc = perceptron_train(X,y)
plt.figure(figsize=[12,4])
plt.plot(acc)
plt.xlabel("Iterations");plt.ylabel("Accuracy");
plt.savefig("learning_curve.png")
def perceptron_train(X,Y,iterations=100,eta=.01):
acc = sp.zeros(iterations)
# initialize weight vector
weights = sp.random.randn(X.shape[1]) * 1e-5
for it in sp.arange(iterations):
# indices of misclassified data
wrong = (sp.sign(X @ weights) != Y).nonzero()[0]
if wrong.shape[0] > 0:
# picking a random misclassified data point
i = sp.random.choice(wrong,1)
rand_ex = X[i]*Y[i]
# update weight vector
weights = weights + (eta/it)*rand_ex
# computing accuracy
acc[it] = sp.double(sp.sum(sp.sign(X @ weights)==Y))/X.shape[0]
# return weight vector and accuracy
return weights,acc
这是一个错误:
<ipython-input-128-f4e41796a9be> in perceptron_train(X, Y, iterations, eta)
16 weights = weights + (eta/it)*rand_ex
17 # compute accuracy
---> 18 acc[it] = sp.double(sp.sum(sp.sign(X @ weights)==Y))/X.shape[0]
19 # return weight vector and accuracy
20 return weights,acc
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 1 is different from 64)
这是用伪代码编写的感知器算法本身: 感知器算法
解决方案
推荐阅读
- docker - 如何在现有的 docker swarm 中生成交互式容器?
- reactjs - React Table 如何添加/更改单个单元格
- r - 闪亮的 DT 表中的 unbindalll 问题
- c++ - 编译 woff2 时出现错误“需要 libbrotli”
- java - 如何从用户输入数组中调用某些信息
- firebase - Flutter - firebase_analytics 请求权限
- python - 使用 kivy-ios 构建并让 pyrebase/pycryptodome(x) 工作(解决方案)
- c# - Winform 使用通用方法将按钮文本复制到文本框
- flutter - Appbar 标题文本 fontSize 不变
- python - SequenceMatcher 在编辑距离和 difflib 中的应用区别?