當前位置:網站首頁>Broadcasting機制

Broadcasting機制

2022-01-28 06:59:00 nefu_0iq

一、Broadcasting機制介紹

  1. Broadcasting(廣播)機制是numpy對不同shape的數組進行數值計算的一種方式。
  2. Numpy的通用函數中要求輸入的兩個tensor的shape一致。當兩個tensor的shape不一致時,會使用Broadcasting機制。

二、Broadcast機制規則

  1. 規則1:讓所有輸入的tensor向dim最大的tensor看齊,dim不足,前補1
  2. 規則2:將輸入的tensor的某一個shape為1的維度拉伸
  3. 規則3:如果存在兩個tensor的shape在任何維度均不匹配,且均沒有等於1的維度,則會報錯

三、實例說明

  1. Broadcasting理解
    1. 三部分代碼分別對應下圖的三個實例
a = torch.tensor([ [0,0,0],[10,10,10],[20,20,20],[30,30,30]]) # torch.Size([4, 3])
b = torch.tensor([ [0,1,2] , [0,1,2],[0,1,2],[0,1,2]]) # torch.Size([4, 3])
print("a + b == \n",a + b)
# 由於a和b的shape相同,不需要使用Broadcasting機制
# ---------------------------------------------------------------------
a = torch.tensor([ [0,0,0],[10,10,10],[20,20,20],[30,30,30]]) # torch.Size([4, 3])
b = torch.tensor([0,1,2]) # torch.Size([3])
print("a + b == \n",a + b)
# a.shape = [4,3] a.dim = 2; b.shape = [3], b.dim = 1, shape不同,需要使用Broadcasting機制
# 由於b.dim = 1 < a.dim = 2 , b向a看齊
# 相當於b.unsqueeze(0),使其shape變為[1,3],然後再使用expand(4,-1)使其shape變為[4,3]

# ---------------------------------------------------------------------
a = torch.tensor([[0],[10],[20],[30]]) # torch.Size([4, 1])
b = torch.tensor([0,1,2]) # torch.Size([3])
print("a + b == \n",a + b)
# a.shape = [4,1] a.dim = 2; b.shape = [3], b.dim = 1, shape不同,需要使用Broadcasting機制
# 由於b.dim = 1 < a.dim = 2 , b向a看齊
# 如果dim不同需要補齊所以b的shape變換為 [3] => [1,3]
# dim變換後,將shape為1的拉伸: a.shape變換為: [4,1] => [4,3] ; b.shape變換為: [1,3] => [4,3]

在這裏插入圖片描述
2. 簡單實際應用
1. 給你一個tensor,其shape含義為[class,students,scores]
2. 先需要將學生的所有成績+5分

a = torch.rand(4,32,8) # torch.Size([4, 32, 8])
b = torch.tensor(5) # torch.Size([])
c = a + b
  1. 錯誤實例
a = torch.rand(4,32,14,14)
b = torch.rand(2,32,14,14)
c = a + b
# RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0
# 這裏出錯原因: 
# Dim 0 has dim,can NOT insert and expand to same
# Dim 0 has distinct dim, NOT size 1
# NOT broadcasting-able
  1. 多種方式理解
    1. 一下例子我們用[b,c,h,w]
    2. [4,3,32,32] + [32,32] : 對於所有的batch,channels,都疊加相同的picture_map,相當於疊加相同的base,使其平移。
    3. [4,3,32,32] + [3,1,1] :對於RGB來說,增加值
    4. [4,3,32,32] + [1,1,1,1] :對於所有像素點都增加一個值

版權聲明
本文為[nefu_0iq]所創,轉載請帶上原文鏈接,感謝
https://cht.chowdera.com/2022/01/202201280659003737.html

隨機推薦