PyTorch一小時掌握之神經網絡分類篇
概述
對于 MNIST 手寫數據集的具體介紹, 我們在 TensorFlow 中已經詳細描述過, 在這里就不多贅述. 有興趣的同學可以去看看之前的文章: https://www.jb51.net/article/222183.htm
在上一節的內容里, 我們用 PyTorch 實現了回歸任務, 在這一節里, 我們將使用 PyTorch 來解決分類任務.
導包
import torchvision import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import matplotlib.pyplot as plt
設置超參數
# Hyperparameters
n_epochs = 3            # number of full passes over the training set
batch_size_train = 64   # training mini-batch size
batch_size_test = 1000  # evaluation batch size
learning_rate = 0.01    # SGD step size
momentum = 0.5          # SGD momentum
log_interval = 100      # log progress (and checkpoint) every N training batches
random_seed = 1
torch.manual_seed(random_seed)  # seed PyTorch's global RNG for reproducibility
讀取數據
# Data loading.
# Both splits share one preprocessing pipeline: ToTensor converts each image
# to a float tensor in [0, 1], then Normalize standardises it with mean
# 0.1307 / std 0.3081 (presumably the MNIST training-set statistics; they are
# hard-coded here, not computed).
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        './data/', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size_train,
    shuffle=True,
)

# The test loader is shuffled as well, so the batch inspected below is a
# random sample of the test set rather than its first images.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        './data/', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size_test,
    shuffle=True,
)

# Grab one batch from the test loader for inspection.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

# Debug output: the batch's labels and the image tensor's shape.
print(example_targets)
print(example_data.shape)
輸出結果:
tensor([7, 6, 7, 5, 6, 7, 8, 1, 1, 2, 4, 1, 0, 8, 4, 4, 4, 9, 8, 1, 3, 3, 8, 6,
2, 7, 5, 1, 6, 5, 6, 2, 9, 2, 8, 4, 9, 4, 8, 6, 7, 7, 9, 8, 4, 9, 5, 3,
1, 0, 9, 1, 7, 3, 7, 0, 9, 2, 5, 1, 8, 9, 3, 7, 8, 4, 1, 9, 0, 3, 1, 2,
3, 6, 2, 9, 9, 0, 3, 8, 3, 0, 8, 8, 5, 3, 8, 2, 8, 5, 5, 7, 1, 5, 5, 1,
0, 9, 7, 5, 2, 0, 7, 6, 1, 2, 2, 7, 5, 4, 7, 3, 0, 6, 7, 5, 1, 7, 6, 7,
2, 1, 9, 1, 9, 2, 7, 6, 8, 8, 8, 4, 6, 0, 0, 2, 3, 0, 1, 7, 8, 7, 4, 1,
3, 8, 3, 5, 5, 9, 6, 0, 5, 3, 3, 9, 4, 0, 1, 9, 9, 1, 5, 6, 2, 0, 4, 7,
3, 5, 8, 8, 2, 5, 9, 5, 0, 7, 8, 9, 3, 8, 5, 3, 2, 4, 4, 6, 3, 0, 8, 2,
7, 0, 5, 2, 0, 6, 2, 6, 3, 6, 6, 7, 9, 3, 4, 1, 6, 2, 8, 4, 7, 7, 2, 7,
4, 2, 4, 9, 7, 7, 5, 9, 1, 3, 0, 4, 4, 8, 9, 6, 6, 5, 3, 3, 2, 3, 9, 1,
1, 4, 4, 8, 1, 5, 1, 8, 8, 0, 7, 5, 8, 4, 0, 0, 0, 6, 3, 0, 9, 0, 6, 6,
9, 8, 1, 2, 3, 7, 6, 1, 5, 9, 3, 9, 3, 2, 5, 9, 9, 5, 4, 9, 3, 9, 6, 0,
3, 3, 8, 3, 1, 4, 1, 4, 7, 3, 1, 6, 8, 4, 7, 7, 3, 3, 6, 1, 3, 2, 3, 5,
9, 9, 9, 2, 9, 0, 2, 7, 0, 7, 5, 0, 2, 6, 7, 3, 7, 1, 4, 6, 4, 0, 0, 3,
2, 1, 9, 3, 5, 5, 1, 6, 4, 7, 4, 6, 4, 4, 9, 7, 4, 1, 5, 4, 8, 7, 5, 9,
2, 9, 4, 0, 8, 7, 3, 4, 2, 7, 9, 4, 4, 0, 1, 4, 1, 2, 5, 2, 8, 5, 3, 9,
1, 3, 5, 1, 9, 5, 3, 6, 8, 1, 7, 9, 9, 9, 9, 9, 2, 3, 5, 1, 4, 2, 3, 1,
1, 3, 8, 2, 8, 1, 9, 2, 9, 0, 7, 3, 5, 8, 3, 7, 8, 5, 6, 4, 1, 9, 7, 1,
7, 1, 1, 8, 6, 7, 5, 6, 7, 4, 9, 5, 8, 6, 5, 6, 8, 4, 1, 0, 9, 1, 4, 3,
5, 1, 8, 7, 5, 4, 6, 6, 0, 2, 4, 2, 9, 5, 9, 8, 1, 4, 8, 1, 1, 6, 7, 5,
9, 1, 1, 7, 8, 7, 5, 5, 2, 6, 5, 8, 1, 0, 7, 2, 2, 4, 3, 9, 7, 3, 5, 7,
6, 9, 5, 9, 6, 5, 7, 2, 3, 7, 2, 9, 7, 4, 8, 4, 9, 3, 8, 7, 5, 0, 0, 3,
4, 3, 3, 6, 0, 1, 7, 7, 4, 6, 3, 0, 8, 0, 9, 8, 2, 4, 2, 9, 4, 9, 9, 9,
7, 7, 6, 8, 2, 4, 9, 3, 0, 4, 4, 1, 5, 7, 7, 6, 9, 7, 0, 2, 4, 2, 1, 4,
7, 4, 5, 1, 4, 7, 3, 1, 7, 6, 9, 0, 0, 7, 3, 6, 3, 3, 6, 5, 8, 1, 7, 1,
6, 1, 2, 3, 1, 6, 8, 8, 7, 4, 3, 7, 7, 1, 8, 9, 2, 6, 6, 6, 2, 8, 8, 1,
6, 0, 3, 0, 5, 1, 3, 2, 4, 1, 5, 5, 7, 3, 5, 6, 2, 1, 8, 0, 2, 0, 8, 4,
4, 5, 0, 0, 1, 5, 0, 7, 4, 0, 9, 2, 5, 7, 4, 0, 3, 7, 0, 3, 5, 1, 0, 6,
4, 7, 6, 4, 7, 0, 0, 5, 8, 2, 0, 6, 2, 4, 2, 3, 2, 7, 7, 6, 9, 8, 5, 9,
7, 1, 3, 4, 3, 1, 8, 0, 3, 0, 7, 4, 9, 0, 8, 1, 5, 7, 3, 2, 2, 0, 7, 3,
1, 8, 8, 2, 2, 6, 2, 7, 6, 6, 9, 4, 9, 3, 7, 0, 4, 6, 1, 9, 7, 4, 4, 5,
8, 2, 3, 2, 4, 9, 1, 9, 6, 7, 1, 2, 1, 1, 2, 6, 9, 7, 1, 0, 1, 4, 2, 7,
7, 8, 3, 2, 8, 2, 7, 6, 1, 1, 9, 1, 0, 9, 1, 3, 9, 3, 7, 6, 5, 6, 2, 0,
0, 3, 9, 4, 7, 3, 2, 9, 0, 9, 5, 2, 2, 4, 1, 6, 3, 4, 0, 1, 6, 9, 1, 7,
0, 8, 0, 0, 9, 8, 5, 9, 4, 4, 7, 1, 9, 0, 0, 2, 4, 3, 5, 0, 4, 0, 1, 0,
5, 8, 1, 8, 3, 3, 2, 1, 2, 6, 8, 2, 5, 3, 7, 9, 3, 6, 2, 2, 6, 2, 7, 7,
6, 1, 8, 0, 3, 5, 7, 5, 0, 8, 6, 7, 2, 4, 1, 4, 3, 7, 7, 2, 9, 3, 5, 5,
9, 4, 8, 7, 6, 7, 4, 9, 2, 7, 7, 1, 0, 7, 2, 8, 0, 3, 5, 4, 5, 1, 5, 7,
6, 7, 3, 5, 3, 4, 5, 3, 4, 3, 2, 3, 1, 7, 4, 4, 8, 5, 5, 3, 2, 2, 9, 5,
8, 2, 0, 6, 0, 7, 9, 9, 6, 1, 6, 6, 2, 3, 7, 4, 7, 5, 2, 9, 4, 2, 9, 0,
8, 1, 7, 5, 5, 7, 0, 5, 2, 9, 5, 2, 3, 4, 6, 0, 0, 2, 9, 2, 0, 5, 4, 8,
9, 0, 9, 1, 3, 4, 1, 8, 0, 0, 4, 0, 8, 5, 9, 8])
torch.Size([1000, 1, 28, 28])
可視化展示
# Plot the first six images of the inspected test batch with their labels.
fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title("Ground Truth: {}".format(example_targets[i].item()))
    plt.xticks([])
    plt.yticks([])
# Lay out once, after every subplot has its image, title and ticks. Calling
# tight_layout inside the loop repeated the work six times, and its final call
# ran before the last subplot's title was set.
plt.tight_layout()
plt.show()
輸出結果:

建立模型
# Build the model.
class Net(nn.Module):
    """Small LeNet-style CNN for 10-class MNIST digit classification.

    Two 5x5 convolutions (1 -> 10 -> 20 channels), each followed by 2x2
    max-pooling and ReLU, then two fully connected layers (320 -> 50 -> 10).
    ``forward`` returns per-class log-probabilities, which is what
    ``F.nll_loss`` in the training code expects.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Map a batch of images to log-probabilities of shape (N, 10).

        Expects ``x`` of shape (N, 1, 28, 28): 28 -conv5-> 24 -pool-> 12
        -conv5-> 8 -pool-> 4, so 20 * 4 * 4 = 320 features feed ``fc1``.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Explicit dim: the implicit-dim form is deprecated and warns; for a
        # 2-D (N, 10) tensor the implicit choice was dim=1 anyway.
        return F.log_softmax(x, dim=1)


network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)
訓練模型
# Training / evaluation history (e.g. for plotting loss curves).
train_losses = []   # batch loss, sampled every log_interval batches
train_counter = []  # training samples seen when each train loss was recorded
test_losses = []    # mean test loss, one entry per call to test()
# test() runs once after each epoch, so its x-positions are the sample counts
# at the end of epochs 1..n_epochs. Starting at 1 (not 0) keeps
# len(test_counter) == len(test_losses) once training finishes.
test_counter = [i * len(train_loader.dataset)
                for i in range(1, n_epochs + 1)]
def train(epoch):
    """Run one training epoch over ``train_loader``.

    Every ``log_interval`` batches this prints progress, records the current
    batch loss in ``train_losses`` (and the number of training samples seen
    so far in ``train_counter``), and checkpoints the model and optimizer
    state to ``./model.pth`` and ``./optimizer.pth``.

    Args:
        epoch: 1-based epoch number, used for logging and for offsetting
            ``train_counter`` across epochs.
    """
    network.train()
    dataset_size = len(train_loader.dataset)
    batch_size = train_loader.batch_size
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # Samples processed before this batch. Use the loader's nominal
            # batch size instead of a hard-coded 64 or len(data) (a short
            # final batch would understate the count).
            seen = batch_idx * batch_size
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, dataset_size,
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(seen + (epoch - 1) * dataset_size)
            torch.save(network.state_dict(), './model.pth')
            torch.save(optimizer.state_dict(), './optimizer.pth')
def test():
    """Evaluate ``network`` on ``test_loader`` and report loss and accuracy.

    Appends the mean per-sample NLL loss to ``test_losses`` and prints it
    together with the number and percentage of correct predictions.
    """
    network.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # Sum (not average) per batch so dividing by the dataset size
            # below gives a true per-sample mean; reduction='sum' replaces
            # the deprecated size_average=False.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    n_samples = len(test_loader.dataset)
    test_loss /= n_samples
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, n_samples, 100. * correct / n_samples))
# Main loop: each epoch is one full training pass followed by one evaluation.
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()
輸出結果:
Train Epoch: 1 [0/60000 (0%)] Loss: 2.297471
Train Epoch: 1 [6400/60000 (11%)] Loss: 1.934886
Train Epoch: 1 [12800/60000 (21%)] Loss: 1.242982
Train Epoch: 1 [19200/60000 (32%)] Loss: 0.979296
Train Epoch: 1 [25600/60000 (43%)] Loss: 1.277279
Train Epoch: 1 [32000/60000 (53%)] Loss: 0.721533
Train Epoch: 1 [38400/60000 (64%)] Loss: 0.759595
Train Epoch: 1 [44800/60000 (75%)] Loss: 0.469635
Train Epoch: 1 [51200/60000 (85%)] Loss: 0.422614
Train Epoch: 1 [57600/60000 (96%)] Loss: 0.417603

Test set: Avg. loss: 0.1988, Accuracy: 9431/10000 (94%)
Train Epoch: 2 [0/60000 (0%)] Loss: 0.277207
Train Epoch: 2 [6400/60000 (11%)] Loss: 0.328862
Train Epoch: 2 [12800/60000 (21%)] Loss: 0.396312
Train Epoch: 2 [19200/60000 (32%)] Loss: 0.301772
Train Epoch: 2 [25600/60000 (43%)] Loss: 0.253600
Train Epoch: 2 [32000/60000 (53%)] Loss: 0.217821
Train Epoch: 2 [38400/60000 (64%)] Loss: 0.395815
Train Epoch: 2 [44800/60000 (75%)] Loss: 0.265737
Train Epoch: 2 [51200/60000 (85%)] Loss: 0.323627
Train Epoch: 2 [57600/60000 (96%)] Loss: 0.236692

Test set: Avg. loss: 0.1233, Accuracy: 9622/10000 (96%)
Train Epoch: 3 [0/60000 (0%)] Loss: 0.500148
Train Epoch: 3 [6400/60000 (11%)] Loss: 0.338118
Train Epoch: 3 [12800/60000 (21%)] Loss: 0.452308
Train Epoch: 3 [19200/60000 (32%)] Loss: 0.374940
Train Epoch: 3 [25600/60000 (43%)] Loss: 0.323300
Train Epoch: 3 [32000/60000 (53%)] Loss: 0.203830
Train Epoch: 3 [38400/60000 (64%)] Loss: 0.379557
Train Epoch: 3 [44800/60000 (75%)] Loss: 0.334822
Train Epoch: 3 [51200/60000 (85%)] Loss: 0.361676
Train Epoch: 3 [57600/60000 (96%)] Loss: 0.218833

Test set: Avg. loss: 0.0911, Accuracy: 9723/10000 (97%)
完整代碼
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
# Hyperparameters.
n_epochs = 3
batch_size_train, batch_size_test = 64, 1000   # train / eval batch sizes
learning_rate, momentum = 0.01, 0.5             # SGD settings
log_interval = 100                              # batches between progress logs
random_seed = 1
torch.manual_seed(random_seed)                  # seed PyTorch's global RNG
# Data loading.
# Both splits share one preprocessing pipeline: ToTensor converts each image
# to a float tensor in [0, 1], then Normalize standardises it with mean
# 0.1307 / std 0.3081 (presumably the MNIST training-set statistics; they are
# hard-coded here, not computed).
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        './data/', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size_train,
    shuffle=True,
)

# The test loader is shuffled as well, so the batch inspected below is a
# random sample of the test set rather than its first images.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        './data/', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size_test,
    shuffle=True,
)

# Grab one batch from the test loader for inspection.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

# Debug output: the batch's labels and the image tensor's shape.
print(example_targets)
print(example_data.shape)
# Plot the first six images of the inspected test batch with their labels.
fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title("Ground Truth: {}".format(example_targets[i].item()))
    plt.xticks([])
    plt.yticks([])
# Lay out once, after every subplot has its image, title and ticks. Calling
# tight_layout inside the loop repeated the work six times, and its final call
# ran before the last subplot's title was set.
plt.tight_layout()
plt.show()
# Build the model.
class Net(nn.Module):
    """Small LeNet-style CNN for 10-class MNIST digit classification.

    Two 5x5 convolutions (1 -> 10 -> 20 channels), each followed by 2x2
    max-pooling and ReLU, then two fully connected layers (320 -> 50 -> 10).
    ``forward`` returns per-class log-probabilities, which is what
    ``F.nll_loss`` in the training code expects.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Map a batch of images to log-probabilities of shape (N, 10).

        Expects ``x`` of shape (N, 1, 28, 28): 28 -conv5-> 24 -pool-> 12
        -conv5-> 8 -pool-> 4, so 20 * 4 * 4 = 320 features feed ``fc1``.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Explicit dim: the implicit-dim form is deprecated and warns; for a
        # 2-D (N, 10) tensor the implicit choice was dim=1 anyway.
        return F.log_softmax(x, dim=1)
# Instantiate the model and a momentum-SGD optimizer over its parameters.
network = Net()
optimizer = optim.SGD(
    network.parameters(),
    lr=learning_rate,
    momentum=momentum,
)
# Training / evaluation history (e.g. for plotting loss curves).
train_losses = []   # batch loss, sampled every log_interval batches
train_counter = []  # training samples seen when each train loss was recorded
test_losses = []    # mean test loss, one entry per call to test()
# test() runs once after each epoch, so its x-positions are the sample counts
# at the end of epochs 1..n_epochs. Starting at 1 (not 0) keeps
# len(test_counter) == len(test_losses) once training finishes.
test_counter = [i * len(train_loader.dataset)
                for i in range(1, n_epochs + 1)]
def train(epoch):
    """Run one training epoch over ``train_loader``.

    Every ``log_interval`` batches this prints progress, records the current
    batch loss in ``train_losses`` (and the number of training samples seen
    so far in ``train_counter``), and checkpoints the model and optimizer
    state to ``./model.pth`` and ``./optimizer.pth``.

    Args:
        epoch: 1-based epoch number, used for logging and for offsetting
            ``train_counter`` across epochs.
    """
    network.train()
    dataset_size = len(train_loader.dataset)
    batch_size = train_loader.batch_size
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # Samples processed before this batch. Use the loader's nominal
            # batch size instead of a hard-coded 64 or len(data) (a short
            # final batch would understate the count).
            seen = batch_idx * batch_size
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, dataset_size,
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(seen + (epoch - 1) * dataset_size)
            torch.save(network.state_dict(), './model.pth')
            torch.save(optimizer.state_dict(), './optimizer.pth')
def test():
    """Evaluate ``network`` on ``test_loader`` and report loss and accuracy.

    Appends the mean per-sample NLL loss to ``test_losses`` and prints it
    together with the number and percentage of correct predictions.
    """
    network.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # Sum (not average) per batch so dividing by the dataset size
            # below gives a true per-sample mean; reduction='sum' replaces
            # the deprecated size_average=False.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    n_samples = len(test_loader.dataset)
    test_loss /= n_samples
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, n_samples, 100. * correct / n_samples))
# Main loop: each epoch is one full training pass followed by one evaluation.
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()
到此這篇關于PyTorch一小時掌握之神經網絡分類篇的文章就介紹到這了,更多相關PyTorch神經網絡分類內容請搜索本站以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支持本站!
版權聲明:本站文章來源標注為YINGSOO的內容版權均為本站所有,歡迎引用、轉載,請保持原文完整并注明來源及原文鏈接。禁止復制或仿造本網站,禁止在非maisonbaluchon.cn所屬的服務器上建立鏡像,否則將依法追究法律責任。本站部分內容來源于網友推薦、互聯網收集整理而來,僅供學習參考,不代表本站立場,如有內容涉嫌侵權,請聯系alex-e#qq.com處理。
 關注官方微信
                    關注官方微信