python-3.x - 如何解决将 GPyTorch 与 SpectralMixture Kernel 一起使用时遇到的错误?
问题描述
我正在使用 GPyTorch 来拟合高斯过程回归模型(主要用于学习过程)。在遵循他们的教程时,我正在尝试使用SpectralMixtureKernel
. 但是,我收到以下错误。但首先是代码(与他们的教程基本相同,但为方便起见,在此处复制):
class ExactGPModel(gpytorch.models.ExactGP):
def __init__(self,train_x,train_y,likelihood):
super(ExactGPModel, self).__init__(train_x,train_y,likelihood)
self.mean_module = gpytorch.means.ConstantMean()
self.covar_module = gpytorch.kernels.SpectralMixtureKernel(num_mixtures=4)
self.covar_module.initialize_from_data(train_x, train_y)
def forward(self,x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return gpytorch.distributions.MultivariateNormal(mean_x,covar_x)
熊猫数据框转换为torch.tensor
以下
train_x = torch.tensor(train_x.values.astype(np.float32))
train_y = torch.tensor(train_y.values.astype(np.float32))
test_x = torch.tensor(test_x.values.astype(np.float32))
test_y = torch.tensor(test_y.values.astype(np.float32))
然后
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x,train_y, likelihood)
运行最后一行后,我收到以下错误:
Traceback (most recent call last):
File "<ipython-input-195-e3bc37af324c>", line 1, in <module>
model = ExactGPModel(train_x,train_y, likelihood)
File "<ipython-input-186-323eff9c5819>", line 7, in __init__
self.covar_module.initialize_from_data(train_x, train_y)
File "/anaconda3/envs/py36/lib/python3.6/site-packages/gpytorch/kernels/spectral_mixture_kernel.py", line 163, in initialize_from_data
self.raw_mixture_scales.data.normal_().mul_(max_dist).abs_().pow_(-1)
RuntimeError: output with shape [4, 1, 1] doesn't match the broadcast shape [4, 1, 33]
任何解决此问题的帮助将不胜感激。
谢谢。
解决方案
推荐阅读
- php - Laravel 7关系查询数据不起作用
- python - 切换到 Python/Selenium-wire 后无法与 iframe 交互
- typescript - 如何设置美化使其不删除括号中的多行?
- go - 使用 golang 中的 crypto/ssh 将 ssh 后的用户切换到服务器
- arrays - 函数在 C 中分配和修改许多(多个)数组
- python - 从 shell 终端 SSH 到另一台服务器并使用 python paramiko 执行命令
- gitlab - GitLab 社区版 - docker 支持
- python - multiprocessing.Pool 产生太多线程
- python - 一个页面中的 Django 多个 chartJs 聊天
- python - 如何编码以修复检测到的致命错误:使用 auto-py-to-exe 将 python 代码编译为 exe 文件后无法执行脚本 BoxDetection?