VGGNet剪枝实战:使用VGGNet训练、稀疏训练、剪枝、微调等,剪枝出只有3M的模型(二)

news/2024/7/7 19:47:23

文章目录

  • 稀疏训练VGGNet
  • 剪枝
    • 导入库文件
    • 测试函数
    • 定义全局参数
    • BN通道排序
    • 制作Mask
    • 剪枝操作
  • 微调
    • 微调方法
    • 微调结果

稀疏训练VGGNet

新建train_sp.py脚本。稀疏化训练的过程和正常训练类似,不同的是在BN层中各权重加入稀疏因子,代码如下:

def updateBN(model,s=0.0001,epoch=1,epochs=1000):
    srtmp = s * (1 - 0.9 * epoch / epochs)
    for m in model.modules():
        if isinstance(m,nn.BatchNorm2d):
            m.weight.grad.data.add_(srtmp*torch.sign(m.weight.data))
            m.bias.grad.data.add_(s * 10 * torch.sign(m.bias.data))

加入到train函数中,如图:
在这里插入图片描述
s的设置需要根据数据集调整,可以通过观察tensorboard的map,gamma变化直方图等选择。我在本次训练种使用的是0.001.

训练完成后,就可以使用tensorboard观察训练结果,在根目录运行:

tensorboard --logdir .

然后看到如下信息:

TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.13.0 at http://localhost:6006/ (Press CTRL+C to quit)

在浏览器中打开http://localhost:6006/就能看到。
在这里插入图片描述
在这里插入图片描述
蓝色的是正常训练,BN权重的分布情况。紫红色的是加入稀疏因子后BN权重的分布情况。
稀疏化训练结果:
在这里插入图片描述
在这里插入图片描述
结果基本上和正常训练一致!最终结果也是95.6%。

剪枝

新建prune.py脚本,这个脚本是剪枝脚本。

导入库文件

import os
import argparse
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms

from vgg import VGG
import numpy as np

测试函数

测试函数,用来测试剪枝后的模型ACC,代码如下:

# simple test model after Pre-processing prune (simple set BN scales to zeros)
def test():
    # 读取数据
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48214436, 0.42969334, 0.33318862], std=[0.2642221, 0.23746745, 0.21696019])
    ])
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)

    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, pin_memory=True, shuffle=False)
    model.eval()
    correct = 0
    for data, target in test_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

定义全局参数

if __name__ == '__main__':
    BATCH_SIZE=16
    percent=0.7
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model_path='checkpoints/vgg_sp/best.pth'
    save_name='pruned.pth'

BATCH_SIZE:测试函数的BatchSize。
percent:剪枝的比率。
DEVICE :如果有显卡则使用GPU,没有则使用cpu。
model_path:稀疏训练模型的路径。
save_name:剪枝后,模型的路径。

BN通道排序

    #加载稀疏训练的模型
    model = torch.load(model_path)
    print(model)
    total = 0 # 统计所有BN层的参数量
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0] # 每个BN层权重w参数量
    bn = torch.zeros(total)
    index = 0
    for m in model.modules():
        #将各个BN层的参数值复制到bn中
        if isinstance(m, nn.BatchNorm2d):
            size = m.weight.data.shape[0]
            bn[index:(index + size)] = m.weight.data.abs().clone()
            index += size
    #对bn中的weight值排序
    y, i = torch.sort(bn)#
    thre_index = int(total * percent)
    thre = y[thre_index]#取bn排序后的第thresh_index索引值为bn权重的截断阈值

制作Mask

    pruned = 0 #统计BN层剪枝通道数
    cfg = []#统计保存通道数
    cfg_mask = []#BN层权重矩阵,剪枝的通道记为0,未剪枝通道记为1
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            weight_copy = m.weight.data.clone()
            mask = weight_copy.abs().gt(thre).float().cuda()#阈值分离权重
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)#更新BN层的权重,剪枝通道的权重值为0
            m.bias.data.mul_(mask)
            cfg.append(int(torch.sum(mask)))#记录未被剪枝的通道数量
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
                  format(k, mask.shape[0], int(torch.sum(mask))))
        elif isinstance(m, nn.MaxPool2d):
            cfg.append('M')

    pruned_ratio = pruned / total
    print('Pre-processing Successful!')
    test()
    # Make real prune
    print(cfg)

剪枝操作

    newmodel = VGG(cfg=cfg,num_classes=12)
    newmodel.cuda()
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[layer_id_in_cfg]
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        if isinstance(m0, nn.BatchNorm2d):
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            m1.weight.data = m0.weight.data[idx1].clone()
            m1.bias.data = m0.bias.data[idx1].clone()
            m1.running_mean = m0.running_mean[idx1].clone()
            m1.running_var = m0.running_var[idx1].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
        elif isinstance(m0, nn.Conv2d):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
            w = m0.weight.data[:, idx0, :, :].clone()
            w = w[idx1, :, :, :].clone()
            m1.weight.data = w.clone()
            # m1.bias.data = m0.bias.data[idx1].clone()
        elif isinstance(m0, nn.Linear):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            m1.weight.data = m0.weight.data[:, idx0].clone()
    torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, save_name)
    print(newmodel)
    model = newmodel
    test()

剪枝后保存模型

微调

微调方法

微调方法和正常训练类似,加载剪枝后的模型和配置,然后训练、验证即可!

if __name__ == '__main__':
    # 创建保存模型的文件夹
    file_dir = 'checkpoints/vgg_pruned'
    if os.path.exists(file_dir):
        print('true')
        os.makedirs(file_dir, exist_ok=True)
    else:
        os.makedirs(file_dir)
    # 设置全局参数
    model_lr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 300
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    classes = 12
    resume = 'pruned.pth'

    # 设置模型
    model = torch.load(resume)
    model_ft=VGG(cfg=model['cfg'],num_classes=classes)
    model_ft.load_state_dict(model['state_dict'])
    model_ft.to(DEVICE)
    print(model_ft)

微调结果

Val set: Average loss: 0.2845, Accuracy: 457/482 (95%)

                           precision    recall  f1-score   support

              Black-grass       0.79      0.86      0.83        36
                 Charlock       1.00      1.00      1.00        42
                 Cleavers       1.00      0.96      0.98        50
         Common Chickweed       0.94      0.91      0.93        34
             Common wheat       0.93      1.00      0.97        42
                  Fat Hen       0.97      0.97      0.97        34
         Loose Silky-bent       0.88      0.78      0.83        46
                    Maize       0.96      1.00      0.98        45
        Scentless Mayweed       0.93      0.96      0.95        45
          Shepherds Purse       0.97      0.97      0.97        35
Small-flowered Cranesbill       1.00      0.97      0.99        36
               Sugar beet       1.00      1.00      1.00        37

                 accuracy                           0.95       482
                macro avg       0.95      0.95      0.95       482
             weighted avg       0.95      0.95      0.95       482

在这里插入图片描述

在这里插入图片描述


http://lihuaxi.xjx100.cn/news/1429323.html

相关文章

Mysql报错1194 - Table ‘‘ is marked as crashed and should be repaired的解决办法

本篇文章主要讲解&#xff1a;Mysql报错1194 - Table ‘’ is marked as crashed and should be repaired的解决办法。 日期&#xff1a;2023年8月9日 作者&#xff1a;任聪聪 具体现象 说明&#xff1a;执行sql语句查询或者检索相关数据时会出现如下报错内容&#xff1a; 11…

CSV文件编辑器——Modern CSV for mac

Modern CSV for Mac是一款功能强大、操作简单的CSV文件编辑器&#xff0c;适用于Mac用户快速、高效地处理和管理CSV文件。Modern CSV具有直观的用户界面&#xff0c;可以轻松导入、编辑和导出CSV文件。它支持各种功能&#xff0c;包括排序、过滤、查找和替换&#xff0c;使您能…

Qt事件过滤器

1 介绍 事件过滤器是一种机制&#xff0c;当某个QObject没有所需要的事件功能时&#xff0c;可将其委托给其它QObject&#xff0c;通过eventFilter成员函数来过滤实现功能。 2 主要构成 委托&#xff1a; ui->QObject1->installEventFilter(QObject2); eventFilter声明 …

NLP(六十五)LangChain中的重连(retry)机制

关于LangChain入门&#xff0c;读者可参考文章NLP&#xff08;五十六&#xff09;LangChain入门 。   本文将会介绍LangChain中的重连机制&#xff0c;并尝试给出定制化重连方案。   本文以LangChain中的对话功能&#xff08;ChatOpenAI&#xff09;为例。 LangChain中的重…

/proc directory in linux

Its zero-length files are neither binary nor text, yet you can examine and display themUnder Linux, everything is managed as a file; even devices are accessed as files (in the /dev directory). Although you might think that “normal” files are either text …

如何查看Linux内核中某个线程的CPU占用率

在Linux中&#xff0c;可以使用top、htop、ps等命令来查看进程和线程的CPU占用率。以下是使用top命令查看某个线程的CPU占用率的步骤&#xff1a; 打开终端并输入top命令。按下ShiftH&#xff0c;将显示所有线程。找到要查看的线程&#xff0c;并记下其PID&#xff08;进程ID&…

centos7安装phpipam1.4

by:铁乐与猫 date&#xff1a;2021-5-11 安装依赖 sudo yum install epel-release sudo yum install php-mcrypt安装 Apache, MySQL, PHP (LAMP) stack packages sudo yum install httpd mariadb-server php php-cli php-gd php-common php-ldap php-pdo php-pear php-snmp …

【ARM Cache 系列文章 9 番外篇 -- ARMv9 系列 Core 介绍】

文章目录 ARMv9 系列CoreARM Cortex-A510 介绍ARM Cortex-A715ARM Cortex-A720 ARMv9 系列Core 2021年5月Arm公布了其最新3款CPU和3款GPU核心设计&#xff0c;三款新CPU分别是旗舰核心Cortex-X2、高性能核心Cortex-A710、高能效核心Cortex-A510 CPU&#xff0c;三款新GPU核心则…