python-3.x - 代码优化:Torch.Tensor 中的计算
问题描述
我目前正在实现一个计算自定义交叉熵损失的函数。函数的定义如下图。
我的代码如下,
output = output.permute(0, 2, 3, 1)
target = target.permute(0, 2, 3, 1)
batch, height, width, channel = output.size()
total_loss = 0.
for b in range(batch): # for each batch
o = output[b]
t = target[b]
loss = 0.
for w in range(width):
for h in range(height): # for every pixel([h,w]) in the image
sid_t = t[h][w][0]
sid_o_candi = o[h][w]
part1 = 0. # to store the first sigma
part2 = 0. # to store the second sigma
for k in range(0, sid_t):
p = torch.sum(sid_o_candi[k:]) # to get Pk(w,h)
part1 += torch.log(p + 1e-12).item()
for k in range(sid_t, intervals):
p = torch.sum(sid_o_candi[k:]) # to get Pk(w,h)
part2 += torch.log(1-p + 1e-12).item()
loss += part1 + part2
loss /= width * height * (-1)
total_loss += loss
total_loss /= batch
return torch.tensor(total_loss, dtype=torch.float32)
我想知道这些代码是否可以进行任何优化。
解决方案
我不确定sid_t = t[h][w][0]
每个像素是否相同。如果是这样,您可以摆脱所有for loop
提高计算速度的损失。
不要使用
.item()
,因为它会返回一个丢失grad_fn
轨道的 Python 值。然后你不能loss.backward()
用来计算梯度。
如果sid_t = t[h][w][0]
不一样,这里有一些修改可以帮助你摆脱至少 1 for-loop
:
batch, height, width, channel = output.size()
total_loss = 0.
for b in range(batch): # for each batch
o = output[b]
t = target[b]
loss = 0.
for w in range(width):
for h in range(height): # for every pixel([h,w]) in the image
sid_t = t[h][w][0]
sid_o_candi = o[h][w]
part1 = 0. # to store the first sigma
part2 = 0. # to store the second sigma
sid1_cumsum = sid_o_candi[:sid_t].flip(dim=(0,)).cumsum(dim=0).flip(dims=(0,))
part1 = torch.sum(torch.log(sid1_cumsum + 1e-12))
sid2_cumsum = sid_o_candi[sid_t:intervals].flip(dim=(0,)).cumsum(dim=0).flip(dims=(0,))
part2 = torch.sum(torch.log(1 - sid2_cumsum + 1e-12))
loss += part1 + part2
loss /= width * height * (-1)
total_loss += loss
total_loss /= batch
return torch.tensor(total_loss, dtype=torch.float32)
这个怎么运作:
x = torch.arange(10);
print(x)
x_flip = x.flip(dims=(0,));
print(x_flip)
x_inverse_cumsum = x_flip.cumsum(dim=0).flip(dims=(0,))
print(x_inverse_cumsum)
# output
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
tensor([45, 45, 44, 42, 39, 35, 30, 24, 17, 9])
希望能帮助到你。
推荐阅读
- python - 无法在 Pycharm 中捕获密码
- c# - 如何在 linq to sql 中使包含方法不区分大小写
- javascript - Zxing 问题:无法让 MatrixToImageConfig 使用 javascript
- handlebars.js - 子文件夹中的零件 - 车把 (Fabricator)
- javascript - 如何外部化 d3.js 转换配置
- javascript - 这是实现cookie的正确方法吗?
- python - 使用 Pandas 自相关图 - 如何限制 x 轴以使其更具可读性?
- c# - 来自 cmd.exe 的不同 dotnet 测试结果
- node.js - NodeJS中是否有类似的“Java同步方法”?
- c - 函数中的 Sigaction 声明