numpy - MLP 输出任何输入的所有训练输出的平均值
问题描述
我试图用 sigmoid 激活实现多层感知器。下面是代码:
import numpy as np
def sigmoid(x):
return 1.0/(1.0 + np.exp(-x))
def sigmoid_derivative(x):
return sigmoid(x) * (1.0 - sigmoid(x))
class MLP:
def __init__(self, layers, x_train, y_train):
self.layers = layers
self.inputs = x_train
self.outputs = y_train
def forward(self, input):
output = input
for layer in self.layers:
layer.activations = output
output = layer.feedforward(output)
return output
def backward(self, output, predicted):
error = np.multiply(2 * np.subtract(output, predicted), sigmoid_derivative(predicted))
for layer in self.layers[::-1]:
#recursively backpropagate the error
error = layer.backpropagate(error)
def train(self):
for i in range(1,500):
predicted = self.forward(self.inputs)
self.backward(self.outputs,predicted)
def test(self, input):
return self.forward(input)
class Layer:
def __init__(self, prevNodes, selfNodes):
self.weights = np.random.rand(prevNodes,selfNodes)
self.biases = np.zeros(selfNodes)
self.activations = np.array([])
def feedforward(self, input):
return sigmoid(np.dot(input, self.weights) + self.biases)
def backpropagate(self, error):
delPropagate = np.dot(error, self.weights.transpose())
dw = np.dot(self.activations.transpose(), error)
db = error.mean(axis=0) * self.activations.shape[0]
self.weights = self.weights + 0.1 * dw
self.biases = self.biases + 0.1 * db
return np.multiply(delPropagate ,sigmoid_derivative(self.activations))
layer1 = Layer(3,4)
layer2 = Layer(4,1)
x_train = np.array([[0,0,1],[0,1,1],[1,0,1],[1,1,1]])
y_train = np.array([[0],[1],[1],[0]])
x_test = np.array([[0,0,1]])
mlp = MLP([layer1,layer2], x_train, y_train)
mlp.train()
mlp.test(x_test)
然而问题是网络饱和输出任何输入的所有训练输出的平均值。例如,在上述情况下,y_train 平均值约为 0.5,无论我向网络提供什么“test_x”值,输出始终在 0.5 左右。
代码中的问题可能出在哪里。我在算法中遗漏了什么吗?帮助表示赞赏
解决方案
问题似乎在于迭代次数较少,将迭代次数从 500 增加到 50000 或将学习率更改为 0.5 也适用于较少的迭代次数。矩阵操作和所有数学似乎是一致的
推荐阅读
- google-sheets - 在查询语句中使用命名范围的 Google 表格不起作用
- android - 如何在 React-Native (Expo) 应用程序中显示应用程序图标徽章?
- mysql - 拒绝访问; 您需要(至少其中一项)超级特权
- elasticsearch - Elasticsearch:创建索引时设置的总字段限制
- android - Pusher Chatkit Android 获取指定房间的状态读数
- python - 尝试使用自定义对象加载模型时出现“ValueError:未知激活:激活”
- apache-spark - 是否为每个 spark 应用程序启动了 spark workers/jvm(重新)?
- php - 运行服务器而不从 Web 浏览器打开它
- excel - 我可以从另一个文件的宏中锁定对共享驱动器中多个 Excel 文件的代码(模块和用户窗体)的访问吗?
- gradle - 错误:缺少 JavaFX 运行时组件,需要使用 Gradle 示例运行此应用程序