python - 使用 Theano 获取 w_0 和 w_1 参数
问题描述
我有一个问题,我必须创建一个数据集,然后,我必须使用 Theano 来获取以下模型的 w_0 和 w_1 参数:
y = log(1 + w_0 * |x|) + (w_1 * |x|)
数据集已创建,我已经计算了 w_0 和 w_1 值,但使用以下代码使用 numpy,但我已经彻底研究但不知道如何使用 theano 计算 w_0 和 w_1 值。我如何使用 theano 计算这些值?这将是很大的帮助谢谢 :) 我正在使用的代码:
import numpy as np
import math
import theano as t
#code to generate datasets
trX = np.linspace(-1, 1, 101)
trY = np.linspace(-1, 1, 101)
for i in range(len(trY)):
trY[i] = math.log(1 + 0.5 * abs(trX[i])) + trX[i] / 3 + np.random.randn() * 0.033
#code that produce w0 w1 and i want to compute it with theano
X = np.column_stack((np.ones(101, dtype=trX.dtype), trX))
print(X.shape)
Xplus = np.linalg.pinv(X) #pseudo-inverse of X
w_opt = Xplus @ trY #The @ symbol denotes matrix multiplication
print(w_opt)
x = abs(trX) #abs is a built in function to return positive values in a array
y= trY
for i in range(len(trX)):
y[i] = math.log(1 + w_opt[0] * x[i]) + (w_opt[1] * x[i])
解决方案
早上好希娜马利克,
使用梯度下降算法和正确的模型选择,应该可以解决这个问题。此外,您应该为每个参数创建 2 个共享变量 (w & c)。
X = T.scalar()
Y = T.scalar()
def model(X, w, c):
return X * w + c
w = theano.shared(np.asarray(0., dtype = theano.config.floatX))
c = theano.shared(np.asarray(0., dtype = theano.config.floatX))
y = model(X, w, c)
learning_rate=0.01
cost = T.mean(T.sqr(y - Y))
gradient_w = T.grad(cost = cost, wrt = w)
gradient_c = T.grad(cost = cost, wrt = c)
updates = [[w, w - gradient_w * learning_rate], [c, c - gradient_c * learning_rate]]
train = theano.function(inputs = [X, Y], outputs = cost, updates = updates)
coste=[] #Variable para almacenar los datos de coste para poder representarlos gráficamente
for i in range(101):
for x, y in zip(trX, trY):
cost_i = train(x, y)
coste.append(cost_i)
w0=float(w.get_value())
w1=float(c.get_value())
print(w0,w1)
我也在 StackOverFlow 的“西班牙语”版本中回复了相同或非常相似的主题:转到解决方案
我希望这可以帮助你
此致
推荐阅读
- java - 如何通过 YAML 加载基于抽象类的对象的数组列表
- vba - 无法使用 Word VBA 在 powershell 脚本中访问 CustomDocumentProperties
- google-cloud-platform - 在大存储桶上设置冷线规则后,云存储 503 错误
- mysql - 如何使用触发器创建 2 个档案?
- javascript - 将模式弹出窗口中的数据传递到 vue.js 中查看
- python - 在将数据帧处理为 csv 时,在每个数据行之后添加空行
- optimization - 如何找出具有二进制变量的函数是否是线性的?
- python - Python和Flask,在html页面上显示mysql DB中的数据
- java - 将平面列表转换为具有每个节点深度的树结构
- php - Laravel 8 通过外部 API 进行登录身份验证,无需数据库,但保留默认的 Laravel 用户和身份验证功能