python-3.x - 返回 nan 的逻辑回归成本函数
问题描述
我最近学习了逻辑回归,我想实践一下。我目前正在使用来自 kaggle 的这个数据集。我尝试以这种方式定义成本函数(我做了所有必要的导入):
# Defining the hypothesis
sigmoid = lambda x: 1 / (1 + np.exp(-x))
predict = lambda trainset, parameters: sigmoid(trainset @ parameters)
# Defining the cost
def cost(theta):
#print(X.shape, y.shape, theta.shape)
preds = predict(X, theta.T)
errors = (-y * np.log(preds)) - ((1-y)*np.log(1-preds))
return np.mean(errors)
theta = []
for i in range(13):
theta.append(1)
theta = np.array([theta])
cost(theta)
当我运行这个单元格时,我得到:
/opt/venv/lib/python3.7/site-packages/ipykernel_launcher.py:9: RuntimeWarning: divide by zero encountered in log
if __name__ == '__main__':
/opt/venv/lib/python3.7/site-packages/ipykernel_launcher.py:9: RuntimeWarning: invalid value encountered in multiply
if __name__ == '__main__':
nan
当我在网上搜索时,我得到了将数据归一化的建议,然后尝试一下。所以这就是我的做法:
df = pd.read_csv("/home/jovyan/work/heart.csv")
df.head()
# The dataset is 303x14 in size (using df.shape)
length = df.shape[0]
# Output vector
y = df['target'].values
y = np.array([y]).T
# We name trainingset as X for convenience
trainingset = df.drop(['target'], axis = 1)
#trainingset = df.insert(0, 'bias', 1)
minmax_normal_trainset = (trainingset - trainingset.min())/(trainingset.max() - trainingset.min())
X = trainingset.values
我真的不知道除以零错误发生在哪里以及如何解决它。如果我在这个实现中犯了任何错误,请纠正我。如果之前有人问过这个问题,我很抱歉,但我能找到的只是标准化数据的提示。提前致谢!
解决方案
np.log(0)
引发divide by zero
错误。所以这是导致问题的部分:
errors = (-y * np.log(preds)) - ((1 - y) * np.log(1 - preds))
############## #################
preds
当 的绝对值x
大于 709 时可以为 0 或 1(因为浮点数学,至少在我的机器上),这就是为什么规范化x
到 0 和 1 之间可以解决问题的原因。
编辑:
您可能希望标准化到更大的范围(0, 1)
- 当前设置的 sigmoid 函数在该范围内几乎是线性的。也许使用:
minmax_normal_trainset = c * (trainingset - trainingset.mean())/(trainingset.stdev())
并调整c
以获得更好的收敛性。
推荐阅读
- python - 如何在 Python 中标记嵌套列表中的匹配项
- python - 轴错误中不包含标签 - 熊猫数据框
- java - Branch.io (Android SDK):“仅在启动器活动中初始化分支”和 GDPR
- vba - 当相同的对应值出现在 B 列中时,计算 A 列中值的时间差
- python - 如何让 Microsoft Edge(和其他 Windows 应用程序)干净地关闭?
- html - DNN 动态输入到模板
- google-bigquery - 如何定义在子查询中引用通配符表的 BQ 视图?
- javascript - 如何在 iOS 键盘(Web / Hybrid / PWA)上的“tel”输入中获得句点
- python - Python pandas 到 DataFrame 的对象列表
- php - 明确书面陈述的解释