machine-learning - 如何在 PYTORCH 中进行 2 层嵌套 FOR 循环?
问题描述
我正在学习在 Pytorch 中实现分解机。并且应该有一些特征交叉操作。比如我有三个特征[A,B,C],嵌入后是[vA,vB,vC],所以特征交叉是“[vA·vB],[vA·vC],[vB ·vc]”。
它可以通过 MATRIX OPERATIONS 来实现。但这仅给出最终结果,例如,单个值。
问题是,如何在不执行 FOR 循环的情况下获取以下所有 cross_vec:注意:“feature_emb”的大小为 [batch_size x feature_len x embedding_size]
g_feature = 0
for i in range(self.featurn_len):
for j in range(self.featurn_len):
if j <= i: continue
cross_vec = feature_emb[:,i,:] * feature_emb[:,j,:]
g_feature += torch.sum(cross_vec, dim=1)
解决方案
你可以
cross_vec = (feature_emb[:, None, ...] * feature_emb[..., None, :]).sum(dim=-1)
这应该给你corss_vec
的形状(batch_size, feature_len, feature_len)
。
或者,您可以使用torch.bmm
cross_vec = torch.bmm(feature_emb, feature_emb.transpose(1, 2))
推荐阅读
- c# - 在 .Net Windows 应用程序窗体中计算 DataGridView 中每个元素的频率
- javascript - 有人可以向我解释这段代码在 vue 中是如何工作的吗
- xml - 使用 Bash 从 XML 文件中删除文本
- javascript - 这个参数在哪里提供参数?
- c++ - 为什么此代码打印出 RValue 而不是 LValue
- vb.net - 如何使用 Bcrypt?
- turtle-graphics - 贪吃蛇游戏分数不加
- html - 如何使用相互堆叠的 HTML 元素实现组件的滚动
- java - 需要帮助将图像保存到 Android Studio 中的 SQLite 数据库
- java - Maven 不下载 pom.xml 和 .jar 文件