python - 使用 NumPy 的 ReLU 导数
问题描述
import numpy as np
def relu(z):
return np.maximum(0,z)
def d_relu(z):
z[z>0]=1
z[z<=0]=0
return z
x=np.array([5,1,-4,0])
y=relu(x)
z=d_relu(y)
print("y = {}".format(y))
print("z = {}".format(z))
上面的代码打印出来:
y = [1 1 0 0]
z = [1 1 0 0]
代替
y = [5 1 0 0]
z = [1 1 0 0]
据我了解,我使用的函数调用应该只是按值传递,传递变量的副本。
为什么我的 d_relu 函数会影响 y 变量?
解决方案
Your first mistake is in assuming python passes objects by value... it doesn't - it's pass by assignment (similar to passing by reference, if you're familiar with this concept). However, only mutable objects, as the name suggests, can be modified in-place. This includes, among other things, numpy arrays.
You shouldn't have d_relu
modify z
inplace, because that's what it's doing right now, through the z[...] = ...
syntax. Try instead building a mask using broadcasted comparison and returning that instead.
def d_relu(z):
return (z > 0).astype(int)
This returns a fresh array instead of modifying z
in-place, and your code prints
y = [5 1 0 0]
z = [1 1 0 0]
If you're building a layered architecture, you can leverage the use of a computed mask during the forward pass stage:
class relu:
def __init__(self):
self.mask = None
def forward(self, x):
self.mask = x > 0
return x * self.mask
def backward(self, x):
return self.mask
Where the derivative is simply 1 if the input during feedforward if > 0, else 0.
推荐阅读
- python - Django 从数据库中删除模型
- c# - 浮点数乘以 0 是否总是得到 0?
- mysql - MySQL 5.7 referencing error
- javascript - standard browser EventDispatcher class to inherit from
- azure - 授予 READER 角色访问 Azure 中的订阅的权限在 Postman 中可以正常工作,但不能通过 Angular。为什么?
- c++ - Logging DJI error codes to ofstream
- c++ - 交换互斥锁
- javascript - how to open result in new tab with autocomplete search jquery?
- user-interface - 创建交互式时间线
- r - 将变量作为字符传递给函数