python - Pytorch NN 中用于 GAN 的嵌入层
问题描述
我创建了一个嵌入层,为我的生成器使用预训练的 GloVe (glove.twitter.27B.25d.txt),它是生成推文的 GAN 的一部分。该层有 size Embedding(5119, 25)
,其中 5119 是我的词汇的大小,25 是嵌入单词的向量的大小。我的输入数据是大小为 [53, 20] 的火炬张量(单词索引对应于 vocab),其中 53 是填充推文的长度,20 是 batch_size。现在我不确定如何在我的生成器中实现嵌入层。
emb = nn.Embedding.from_pretrained(torch.FloatTensor(weights_matrix))
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
emb,
nn.Linear(lendata, batch_size),
nn.ReLU(),
nn.Linear(batch_size, 32),
nn.ReLU(),
nn.Linear(32, lendata),
)
def forward(self, x):
output = self.model(x)
print(output)
return output
generator = Generator()
latent_space_samples = torch.zeros((lendata, batch_size)).long()
generated_samples = generator(latent_space_samples)
我需要将真实样本和生成样本放在一起,以便将它们放入鉴别器中
all_samples = torch.cat((real_samples, generated_samples))
但是生成的样本现在具有以下大小:torch.Size([53, 20, 25]) 并且真实样本具有输入数据的大小 torch.Size([53, 20])。我理解我得到的错误,RuntimeError: Tensors must have same number of dimensions: got 2 and 3
但我不知道如何让输出具有相同的尺寸
我不明白发生了什么以及我应该怎么做才能正确地将其传递给鉴别器。我正在做的嵌入是否正确?我的生成器的代码好吗?
解决方案
推荐阅读
- java - 我如何在 Java 中过滤这个 ArrayList
- python - 无法找到从页面访问可下载元素的方法
- javascript - 查询以搜索特定 DOM 元素的 DOM
- c++ - 类没有合适的复制构造函数,我需要在传递给函数之前初始化吗?
- telegram - 如何将电报机器人添加到我不是管理员的电报组
- c - 计算 C 中出现的百分比
- ios - NSMutableAttributedString 文本到达右锚点时下降
- python - 如何定义一个可以再次定义来做某事的函数?
- javascript - Insert a new key value to the existing object array?
- mysql - 在 MySQL SET 中查找条目