python - 如何在 mxnet 中使用类似 torch.nn.functional.conv2d() 的函数?
问题描述
我想用输入数据和内核做一些卷积计算。
在torch中,我可以写一个函数:
import torch
def torch_conv_func(x, num_groups):
batch_size, num_channels, height, width = x.size()
conv_kernel = torch.ones(num_channels, num_channels, 1, 1)
return torch.nn.functional.conv2d(x, conv_kernel)
它工作得很好,现在我需要在 MXnet 中重建,所以我写了这个:
from mxnet import nd
from mxnet.gluon import nn
def mxnet_conv_func(x, num_groups):
batch_size, num_channels, height, width = x.shape
conv_kernel = nd.ones((num_channels, num_channels, 1, 1))
return nd.Convolution(x, conv_kernel)
我得到了错误
mxnet.base.MXNetError: Required parameter kernel of Shape(tuple) is not presented, in operator Convolution(name="")
如何解决?
解决方案
您缺少mxnet.nd.Convolution
. 你可以这样做:
from mxnet import nd
def mxnet_convolve(x):
B, C, H, W = x.shape
weight = nd.ones((C, C, 1, 1))
return nd.Convolution(x, weight, no_bias=True, kernel=(1,1), num_filter=C)
x = nd.ones((16, 3, 32, 32))
mxnet_convolve(x)
由于您没有使用偏差,因此您需要设置no_bias
为True
. 此外,mxnet 要求您使用kernel
andnum_filter
参数指定内核尺寸。
推荐阅读
- java - 控制器和检票口页面的设计模式帮助
- python-3.x - 如何从 Python 3 中的脚本连接 Google Datastore
- javascript - 如何在 for 循环中运行 mongoose 方法,因为 mongoose 函数是异步的
- python - 从不同的功能单击按钮时获取 tkinter 条目的值
- sql - 如何使用pymssql通过python检索Query Explorer SQL Server的插入/更新语句的确切响应
- python - 找到正确的 Python 类型提示,例如,内置函数 map() 的签名
- php - 进行数据库播种时出现 InvalidArgumentException
- python - 每次创建一个新文件
- npm - 在哪里手动下载 npm 模块?
- azure-web-app-service - Azure 中容器的 Web 应用程序限制