model_analysis

参考链接:

python库

用于分析和统计深度学习模型的python库

  • torchsummary : 获得详细参数量
  • thop :获得总参数量,总浮点运算次数
  • torchstat

torchsummary

1
2
3
4
5
6
from torchsummary import summary
import torchvision.models as models

model = models.resnet18()
# 打印模型摘要信息
summary(model, (3, 244, 244), device='cpu')

输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 122, 122] 9,408
BatchNorm2d-2 [-1, 64, 122, 122] 128
ReLU-3 [-1, 64, 122, 122] 0
... ... ...
Conv2d-60 [-1, 512, 8, 8] 2,359,296
ReLU-62 [-1, 512, 8, 8] 0
BatchNorm2d-64 [-1, 512, 8, 8] 1,024
BasicBlock-66 [-1, 512, 8, 8] 0
Linear-68 [-1, 1000] 513,000
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.68
Forward/backward pass size (MB): 76.08
Params size (MB): 44.59
Estimated Total Size (MB): 121.36
----------------------------------------------------------------

thop

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from thop import profile
import torchvision.models as models
import torch

# 定义一个示例模型
model = models.resnet18()

# 创建一个随机输入
input_data = torch.randn(1, 3, 244, 244)

# 使用 thop 分析模型
flops, params = profile(model, inputs=(input_data,))

print(f"FLOPs: {flops}, Params: {params}")

输出

1
2
3
4
5
6
7
8
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPs: 2278887168.0, Params: 11689512.0

torchstat

输出

  • 理论计算次数

  • 理论耗时 一轮*n

    • 只输入模型结构 loss optim
    • 不输入数据,后台随机生成
    • 没有loss图像
  • 模型存储占用:模型参数量大小

  • 指标:

    • 结构匹配度
    • 算法适应度:理论耗时/理论计算次数 转换到 -> %
    • 并行程度

响应时间:

  • 3s内

跑一轮:

  • tqdm

进度条:跑完

存储

-