python - 如何在张量流 LinearRegressor 中强制偏差为零?
问题描述
我正在使用tensorflow LinearRegressor API 来解决回归问题(https://www.tensorflow.org/api_docs/python/tf/estimator/LinearRegressor)。我知道我的模型中的偏差正好是 0。
如何强制LinearRegressor 学习 0 的偏差?
这是一个最小的例子:
import tensorflow as tf
import numpy as np
from sklearn.linear_model import SGDRegressor
用 2 个特征(+ 0 的偏差)模拟一些数据 y = 0 + 2*x1 + 3*x2 + 噪声
np.random.seed(5332)
n = 1000
weights = np.array([
[2],
[3],
])
bias = 0
x = np.random.randn(n, np.shape(weights)[0])
y = (bias + np.matmul(x, weights) + np.random.randn(n, 1)).ravel()
在 sklearn 中,我会使用 fit_intercept=False 将偏差强制为 0:
ols = SGDRegressor(tol=0.000001, fit_intercept=False)
ols.fit(x, y)
print("True weights: {}".format(weights.ravel()))
print("Learned weights: {}".format(np.round(ols.coef_), 3))
print("True bias: {}".format([bias]))
print("Learned bias: {}".format(np.round(ols.intercept_), 3))
输出:
True weights: [2 3]
Learned weights: [2. 3.]
True bias: [0]
Learned bias: [0.]
在张量流中,我做了以下事情:
column = tf.feature_column.numeric_column('x', shape=np.shape(x)[1])
ols = tf.estimator.LinearRegressor(
feature_columns=[column],
optimizer=tf.train.GradientDescentOptimizer(0.0001)
)
train_input = tf.estimator.inputs.numpy_input_fn(
x={"x": x},
y=y,
shuffle=False,
num_epochs=100,
batch_size=int(len(y) / 20)
)
ols.train(train_input)
print("True weights: {}".format(weights.ravel()))
print("Learned weights: {}".format(np.round(ols.get_variable_value('linear/linear_model/x/weights').flatten(), 3)))
print("True bias: {}".format([bias]))
print("Learned bias: {}".format(np.round(ols.get_variable_value('linear/linear_model/bias_weights').flatten(), 3)))
输出:
True weights: [2 3]
Learned weights: [1.993 2.998]
True bias: [0]
Learned bias: [-0.067]
但是学习到的偏差应该是:[0],我该如何执行呢?
解决方案
我猜 tf.keras.constraints 就是您要搜索的内容。
推荐阅读
- html - CSS Flip Card 动画仅在 Firefox 中无法正常工作
- node.js - 如何在现有的 NodeJS 后端启用 ocsp 装订?
- python - 如何在python中正确导入arch
- postgresql - odoo.conf 导致 Postgres 身份验证失败
- opencv - 为什么模型的 opengl 投影没有给出与 opencv 投影完全相同的结果?
- swift - 这可以改变barbutton的“Y”位置吗?
- jsf - 如何使用 Primefaces 为热键下划线?
- excel - 将小时分钟转换为数字的公式
- mule - Mulesoft - Netsuite 集成 - 如何传递 recrodRef 值?
- c# - 配置 CORS 网页后返回错误但也有结果