當前位置:網站首頁>【Pytorch(四)】學習如何使用 PyTorch 讀取並處理數據集

【Pytorch(四)】學習如何使用 PyTorch 讀取並處理數據集

2022-01-28 10:28:06 zqwlearning

學習如何使用 PyTorch 讀取並處理數據集

在處理任何機器學習問題之前都需要讀取數據,並對數據進行預處理。處理數據樣本的代碼可能會變得混亂且難以維護,因此 PyTorch 將數據集代碼與模型訓練代碼相分離,從而獲得更好的可讀性和模塊化。

下面我們將以手寫數字0~9的數據集 MNIST 為例,學習如何在 PyTorch 中讀取和處理數據。

1. 准備數據集

MNIST數據集

2. 讀取並處理數據集 MNIST

下面我們來通過 PyTorch 讀取和處理 MNIST 數據集。在這一節中我們將把數據集讀取到 train_loader(訓練數據集)和 test_loader (測試數據集)中。

PyTorch 中與數據讀取和預處理相關的模塊包括 torchvision 和 torch.utils.data。我們首先導入相關包(torch & torchvision),並查看版本。其中,torch 是頂層的 PyTorch 包和張量庫;torchvision 是一個單獨的包,通過它可以便捷的訪問一些常用的數據集(如 MNIST、Fashion-MNIST、 Cifar 和 ImageNet 等),以及模型架構(如 VGG)和圖像轉換方法。

import torch  # top-level pytorch package and tensor library
import torchvision

print(torch.__version__)

print(torch.cuda.is_available())
print(torch.version.cuda)
1.7.1+cu110
True
11.0

數據讀取和預處理總結起來包括如下幾個步驟:

1) 提取原始數據 (extract data from the dataset)

2) 將提取出來的原始數據轉換為合適的格式 (transform it into the desirable format (Dataset object))

3) 將數據加載到合適的數據結構中 (load the data into a suitable data structure (DataLoader object))

batch_size_train = 128  # 設置訓練集的 batch size,即每批次將參與運算的樣本數
batch_size_test = 128  # 設置測試集 batch size
# 我們首先提取原始數據,即使用 PyTorch 的內置函數從網絡上獲得 MNIST 數據集。
# 數據集下載網址:http://yann.lecun.com/exdb/mnist/
# (大家可在上述網址閱讀 MNIST 的詳細信息。)

# 此處由於下載數據集可能會卡死,我們為大家准備好了提前下載好的數據集,即本篇
# 開頭讓大家下載並上傳到特定路徑的四個壓縮包。上傳到特定路徑是為了讓 PyTorch
# 能够找到。


######################### Please explain the following code #########################

# 下述代碼除了提取原始數據,還會對原始數據進行預處理 (transform)。經過轉換
# 後的數據被保存為一個 Dataset object,其中包含樣本及其對應的標簽。

# 請同學們查閱資料,在實驗報告中對以下代碼進行解釋,說明其每個參數對應的意義,
# 和代碼進行的具體操作。

train_set = torchvision.datasets.MNIST('./dataset_mnist', train=True, download=True,
                                       transform=torchvision.transforms.Compose([
                                           torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize(
                                               (0.1307,), (0.3081,)
                                           )
                                       ])
)

##################################### end ##########################################
# 接下來,請同學們仿照上述訓練集,對測試集進行相似處理,並將轉換後的測試集數據保存在 test_set 中。

############################ Please finish the code ################################

# test_set = XXX

test_set = torchvision.datasets.MNIST('./dataset_mnist', train=False, download=True,
                                       transform=torchvision.transforms.Compose([
                                           torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize(
                                               (0.1307,), (0.3081,)
                                           )
                                       ])
)


##################################### end ##########################################
# 完成數據轉換後,最後一步即將數據 (Dataset) 加載到合適的數據結構中,即 DataLoader。
# DataLoader 可以幫助我們便捷的對數據進行操作,例如我們可以方便的設置 batch_size
# (每一批的樣本個數), shuffle(是否隨機打亂樣本順序), num_workers(加載數據的時候
# 使用幾個子進程)等。
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size_test, shuffle=True)

我們的數據集已經准備好了,在開始使用 PyTorch 搭建神經網絡前,讓我們先來查看一下讀取到的數據集。

  1. 查看數據集整體情况
print(len(train_set))  # train_set 中的樣本總數
print(train_set.train_labels)  # train_set中的樣本標簽
print(train_set.train_labels.bincount())  # 查看每一個標簽有多少樣本
60000
tensor([5, 0, 4,  ..., 5, 6, 8])
tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])
print(train_set.classes)  # 查看 train_set 的樣本類別
print(len(train_set.classes))  # 查看train_set中有所少種類別
print(train_set.class_to_idx)  # 查看樣本類別和樣本標簽的對應關系
['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
10
{'0 - zero': 0, '1 - one': 1, '2 - two': 2, '3 - three': 3, '4 - four': 4, '5 - five': 5, '6 - six': 6, '7 - seven': 7, '8 - eight': 8, '9 - nine': 9}
  1. 查看 Dataset object 中的單個樣本
sample = next(iter(train_set))  # get an item from train_set

print("For each item in train_set: \n\n \t type: ", type(sample))  # tuple (image, label)
print("\t Length: ", len(sample), '\n')  # 2
For each item in train_set: 

 	 type:  <class 'tuple'>
	 Length:  2 
image, label = sample # unpack the sample

print("For each image: \n\n \t type: ", type(image)) # rank-3 tensor
print("\t shape: ", image.shape, '\n') # [channel, height, width] = [1, 28, 28] Note: 僅有3維!
print("For each label: \n\n \t type: ", type(label), '\n')
For each image: 

 	 type:  <class 'torch.Tensor'>
	 shape:  torch.Size([1, 28, 28]) 

For each label: 

 	 type:  <class 'int'> 
import matplotlib.pyplot as plt
import numpy as np

print("Let's check an image: \n ")
plt.imshow(image.squeeze(), cmap='gray')
print(f'label: {
      label}')
Let's check an image: 
 
label: 5

在這裏插入圖片描述

  1. 查看 DataLoader object 中一個批次的樣本
train_loader_plot = torch.utils.data.DataLoader(
    train_set, batch_size=40
)  # 假設一個批次有40個樣本

batch = next(iter(train_loader_plot))
print("type(batch): \t", type(batch))  # list [images, labels]
print("len(batch): \t", len(batch), "\n") # 2

images, labels = batch

print("type(images): \t", type(images))  # rank-4 tensor
print("images.shape: \t", images.shape)  # [batch_size, channel, height, width] = [10, 1, 28, 28]

print("type(labels): \t", type(labels))  # rank-1 tensor
print("labels.shape: \t", labels.shape)  # size=batch size
type(batch): 	 <class 'list'>
len(batch): 	 2 

type(images): 	 <class 'torch.Tensor'>
images.shape: 	 torch.Size([40, 1, 28, 28])
type(labels): 	 <class 'torch.Tensor'>
labels.shape: 	 torch.Size([40])
# 畫出第一個批次的樣本

grid = torchvision.utils.make_grid(images, nrow=10)  # make a grid of images (grid is a tensor)
plt.figure(figsize=(12,12))
plt.imshow(np.transpose(grid, (1,2,0))) # np.transpose permutes the dimensions
print(f'labels: {
      labels}')
labels: tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1,
        1, 2, 4, 3, 2, 7, 3, 8, 6, 9, 0, 5, 6, 0, 7, 6])

在這裏插入圖片描述

版權聲明
本文為[zqwlearning]所創,轉載請帶上原文鏈接,感謝
https://cht.chowdera.com/2022/01/202201281028060551.html

隨機推薦