python - 这是在最新版本的 PyTorch(TensorFlow 到 PyTorch 转换)中计算内核的最佳方法吗?
问题描述
我正在尝试将 MMD-VAE 实现从 TensorFlow 转换为 PyTorch。我的大部分模型都构建得很好,但我只是想确保我正确转换了以下函数(一切正常,但我没有得到我期望的结果,所以我想也许我计算内核不正确因为我在 TensorFlow 中不是那么强)。
在 TensorFlow 中:
def compute_kernel(x, y):
x_size = tf.shape(x)[0]
y_size = tf.shape(y)[0]
dim = tf.shape(x)[1]
tiled_x = tf.tile(tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1]))
tiled_y = tf.tile(tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1]))
return tf.exp(-tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32))
我对 PyTorch 有什么:
def compute_kernel(x, y):
x_size = x.size(0)
y_size = y.size(0)
dim = x.size(1)
x = x.unsqueeze(1)
y = y.unsqueeze(0)
tiled_x = x.expand(x_size, y_size, dim)
tiled_y = y.expand(x_size, y_size, dim)
kernel_input = (tiled_x - tiled_y).pow(2).mean(2)/float(dim)
return torch.exp(-kernel_input)
谢谢你的帮助!!
解决方案
推荐阅读
- spring - BootRun not booting?
- javascript - Error message on javascript after publishing azure website
- wordpress - 图像替代文本未显示在 WordPress 上
- c++ - Boost geometry polygon inner representation as a STL list?
- javascript - How to enable ts-check in es6
- regex - Extract string between newline and variable data in Bash
- java - Spring security 将 ApplicationEventListener 添加到 ExpiredJwtException
- oracle-sqldeveloper - 在 Oracle SQL Developer 中可以弹出查询结果/脚本输出吗
- conv-neural-network - 在 python 中可视化 RGBN tiff 卫星图像
- docker - docker 在 docker localhost 网络问题