python - 如何在 RNN 中嵌入句子序列?
问题描述
我正在尝试制作一个 RNN 模型(在 Pytorch 中),它需要几个句子,然后将其分类为Class 0或Class 1。
为了这个问题,我们假设句子的 max_len 为 4,时间步长的 max_amount 为 5。因此,每个数据点都在表单上(0 是用于填充填充值的值):
x[1] = [
# Input features at timestep 1
[1, 48, 91, 0],
# Input features at timestep 2
[20, 5, 17, 32],
# Input features at timestep 3
[12, 18, 0, 0],
# Input features at timestep 4
[0, 0, 0, 0],
# Input features at timestep 5
[0, 0, 0, 0]
]
y[1] = [1]
当我每个目标只有一个句子时:我只是将每个单词传递给嵌入层,然后传递给 LSTM 或 GRU,但是当每个目标有一系列句子时,我有点卡住了怎么办?
如何构建可以处理句子的嵌入?
解决方案
最简单的方法是使用 2 种 LSTM。
准备玩具数据集
xi = [
# Input features at timestep 1
[1, 48, 91, 0],
# Input features at timestep 2
[20, 5, 17, 32],
# Input features at timestep 3
[12, 18, 0, 0],
# Input features at timestep 4
[0, 0, 0, 0],
# Input features at timestep 5
[0, 0, 0, 0]
]
yi = 1
x = torch.tensor([xi, xi])
y = torch.tensor([yi, yi])
print(x.shape)
# torch.Size([2, 5, 4])
print(y.shape)
# torch.Size([2])
然后,x
是输入的批次。这里batch_size
= 2。
嵌入输入
vocab_size = 1000
embed_size = 100
hidden_size = 200
embed = nn.Embedding(vocab_size, embed_size)
# shape [2, 5, 4, 100]
x = embed(x)
第一个词-LSTM是将每个序列编码成一个向量
# convert x into a batch of sequences
# Reshape into [2, 20, 100]
x = x.view(bs * 5, 4, 100)
wlstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
# get the only final hidden state of each sequence
_, (hn, _) = wlstm(x)
# hn shape [1, 10, 200]
# get the output of final layer
hn = hn[0] # [10, 200]
第二个seq-LSTM是将序列编码成单个向量
# Reshape hn into [bs, num_seq, hidden_size]
hn = hn.view(2, 5, 200)
# Pass to another LSTM and get the final state hn
slstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
_, (hn, _) = slstm(hn) # [1, 2, 200]
# Similarly, get the hidden state of the last layer
hn = hn[0] # [2, 200]
添加一些分类层
pred_linear = nn.Linear(hidden_size, 1)
# [2, 1]
output = torch.sigmoid(pred_linear(hn))
推荐阅读
- python - 获取在特定时间段内创建的文件名列表的最快方法是什么
- sftp - 从 HP-UX 服务器下载时的 SFTP“,u”(逗号和 u 字母)选项
- sql - sys.tables 与 COUNT
- php - 我在 PHP 中收到“警告:mysqli_fetch_array() 期望参数 1 为 mysqli_result”错误
- node.js - 我无法让 NPM Start 在 Visual Code Studio 中工作;它给了我一个我无法理解的错误
- laravel - 为什么 Laravel 没有超时?
- html - 如果我想要两个或三个或四个不同的段落都是单独的颜色怎么办?
- javascript - 在 iOS 上使用 Chrome 重新加载时的视口高度错误
- vue.js - 无法在 Vuejs 上下文中使用 chartjs 中的 API 调用显示图表
- android - 强制 Android WebView 在默认浏览器中打开外部链接