對pytorch中不定長序列補(bǔ)齊的操作
第二種方法通常是在load一個(gè)batch數(shù)據(jù)時(shí), 在collate_fn中進(jìn)行補(bǔ)齊的.
以下給出兩種思路:
第一種思路是比較容易想到的, 就是對一個(gè)batch的樣本進(jìn)行遍歷, 然后使用np.pad對每一個(gè)樣本進(jìn)行補(bǔ)齊.
for unit in data: mask = np.zeros(max_length) s_len = len(unit[0]) # calculate the length of sequence in each unit mask[: s_len] = 1 unit[0] = np.pad(unit[0], (0, max_length - s_len), 'constant', constant_values=(0, 0)) mask_batch.append(mask)
但是這種方法在batch size很大的情況下會很慢, 因?yàn)槭褂胒or循環(huán)進(jìn)行了遍歷. 我在實(shí)際用的時(shí)候, 當(dāng)batch_size=128時(shí), 一個(gè)batch的加載時(shí)間甚至是一個(gè)batch訓(xùn)練時(shí)間的幾倍!
因此, 我想到如何并行地對序列進(jìn)行補(bǔ)齊. 第二種方法的思路就是使用torch中自帶的pad_sequence來并行補(bǔ)齊.
batch_sequence = list(map(lambda x: torch.tensor(x[findex]), x_data)) batch_data[feat] = torch.nn.utils.rnn.pad_sequence(batch_sequence).T
可以看到這里使用pad_sequence一次性對整個(gè)batch進(jìn)行補(bǔ)齊. 下面對這個(gè)函數(shù)進(jìn)行詳細(xì)說明.
pad_sequence詳解
from torch.utils.rnn import pad_sequence a = torch.ones(10) b = torch.ones(6) c = torch.ones(20) abc = pad_sequence([a,b,c]) # shape(20, 3)
注意這個(gè)函數(shù)接收的是一個(gè)元素為tensor的列表, 而不是tensor.
最終, 這個(gè)函數(shù)會將所有tensor轉(zhuǎn)換為tensor矩陣#shape(max_length, batch_size). 因此, 在使用完后通常還需要轉(zhuǎn)置一下.
補(bǔ)充:PyTorch中用于RNN變長序列填充函數(shù)的簡單使用
1、PyTorch中RNN變長序列的問題
RNN在處理變長序列時(shí)有它的優(yōu)勢。在分批處理變長序列問題時(shí),每個(gè)序列的長度往往不會完全相等,因此針對一個(gè)batch中序列長度不一的情況,需要對某些序列進(jìn)行PAD(填充)操作,使得一個(gè)batch內(nèi)的序列長度相等。
PyTorch中的pack_padded_sequence和pad_packed_sequence可處理上述問題,以下用一個(gè)示例演示這兩個(gè)函數(shù)的簡單使用方法。
2、填充函數(shù)簡介
“壓縮”函數(shù):用于將填充后的序列tensor進(jìn)行壓縮,方便RNN處理
pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)
(1)input->被“壓縮”的tensor,維度一般為[batch_size,_max_seq_len[,embedding_size]]或者[max_seq_len,batch_size[,embedding_size]]
若input維度為:[batch_size,_max_seq_len[,embedding_size]]
要將batch_first設(shè)置為True,這表示input的第一個(gè)維度為batch的數(shù)量
若input維度為:[max_seq_len,batch_size[,embedding_size]]
要將batch_first設(shè)置為False(默認(rèn)值),這表示input的第一個(gè)維度不是batch的數(shù)量
(2)lengths->lengths參數(shù)表示一個(gè)batch中序列真實(shí)長度,類型為列表,在例子中詳細(xì)說明
(3)batch_first->表示batch的數(shù)量是否在input的第一維度,默認(rèn)值為False
(4)enforce_sorted->input中的會自動按照lengths的情況進(jìn)行排序,默認(rèn)值為
“解壓”函數(shù):該函數(shù)與"壓縮函數(shù)"相對應(yīng),經(jīng)“壓縮函數(shù)”處理的輸入經(jīng)過RNN得到的最終結(jié)果可以利用該函數(shù)進(jìn)行“解壓”
pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):
(1)sequence->壓縮函數(shù)處理過的input經(jīng)RNN后得到的結(jié)果
(2)batch_first->與“壓縮”函數(shù)中的batch_first一致
(3)padding_value->序列進(jìn)行填充時(shí)使用的索引,默認(rèn)為0
(4)total_length->暫略
3、PyTorch代碼示例
代碼如下(示例):
# Create by leslie_miao on 2020/11/1 import torch import torch.nn as nn d_model = 10 # 詞嵌入的維度 hidden_size = 20 # lstm隱藏層單元數(shù)量 layer_num = 1 # lstm層數(shù) # 輸入inputs,維度為[batch_size,max_seq_len]=[3,4],其中0代表填充 # 該input包含3個(gè)序列,每個(gè)序列的真實(shí)長度分別為: 4 3 2 inputs = torch.tensor([[1,2,3,4],[1,2,3,0],[1,2,0,0]]) embedding = nn.Embedding(5,d_model) # 獲取詞嵌入后的inputs 當(dāng)前inputs的維度為[batch_size,max_seq_len,d_model]=[3,4,10] inputs = embedding(inputs) # 查看inputs的維度 print(inputs.size()) # print: torch.Size([3, 4, 10]) # 利用“壓縮”函數(shù)對inputs進(jìn)行壓縮處理,[4,3,2]分別為inputs中序列的真實(shí)長度,batch_first=True表示inputs的第一維是batch_size inputs = nn.utils.rnn.pack_padded_sequence(inputs,lengths=[4,3,2],batch_first=True) # 查看經(jīng)“壓縮”函數(shù)處理過的inputs的維度 print(inputs[0].size()) # print: torch.Size([9, 10]) # 定義RNN網(wǎng)絡(luò) network = nn.LSTM(input_size=d_model,hidden_size=hidden_size,batch_first=True,num_layers=layer_num) # 初始化RNN相關(guān)門參數(shù) c_0 = torch.zeros((layer_num,3,hidden_size)) h_0 = torch.zeros((layer_num,3,hidden_size)) # [rnn層數(shù),batch_size,hidden_size] # inputs經(jīng)過RNN網(wǎng)絡(luò)后得到的結(jié)果outputs output,(h_n,c_n) = network(inputs,(h_0,c_0)) #查看未經(jīng)“解壓函數(shù)”處理的outputs維度 print(output[0].size()) # print: torch.Size([9, 20]) # 利用“解壓函數(shù)”對outputs進(jìn)行解壓操作,其中batch_first設(shè)置與“壓縮函數(shù)相同”,padding_value為0 output = nn.utils.rnn.pad_packed_sequence(output,batch_first=True,padding_value=0) # 查看經(jīng)“解壓函數(shù)”處理的outputs維度 print(output[0].size()) # print:torch.Size([3, 4, 20])
總結(jié)
介紹了PyTorch中兩個(gè)應(yīng)用于RNN變長序列填充的函數(shù)pack_padded_sequence和 pad_packed_sequence的簡單使用方法,歡迎指正交流!
版權(quán)聲明:本站文章來源標(biāo)注為YINGSOO的內(nèi)容版權(quán)均為本站所有,歡迎引用、轉(zhuǎn)載,請保持原文完整并注明來源及原文鏈接。禁止復(fù)制或仿造本網(wǎng)站,禁止在非maisonbaluchon.cn所屬的服務(wù)器上建立鏡像,否則將依法追究法律責(zé)任。本站部分內(nèi)容來源于網(wǎng)友推薦、互聯(lián)網(wǎng)收集整理而來,僅供學(xué)習(xí)參考,不代表本站立場,如有內(nèi)容涉嫌侵權(quán),請聯(lián)系alex-e#qq.com處理。
關(guān)注官方微信