tensorflow - pyTorch 中 tensorflow 的时间分布等价物
问题描述
是否有 tensorflow.keras.layers.Timedistributed for pytorch 的等效实现?
我正在尝试构建类似 Timedistributed(Resnet50()) 的东西。
解决方案
在这个主题上感谢 miguelvr 。
您可以使用此代码,它是为模仿 Timeditributed 包装器而开发的 PyTorch 模块。
import torch.nn as nn
class TimeDistributed(nn.Module):
def __init__(self, module, batch_first=False):
super(TimeDistributed, self).__init__()
self.module = module
self.batch_first = batch_first
def forward(self, x):
if len(x.size()) <= 2:
return self.module(x)
# Squash samples and timesteps into a single axis
x_reshape = x.contiguous().view(-1, x.size(-1)) # (samples * timesteps, input_size)
y = self.module(x_reshape)
# We have to reshape Y
if self.batch_first:
y = y.contiguous().view(x.size(0), -1, y.size(-1)) # (samples, timesteps, output_size)
else:
y = y.view(-1, x.size(1), y.size(-1)) # (timesteps, samples, output_size)
return y
推荐阅读
- excel - Excel VBA:引用命名范围,就像引用图表或表格一样
- assembly - 分段错误(x86 程序集)
- c# - 如何在 AzureDevops 中为 Dockerfile 设置环境变量(ASPNETCORE_ENVIRONMENT)?
- google-chrome - 当使用 Spring 将浏览器更改请求从 HTTPs 重定向到 HTTP
- javascript - 如果其他条件在赛普拉斯
- postgresql - Helidon 应用程序中的 Db health Ping 子句出现异常
- javascript - 滚动错误时的主动导航 - Vanilla JS
- windows-10 - Windows Defender Win32/Persistence.DQ!ml,它是什么?
- ontology - obo、bio2rdf 和 bioportal 有什么区别
- flutter - 当一个小部件被按下时,如何插入另一个小部件来代替另一个?