關(guān)于Pytorch中怎么自定義Dataset數(shù)據(jù)集類、怎樣使用DataLoader迭代加載數(shù)據(jù),這篇官方文檔已經(jīng)說得很清楚了,這里就不在贅述。
創(chuàng)新互聯(lián)科技有限公司專業(yè)互聯(lián)網(wǎng)基礎(chǔ)服務(wù)商,為您提供德陽機房服務(wù)器托管,高防服務(wù)器租用,成都IDC機房托管,成都主機托管等互聯(lián)網(wǎng)服務(wù)。現(xiàn)在的問題:有的時候,特別對于NLP任務(wù)來說,輸入的數(shù)據(jù)可能不是定長的,比如多個句子的長度一般不會一致,這時候使用DataLoader加載數(shù)據(jù)時,不定長的句子會被胡亂切分,這肯定是不行的。
解決方法是重寫DataLoader的collate_fn,具體方法如下:
# 假如每一個樣本為: sample = { # 一個句子中各個詞的id 'token_list' : [5, 2, 4, 1, 9, 8], # 結(jié)果y 'label' : 5, } # 重寫collate_fn函數(shù),其輸入為一個batch的sample數(shù)據(jù) def collate_fn(batch): # 因為token_list是一個變長的數(shù)據(jù),所以需要用一個list來裝這個batch的token_list token_lists = [item['token_list'] for item in batch] # 每個label是一個int,我們把這個batch中的label也全取出來,重新組裝 labels = [item['label'] for item in batch] # 把labels轉(zhuǎn)換成Tensor labels = torch.Tensor(labels) return { 'token_list': token_lists, 'label': labels, } # 在使用DataLoader加載數(shù)據(jù)時,注意collate_fn參數(shù)傳入的是重寫的函數(shù) DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
網(wǎng)頁標題:PytorchDataLoader變長數(shù)據(jù)處理方式-創(chuàng)新互聯(lián)
當前鏈接:http://jinyejixie.com/article12/ggogc.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供標簽優(yōu)化、小程序開發(fā)、軟件開發(fā)、微信小程序、服務(wù)器托管、移動網(wǎng)站建設(shè)
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請盡快告知,我們將會在第一時間刪除。文章觀點不代表本網(wǎng)站立場,如需處理請聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時需注明來源: 創(chuàng)新互聯(lián)