當前位置:網站首頁>Pytorch加載模型只導入部分層權重,即跳過指定網絡層的方法

Pytorch加載模型只導入部分層權重,即跳過指定網絡層的方法

2022-05-15 07:14:04m0_61899108

需求

Pytorch加載模型時,只導入部分層權重,跳過部分指定網絡層。(權重文件存儲為dict形式)

方法一

常見方法:加載權重時用if對網絡層進行篩選

'''
# model為定義的網絡結構:
class model(nn.Module):
    def __init__(self):
        super(model,self).__init__()
        ……

    def forward(self,x):
        ……
        return x
'''

model = model()  
# load存在的模型參數(權重文件),後綴名可能不同    
pretrained_dict = torch.load('model.pkl')
model_dict = model.state_dict()
# 關鍵在於下面這句,從model_dict中讀取key、value時,用if篩選掉不需要的網絡層 
pretrained_dict = {key: value for key, value in pretrained_dict.items() if (key in model_dict and 'Prediction' not in key)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

方法二

不完全匹配,只加載權重中存在的參數,不匹配就跳過

# load_state_dict() 默認strict=True,需要完全匹配,否則報錯
# 修改為strict=False後,只匹配存在的參數
pretrained_dict = torch.load(weight_path)
model.load_state_dict(pretrained_dict, strict=False)

方法三

 不使用原有權重文件訓練,對原有權重文件進行拷貝,拷貝文件中只包含需要的網絡層,後續直接利用拷貝權重文件進行訓練。

    # 對原有權重文件進行拷貝,拷貝文件中只包含需要的網絡層,
    # 後續直接利用拷貝文件進行訓練。
    import pickle

    model = model()
    net = model
    path_weight = 'R-50.pkl'
    path_weight2 = 'R2-50.pkl'

    with open(path_weight,'rb') as f:
        obj=f.read()
    # 用pickle.loads()加載權重信息
    la_obj=pickle.loads(obj,encoding='latin1')
    # 用if進行篩選
    weights= {key: value for key, value in la_obj.items()}
              #if key in la_obj and 'backbone.bottom_up.stem.conv1.weight' not in key}
    # 使用print查看權重文件信息 
    print(weights)
    
    # 再深拷貝一份文件保存
    state_dict = copy.deepcopy(weights)
    with open(path_weight2,'wb') as f2:
        pickle.dump(state_dict, f2)

    # 可以寫入txt,便於查看信息
    path_weight2 = 'R2-101.txt'
    inf = str(state_dict)
    ff = open(path_weight2,'w')
    ff.write(inf)

下面是對載入參數的優化有特殊要求:參數固定、或者參數更新速度不同。

方法四

如果載入的這些參數中,有些參數不要求被更新,即固定不變,不參與訓練,需要手動設置這些參數的梯度屬性為Fasle,並且在optimizer傳參時篩選掉這些參數:

# 載入預訓練模型參數後...
for name, value in model.named_parameters():
    if name 滿足某些條件:
        value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

方法五

如果載入的這些參數中,所有參數都更新,但要求一些參數和另一些參數的更新速度(學習率learning rate)不一樣,最好知道這些參數的名稱都有什麼:

# 載入預訓練模型參數後...
for name, value in model.named_parameters():
    print(name)
# 或
print(model.state_dict().keys())

假設該模型中有encoder,viewer和decoder兩部分,參數名稱分別是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假設要求encode、viewer的學習率為1e-6, decoder的學習率為1e-4,那麼在將參數傳入優化器時:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
                              {'params':model.decoder.parameters()}
                              ],
                              lr=1e-4, momentum=0.9)

代碼的結果是除decoder參數的learning_rate=1e-4 外,其他參數的learning_rate=1e-6。
在傳入optimizer時,和一般的傳參方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,參數部分用了一個list, list的每個元素有params和lr兩個鍵值。如果沒有 lr則應用Adam的lr屬性。Adam的屬性除了lr, 其他都是參數所共有的(比如momentum)。
 

遇見的問題

torch.load 加載權重文件時報錯 Magic Number Error 

有時候使用 torch.load 加載比較古老的權重文件時可能報錯 Magic Number Error,這有可能是因為該文件使用 pickle 存儲並且編碼使用了 latin1,此時可以這樣加載:

若要進行篩選,同理可以在後面加上if進行判斷。

import pickle
with open(weights_path, 'rb') as f:
    obj = f.read()
# 用pickle進行load,編碼方式為latin1
weights = {key: weight_dict for key, weight_dict in pickle.loads(obj,encoding='latin1').items()}
# 同理,可以用if判斷進行篩選
# weights = {key: value for key, value in pickle.loads(obj,encoding='latin1').items() if (key in model_dict and 'Prediction' not in key)}
model.load_state_dict(weights) 

TypeError: a bytes-like object is required, not 'str'

python3和python2在套接字返回值解碼上有區別。

套接字就是 socket,用於描述 IP 地址和端口,應用程序通過套接字向網絡發出請求或者應答網絡請求,可以認為是計算機網絡的數據接口。目前套接字分為兩種:基於文件型和基於網絡型。

解决方法

使用函數 encode() 和decode():

  1. str 通過 encode() 函數編碼變為 bytes
  2. bytes 通過 decode() 函數編碼變為 str。(當我們從網絡或磁盤上讀取了字節流,則讀到的數據就是 bytes)

補充:

str --> bytes

# 聲明一個字符串s:
>>> s = 'abc'
>>> type(s)
<class 'str'>

# 四種轉換方式:
>>> b1 = s.encode()
>>> type(b1)
<class 'bytes'>
>>> b2 = str.encode(s)
>>> type(b2)
<class 'bytes'>
>>> b3 = s.encode(encoding='utf-8')
>>> type(b3)
<class 'bytes'>
>>> b4 = bytes(s,encoding='utf-8')
>>> type(b4)
<class 'bytes'>

bytes --> str

# 聲明一個bytes:
>>> b = b'abc'
>>> type(b)
<class 'bytes'>

# 三種轉換方式:
>>> s1 = bytes.decode(b)
>>> type(s1)
<class 'str'>
>>> s2 = b.decode()
>>> type(s2)
<class 'str'>
>>> s3 = str(b,encoding='utf-8')
>>> type(s3)
<class 'str'>

參考博客

Pytorch中只導入部分層權重的方法_汐夢聆海的博客-CSDN博客_pytorch加載部分權重

pytorch微調模型—只加載預訓練模型的某些層_農夫山泉2號的博客-CSDN博客

Pytorch加載模型不完全匹配 & 只加載部分參數權重 load_hxxjxw的博客-CSDN博客_pytorch加載模型不匹配跳過

pytorch載入預訓練模型後,只想訓練個別層怎麼辦?_慕白-的博客-CSDN博客_pytorch只訓練最後一層

PyTorch | 保存和加載模型 - 知乎 (zhihu.com)

torch.load加載權重時報錯 Magic Number Error - 仰望高端玩家的小清新 - 博客園 (cnblogs.com)

Python報錯:TypeError: a bytes-like object is required, not ‘str‘_程序媛三妹的博客-CSDN博客 

版權聲明
本文為[m0_61899108]所創,轉載請帶上原文鏈接,感謝
https://cht.chowdera.com/2022/135/202205142322539438.html

隨機推薦