c++ - How to back propagate a Neural Network in C++?
问题描述
I have been trying to make a neural network in C++, and my backpropagation code is not working the way I want it to. I have a text document that tells the network how to function. It has 2 input neurons, 1 hidden layer with 4 neurons, and 2 output neurons, and right now I am training it to act as an XOR gate. My update takes the cost of the network, multiplies it by 0.55 (scaling), and adds that amount to or subtracts it from the weights/biases, depending on how close the output is to the correct answer and whether each weight/bias is positive or negative. Here's the code:
void Network::backProp(void)
{
double b = 0,a;
int loop,l;
for(loop=0;loop<4;loop++)
{
//Adds up the cost of the data
b = b + (pow(results[2*loop]-key[4*loop+2],2)+pow(results[2*loop+1]-key[4*loop+3],2));
}
a=.55*b;
if(b>.01)
{
for(l=0;l<4;l++)
{
if(round(results[2*l])!=key[4*l+2])
{
if(data[0] <= 0)
{
data[0] = data[0]+a; //(abs(data[0])/a);
}
else
{
data[0] = data[0]-a; //(abs(data[0])/a);
}
if(data[1] <= 0)
{
data[1] = data[1]+a; //(abs(data[1])/a);
}
else
{
data[1] = data[1]-a; //(abs(data[1])/a);
}
if(data[2] <= 0)
{
data[2] = data[2]+a; //(abs(data[2])/a);
}
else
{
data[2] = data[2]-a; //(abs(data[2])/a);
}
if(data[3] <= 0)
{
data[3] = data[3]+a; //(abs(data[3])/a);
}
else
{
data[3] = data[3]-a; //(abs(data[3])/a);
}
if(data[4] <= 0)
{
data[4] = data[4]+a; //(abs(data[4])/a);
}
else
{
data[4] = data[4]-a; //(abs(data[4])/a);
}
if(data[6] <= 0)
{
data[6] = data[6]+a; //(abs(data[6])/a);
}
else
{
data[6] = data[6]-a; //(abs(data[6])/a);
}
if(data[7] <= 0)
{
data[7] = data[7]+a; //(abs(data[7])/a);
}
else
{
data[7] = data[7]-a; //(abs(data[7])/a);
}
if(data[8] <= 0)
{
data[8] = data[8]+a; //(abs(data[8])/a);
}
else
{
data[8] = data[8]-a; //(abs(data[8])/a);
}
if(data[9] <= 0)
{
data[9] = data[9]+a; //(abs(data[9])/a);
}
else
{
data[9] = data[9]-a; //(abs(data[9])/a);
}
if(data[10] <= 0)
{
data[10] = data[10]+a; //(abs(data[10])/a);
}
else
{
data[10] = data[10]-a; //(abs(data[10])/a);
}
if(data[11] <= 0)
{
data[11] = data[11]+a; //(abs(data[11])/a);
}
else
{
data[11] = data[11]-a; //(abs(data[11])/a);
}
if(data[12] <= 0)
{
data[12] = data[12]+a; //(abs(data[12])/a);
}
else
{
data[12] = data[12]-a; //(abs(data[12])/a);
}
if(data[13] <= 0)
{
data[13] = data[13]+a; //(abs(data[13])/a);
}
else
{
data[13] = data[13]-a; //(abs(data[13])/a);
}
if(data[14] <= 0)
{
data[14] = data[14]+a; //(abs(data[14])/a);
}
else
{
data[14] = data[14]-a; //(abs(data[14])/a);
}
if(data[16] <= 0)
{
data[16] = data[16]+a; //(abs(data[16])/a);
}
else
{
data[16] = data[16]-a; //(abs(data[16])/a);
}
if(data[18] <= 0)
{
data[18] = data[18]+a; //(abs(data[18])/a);
}
else
{
data[18] = data[18]-a; //(abs(data[18])/a);
}
if(data[20] <= 0)
{
data[20] = data[20]+a; //(abs(data[20])/a);
}
else
{
data[20] = data[20]-a; //(abs(data[20])/a);
}
}
else
{
if(data[0] <= 0)
{
data[0] = data[0]-a; //(abs(data[0])/a);
}
else
{
data[0] = data[0]+a; //(abs(data[0])/a);
}
if(data[1] <= 0)
{
data[1] = data[1]-a; //(abs(data[1])/a);
}
else
{
data[1] = data[1]+a; //(abs(data[1])/a);
}
if(data[2] <= 0)
{
data[2] = data[2]-a; //(abs(data[2])/a);
}
else
{
data[2] = data[2]+a; //(abs(data[2])/a);
}
if(data[3] <= 0)
{
data[3] = data[3]-a; //(abs(data[3])/a);
}
else
{
data[3] = data[3]+a; //(abs(data[3])/a);
}
if(data[4] <= 0)
{
data[4] = data[4]-a; //(abs(data[4])/a);
}
else
{
data[4] = data[4]+a; //(abs(data[4])/a);
}
if(data[6] <= 0)
{
data[6] = data[6]-a; //(abs(data[6])/a);
}
else
{
data[6] = data[6]+a; //(abs(data[6])/a);
}
if(data[7] <= 0)
{
data[7] = data[7]-a; //(abs(data[7])/a);
}
else
{
data[7] = data[7]+a; //(abs(data[7])/a);
}
if(data[8] <= 0)
{
data[8] = data[8]-a; //(abs(data[8])/a);
}
else
{
data[8] = data[8]+a; //(abs(data[8])/a);
}
if(data[9] <= 0)
{
data[9] = data[9]-a; //(abs(data[9])/a);
}
else
{
data[9] = data[9]+a; //(abs(data[9])/a);
}
if(data[10] <= 0)
{
data[10] = data[10]-a; //(abs(data[10])/a);
}
else
{
data[10] = data[10]+a; //(abs(data[10])/a);
}
if(data[11] <= 0)
{
data[11] = data[11]-a; //(abs(data[11])/a);
}
else
{
data[11] = data[11]+a; //(abs(data[11])/a);
}
if(data[12] <= 0)
{
data[12] = data[12]-a; //(abs(data[12])/a);
}
else
{
data[12] = data[12]+a; //(abs(data[12])/a);
}
if(data[13] <= 0)
{
data[13] = data[13]-a; //(abs(data[13])/a);
}
else
{
data[13] = data[13]+a; //(abs(data[13])/a);
}
if(data[14] <= 0)
{
data[14] = data[14]-a; //(abs(data[14])/a);
}
else
{
data[14] = data[14]+a; //(abs(data[14])/a);
}
if(data[16] <= 0)
{
data[16] = data[16]-a; //(abs(data[16])/a);
}
else
{
data[16] = data[16]+a; //(abs(data[16])/a);
}
if(data[18] <= 0)
{
data[18] = data[18]-a; //(abs(data[18])/a);
}
else
{
data[18] = data[18]+a; //(abs(data[18])/a);
}
if(data[20] <= 0)
{
data[20] = data[20]-a; //(abs(data[20])/a);
}
else
{
data[20] = data[20]+a; //(abs(data[20])/a);
}
}
if(round(results[2*l+1])!=key[4*l+3])
{
if(data[0] <= 0)
{
data[0] = data[0]+a; //(abs(data[0])/a);
}
else
{
data[0] = data[0]-a; //(abs(data[0])/a);
}
if(data[1] <= 0)
{
data[1] = data[1]+a; //(abs(data[1])/a);
}
else
{
data[1] = data[1]-a; //(abs(data[1])/a);
}
if(data[2] <= 0)
{
data[2] = data[2]+a; //(abs(data[2])/a);
}
else
{
data[2] = data[2]-a; //(abs(data[2])/a);
}
if(data[3] <= 0)
{
data[3] = data[3]+a; //(abs(data[3])/a);
}
else
{
data[3] = data[3]-a; //(abs(data[3])/a);
}
if(data[4] <= 0)
{
data[4] = data[4]+a; //(abs(data[4])/a);
}
else
{
data[4] = data[4]-a; //(abs(data[4])/a);
}
if(data[5] <= 0)
{
data[5] = data[5]+a; //(abs(data[5])/a);
}
else
{
data[5] = data[5]-a; //(abs(data[5])/a);
}
if(data[7] <= 0)
{
data[7] = data[7]+a; //(abs(data[7])/a);
}
else
{
data[7] = data[7]-a; //(abs(data[7])/a);
}
if(data[8] <= 0)
{
data[8] = data[8]+a; //(abs(data[8])/a);
}
else
{
data[8] = data[8]-a; //(abs(data[8])/a);
}
if(data[9] <= 0)
{
data[9] = data[9]+a; //(abs(data[9])/a);
}
else
{
data[9] = data[9]-a; //(abs(data[9])/a);
}
if(data[10] <= 0)
{
data[10] = data[10]+a; //(abs(data[10])/a);
}
else
{
data[10] = data[10]-a; //(abs(data[10])/a);
}
if(data[11] <= 0)
{
data[11] = data[11]+a; //(abs(data[11])/a);
}
else
{
data[11] = data[11]-a; //(abs(data[11])/a);
}
if(data[12] <= 0)
{
data[12] = data[12]+a; //(abs(data[12])/a);
}
else
{
data[12] = data[12]-a; //(abs(data[12])/a);
}
if(data[13] <= 0)
{
data[13] = data[13]+a; //(abs(data[13])/a);
}
else
{
data[13] = data[13]-a; //(abs(data[13])/a);
}
if(data[15] <= 0)
{
data[15] = data[15]+a; //(abs(data[15])/a);
}
else
{
data[15] = data[15]-a; //(abs(data[15])/a);
}
if(data[17] <= 0)
{
data[17] = data[17]+a; //(abs(data[17])/a);
}
else
{
data[17] = data[17]-a; //(abs(data[17])/a);
}
if(data[19] <= 0)
{
data[19] = data[19]+a; //(abs(data[19])/a);
}
else
{
data[19] = data[19]-a; //(abs(data[19])/a);
}
if(data[21] <= 0)
{
data[21] = data[21]+a; //(abs(data[21])/a);
}
else
{
data[21] = data[21]-a; //(abs(data[21])/a);
}
}
else
{
if(data[0] <= 0)
{
data[0] = data[0]-a; //(abs(data[0])/a);
}
else
{
data[0] = data[0]+a; //(abs(data[0])/a);
}
if(data[1] <= 0)
{
data[1] = data[1]-a; //(abs(data[1])/a);
}
else
{
data[1] = data[1]+a; //(abs(data[1])/a);
}
if(data[2] <= 0)
{
data[2] = data[2]-a; //(abs(data[2])/a);
}
else
{
data[2] = data[2]+a; //(abs(data[2])/a);
}
if(data[3] <= 0)
{
data[3] = data[3]-a; //(abs(data[3])/a);
}
else
{
data[3] = data[3]+a; //(abs(data[3])/a);
}
if(data[4] <= 0)
{
data[4] = data[4]-a; //(abs(data[4])/a);
}
else
{
data[4] = data[4]+a; //(abs(data[4])/a);
}
if(data[5] <= 0)
{
data[5] = data[5]-a; //(abs(data[5])/a);
}
else
{
data[5] = data[5]+a; //(abs(data[5])/a);
}
if(data[7] <= 0)
{
data[7] = data[7]-a; //(abs(data[7])/a);
}
else
{
data[7] = data[7]+a; //(abs(data[7])/a);
}
if(data[8] <= 0)
{
data[8] = data[8]-a; //(abs(data[8])/a);
}
else
{
data[8] = data[8]+a; //(abs(data[8])/a);
}
if(data[9] <= 0)
{
data[9] = data[9]-a; //(abs(data[9])/a);
}
else
{
data[9] = data[9]+a; //(abs(data[9])/a);
}
if(data[10] <= 0)
{
data[10] = data[10]-a; //(abs(data[10])/a);
}
else
{
data[10] = data[10]+a; //(abs(data[10])/a);
}
if(data[11] <= 0)
{
data[11] = data[11]-a; //(abs(data[11])/a);
}
else
{
data[11] = data[11]+a; //(abs(data[11])/a);
}
if(data[12] <= 0)
{
data[12] = data[12]-a; //(abs(data[12])/a);
}
else
{
data[12] = data[12]+a; //(abs(data[12])/a);
}
if(data[13] <= 0)
{
data[13] = data[13]-a; //(abs(data[13])/a);
}
else
{
data[13] = data[13]+a; //(abs(data[13])/a);
}
if(data[15] <= 0)
{
data[15] = data[15]-a; //(abs(data[15])/a);
}
else
{
data[15] = data[15]+a; //(abs(data[15])/a);
}
if(data[17] <= 0)
{
data[17] = data[17]-a; //(abs(data[17])/a);
}
else
{
data[17] = data[17]+a; //(abs(data[17])/a);
}
if(data[19] <= 0)
{
data[19] = data[19]-a; //(abs(data[19])/a);
}
else
{
data[19] = data[19]+a; //(abs(data[19])/a);
}
if(data[21] <= 0)
{
data[21] = data[21]-a; //(abs(data[21])/a);
}
else
{
data[21] = data[21]+a; //(abs(data[21])/a);
}
}
}
}
}
I know it's a mess but this is what I came up with. I can post the rest of the code if that would help.
解决方案
这是您的代码的简化版本
/**
 * Simplified form of the question's backProp(). It uses the same cost,
 * threshold, step size, touched parameters and update directions, but
 * replaces the unrolled copy-paste with two index tables.
 *
 * For each of the 4 samples and each output neuron, every listed entry of
 * `data` moves by a = 0.55 * cost:
 *  - toward zero if the rounded output differs from the expected value;
 *  - away from zero if it matches.
 * Entries equal to 0 are treated as "<= 0", as in the question's code.
 */
void Network::backProp(void)
{
    // Parameter indices adjusted for output neuron 0 (inclusion1) and
    // output neuron 1 (inclusion2), copied from the question's code.
    const int inclusion1[] = {0,1,2,3,4,6,7,8,9,10,11,12,13,14,16,18,20};
    const int inclusion2[] = {0,1,2,3,4,5,7,8,9,10,11,12,13,15,17,19,21};

    double b = 0;
    for (int loop = 0; loop < 4; loop++)
    {
        // Adds up the cost of the data (squared error of both outputs).
        const double d0 = results[2*loop] - key[4*loop+2];
        const double d1 = results[2*loop+1] - key[4*loop+3];
        b = b + (d0*d0 + d1*d1);
    }
    const double a = .55*b;
    if (b > .01)
    {
        // Uses an explicit <= 0 test instead of abs(w)/w: the latter is
        // 0/0 = NaN for a zero parameter. Both outputs use the same rule
        // (a wrong output pulls toward zero), and a correct output pushes
        // away from zero, matching the original else-branches.
        auto adjust = [&](int idx, bool wrong)
        {
            const bool nonPositive = data[idx] <= 0;
            if (nonPositive == wrong)
                data[idx] = data[idx] + a;
            else
                data[idx] = data[idx] - a;
        };
        for (int l = 0; l < 4; l++)
        {
            const bool wrong1 = round(results[2*l]) != key[4*l+2];
            for (int idx : inclusion1)
                adjust(idx, wrong1);
            const bool wrong2 = round(results[2*l+1]) != key[4*l+3];
            for (int idx : inclusion2)
                adjust(idx, wrong2);
        }
    }
}
我看到的基本问题是,我认为您的校正变量 b 的定义不够准确。
更合理的定义应该是:
// Per-sample Euclidean error. Do not write the exponent as 1/2: that is integer
// division (== 0), so pow(x, 1/2) would always be 1 and b would always be 4.
b = b + sqrt(pow(results[2*loop]-key[4*loop+2],2)+pow(results[2*loop+1]-key[4*loop+3],2));
推荐阅读
- sql-server - 如何在 ASP.NET MVC 中执行 SQL 查询,并传递给视图
- java - 如何从外部项目的 pom 执行 maven liquibase pom?
- sql-server - 在 SSIS 中使用参数或变量设置连接超时?
- laravel - Laravel 侧边栏的 View Composer
- angular - Google Firestore 如何在后台工作?
- mysql - Mysql 中的 RowNum,分组依据(MySQL 5.x)
- python - 无法导入模块“lambda_function”:无法从“ctypes”导入名称“WinDLL”(/var/lang/lib/python3.7/ctypes/__init__.py
- python - Google Speech API 在 Python 子进程中不起作用
- c++ - 使用 Bazel 与库之间的循环依赖关系并改变这些库的 strip_include_prefix 属性
- encryption - 使用密码加密/解密数据,而数据库管理员无法解密