當前位置:網站首頁>【Pytorch(四)】學習如何使用 PyTorch 讀取並處理數據集
【Pytorch(四)】學習如何使用 PyTorch 讀取並處理數據集
2022-01-28 10:28:06 【zqwlearning】
學習如何使用 PyTorch 讀取並處理數據集
在處理任何機器學習問題之前都需要讀取數據,並對數據進行預處理。處理數據樣本的代碼可能會變得混亂且難以維護,因此 PyTorch 將數據集代碼與模型訓練代碼相分離,從而獲得更好的可讀性和模塊化。
下面我們將以手寫數字0~9的數據集 MNIST 為例,學習如何在 PyTorch 中讀取和處理數據。
1. 准備數據集
2. 讀取並處理數據集 MNIST
下面我們來通過 PyTorch 讀取和處理 MNIST 數據集。在這一節中我們將把數據集讀取到 train_loader(訓練數據集)和 test_loader (測試數據集)中。
PyTorch 中與數據讀取和預處理相關的模塊包括 torchvision 和 torch.utils.data。我們首先導入相關包(torch & torchvision),並查看版本。其中,torch 是頂層的 PyTorch 包和張量庫;torchvision 是一個單獨的包,通過它可以便捷的訪問一些常用的數據集(如 MNIST、Fashion-MNIST、 Cifar 和 ImageNet 等),以及模型架構(如 VGG)和圖像轉換方法。
import torch # top-level pytorch package and tensor library
import torchvision
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.version.cuda)
1.7.1+cu110
True
11.0
數據讀取和預處理總結起來包括如下幾個步驟:
1) 提取原始數據 (extract data from the dataset)
2) 將提取出來的原始數據轉換為合適的格式 (transform it into the desirable format (Dataset object))
3) 將數據加載到合適的數據結構中 (load the data into a suitable data structure (DataLoader object))
batch_size_train = 128 # 設置訓練集的 batch size,即每批次將參與運算的樣本數
batch_size_test = 128 # 設置測試集 batch size
# 我們首先提取原始數據,即使用 PyTorch 的內置函數從網絡上獲得 MNIST 數據集。
# 數據集下載網址:http://yann.lecun.com/exdb/mnist/
# (大家可在上述網址閱讀 MNIST 的詳細信息。)
# 此處由於下載數據集可能會卡死,我們為大家准備好了提前下載好的數據集,即本篇
# 開頭讓大家下載並上傳到特定路徑的四個壓縮包。上傳到特定路徑是為了讓 PyTorch
# 能够找到。
######################### Please explain the following code #########################
# 下述代碼除了提取原始數據,還會對原始數據進行預處理 (transform)。經過轉換
# 後的數據被保存為一個 Dataset object,其中包含樣本及其對應的標簽。
# 請同學們查閱資料,在實驗報告中對以下代碼進行解釋,說明其每個參數對應的意義,
# 和代碼進行的具體操作。
train_set = torchvision.datasets.MNIST('./dataset_mnist', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,)
)
])
)
##################################### end ##########################################
# 接下來,請同學們仿照上述訓練集,對測試集進行相似處理,並將轉換後的測試集數據保存在 test_set 中。
############################ Please finish the code ################################
# test_set = XXX
test_set = torchvision.datasets.MNIST('./dataset_mnist', train=False, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,)
)
])
)
##################################### end ##########################################
# 完成數據轉換後,最後一步即將數據 (Dataset) 加載到合適的數據結構中,即 DataLoader。
# DataLoader 可以幫助我們便捷的對數據進行操作,例如我們可以方便的設置 batch_size
# (每一批的樣本個數), shuffle(是否隨機打亂樣本順序), num_workers(加載數據的時候
# 使用幾個子進程)等。
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size_test, shuffle=True)
我們的數據集已經准備好了,在開始使用 PyTorch 搭建神經網絡前,讓我們先來查看一下讀取到的數據集。
- 查看數據集整體情况
print(len(train_set)) # train_set 中的樣本總數
print(train_set.train_labels) # train_set中的樣本標簽
print(train_set.train_labels.bincount()) # 查看每一個標簽有多少樣本
60000
tensor([5, 0, 4, ..., 5, 6, 8])
tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])
print(train_set.classes) # 查看 train_set 的樣本類別
print(len(train_set.classes)) # 查看train_set中有所少種類別
print(train_set.class_to_idx) # 查看樣本類別和樣本標簽的對應關系
['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
10
{'0 - zero': 0, '1 - one': 1, '2 - two': 2, '3 - three': 3, '4 - four': 4, '5 - five': 5, '6 - six': 6, '7 - seven': 7, '8 - eight': 8, '9 - nine': 9}
- 查看 Dataset object 中的單個樣本
sample = next(iter(train_set)) # get an item from train_set
print("For each item in train_set: \n\n \t type: ", type(sample)) # tuple (image, label)
print("\t Length: ", len(sample), '\n') # 2
For each item in train_set:
type: <class 'tuple'>
Length: 2
image, label = sample # unpack the sample
print("For each image: \n\n \t type: ", type(image)) # rank-3 tensor
print("\t shape: ", image.shape, '\n') # [channel, height, width] = [1, 28, 28] Note: 僅有3維!
print("For each label: \n\n \t type: ", type(label), '\n')
For each image:
type: <class 'torch.Tensor'>
shape: torch.Size([1, 28, 28])
For each label:
type: <class 'int'>
import matplotlib.pyplot as plt
import numpy as np
print("Let's check an image: \n ")
plt.imshow(image.squeeze(), cmap='gray')
print(f'label: {
label}')
Let's check an image:
label: 5
- 查看 DataLoader object 中一個批次的樣本
train_loader_plot = torch.utils.data.DataLoader(
train_set, batch_size=40
) # 假設一個批次有40個樣本
batch = next(iter(train_loader_plot))
print("type(batch): \t", type(batch)) # list [images, labels]
print("len(batch): \t", len(batch), "\n") # 2
images, labels = batch
print("type(images): \t", type(images)) # rank-4 tensor
print("images.shape: \t", images.shape) # [batch_size, channel, height, width] = [10, 1, 28, 28]
print("type(labels): \t", type(labels)) # rank-1 tensor
print("labels.shape: \t", labels.shape) # size=batch size
type(batch): <class 'list'>
len(batch): 2
type(images): <class 'torch.Tensor'>
images.shape: torch.Size([40, 1, 28, 28])
type(labels): <class 'torch.Tensor'>
labels.shape: torch.Size([40])
# 畫出第一個批次的樣本
grid = torchvision.utils.make_grid(images, nrow=10) # make a grid of images (grid is a tensor)
plt.figure(figsize=(12,12))
plt.imshow(np.transpose(grid, (1,2,0))) # np.transpose permutes the dimensions
print(f'labels: {
labels}')
labels: tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1,
1, 2, 4, 3, 2, 7, 3, 8, 6, 9, 0, 5, 6, 0, 7, 6])
版權聲明
本文為[zqwlearning]所創,轉載請帶上原文鏈接,感謝
https://cht.chowdera.com/2022/01/202201281028060551.html
邊欄推薦
猜你喜歡
隨機推薦
- uniapp上傳圖片及組件傳值
- 瑞利年金險資金保障安全嗎?收益高不高啊?
- 華為手機USB連不上電腦的解决方法
- Flutter 2,移動金融應用開發
- 關於st25系列NFC標簽簡單介紹及st25TV系列用於門禁讀取時的注意事項總結
- 關於用ffmpeg轉手機視頻發現視頻長寬倒了的問題
- 數組中的第k個最大的元素--優先級隊列、排序、堆、排序
- 單片機實例27——ADC0809A/D轉換器基本應用技術(硬件電路圖+匯編程序+C語言程序)
- Collection集合的學習
- 一場面試結束,某度員工從事Android 5年為何還是初級工程師?
- 3本書閱讀筆記【人月神話-Go語言實戰-研發能力持續成長路線】01
- PHP垃圾回收機制
- 【電子技術】什麼是LFSR?
- 死鎖?如何定比特到死鎖?如何修複死鎖?(jps和jstack兩個工具)
- 快樂寒假 22/01/20
- image
- 噴程序員?SURE?
- LDO分壓電阻計算小工具
- 面試之求一串字符串中每個字符的出現次數
- 【ISO15765_UDS&OBD診斷】-01-概述
- 【Mysql上分之路】第九篇:Mysql存儲引擎
- RHCE 第一次作業
- 2021.10.16我的第一篇博客:一切皆有可能!
- CTA-敏感行為-讀取IMEI
- 面試被問怎麼排查平時遇到的系統CPU飆高和頻繁GC,該怎麼回答?
- nuxt項目總結-綜合
- 自然語言處理學習筆記(一)
- C語言第一課
- XCTFre逆向(四):insanity
- 理解什麼是真正的並發數
- JVM腦圖
- 函數棧幀的創建與銷毀
- 構建神經網絡- 手寫字體識別案例
- 多模態生成模型ERNIE-VILG
- kotlin不容忽視的小細節
- 備戰一年,終於斬獲騰訊T3,我堅信成功是可以複制的