neural-network - Ordinal Pooling 神经网络的函数
问题描述
请我想创建一个计算 Ordinal Pooling 神经网络的函数,如下图所示:
这是我的功能:
def Ordinal_Pooling_NN(x):
wights = torch.tensor([0.6, 0.25, 0.10, 0.05])
top = torch.topk(x, 4, dim = 1)
wights = wights.repeat(x.shape[0], 1)
result = torch.sum(wights * (top.values), dim = 1 )
return result
但结果,我收到以下错误:
<ipython-input-112-ddf99c812d56> in Ordinal_Pooling_NN(x)
9 top = torch.topk(x, 4, dim = 1)
10 wights = wights.repeat(x.shape[0], 1)
---> 11 result = torch.sum(wights * (top.values), dim = 1 )
12 return result
RuntimeError: The size of tensor a (4) must match the size of tensor b (16) at non-singleton dimension 2
解决方案
您的实现实际上是正确的,我相信您没有使用 2D 张量来提供函数,输入必须具有批处理轴。例如,下面的代码将运行:
>>> Ordinal_Pooling_NN(torch.tensor([[1.9, 0.4, 1.3, 0.8]]))
tensor([1.5650])
请注意,您不需要重复权重张量,它会在计算逐点乘法时自动广播。您只需要以下内容:
def Ordinal_Pooling_NN(x):
w = torch.tensor([0.6, 0.25, 0.10, 0.05])
top = torch.topk(x, k=4, dim=1)
result = torch.sum(w*top.values, dim=1)
return result
推荐阅读
- swiftui - 使用 SwiftUI 和图表库合并两个 xy 图
- python - Kaggle 练习:列出问题编号。5. 使用“print”的语法错误
- python - Python Streamlit nltk 应用程序的 Heroku R10 和 H20 超时错误
- python - 即使我尝试先启动进程,进程也会在我调用函数后启动
- string - Groovy:将字符串标记为仅出现第 3 次分隔符
- python - 对数据框进行排序
- sql - 可以使用 sql 在月历中显示事件吗?
- python - 如果它们一起使用,如何替换令牌?
- vue.js - Nuxt JS 插件是否可以只运行一次?
- for-loop - gnuplot:在 for 循环中设置线条样式