python - python中带有意外阴影的全局变量行为
问题描述
为什么在这段代码中,访问global
变量会产生错误?
版本 1:在没有全局变量的情况下按预期工作
import numpy as np
import matplotlib.pyplot as plt
# Initializations
scale = 2000
a, b, c, d = (np.random.randn() for i in range(4))
x = np.linspace(-np.math.pi, np.math.pi, scale)
y1 = np.sin(x)
y2 = a + b * x + c * x ** 2 + d * x ** 3
learning_rate = 1e-6
for i in range(scale):
y2 = a + b * x + c * x ** 2 + d * x ** 3
grad_loss = 2.0*(y2-y1)
grad_a = grad_loss.sum()
grad_b = (grad_loss * x).sum()
grad_c = (grad_loss * x ** 2).sum()
grad_d = (grad_loss * x ** 3).sum()
a -= learning_rate * grad_a
b -= learning_rate * grad_b
c -= learning_rate * grad_c
d -= learning_rate * grad_d
plt.plot(y1)
plt.plot(y2)
plt.show()
但是,对于函数内部的全局变量,行为是不同的
import numpy as np
import matplotlib.pyplot as plt
# Initializations
scale = 2000
a, b, c, d = (np.random.randn() for i in range(4))
x = np.linspace(-np.math.pi, np.math.pi, scale)
y1 = np.sin(x)
y2 = a + b * x + c * x ** 2 + d * x ** 3
def forward():
y2 = a + b * x + c * x ** 2 + d * x ** 3
#print(np.square(y2 - y1).sum())
learning_rate = 1e-6
def backward():
grad_loss = 2.0*(y2-y1)
grad_a = grad_loss.sum()
grad_b = (grad_loss * x).sum()
grad_c = (grad_loss * x ** 2).sum()
grad_d = (grad_loss * x ** 3).sum()
global a,b,c,d
a -= learning_rate * grad_a
b -= learning_rate * grad_b
c -= learning_rate * grad_c
d -= learning_rate * grad_d
for i in range(scale):
forward()
backward()
plt.plot(y1)
plt.plot(y2)
plt.show()
解决方案
在forward
函数中 y2 被更新。在您提供的代码中, y2 未声明为全局变量,因此未更新全局变量。添加global y2
,结果将是相同的:
import numpy as np
import matplotlib.pyplot as plt
# Initializations scale = 2000 a, b, c, d = (np.random.randn() for i in range(4)) x = np.linspace(-np.math.pi, np.math.pi, scale) y1 = np.sin(x) y2 = a + b * x + c * x ** 2 + d * x ** 3
def forward():
global y2
y2 = a + b * x + c * x ** 2 + d * x ** 3
#print(np.square(y2 - y1).sum())
learning_rate = 1e-6
def backward():
grad_loss = 2.0*(y2-y1)
grad_a = grad_loss.sum()
grad_b = (grad_loss * x).sum()
grad_c = (grad_loss * x ** 2).sum()
grad_d = (grad_loss * x ** 3).sum()
global a,b,c,d
a -= learning_rate * grad_a
b -= learning_rate * grad_b
c -= learning_rate * grad_c
d -= learning_rate * grad_d
for i in range(scale):
forward()
backward()
plt.plot(y1) plt.plot(y2) plt.show()
推荐阅读
- selenium - 无法从 testNG xml 中获取参数化值
- asp.net-mvc - 我们如何在 ASP.NET MVC 中临时限制禁用用户的访问?
- java - android中的自我更新应用程序
- docker - 如何进入正在运行的 docker 容器并检查机器学习训练结果?
- typescript - 选定的文件上传在 Angular 7 中不起作用
- sorting - YII2:搜索模型中的自定义排序
- java - connectionerror chaquo python-requests 尽管互联网已连接
- here-api - harp.gl:通过three.js添加的正确缩放地理参考3D几何(以米为单位)
- c# - 有没有办法把一个 int 变成一个字符串,然后再变成一个 int
- python - pymysql.err.OperationalError: (2013, 'Lost connection to MySQL server during query')