python - 跟踪神经网络中的缓存
问题描述
我尝试实现基于 Coursera 深度学习课程的 L 层神经网络模型,但在将缓存添加到缓存列表时遇到问题。
caches = list.append(cache)
该课程建议使用 list.append() 函数。其余代码如下:
def L_model_forward(X, parameters):
"""
Implement forward propagation
Arguments:
X -- data, numpy array of shape (input size, number of examples)
parameters -- output of initialize_parameters_deep()
Returns:
AL -- last post-activation value
caches -- list of caches containing:
every cache of linear_activation_forward() (there are L-1 of them, indexed from 0 to L-1)
"""
caches = []
A = X
L = len(parameters) // 2 # number of layers in the neural network
# Implement [LINEAR -> RELU]*(L-1). Add "cache" to the "caches" list.
for l in range(1, L):
A_prev = A
### START CODE HERE ### (≈ 2 lines of code)
A, cache = linear_activation_forward(A_prev, parameters['W' + str(l)], parameters['b' + str(l)], activation = "relu")
caches = list.append(cache)
### END CODE HERE ###
# Implement LINEAR -> SIGMOID. Add "cache" to the "caches" list.
### START CODE HERE ### (≈ 2 lines of code)
AL, cache = linear_activation_forward(A, parameters['W' + str(L)], parameters['b' + str(L)], activation = "sigmoid")
caches = list.append(cache)
### END CODE HERE ###
assert(AL.shape == (1,X.shape[1]))
return AL, caches
当我运行代码时,这是错误:
TypeError:描述符“附加”需要一个“列表”对象但收到一个“元组”
解决方案
您不能使用 追加到列表中caches = list.append(cache)
。它必须由caches.append(cache)
and 附加,然后在caches
将元组缓存添加到列表中之前使用cache=list(cache)
,然后在缓存列表中使用caches.append(cache)
.
您的代码将如下所示:-
def L_model_forward(X, parameters):
caches = []
A = X
L = len(parameters) // 2 # number of layers in the neural network
for l in range(1, L):
A_prev = A
A, cache = linear_activation_forward(A_prev, parameters['W' + str(l)], parameters['b' + str(l)], activation = "relu")
caches.append(list(cache))
AL, cache = linear_activation_forward(A, parameters['W' + str(L)], parameters['b' + str(L)], activation = "sigmoid")
caches.append(list(cache))
assert(AL.shape == (1,X.shape[1]))
return AL, caches
推荐阅读
- python - 如何通过重复列表将熊猫数据框的行(而不是列)划分?
- c++ - 如何打印数组的最小数量及其使用函数调用它们的索引?
- cakephp - Configure::read 和控制器全局变量在 cakephp 2.10.12 中不起作用
- c# - 如何在一个 NUnit 项目中正确测试弱/强命名的内部类?
- laravel - 在laravel的select标签中选择供应商名称时如何显示其他数据
- ios - 文件提供程序扩展的显示名称
- linux - GCC 将 uint8_t 和 uint16_t 解释为已签名?
- python-3.x - 无法使用 matplotlib 绘制 3d 条形图
- python - 如何在网上抓取谷歌学者每年每篇论文的引用次数?
- css - React:如何以编程方式在反应组件中设置元素的宽度和高度