五月综合激情婷婷六月,日韩欧美国产一区不卡,他扒开我内裤强吻我下面视频 ,无套内射无矿码免费看黄,天天躁,日日躁,狠狠躁

新聞動態(tài)

PyTorch 如何檢查模型梯度是否可導(dǎo)

發(fā)布日期:2022-03-18 16:36 | 文章來源:源碼之家

一、PyTorch 檢查模型梯度是否可導(dǎo)

當(dāng)我們構(gòu)建復(fù)雜網(wǎng)絡(luò)模型或在模型中加入復(fù)雜操作時(shí),可能會需要驗(yàn)證該模型或操作是否可導(dǎo),即模型是否能夠優(yōu)化,在PyTorch框架下,我們可以使用torch.autograd.gradcheck函數(shù)來實(shí)現(xiàn)這一功能。

首先看一下官方文檔中關(guān)于該函數(shù)的介紹:

可以看到官方文檔中介紹了該函數(shù)基于何種方法,以及其參數(shù)列表,下面給出幾個(gè)例子介紹其使用方法,注意:

Tensor需要是雙精度浮點(diǎn)型且設(shè)置requires_grad = True

第一個(gè)例子:檢查某一操作是否可導(dǎo)

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

輸出為:

Are the gradients correct: True

第二個(gè)例子:檢查某一網(wǎng)絡(luò)模型是否可導(dǎo)

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定義神經(jīng)網(wǎng)絡(luò)模型
class Net(nn.Module):
 
 def __init__(self):
  super(Net, self).__init__()
  self.net = nn.Sequential(
nn.Linear(15, 30),
nn.ReLU(),
nn.Linear(30, 15),
nn.ReLU(),
nn.Linear(15, 1),
nn.Sigmoid()
  )
 
 def forward(self, x):
  y = self.net(x)
  return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

輸出為:

Are the gradients correct: True

二、Pytorch求導(dǎo)

1.標(biāo)量對矩陣求導(dǎo)

驗(yàn)證:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]]) # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩陣,注意,值必須要是float類型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad#df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
 [ 4.,  6.,  8.],
 [ 6.,  9., 12.],
 [ 8., 12., 16.]])
>>>a.grad b.grad# a和b的requires_grad都為默認(rèn)(默認(rèn)為False),所以求導(dǎo)時(shí),沒有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
 tensor([[ 2.,  3.,  4.],
 [ 4.,  6.,  8.],
 [ 6.,  9., 12.],
 [ 8., 12., 16.]])

2.矩陣對矩陣求導(dǎo)

驗(yàn)證:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩陣
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩陣
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
 [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括號里要加上這句
>>>X.grad
tensor([[4., 4., 4.],
 [6., 6., 6.]])

注意:

requires_grad為True的數(shù)組必須是float類型

進(jìn)行backgrad的必須是標(biāo)量,如果是向量,必須在后面括號里加上torch.ones_like(X)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持本站。

美國服務(wù)器租用

版權(quán)聲明:本站文章來源標(biāo)注為YINGSOO的內(nèi)容版權(quán)均為本站所有,歡迎引用、轉(zhuǎn)載,請保持原文完整并注明來源及原文鏈接。禁止復(fù)制或仿造本網(wǎng)站,禁止在非maisonbaluchon.cn所屬的服務(wù)器上建立鏡像,否則將依法追究法律責(zé)任。本站部分內(nèi)容來源于網(wǎng)友推薦、互聯(lián)網(wǎng)收集整理而來,僅供學(xué)習(xí)參考,不代表本站立場,如有內(nèi)容涉嫌侵權(quán),請聯(lián)系alex-e#qq.com處理。

相關(guān)文章

實(shí)時(shí)開通

自選配置、實(shí)時(shí)開通

免備案

全球線路精選!

全天候客戶服務(wù)

7x24全年不間斷在線

專屬顧問服務(wù)

1對1客戶咨詢顧問

在線
客服

在線客服:7*24小時(shí)在線

客服
熱線

400-630-3752
7*24小時(shí)客服服務(wù)熱線

關(guān)注
微信

關(guān)注官方微信
頂部