python - 如何在计算 FLOPs 和 Params 时不计入值为 0 的权重?
问题描述
我的剪枝(prune)代码如下,运行后会得到一个名为“pruned_model.pth”的文件。
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from cnn import net
# Prune conv1 of a trained `net` by L1 magnitude (unstructured) and save the
# resulting state_dict, in which pruned entries are stored as explicit zeros.
ori_model = '/content/drive/My Drive/ECG_weight_prune/checkpoint_dir/model.pth'
save_path = '/content/drive/My Drive/ECG_weight_prune/checkpoint_dir/pruned_model.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = net().to(device)
# map_location makes a checkpoint saved from a GPU loadable on the device chosen
# above (e.g. CPU-only runtimes); without it torch.load tries to restore tensors
# onto the device they were saved from.
model.load_state_dict(torch.load(ori_model, map_location=device))
module = model.conv1
print(list(module.named_parameters()))
print(list(module.named_buffers()))
# A float `amount` is a fraction of entries, an int `amount` an absolute count:
# this prunes the 30% smallest-magnitude weights...
prune.l1_unstructured(module, name="weight", amount=0.3)
# ...but only the 3 smallest-magnitude bias entries (use 0.3 for 30%).
prune.l1_unstructured(module, name="bias", amount=3)
# While pruning is active the module holds `*_orig` parameters and `*_mask`
# buffers, and a forward pre-hook recomputes `weight`/`bias` from them.
print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(module.bias)
print(module.weight)
print(module._forward_pre_hooks)
# Make pruning permanent: fold the masks into plain parameters and drop the
# hooks. Tensor shapes are unchanged; pruned entries simply become zeros.
prune.remove(module, 'weight')
prune.remove(module, 'bias')
print(list(module.named_parameters()))
print(model.state_dict())
torch.save(model.state_dict(), save_path)
结果是:
[('weight', Parameter containing:
tensor([[[-0.0000, -0.3137, -0.3221, ..., 0.5055, 0.3614, -0.0000]],
[[ 0.8889, 0.2697, -0.3400, ..., 0.8546, 0.2311, -0.0000]],
[[-0.2649, -0.1566, -0.0000, ..., 0.0000, 0.0000, 0.3855]],
...,
[[-0.2836, -0.0000, 0.2155, ..., -0.8894, -0.7676, -0.6271]],
[[-0.7908, -0.6732, -0.5024, ..., 0.2011, 0.4627, 1.0227]],
[[ 0.4433, 0.5048, 0.7685, ..., -1.0530, -0.8908, -0.4799]]],
device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.7497, -1.3594, -1.7613, -2.0137, -1.1763, 0.4150, -1.6996, -1.5354,
0.4330, -0.9259, 0.4156, -2.3099, -0.4282, -0.5199, 0.1188, -1.1725,
-0.9064, -1.6639, -1.5834, -0.3655, -2.0727, -2.1078, -1.6431, -0.0694,
-0.5435, -1.9623, 0.5481, -0.8255, -1.5108, -0.4029, -1.9759, 0.0522,
0.0599, -2.2469, -0.5599, 0.1039, -0.4472, -1.1706, -0.0398, -1.9441,
-1.5310, -0.0837, -1.3250, -0.2098, -0.1919, 0.4600, -0.8268, -1.0041,
-0.8168, -0.8701, 0.3869, 0.1706, -0.0226, -1.2711, -0.9302, -2.0696,
-1.1838, 0.4497, -1.1426, 0.0772, -2.4356, -0.3138, 0.6297, 0.2022,
-0.4024, 0.0000, -1.2337, 0.2840, 0.4515, 0.2999, 0.0273, 0.0374,
0.1325, -0.4890, -2.3845, -1.9663, 0.2108, -0.1144, 0.0544, -0.2629,
0.0393, -0.6728, -0.9645, 0.3118, -0.5142, -0.4097, -0.0000, -1.5142,
-1.2798, 0.2871, -2.0122, -0.9346, -0.4931, -1.4895, -1.1401, -0.8823,
0.2210, 0.4282, 0.1685, -1.8876, -0.7459, 0.2505, -0.6315, 0.3827,
-0.3348, 0.1862, 0.0806, -2.0277, 0.2068, 0.3281, -1.8045, -0.0000,
-2.2377, -1.9742, -0.5164, -0.0660, 0.8392, 0.5863, -0.7301, 0.0778,
0.1611, 0.0260, 0.3183, -0.9097, -1.6152, 0.4712, -0.2378, -0.4972],
device='cuda:0', requires_grad=True))]
剪枝后的权重中存在许多零值。如何在计算 FLOPs 和 Params 时排除与这些零权重相关的运算?
我使用以下代码来计算 FLOPs 和 Params。
import torch
from cnn import net
from ptflops import get_model_complexity_info
# Profile MACs and parameter counts of the (pruned) model with ptflops.
ori_model = '/content/drive/My Drive/ECG_weight_prune/checkpoint_dir/model.pth'
pthfile = '/content/drive/My Drive/ECG_weight_prune/checkpoint_dir/pruned_model.pth'
# Set this to ori_model to profile the unpruned baseline instead.
checkpoint = pthfile
model = net()
# The model stays on the CPU here, but the checkpoint may have been written from
# a CUDA model; map_location='cpu' keeps torch.load from requiring a GPU.
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
macs, params = get_model_complexity_info(model, (1, 260), as_strings=False,
                                         print_per_layer_stat=True, verbose=True)
print('{:<30} {:<8}'.format('Computational complexity: ', macs))
print('{:<30} {:<8}'.format('Number of parameters: ', params))
# Unstructured pruning only writes zeros into dense tensors, so the parameter
# count above is identical before and after pruning; count the non-zero
# entries explicitly to see the effect of pruning.
nonzero_params = sum(int((p != 0).sum()) for p in model.parameters())
print('{:<30} {:<8}'.format('Non-zero parameters: ', nonzero_params))
无论加载 ori_model 还是 pthfile,输出都完全相同,如下:
Warning: module Dropout2d is treated as a zero-op.
Warning: module Flatten is treated as a zero-op.
Warning: module net is treated as a zero-op.
net(
0.05 M, 100.000% Params, 0.001 GMac, 100.000% MACs,
(conv1): Conv1d(0.007 M, 13.143% Params, 0.0 GMac, 45.733% MACs, 1, 128, kernel_size=(50,), stride=(3,))
(conv2): Conv1d(0.029 M, 57.791% Params, 0.001 GMac, 50.980% MACs, 128, 32, kernel_size=(7,), stride=(1,))
(conv3): Conv1d(0.009 M, 18.619% Params, 0.0 GMac, 0.913% MACs, 32, 32, kernel_size=(9,), stride=(1,))
(fc1): Linear(0.004 M, 8.504% Params, 0.0 GMac, 0.404% MACs, in_features=32, out_features=128, bias=True)
(fc2): Linear(0.001 M, 1.299% Params, 0.0 GMac, 0.063% MACs, in_features=128, out_features=5, bias=True)
(bn1): BatchNorm1d(0.0 M, 0.515% Params, 0.0 GMac, 1.793% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(bn2): BatchNorm1d(0.0 M, 0.129% Params, 0.0 GMac, 0.114% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(dropout): Dropout2d(0.0 M, 0.000% Params, 0.0 GMac, 0.000% MACs, p=0.5, inplace=False)
(faltten): Flatten(0.0 M, 0.000% Params, 0.0 GMac, 0.000% MACs, )
)
Computational complexity: 1013472.0
Number of parameters: 49669
解决方案
非结构化剪枝只是把权重置零,张量形状并没有改变,而 ptflops 只根据层的形状计数,所以剪枝前后的结果相同。一种做法是在计算 FLOPs 时排除绝对值低于某个阈值的权重。为此,需要修改 ptflops 中各层的 FLOPs 计数钩子函数(flops counter hook)。
下面给出修改全连接层(fc)和卷积层(conv)计数函数的示例。
def linear_flops_counter_hook(module, input, output, *, zero_threshold=1e-9):
    """ptflops forward hook for ``nn.Linear`` that ignores MACs of zero weights.

    The dense count (every input element times every output feature) is scaled
    by the fraction of weights whose magnitude is not below ``zero_threshold``;
    every weight is used equally often, so each zero weight removes the same
    share of multiply-adds. Integer arithmetic keeps ``__flops__`` an int and
    works regardless of the device the weights live on.

    Args:
        module: the Linear layer; its ``__flops__`` counter is incremented.
        input: tuple of forward inputs (only the first one is used).
        output: the layer's forward output.
        zero_threshold: weights with ``abs(w) < zero_threshold`` count as zero.
    """
    input = input[0]
    output_last_dim = output.shape[-1]  # pytorch checks dimensions, so here we don't care much
    dense_flops = input.numel() * output_last_dim

    weight = module.weight.data
    num_weights = weight.numel()
    # .sum() -> int() avoids .numpy(), which fails for CUDA tensors.
    num_nonzero_weights = num_weights - int((weight.abs() < zero_threshold).sum())
    # NOTE: bias additions are not counted, as in the dense count this replaces.
    module.__flops__ += (dense_flops * num_nonzero_weights // num_weights
                         if num_weights else 0)
def conv_flops_counter_hook(conv_module, input, output, *, zero_threshold=1e-9):
    """ptflops forward hook for convolutions that ignores zero weights/biases.

    The dense convolution MACs (kernel volume x in_channels x filters per group
    at every output position) are scaled by the fraction of weights whose
    magnitude is not below ``zero_threshold``. Bias additions are counted only
    for non-zero bias entries, so a pruned bias contributes nothing either.
    Integer arithmetic keeps ``__flops__`` an int on any device.

    Args:
        conv_module: the conv layer; its ``__flops__`` counter is incremented.
        input: tuple of forward inputs (only the first one is used).
        output: the layer's forward output, laid out as (batch, channels, *spatial).
        zero_threshold: values with ``abs(v) < zero_threshold`` count as zero.
    """
    import math

    # Can have multiple inputs, getting the first one
    input = input[0]
    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])
    kernel_dims = list(conv_module.kernel_size)
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups
    filters_per_channel = out_channels // groups
    conv_per_position_flops = math.prod(kernel_dims) * in_channels * filters_per_channel
    active_elements_count = batch_size * math.prod(output_dims)

    weight = conv_module.weight.data
    num_weights = weight.numel()
    # .sum() -> int() avoids .numpy(), which fails for CUDA tensors.
    num_nonzero_weights = num_weights - int((weight.abs() < zero_threshold).sum())
    dense_conv_flops = conv_per_position_flops * active_elements_count
    overall_conv_flops = (dense_conv_flops * num_nonzero_weights // num_weights
                          if num_weights else 0)

    bias_flops = 0
    if conv_module.bias is not None:
        bias = conv_module.bias.data
        num_nonzero_bias = bias.numel() - int((bias.abs() < zero_threshold).sum())
        bias_flops = num_nonzero_bias * active_elements_count

    conv_module.__flops__ += overall_conv_flops + bias_flops
请注意,这里使用 1e-9 作为判定权重为零的阈值(剪枝后的权重恰好为 0,因此该阈值足以识别它们)。
推荐阅读
- powerbi - Calendar/GenerateSeries DAX 函数给出错误“日历函数中的开始日期或结束日期不能为空白值”。
- r - R - 使用序列但排除特定数字的 For 循环
- python - 无法将 NumPy 数组转换为张量(不支持的对象类型 numpy.ndarray)
- groovy - 尝试使用 Groovy Spock Mock 模拟两个类:GroovyCastException
- selenium - Gitlab On-Premise Runner 无法使用 Selenium Chrome 服务运行 Codeception
- python - 如何将热图与面板中的曲线结合起来
- javascript - 我不知道如何将 json 响应放入 html 元素
- python - 用函数修改列表中的每个元素,然后修改列表中的所有元素
- html - 强大的表单单选按钮样式
- java - 如何使用 MVVM 模式在 UI 中更新服务器响应