c++ - DNN 精度低
问题描述
我最近一直在参照 http://neuralnetworksanddeeplearning.com/ 实现一个神经网络(NN)。我实现的反向传播和 SGD 算法与该书作者的几乎完全相同。问题是,作者在 1 个 epoch 后就能达到 90% 左右的准确率,而我即使使用相同的超参数,在 5 个 epoch 后也只有 30%。你知道可能是什么原因吗?这是我的代码仓库:https://github.com/PiPower/Deep-Neural-Network 。以下是 Network.cpp 中实现的反向传播和 SGD 算法的部分代码:
// Trains the network with mini-batch stochastic gradient descent.
//
// For every epoch the sample order is reshuffled, the data is walked in chunks
// of BatchSize samples, each sample is pushed forward through all layers, and
// the gradients returned by BackPropagation are summed over the chunk. After
// each chunk the weights and biases of every layer are moved against the
// AVERAGE gradient of that chunk, scaled by LearningRate.
//
// Parameters:
//   TrainingData   - one input matrix per sample.
//   TrainingLabels - one label matrix per sample; must have the same size as
//                    TrainingData (asserted).
//   BatchSize      - number of samples per mini-batch; must be positive.
//   epochs         - number of full passes over TrainingData.
//   LearningRate   - step size; must be positive (asserted).
//
// Fix compared with the previous version: the summed gradient used to be
// applied without dividing by the number of samples in the batch, which made
// the effective learning rate BatchSize times larger than requested.
void Network::Train(MatrixD_Array& TrainingData, MatrixD_Array& TrainingLabels, int BatchSize,int epochs, double LearningRate)
{
    assert(TrainingData.size() == TrainingLabels.size() && CostFunc != nullptr && CostFuncDer != nullptr && LearningRate > 0);
    // A batch size of zero would keep the batch loop below from ever advancing.
    assert(BatchSize > 0);
    if (BatchSize <= 0)
    {
        // Same guard for builds where assert is compiled out.
        return;
    }

    const std::size_t sampleCount = TrainingData.size();
    const std::size_t layerCount = Layers.size();
    const std::size_t batchSize = static_cast<std::size_t>(BatchSize);

    // Sample indices; shuffled once per epoch so every epoch sees a new order.
    std::vector<long unsigned int> indexes;
    indexes.reserve(sampleCount);
    for (std::size_t sample = 0; sample < sampleCount; ++sample) indexes.push_back(sample);

    std::random_device rd;
    std::mt19937 generator(rd());

    // Per-layer gradient accumulators, cleared at the start of every batch.
    std::vector<Matrix<double>> NablaWeights(layerCount);
    std::vector<Matrix<double>> NablaBiases(layerCount);
    for (std::size_t layer = 0; layer < layerCount; ++layer)
    {
        NablaWeights[layer] = Matrix<double>(Layers[layer].GetInDim(), Layers[layer].GetOutDim());
        NablaBiases[layer] = Matrix<double>(1, Layers[layer].GetOutDim());
    }

    // Forward-pass buffers. Every element is overwritten for each sample, so
    // they can be allocated once and reused.
    std::vector<Matrix<double>> ActivationOutput(layerCount + 1);
    std::vector<Matrix<double>> Z_Output(layerCount);

    //---- Epoch iterating
    for (int epoch = 0; epoch < epochs; ++epoch)
    {
        std::cout << "Epoch number: " << epoch << std::endl;
        std::shuffle(indexes.begin(), indexes.end(), generator);

        // Batch iterating
        for (std::size_t batchStart = 0; batchStart < sampleCount; batchStart += batchSize)
        {
            for (std::size_t layer = 0; layer < layerCount; ++layer)
            {
                NablaWeights[layer].Clear();
                NablaBiases[layer].Clear();
            }

            // The last batch of an epoch may hold fewer than BatchSize samples.
            const std::size_t samplesInBatch = std::min(batchSize, sampleCount - batchStart);

            for (std::size_t offset = 0; offset < samplesInBatch; ++offset)
            {
                const long unsigned int sample = indexes[batchStart + offset];
                ActivationOutput[0] = TrainingData[sample];

                // Pushing values through. Iterate by reference: taking each
                // layer by value copied it (weights included) once per sample.
                std::size_t index = 0;
                for (auto& layer : Layers)
                {
                    Z_Output[index] = layer.Mul(ActivationOutput[index]);
                    ActivationOutput[index + 1] = layer.ApplyActivation(Z_Output[index]);
                    ++index;
                }

                // ---- Accumulate this sample's gradient into the batch sum
                auto DeltaNabla = BackPropagation(ActivationOutput, Z_Output, TrainingLabels[sample]);
                for (std::size_t layer = 0; layer < layerCount; ++layer)
                {
                    NablaWeights[layer] = NablaWeights[layer] + DeltaNabla.first[layer];
                    NablaBiases[layer] = NablaBiases[layer] + DeltaNabla.second[layer];
                }
            }

            // Gradient-descent step with the batch-averaged gradient:
            //   w <- w - (LearningRate / samplesInBatch) * sum(gradients)
            const double step = LearningRate / static_cast<double>(samplesInBatch);
            for (std::size_t layer = 0; layer < layerCount; ++layer)
            {
                Layers[layer].Weights = Layers[layer].Weights - NablaWeights[layer] * step;
                Layers[layer].Biases = Layers[layer].Biases - NablaBiases[layer] * step;
            }
        }
    }
}
// Computes the gradient of the cost for a single sample by backpropagation.
//
// Parameters:
//   ActivationOutput - activations of the forward pass; the last element is
//                      the network output and element k feeds layer k.
//   Z_Output         - pre-activation values of the forward pass, one per layer.
//   label            - expected output for this sample.
//
// Returns a pair (weight gradients, bias gradients), each holding one matrix
// per layer in the same order as Layers.
std::pair<MatrixD_Array, MatrixD_Array> Network::BackPropagation( MatrixD_Array& ActivationOutput, MatrixD_Array& Z_Output,Matrix<double>& label)
{
    const std::size_t layerCount = Layers.size();
    const std::size_t activationCount = ActivationOutput.size();
    const std::size_t outputLayer = layerCount - 1;

    MatrixD_Array weightGradients;
    MatrixD_Array biasGradients;
    weightGradients.resize(layerCount);
    biasGradients.resize(layerCount);

    // Output layer: delta = dC/da (Hadamard) activation'(z).
    auto outputPrime = Layers[outputLayer].ActivationPrime(Z_Output[Z_Output.size() - 1]);
    Matrix<double> delta = Hadamard(CostFuncDer(ActivationOutput[activationCount - 1], label), outputPrime);
    weightGradients[outputLayer] = delta * ActivationOutput[activationCount - 2].Transpose();
    biasGradients[outputLayer] = delta;

    // Walk backwards through the remaining layers; `depth` counts from the
    // end, so depth == 2 is the layer just before the output layer.
    for (std::size_t depth = 2; depth <= layerCount; ++depth)
    {
        const std::size_t layer = layerCount - depth;
        auto layerPrime = Layers[layer].ActivationPrime(Z_Output[layer]);
        // Pull the error of the following layer back through its weights.
        delta = Hadamard(Layers[layer + 1].Weights.Transpose() * delta, layerPrime);
        weightGradients[layer] = delta * ActivationOutput[activationCount - depth - 1].Transpose();
        biasGradients[layer] = delta;
    }

    return std::make_pair(std::move(weightGradients), std::move(biasGradients));
}
解决方案
事实证明,问题并不在上面的训练代码,而是出在 MNIST 数据加载器(loader)上——它没有正常工作。
推荐阅读
- phaser-framework - Phaser 3 Arcade 物理没有被激活
- java - Collectors.groupingBy 抛出 NullPointerException
- javascript - 我正在尝试调用事件,但我的可视代码说(“事件已弃用 ts(6385)”)
- mysql - 如何在sql中仅比较和使用日期时间格式的时间?我的插入以日期时间格式给出
- java - Java - 主要因素 - 提示修复?
- sql - SQL Developer - 根据不同的 ID 更新/插入来自不同表的列总和
- tensorflow - Resnet 152 tensorflow Keras 迁移学习的特征提取
- c# - 在 C# 中使用用户控件的自动完成文本框在表单重新启动之前不显示更新的结果
- google-apps-script - file.next().setTrashed(true) 无法从谷歌驱动器中删除文件
- pandas - 转换包含时间序列的数据框