python - 从头开始的神经网络 - 预测单个示例
问题描述
这是我从 Coursera Deep Learning Specialization 修改的神经网络,用于在包含扁平化训练数据数组的数据集上进行训练:
%reset -s -f
import numpy as np
import math
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def initialize_with_zeros(dim):
w = np.zeros(shape=(dim, 1))
b = 0
return w, b
X = np.array([[1,1,1,1],[1,0,1,0] , [1,1,1,0], [0,0,0,0], [0,1,0,0], [0,1,0,1]])
Y = np.array([[1,0,1,1,1,1]])
X = X.reshape(X.shape[0], -1).T
Y = Y.reshape(Y.shape[0], -1).T
print('X shape' , X.shape)
print('Y shape' , Y.shape)
b = 1
w, b = initialize_with_zeros(4)
def propagate(w, b, X, Y) :
m = X.shape[1]
A = sigmoid(np.dot(w.T, X) + b) # compute activation
cost = (- 1 / m) * np.sum(Y * np.log(A) + (1 - Y) * (np.log(1 - A))) # compute cost
dw = (1./m)*np.dot(X,((A-Y).T))
db = (1./m)*np.sum(A-Y, axis=1)
cost = np.squeeze(cost)
grads = {"dw": dw,
"db": db}
return grads, cost
propagate(w , b , X , Y)
learning_rate = .001
costs = []
def optimize(w , b, X , Y) :
for i in range(2):
grads, cost = propagate(w=w, b=b, X=X, Y=Y)
dw = grads["dw"]
db = grads["db"]
w = w - learning_rate*dw
b = b - learning_rate*db
if i % 100 == 0:
costs.append(cost)
return w , b
w , b = optimize(w , b , X , Y)
def predict(w, b, X):
m = 6
Y_prediction = np.zeros((1,m))
# w = w.reshape(X.shape[0], 1)
A = sigmoid(np.dot(w.T, X) + b)
for i in range(A.shape[1]):
if A[0, i] >= 0.5:
Y_prediction[0, i] = 1
else:
Y_prediction[0, i] = 0
return Y_prediction
predict(w , b, X)
这按预期工作,但我很难预测一个例子。
如果我使用:
predict(w , b, X[0])
返回错误:
ValueError: shapes (6,4) and (6,) not aligned: 4 (dim 1) != 6 (dim 0)
如何重新安排矩阵运算以预测单个实例?
解决方案
尝试
predict(w, b, X[:1])
看起来你的predict
函数应该是二维X
的,当只传递一个时,X
它应该有一个单一的第二维(即 shape=(6,1))而不是一个单一的维度(即 shape=(6,)) .
推荐阅读
- android - 尝试调用虚拟方法 'void android.widget.Editor$InsertionPointCursorController.hide()'
- html - html页面只有在重新加载后才能正确排列
- php - 如何在以下示例中使用 PHP foreach?
- python - 展开列表列表
- node.js - 无法让 vue cli 4 找到依赖项
- multithreading - 多线程合并排序堆栈溢出错误
- flutter - 在颤动中获取小部件的高度
- mysql - 交叉连接中返回的重复数据具有不同的结果#
- android - 将String设置为Edittext时\ n(新行)的两种方式数据绑定问题
- python - django for python 设置问题