python - 为函数返回 nan 数组但仅适用于某些实例?
问题描述
我正在学习数据科学课程,最近我们不得不编写梯度下降函数和成本函数来找到模型的最佳参数。我的函数在之前的项目中运行良好,并且完成了它的工作,为我的模型找到了 6 个最佳系数。
本周我必须在家庭作业中再次使用它,并通过观察当您继续向模型添加越来越多的参数时会发生什么来观察过拟合/欠拟合。我的函数用于提供 2 和 3 参数,但是当我尝试执行 4 和 5 时,突然间我将所有 nan 值返回到结果数组中,尽管没有更改函数定义中的任何内容,或者我的状态如何定义初始参数值。这是我的代码,我非常感谢一些反馈,因为这让我很沮丧。
我看不出有任何方法可以获取一些未定义的值或任何会导致它返回 nan 的东西,尤其是看到它适用于我之前的项目找到 6 个系数,并且在这个例子中找到 2 和 3 效果很好(我知道这是正确的,因为这是在以前的作业中使用的一个例子)。
代码截图,nan 的输出被截断,但它是 4x1 数组所有值 nan
复制并粘贴代码:
x=np.array([[2.9], [-1.5], [0.1], [-1.0], [2.1], [-4.0], [-2.0], [2.2], [0.2], [2.9], [1.5], [-2.5]])
y=np.array([[4.0], [-0.9], [0], [-1], [3.0], [-5.0], [-3.5], [2.6], [1.0], [3.5], [1.0], [-4.7]])
def cost_function(a, X, y):
m=len(y)
h=np.dot(X, a)
cost=(1/(2*m))*np.sum(np.square(h-y))
return cost
def gradient_descent(a, X, y, learning_rate=0.01, iteration=1000):
m=len(y)
a_history=[a]
J_history=[cost_function(a, X, y)]
x=0
while x < iteration:
h=np.dot(X, a)
a=a-((1/m)*learning_rate*np.dot((h-y).T, X)).T
a_history.append(a)
J=cost_function(a, X, y)
J_history.append(J)
x+=1
return a, J_history, a_history
#with two parameters
a=np.array([[0], [0]])
X=np.hstack([np.ones((12, 1), dtype=float), x])
a_values, J_history, a_history=gradient_descent(a, X, y)
print('2 coefficents: ',a_values)
#with three parameters
a=np.array([[0], [0], [0]])
X=np.hstack([np.ones((12, 1), dtype=float), x, x**2])
a_values, J_history, a_history=gradient_descent(a, X, y)
print('3 coefficents: ', a_values)
plt.scatter(x1, y)
x1.sort()
newy=[((a_values[2]*(x**2)) + (a_values[1]*x) + a_values[0]) for x in x1]
plt.plot(x, y, 'o') # create scatter plot.
plt.plot(x1, newy) #add line of best fit.
plt.title('3 Coefficent Model')
plt.show()
#4 coefficents
#This is where I'm getting stuck--for 4 coefficents, my gradient descent function is returning
# nan values.
a=np.array([[0], [0], [0], [0]])
X=np.hstack([np.ones((12, 1), dtype=float), x, x**2, x**3])
a_values, J_history, a_history=gradient_descent(a, X, y)
print(a_values)
解决方案
推荐阅读
- javascript - 在我刷新页面之前,来自图像的输入不会加载(角度 5)
- php - 有一个使用多个数据库的 codeigniter 站点配置?
- javascript - 组件每次渲染调用函数 10 次
- mysql - 挣扎于 SQL 语法 - 请问我错过了什么?(这里是 SQL 新手)
- c# - 使用 ITextSharp 获取 PDF 中图像的位置(X、Y、宽度、高度)
- python - 了解多元线性回归并使用 python 来完成这个?
- sql-server - 在级联递归表上删除 - SQL Server
- google-apps-script - 谷歌脚本复制最后一行,所有列来自
- c# - Android.Content.ActivityNotFoundException:
Xamarin C# - angular - Angular 6材料应用程序仅打印第一页