python - Tensor dimensions in my WaveNet are incompatible with PyTorch's cross_entropy function
Problem description
I have been working on a project building my own implementation of WaveNet, as DeepMind presented it in early 2016, in Python.
The preprocessing consists of mu-law encoding followed by one-hot encoding. The model itself runs fine; my problem is with the loss function used during training, torch.nn.functional.cross_entropy, documented here: https://pytorch.org/docs/stable/nn.functional.html
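For reference, a minimal sketch of what standard mu-law quantization plus one-hot encoding looks like (my encode_mu_law / one_hot_encoding helpers are not shown here; the `_sketch` functions below are only illustrative stand-ins, assuming raw audio in [-1, 1]):

import torch

def encode_mu_law_sketch(x, mu=256):
    # Mu-law companding of raw audio in [-1, 1], quantized to integer classes 0..mu-1.
    mu = mu - 1
    fx = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(torch.tensor(float(mu)))
    return ((fx + 1) / 2 * mu + 0.5).long()

def one_hot_encoding_sketch(classes, num_classes=256):
    # (batch, time) integer classes -> (batch, num_classes, time) float one-hot.
    return torch.nn.functional.one_hot(classes, num_classes).permute(0, 2, 1).float()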
Specifically, the relationship between my output and my target tensor, namely:
input_tensor.shape = torch.Size([1, 256, 225332]) # [batch_size, sample_size, audio_length]
output.shape = torch.Size([1, 256, 225332])
According to F.cross_entropy, I must have output = (N, C) and target = input_tensor = (N). My supervisor told me to do the following:
output = output.T.reshape(-1, 256)  # shape: torch.Size([225332, 256])
target = input_tensor.T.long()      # shape: torch.Size([225332, 256, 1]) - This needs to be 1-dimensional, help?
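To make the expected shapes concrete, here is a minimal, self-contained sketch of the contract F.cross_entropy enforces (small dummy sizes standing in for my 225332 time steps and 256 mu-law classes):

import torch
import torch.nn.functional as F

N, C = 10, 256                           # stand-ins for 225332 flattened time steps and 256 classes
logits = torch.randn(N, C)               # the shape that output.T.reshape(-1, 256) produces
target = torch.randint(0, C, (N,))       # 1-D class indices of dtype long, NOT one-hot
loss = F.cross_entropy(logits, target)   # accepted: input (N, C), target (N,)
print(loss.shape)                        # torch.Size([]) - a scalar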
For anyone interested in the explicit code, it is shown below. NOTE: the receptive field is not padded, so purely for debugging purposes I subtract it, although I know that is not the natural way to do it.
>>> output.T.reshape(-1, 256).shape
torch.Size([225332, 256])
>>> input_tensor[:, :, model.input_size - model.output_size:].T.shape
torch.Size([225332, 256, 1])
>>> loss = F.cross_entropy(output.T.reshape(-1, 256), input_tensor[:, :, model.input_size - model.output_size:].T.long().to(device))
Traceback (most recent call last):
File "C:\Program Files\JetBrains\PyCharm Community Edition 2020.3.3\plugins\python-ce\helpers\pydev\_pydevd_bundle\pydevd_exec2.py", line 3, in Exec
exec(exp, global_vars, local_vars)
File "<input>", line 1, in <module>
File "C:\Users\JaQtae\anaconda3\envs\CortiGit\lib\site-packages\torch\nn\functional.py", line 2693, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "C:\Users\JaQtae\anaconda3\envs\CortiGit\lib\site-packages\torch\nn\functional.py", line 2388, in nll_loss
ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: 1D target tensor expected, multi-target not supported
I am somewhat of a newcomer to ML and AI training, and to the PyTorch library in particular.
Any advice on how I should tackle this issue would be appreciated.
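A possible reshaping, sketched with the names from the snippets above (the slicing is copied from my code and I have not verified that the lengths line up with the model output):

# Sketch only: recover 1-D class indices from the one-hot target and flatten the output.
trimmed = input_tensor[:, :, model.input_size - model.output_size:]  # drop the un-padded receptive field
target = trimmed.argmax(dim=1).reshape(-1)                           # 1-D class indices in [0, 255]
logits = output.permute(0, 2, 1).reshape(-1, 256)                    # (time, 256); same as output.T.reshape(-1, 256) for batch size 1
loss = F.cross_entropy(logits, target)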
Training:
model = Wavenet(layers=3, blocks=2, output_size=32).to(device)
model.apply(initialize)  # Initialize causalconv1d() with xavier_uniform_ weights and bias of 0.
model.train()

optimizer = optim.Adam(model.parameters(), lr=0.0003)

for i, batch in tqdm(enumerate(train_loader)):
    mu_enc_my_x = encode_mu_law(x=batch, mu=256)
    input_tensor = one_hot_encoding(mu_enc_my_x)
    input_tensor = input_tensor.to(device)

    output = model(input_tensor)

    # TODO: Inspect input/output formats, maybe something wrong....
    loss = F.cross_entropy(output.T.reshape(-1, 256), input_tensor[:, :, model.input_size - model.output_size:].long().to(device))  # subtract receptive field instead of pad it, workaround for quick debugging of loss-issue.
    print("\nLoss:", loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 1000 == 0:
        print("\nSaving model")
        torch.save(model.state_dict(), "wavenet.pt")
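An alternative sketch that avoids the one-hot round trip: keep the mu-law class indices from encode_mu_law and use them directly as the cross_entropy target (this assumes encode_mu_law returns integers in [0, 255] with shape (batch, time), which I have not double-checked):

mu_enc_my_x = encode_mu_law(x=batch, mu=256)
input_tensor = one_hot_encoding(mu_enc_my_x).to(device)

output = model(input_tensor)                                    # assumed (batch, 256, time_out)
target = mu_enc_my_x[:, model.input_size - model.output_size:]  # trim the receptive field as above
target = target.reshape(-1).long().to(device)                   # (batch * time_out,)
logits = output.permute(0, 2, 1).reshape(-1, 256)               # (batch * time_out, 256)
loss = F.cross_entropy(logits, target)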
The goal is to get my loss function working properly so that I can generate sound files. The ones generated with my current, broken loss function obviously return pure noise.
In case it helps, here is my full model:
"""
Wavenet model
Sources:
https://github.com/kan-bayashi/PytorchWaveNetVocoder/blob/master/wavenet_vocoder/nets/wavenet.py
https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/wavenet.py
https://github.com/Dankrushen/Wavenet-PyTorch/blob/master/wavenet/models.py
https://github.com/vincentherrmann/pytorch-wavenet
"""
from torch import nn
import torch
#TODO: Add local and global conditioning
def initialize(m):
    """
    Initialize CNN with Xavier_uniform weight and 0 bias.
    """
    if isinstance(m, torch.nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0.0)
class CausalConv1d(torch.nn.Module):
    """
    Causal Convolution for WaveNet
    Causality can be introduced with padding as (kernel_size - 1) * dilation (see Keras documentation)
    or it can be introduced as follows according to Golbin.
    https://github.com/golbin/WaveNet/blob/05545339096c3a1d9909d96fb19da4fbae28d8c6/wavenet/networks.py#L38
    Else, look at the following article, several ways to implement it using PyTorch:
    https://github.com/pytorch/pytorch/issues/1333
    - Jakob
    """
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, bias=True):
        super(CausalConv1d, self).__init__()
        # padding = (kernel_size - 1) * dilation keeps the same size (length) between input and output for causal convolution
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding = padding = (kernel_size - 1) * dilation  # e.g. kernel_size = 2: (2 - 1) * dilation = 1. - Jakob.
        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size, padding=padding, dilation=dilation,
                                    bias=bias)  # Fixed for WaveNet but not sure

    def forward(self, x):
        output = self.conv(x)
        if self.padding != 0:
            output = output[:, :, :-self.padding]
        return output
class Wavenet(nn.Module):
    def __init__(self,
                 layers=3,
                 blocks=2,
                 dilation_channels=32,
                 residual_block_channels=512,
                 skip_connection_channels=512,
                 output_channels=256,
                 output_size=32,
                 kernel_size=3
                 ):
        super(Wavenet, self).__init__()
        self.layers = layers
        self.blocks = blocks
        self.dilation_channels = dilation_channels
        self.residual_block_channels = residual_block_channels
        self.skip_connection_channels = skip_connection_channels
        self.output_channels = output_channels
        self.kernel_size = kernel_size
        self.output_size = output_size

        # initialize dilation variables
        receptive_field = 1
        init_dilation = 1

        # List of layers and connections
        self.dilations = []
        self.residual_convs = nn.ModuleList()
        self.filter_conv_layers = nn.ModuleList()
        self.gate_conv_layers = nn.ModuleList()
        self.skip_convs = nn.ModuleList()

        # First convolutional layer
        self.first_conv = CausalConv1d(in_channels=self.output_channels,
                                       out_channels=residual_block_channels,
                                       kernel_size=2)

        # Building the ModuleLists for the residual blocks
        for b in range(blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for i in range(layers):
                # dilations of this layer
                self.dilations.append((new_dilation, init_dilation))

                # dilated convolutions
                self.filter_conv_layers.append(nn.Conv1d(in_channels=residual_block_channels, out_channels=dilation_channels, kernel_size=kernel_size, dilation=new_dilation))
                self.gate_conv_layers.append(nn.Conv1d(in_channels=residual_block_channels, out_channels=dilation_channels, kernel_size=kernel_size, dilation=new_dilation))

                # 1x1 convolution for residual connection
                self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels, out_channels=residual_block_channels, kernel_size=1))

                # 1x1 convolution for skip connection
                self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels,
                                                 out_channels=skip_connection_channels,
                                                 kernel_size=1))

                # Update receptive field and dilation
                receptive_field += additional_scope
                additional_scope *= 2
                init_dilation = new_dilation
                new_dilation *= 2

        # Last two convolutional layers
        self.last_conv_1 = nn.Conv1d(in_channels=skip_connection_channels,
                                     out_channels=skip_connection_channels,
                                     kernel_size=1)
        self.last_conv_2 = nn.Conv1d(in_channels=skip_connection_channels,
                                     out_channels=output_channels,
                                     kernel_size=1)

        # Calculate model receptive field and the required input size for the given output size
        self.receptive_field = receptive_field
        self.input_size = receptive_field + output_size - 1

    def forward(self, input):
        # Feed first convolutional layer with input
        x = self.first_conv(input)

        # Initialize skip connection
        skip = 0

        # Residual blocks
        for i in range(self.blocks * self.layers):
            (dilation, init_dilation) = self.dilations[i]

            # Residual connection bypassing dilated convolution block
            residual = x

            # input to dilated convolution block
            filter = self.filter_conv_layers[i](x)
            filter = torch.tanh(filter)
            gate = self.gate_conv_layers[i](x)
            gate = torch.sigmoid(gate)
            x = filter * gate

            # Feed into 1x1 convolution for skip connection
            s = self.skip_convs[i](x)

            # Adding skip & match size with decreasing dimensionality of x
            if isinstance(skip, torch.Tensor):  # `skip is not 0` relies on int identity and raises a SyntaxWarning on Python 3.8+
                skip = skip[:, :, -s.size(2):]
            skip = s + skip  # Sum all skip connections

            # Feed into 1x1 convolution for residual connection
            x = self.residual_convs[i](x)

            # Adding residual & match size with decreasing dimensionality of x
            x = x + residual[:, :, dilation * (self.kernel_size - 1):]
            # print(x.shape)

        x = torch.relu(skip)

        # Last conv layers
        x = torch.relu(self.last_conv_1(x))
        x = self.last_conv_2(x)

        # Note: F.cross_entropy applies log_softmax itself, so feeding it these
        # softmax probabilities ends up applying softmax twice.
        soft = torch.nn.Softmax(dim=1)
        x = soft(x)
        return x
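As a sanity check on the receptive-field bookkeeping, with the training settings (layers=3, blocks=2, kernel_size=3, output_size=32) the loops above give receptive_field = 1 + 2 * (2 + 4 + 8) = 29 and input_size = 29 + 32 - 1 = 60. A small sketch to verify the shapes (assuming the model code runs as posted):

causal = CausalConv1d(in_channels=4, out_channels=4, kernel_size=2, dilation=1)
print(causal(torch.randn(1, 4, 10)).shape)      # torch.Size([1, 4, 10]): padding is trimmed from the right, so length is preserved

model = Wavenet(layers=3, blocks=2, output_size=32, kernel_size=3)
print(model.receptive_field, model.input_size)  # expected: 29 60

x = torch.randn(1, 256, model.input_size)       # dummy (batch, channels, input_size) input
print(model(x).shape)                           # expected: torch.Size([1, 256, 32]); each dilated conv trims (kernel_size - 1) * dilation samples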
Edit: added the training snippet and the full model for clarity.
Solution