Python激活函数

AlexNet参数及可视化

Python示例代码

    import torch
    import torchvision.models as models

    # Build AlexNet with the default pretrained weights and switch it to
    # evaluation mode in one expression (Module.eval() returns the module).
    model = models.alexnet(weights=models.AlexNet_Weights.DEFAULT).eval()

    def analyze_alexnet(model):
        """Print a per-layer parameter table and return the parameter counts.

        Only ``torch.nn.Conv2d`` modules found in ``model.features`` and
        ``torch.nn.Linear`` modules found in ``model.classifier`` are listed
        and counted; every other module is ignored.

        Args:
            model: Object exposing iterable ``features`` and ``classifier``
                module containers (e.g. torchvision's AlexNet).

        Returns:
            A tuple ``(total_params, conv_params, fc_params)`` where
            ``total_params == conv_params + fc_params``.
        """
        print(f"{'Layer':<15} {'Type':<15} {'Params':<10} {'Shape':<25}")
        print("-" * 65)

        # Convolutional layers. Label them by order of appearance: the conv
        # modules are not evenly spaced inside `features`, so deriving the
        # label from the positional index (i // 3) printed "conv3" twice.
        conv_params = 0
        conv_layers = [layer for layer in model.features if isinstance(layer, torch.nn.Conv2d)]
        for idx, layer in enumerate(conv_layers, start=1):
            params = sum(p.numel() for p in layer.parameters())
            conv_params += params
            print(f"{f'conv{idx}':<15} {'Conv2d':<15} {params:<10,} {str(tuple(layer.weight.shape)):<25}")

        # Fully connected layers. Labels are generated instead of zipped with
        # a fixed three-name list, so extra Linear layers are never silently
        # dropped from the table or the totals.
        fc_params = 0
        fc_layers = [layer for layer in model.classifier if isinstance(layer, torch.nn.Linear)]
        for idx, layer in enumerate(fc_layers, start=1):
            params = sum(p.numel() for p in layer.parameters())
            fc_params += params
            print(f"{f'fc{idx}':<15} {'Linear':<15} {params:<10,} {str(tuple(layer.weight.shape)):<25}")

        total_params = conv_params + fc_params
        print("-" * 65)
        print(f"{'Total':<15} {'-':<15} {total_params:<10,} {'-':<25}")

        return total_params, conv_params, fc_params

    # Run the analysis
    total, conv, fc = analyze_alexnet(model)

    print("\n参数分布:")
    print(f"卷积层: {conv:,} ({conv/total*100:.1f}%)")
    print(f"全连接层: {fc:,} ({fc/total*100:.1f}%)")
    print(f"总参数: {total:,} ({total/1e6:.1f}M)")

    # Check input/output shapes with a single forward pass. This is
    # inference only, so run it under no_grad() to skip autograd bookkeeping.
    test_input = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        output = model(test_input)
    print(f"\n输入: {test_input.shape} -> 输出: {output.shape}")
        

    运行结果
    Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
    100%|██████████| 233M/233M [36:29<00:00, 112kB/s]    
    Layer           Type            Params     Shape                    
    -----------------------------------------------------------------
    conv1           Conv2d          23,296     (64, 3, 11, 11)          
    conv2           Conv2d          307,392    (192, 64, 5, 5)          
    conv3           Conv2d          663,936    (384, 192, 3, 3)         
    conv4           Conv2d          884,992    (256, 384, 3, 3)         
    conv5           Conv2d          590,080    (256, 256, 3, 3)         
    fc1             Linear          37,752,832 (4096, 9216)             
    fc2             Linear          16,781,312 (4096, 4096)             
    fc3             Linear          4,097,000  (1000, 4096)             
    -----------------------------------------------------------------
    Total           -               61,100,840 -                        

    参数分布:
    卷积层: 2,469,696 (4.0%)
    全连接层: 58,631,144 (96.0%)
    总参数: 61,100,840 (61.1M)

    输入: torch.Size([1, 3, 224, 224]) -> 输出: torch.Size([1, 1000])
    ​