首页 > 解决方案 > 由于使用 pyro 和 pytorch 的样本中存在多类分布,因此在 svi 步骤中出现错误

问题描述

我正在研究一个因果变分自动编码器,它使用类分割掩码、类标签和因果关系(0 或 1)作为输入。

由于 svi 步骤,使用大于 1 的批量大小时出现错误。我正在使用伯努林函数,因为我希望它学习图像中多个类的概率分布。我认为分类分布也符合此处的要求,但我也遇到了同样的错误。

当我尝试缩小产生问题的代码行时,我认为它在模型函数中:

one_vec2 = torch.ones([batch_size, self.lbl_shape[0]], **options)
class_labels = pyro.sample('class_labels', dist.Bernoulli(one_vec2*0.5), obs = lbls)      

错误:

ValueError                                Traceback (most recent call last)
<ipython-input-19-8cbc046dd2c1> in <module>()
      6 vae = Vae_Model1(lbl_sz, ch, img_sz).to(device)
      7 svi = SVI(vae.model, vae.guide, optimizer, loss = Trace_ELBO())
----> 8 train(svi, train_loader, USE_CUDA)

6 frames
/usr/local/lib/python3.6/dist-packages/pyro/util.py in check_site_shape(site, max_plate_nesting)
    320                 '- enclose the batched tensor in a with plate(...): context',
    321                 '- .to_event(...) the distribution being sampled',
--> 322                 '- .permute() data dimensions']))
    323 
    324     # Check parallel dimensions on the left of max_plate_nesting.

ValueError: at site "class_labels", invalid log_prob shape
  Expected [-1], actual [32, 21]
  Try one of the following fixes:
  - enclose the batched tensor in a with plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

目前批量大小为 32,lbl_shape[0] 为 21(VOC 数据集(背景和其他标签))

有人可以帮我解决这个问题吗?将不胜感激。谢谢

标签: pythonpytorchpyro.ai

解决方案


推荐阅读