import os import cv2 import torch import torch.nn as nn from torchvision import models,transforms from torch.utils.data import DataLoader,Dataset import numpy as np from PIL import Image from torch.optim import lr_scheduler import copy
# Read the "<image path>\t<label>" list file, shuffle the entries and decode
# every image.  NOTE(review): `filelist`, `imgs` and `labels` are defined by
# the surrounding code, which is not part of this excerpt.
with open(filelist) as f:
    lines = [line.strip() for line in f]  # strip surrounding whitespace / newline
np.random.shuffle(lines)  # randomise the sample order in place
for line in lines:
    if not line:
        continue  # tolerate blank lines, e.g. a trailing newline at EOF
    img_path, label = line.split('\t')  # image path and class label
    # cv2.imdecode expects the raw encoded file bytes as a uint8 buffer
    # (presumably used instead of cv2.imread to cope with non-ASCII paths).
    raw = np.fromfile(img_path, dtype=np.uint8)
    bgr = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    if bgr is None:
        raise ValueError("cannot decode image: %s" % img_path)
    # OpenCV decodes to BGR channel order; PIL expects RGB.
    img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    imgs.append(img)
    labels.append(int(label))  # numeric label, so torch.tensor(label) works
# Load this split's data once, at construction time.  load_origin_data()
# returns only the image list for the unlabeled 'test' split and an
# (images, labels) pair for the other splits, so self.labels is left unset
# in test mode.
# NOTE(review): `kind` is presumably the constructor argument that is also
# stored as self.mode (read by load_origin_data and __getitem__); the
# enclosing __init__ header is not visible here -- confirm.
if kind=="test":
    self.imgs=self.load_origin_data()
else:
    self.imgs,self.labels=self.load_origin_data()
def __getitem__(self, item):
    """Return the sample at index ``item``.

    In ``'test'`` mode there are no labels, so only the transformed image is
    returned; otherwise returns ``(transformed image, label tensor)``.
    """
    img = self.transform(self.imgs[item])
    if self.mode == "test":
        return img
    return img, torch.tensor(self.labels[item])
def __len__(self):
    """Return the number of loaded images in this split."""
    return len(self.imgs)
def load_origin_data(self):
    """Load the raw images (and labels) for the split named by ``self.mode``.

    * ``'train'`` / ``'val'``: reads ``./data/<mode>_split_list.txt``.  Each
      non-blank line is ``<image path><TAB><integer label>``, with the image
      path relative to ``<cwd>/data``.  The train list is shuffled.  An entry
      whose label or image cannot be read is reported with ``print`` and
      skipped.
    * ``'test'``: loads every file in ``<cwd>/data/cat_12_test/`` in
      ``os.listdir`` order (no labels).

    All images are returned as RGB ``PIL.Image`` objects, so the labelled
    splits and the test split share the same channel order.

    Returns:
        ``(imgs, labels)`` for 'train'/'val' (labels are ints);
        just ``imgs`` for 'test'.

    Raises:
        ValueError: if ``self.mode`` is not 'train', 'val' or 'test', or a
            list line does not split into exactly two tab-separated fields.
    """
    data_dir = os.path.join(os.getcwd(), "data")
    imgs, labels = [], []

    if self.mode in ('train', 'val'):
        filelist = './data/%s_split_list.txt' % self.mode
        with open(filelist) as f:
            lines = [line.strip() for line in f]
        if self.mode == 'train':
            np.random.shuffle(lines)
        for line in lines:
            if not line:
                continue  # e.g. a trailing empty line at the end of the list
            img_path, label = line.split('\t')
            img_path = os.path.join(data_dir, img_path)
            try:
                target = int(label)
                # imdecode needs the raw encoded bytes as uint8; reading via
                # np.fromfile presumably avoids cv2.imread's non-ASCII path issue.
                raw = np.fromfile(img_path, dtype=np.uint8)
                bgr = cv2.imdecode(raw, cv2.IMREAD_COLOR)
                if bgr is None:
                    raise ValueError("cv2.imdecode could not decode the file")
                # OpenCV decodes to BGR; convert so channels match Image.open.
                img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
            except (OSError, ValueError, cv2.error) as e:
                print("The path %s may be wrong: %s" % (img_path, e))
                continue
            imgs.append(img)
            labels.append(target)
        return imgs, labels

    if self.mode == "test":
        test_dir = os.path.join(data_dir, "cat_12_test")
        for name in os.listdir(test_dir):
            # Image.open is lazy and keeps the file open; convert() loads the
            # pixels into a new image, so the file is closed right away and
            # grayscale/RGBA/palette images become 3-channel RGB.
            with Image.open(os.path.join(test_dir, name)) as im:
                imgs.append(im.convert("RGB"))
        return imgs

    raise ValueError(
        "unknown mode %r; expected 'train', 'val' or 'test'" % (self.mode,)
    )
def get_Dataloader(*, batch_size=32, num_workers=0):
    """Build the train/val/test datasets and wrap each one in a DataLoader.

    Args:
        batch_size: samples per batch for every split.  NOTE(review): the
            value the tutorial intended is not visible here -- adjust.
        num_workers: worker processes passed through to ``DataLoader``.

    Returns:
        ``(dataloaders, dataset_sizes)``: two dicts keyed by ``'train'``,
        ``'val'`` and ``'test'``.
    """
    phases = ['train', 'val', 'test']
    img_datasets = {x: myData(x) for x in phases}
    dataset_sizes = {x: len(img_datasets[x]) for x in phases}
    dataloaders = {
        x: DataLoader(
            img_datasets[x],
            batch_size=batch_size,
            # Reshuffle training data every epoch; keep val/test in file
            # order so test predictions line up with the listed images.
            shuffle=(x == 'train'),
            num_workers=num_workers,
        )
        for x in phases
    }
    return dataloaders, dataset_sizes
# Report this phase's loss and accuracy, both with four decimal places
# ("{:.4}" would give four *significant* digits, e.g. "0.85" vs "0.8500").
print("{} Loss :{:.4f} Acc {:.4f}".format(phase, epoch_loss, epoch_acc))
if phase=="val"and epoch_acc>best_acc: best_acc=epoch_acc best_model_wts=copy.deepcopy(model.state_dict()) print("Best val Acc : {:4f}".format(best_acc)) model.load_state_dict(best_model_wts) return model
三、迁移学习
迁移学习（Transfer Learning）是指以预训练好的大模型参数为起点，去学习新数据的分布。
在这个过程中，我们通常不希望预训练模型的原始参数被改变，因此需要先把它们"冻结"：将每个参数的 `requires_grad` 设为 `False`，这样反向传播时不会为这些参数计算梯度，优化器也就不会更新它们。做法如下：
# Freeze the pretrained weights: with requires_grad disabled, autograd
# computes no gradients for them, so the optimizer leaves them unchanged.
for frozen_weight in model.parameters():
    frozen_weight.requires_grad = False