當前位置:網站首頁>*精度優化*優化策略1:網絡+SAM優化器

*精度優化*優化策略1:網絡+SAM優化器

2022-07-23 05:01:12夏天|여름이다

一:SAM優化器介紹:

SAM:Sharpness Awareness Minimization銳度感知最小化

SAM不是一個新的優化器,它與其他常見的優化器一起使用,比如SGD/Adam

論文:2020 Sharpness-Aware Minimization for Efficiently Improving Generalization

論文地址:https://arxiv.org/pdf/2010.01412v2.pdf

項目地址:GitHub - davda54/sam: SAM: Sharpness-Aware Minimization (PyTorch)

(依舊建議大家使用GPU去訓練,一般電腦cpu可以運行,但是非常卡,能卡出數據集,但是沒卡出結果。)

下載解壓後非常簡單,把sam.py文件直接複制到example文件夾下就可以直接跑train.py.

運行後會自動下載數據集,會進行批次訓練。

運行結果:(我改的epochs比較小,改大效果更好)

重要部分如下train.py:

import argparse
import torch

from model.wide_res_net import WideResNet#導入模型中的wide_res_net網絡
from model.smooth_cross_entropy import smooth_crossentropy#導入損失函數
from data.cifar import Cifar#導入數據集
from utility.log import Log#導入工具類日志文件
from utility.initialize import initialize#導入工具類初始化
from utility.step_lr import StepLR#導入工具類階梯學習率
from utility.bypass_bn import enable_running_stats, disable_running_stats#導入工具類繞過BN,啟用運行統計,禁用運行統計

import sys; sys.path.append("..")#導入sys.path中需要用到的XXX包,然後加載
from sam import SAM#引入SAM


if __name__ == "__main__":
    #創建解析器(arg對象)
    parser = argparse.ArgumentParser()
    #添加參數
    parser.add_argument("--adaptive", default=True, type=bool, help="True if you want to use the Adaptive SAM.")
    parser.add_argument("--batch_size", default=12, type=int, help="Batch size used in the training and validation loop.")
    parser.add_argument("--depth", default=16, type=int, help="Number of layers.")
    parser.add_argument("--dropout", default=0.0, type=float, help="Dropout rate.")
    parser.add_argument("--epochs", default=2, type=int, help="Total number of epochs.")
    parser.add_argument("--label_smoothing", default=0.1, type=float, help="Use 0.0 for no label smoothing.")
    parser.add_argument("--learning_rate", default=0.1, type=float, help="Base learning rate at the start of the training.")
    parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum.")
    parser.add_argument("--threads", default=2, type=int, help="Number of CPU threads for dataloaders.")
    parser.add_argument("--rho", default=2.0, type=int, help="Rho parameter for SAM.")
    parser.add_argument("--weight_decay", default=0.0005, type=float, help="L2 weight decay.")
    parser.add_argument("--width_factor", default=8, type=int, help="How many times wider compared to normal ResNet.")
    #解析參數
    args = parser.parse_args()
    #初始化
    initialize(args, seed=42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #定義數據集
    dataset = Cifar(args.batch_size, args.threads)
    #記錄日志
    log = Log(log_each=10)
    #定義模型
    model = WideResNet(args.depth, args.width_factor, args.dropout, in_channels=3, labels=10).to(device)
    #定義基礎優化器
    base_optimizer = torch.optim.SGD
    #定義第二個優化器SAM
    optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=args.adaptive, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    #將optimizer作為參數傳遞給scheduler,每次通過調用scheduler.step()就會更新optimizer中每一個param_group[‘lr’],每過固定個epoch,學習率會按照gamma倍率進行衰减。
    scheduler = StepLR(optimizer, args.learning_rate, args.epochs)
    
    for epoch in range(args.epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)

            # first forward-backward step
            enable_running_stats(model)
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            disable_running_stats(model)
            smooth_crossentropy(model(inputs), targets, smoothing=args.label_smoothing).mean().backward()
            optimizer.second_step(zero_grad=True)

            with torch.no_grad():
                correct = torch.argmax(predictions.data, 1) == targets
                log(model, loss.cpu(), correct.cpu(), scheduler.lr())
                scheduler(epoch)

        model.eval()
        log.eval(len_dataset=len(dataset.test))

        with torch.no_grad():
            for batch in dataset.test:
                inputs, targets = (b.to(device) for b in batch)

                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets)
                correct = torch.argmax(predictions, 1) == targets
                log(model, loss.cpu(), correct.cpu())

    log.flush()

二:把SAM應用到自己的項目上:

step1:把SAM的工具文件複制到自己的項目下

把utility文件夾複制到自己的項目下,

把sam.py複制到項目根目錄,在train.py裏導入包。

step2:把數據集改為自己的數據集

step3:把網絡改為自己的網絡

(我的項目是多個獨立的網絡,幾個網絡就寫幾遍)

step4:添加基礎優化器和SAM

    base_optimizer = torch.optim.SGD
    optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=args.adaptive, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, args.learning_rate, args.epochs)
#一定要根據自己的項目去改相關參數等

step5:把損失函數改為自己原本的損失函數添加SAM工具類

            ...

            #opt.zero_grad()注釋掉原本的

            #添加SAM工具類裏的函數
            enable_running_stats(model_context)
            enable_running_stats(model_body)
            enable_running_stats(emotic_model)
            #我的項目是三個網絡。如果是一個網絡的話,寫一次
            #類似於enable_running_stats(model)

            ...
 
            loss.backward()
            opt.first_step(zero_grad=True)#在項目的loss反向傳播後先用優化器first step

             #添加SAM工具類裏的函數
            disable_running_stats(model_context)
            disable_running_stats(model_body)
            disable_running_stats(emotic_model)

            ...

            loss.backward()
            opt.second_step(zero_grad=True)#在項目的loss反向傳播後再用優化器second step

            # opt.step()注釋掉原本的

step6:添加SAM所需的超參數(可選,不改也不會出錯)


原項目:

修改後:

原項目:

 

 修改後:

#黃色框為修改比特置

 加入SAM優化器後,比原來精度提高了將近3%。

 

 

以上。(全是自己的理解,不正確望指正,感謝。)

版權聲明
本文為[夏天|여름이다]所創,轉載請帶上原文鏈接,感謝
https://cht.chowdera.com/2022/204/202207221752585451.html

隨機推薦