语义分割网络
语义分割是对图像在像素级别上进行分类的方法,在一张图像中,属于同一类的像素点都要被预测为相同的类。因此语义分割是从像素级别来理解图像。
注意,语义分割仅仅是把某一类划分出来,而针对每个个体没办法进行分割(实例分割)。
常见的语义分割网络有很多,如FCN、U-Net、SegNet、DeepLab等。
FCN
FCN(Fully Convolutional Networks)属于利用深度网络进行图片语义分割的开山之作,其主要思想为:
对于一般的CNN网络分类图像,如VGG和ResNet,在网络的最后是通过全连接层,通过softmax进行分类,但这只能表示整个图片的类别。FCN把最后几个全连接层都换成了卷积操作,得到和输入图像尺寸相当的特征映射,最后通过softmax获取每个像素点的分类信息,实现像素点的图像分割。
端到端像素级语义分割任务,需要输出分类结果尺寸和输入图像尺寸一致,面对池化造成的图面尺寸缩小,FCN采用反卷积(deconvolution)进行上采样,从而保证图像大小的一致。
为了更有效的利用特征映射的信息,FCN提出一种跨层连接结构,将低层和高层的目标位置信息的特征映射进行融合,即将低层位置强语义弱的信息跟高层位置弱语义强的信息进行融合,提升网络对图像分割的性能。
U-Net
U-Net基于FCN网络提出,能够适应较小的训练集。其采用大量弹性形变的方法对数据进行增强,让模型更好的学习形变不变形。在不同特征融合方式上,U-Net采用通道维度上的拼接融合代替FCN的逐点相加。
SegNet
SegNet的网络结构借鉴了自编码网络的思想,具有编码器网络和解码器网络。最后通过softmax分类器对每个像素点进行分类。网络在编码器处会执行卷积和最大池化,在解码器部分则会执行上采样和卷积。
基于PyTorch预训练好的语义分割网络实现VOC数据集分类
数据集
本次使用VOC2012数据集,来源于:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html。
数据集中存在20个类别的1个背景类:
Person: person
Animal: bird, cat, cow, dog, horse, sheep
Vehicle: aeroplane, bicycle, boat, bus, car, motorbike, train
Indoor: bottle, chair, dining table, potted plant, sofa, tv/monitor
在Annotations文件夹中,存放有对应图片的标记文件,以XML格式存储。
网络
在Pytorch中,提供训练好的fcn和deeplabv3网络,可以用作图像分割。
实现
1.1 导入模块
需要用到的模块主要是torchvision,直接pip install torchvision即可。
1 2 3 4 5 6 7 import numpy as npimport pandas as pdimport matplotlib.pyplot as pltimport PIL.Image as Imageimport torchfrom torchvision import transformsimport torchvisio
1.2 数据处理
我们加载torch中训练好的全卷积残差网络fcn_resnet101,设置预训练。如果是第一次加载需要在网络上下载参数。
由于该网络已经训练好了,所以我们不再进行训练,使用其评估模式eval。该该模式下,不启用 Batch Normalization 和 Dropout。即在测试过程中保证BN层均值方差不变,在Dropout层不随机舍弃神经元。
1 2 3 model=torchvision.models.segmentation.fcn_resnet101(pretrained=True ) model.eval ()
然后就需要把我们的图像读取进来啦,这里随机选用一张VOC2012数据。
对图片需要进行预处理:
数据格式转化为张量
RGB通道标准化
添加batch维度
1 2 3 4 5 6 7 8 9 image=Image.open (r"F:\VOCdevkit\VOC2012\JPEGImages\2007_002488.jpg" ) image_transf=transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.485 ,0.456 ,0.406 ], std=[0.229 ,0.224 ,0.225 ]) ]) image_tensor=image_transf(image).unsqueeze(0 )
1.3 网络预测
1 output=model(image_tensor)["out" ]
1.4 结果可视化
输出的Tensor是结果分类的,为了方便可视化,需要做以下处理:
将Tensor重新转为图像
定义每一类对应的色彩,并将图像编码
1 2 3 4 5 6 7 8 9 outputarg=torch.argmax(output.squeeze(),dim=0 ).numpy() label_colors=np.array([(0 ,0 ,0 ),(128 ,0 ,0 ),(0 ,128 ,0 ),(128 ,128 ,0 ),(0 ,0 ,128 ),(128 ,0 ,128 ), (0 ,128 ,128 ),(128 ,128 ,128 ),(64 ,0 ,0 ),(192 ,0 ,0 ),(64 ,128 ,0 ),(192 ,128 ,0 ), (64 ,0 ,128 ),(192 ,0 ,128 ),(64 ,128 ,128 ),(192 ,128 ,128 ),(0 ,64 ,0 ),(128 ,64 ,0 ), (0 ,192 ,0 ),(128 ,192 ,0 ),(0 ,64 ,128 )])
图像编码的话,先生成三个维度的初始数据,接着获取输出类别的位置,并在该位置上为三个维度附上颜色值。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 def decode_segmaps (image,label_colors,nc=21 ): r=np.zeros_like(image).astype(np.uint8) g=np.zeros_like(image).astype(np.uint8) b=np.zeros_like(image).astype(np.uint8) for cla in range (0 ,nc): idx=(image==cla) r[idx]=label_colors[cla,0 ] g[idx]=label_colors[cla,1 ] b[idx]=label_colors[cla,2 ] rgbimage=np.stack([r,g,b],axis=2 ) return rgbimage
最后输出即可
1 2 3 4 5 6 7 8 9 10 11 outputrgb=decode_segmaps(outputarg,label_colors) plt.figure(figsize=(12 ,8 )) plt.subplot(1 ,2 ,1 ) plt.imshow(image) plt.axis("off" ) plt.subplot(1 ,2 ,2 ) plt.imshow(outputrgb) plt.axis("off" ) plt.subplots_adjust(wspace=0.05 ) plt.show()
结果
训练语义分割网络
基于VGG19搭建全卷积语义分割网络。
首先数据有一个标记集:
也有一个原始集:
针对一个图像,在训练阶段我们需要做的事情是:
将标记图像和原始图像对应的图片路径一一对应
将图像统一分为固定的尺寸 时,需要保持原始图像和其对应的标记图像从相同位置进行分割。
对原始图像进行标准化
对标记好的图像,其中的RGB值对应着一个类,需要将其转化为一个二维数据,其中每个位置的取值对应着图像在该像素点的类。
实现
1.1 导入模块
1 2 3 4 5 6 7 8 9 10 11 12 13 14 import numpy as npimport pandas as pdimport matplotlib.pyplot as pltimport PIL.Image as Imageimport torch.utils.data as Datafrom time import timeimport copyimport torchimport torch.nn as nnfrom torchvision import transformsimport torchvisionfrom torch.nn import functional as Fimport torch.optim as optimfrom torchsummary import summary
1.2 数据处理
全局信息,包括是否使用GPU,以及图像色带。
1 2 3 4 5 6 7 device=torch.device("cuda" if torch.cuda.is_available() else "cpu" ) colormap=[(0 ,0 ,0 ),(128 ,0 ,0 ),(0 ,128 ,0 ),(128 ,128 ,0 ),(0 ,0 ,128 ),(128 ,0 ,128 ), (0 ,128 ,128 ),(128 ,128 ,128 ),(64 ,0 ,0 ),(192 ,0 ,0 ),(64 ,128 ,0 ),(192 ,128 ,0 ), (64 ,0 ,128 ),(192 ,0 ,128 ),(64 ,128 ,128 ),(192 ,128 ,128 ),(0 ,64 ,0 ),(128 ,64 ,0 ), (0 ,192 ,0 ),(128 ,192 ,0 ),(0 ,64 ,128 )]
对于图像数据,我们主要进行以下几个工作:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 def img2lab (img,colormap ): cm2lbl=np.zeros(256 **3 ) for i,cm in enumerate (colormap): cm2lbl[((cm[0 ]*256 +cm[1 ])*256 +cm[2 ])]=i image=np.array(img,dtype="int64" ) ix=((image[:,:,0 ]*256 +image[:,:,1 ])*256 +image[:,:,2 ]) image2=cm2lbl[ix] return image2
1 2 3 4 5 6 7 8 9 10 11 def rand_crop (data,label,high,width ): im_width,im_high=data.size left=np.random.randint(0 ,im_width-width) top=np.random.randint(0 ,im_high-high) right=left+width bottom=top+high data=data.crop((left,top,right,bottom)) label=label.crop((left,top,right,bottom)) return data,label
1 2 3 4 5 6 7 8 9 10 11 def img_transforms (data,label,high,width,colormap ): data,label=rand_crop(data,label,high,width) data_tfs=transforms.Compose([ transforms.ToTensor(), transforms.Normalize([0.485 ,0.456 ,0.406 ],[0.229 ,0.224 ,0.225 ]) ]) data=data_tfs(data) label=torch.from_numpy(img2lab(label,colormap)) return data,label
1.3 读取文件
VOC2012的数据集路径保存在train.txt中,我们需要获取该文件,通过np.loadtxt保存路径信息。
1 2 3 4 5 6 7 8 9 def read_image_path (root=r"F:\VOCdevkit\VOC2012\ImageSets\Segmentation\train.txt" ): image=np.loadtxt(root,dtype=str ) n=len (image) data,label=[None ]*n,[None ]*n for i,fname in enumerate (image): data[i]=r"F:\VOCdevkit\VOC2012\JPEGImages\%s.jpg" %(fname) label[i]=r"F:\VOCdevkit\VOC2012\SegmentationClass\%s.png" %(fname) return data,label
接着我们需要定义一个Dataset类,继承自torch.utils.data.Dataset,作为DataLoade中的数据源。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 class MyDataset (Data.Dataset): def __init__ (self,data_root,high,width,imtransform,colormap ): self.data_root=data_root self.high=high self.width=width self.imtransform=imtransform self.cm=colormap data_list,label_list=read_image_path(data_root) self.data_list=self._filter (data_list) self.label_list=self._filter (label_list) def _filter (self,images ): imlist=[] for im in images: img=Image.open (im) if img.size[1 ]>self.high and img.size[0 ]>self.width: imlist.append(im) return imlist def __getitem__ (self,idx ): img=self.data_list[idx] lab=self.label_list[idx] img=Image.open (img) lab=Image.open (lab).convert("RGB" ) img,lab=self.imtransform(img,lab,self.high,self.width,self.cm) return img,lab def __len__ (self ): return len (self.data_list)
查看数据
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 high,width=320 ,480 voc_train=MyDataset(r"F:\VOCdevkit\VOC2012\ImageSets\Segmentation\train.txt" ,high,width,img_transforms,colormap) voc_val=MyDataset(r"F:\VOCdevkit\VOC2012\ImageSets\Segmentation\val.txt" ,high,width,img_transforms,colormap) train_loader=Data.DataLoader(voc_train,batch_size=4 ,shuffle=True ,pin_memory=True ) val_loader=Data.DataLoader(voc_val,batch_size=4 ,shuffle=True ,pin_memory=True ) for step,(bx,by) in enumerate (train_loader): if step>0 : break print ("bx.shape" ,bx.shape)print ("by.shape" ,by.shape)print ("bx" ,bx)print ("by" ,by)
1.4 可视化函数
本部分需要做的事情是将Tensor数据转化为图像数据,包括label和image的转化。
需要反标准化回去,并且将溢出的浮点抹去
1 2 3 4 5 def inv_normalize_img (data ): rgb_mean=np.array([0.485 ,0.456 ,0.406 ]) rgb_std=np.array([0.229 ,0.224 ,0.225 ]) data=data.astype("float32" )*rgb_std+rgb_mean return data.clip(0 ,1 )
将标签转化为RGB图像
1 2 3 4 5 6 7 8 9 def label2img (prelab,colormap ): h,w=prelab.shape prelab=prelab.reshape(h*w,-1 ) img=np.zeros((h*w,3 ),dtype="int32" ) for ii in range (len (colormap)): index=np.where(prelab==ii) img[index,:]=colormap[ii] return img.reshape(h,w,3 )
可视化图像
1 2 3 4 5 6 7 8 9 10 11 12 13 14 bx_numpy=bx.data.numpy() bx_numpy=bx_numpy.transpose(0 ,2 ,3 ,1 ) by_numpy=by.data.numpy() plt.figure(figsize=(16 ,6 )) for i in range (4 ): plt.subplot(2 ,4 ,i+1 ) plt.imshow(inv_normalize_img(bx_numpy[i])) plt.axis("off" ) plt.subplot(2 ,4 ,i+5 ) plt.imshow(label2img(by_numpy[i],colormap)) plt.axis("off" ) plt.subplots_adjust(wspace=0.1 ,hspace=0.1 ) plt.show()
1.5 网络构建
使用训练好的VGG19网络作为backbone,定义语义分割网络FCN8S。其核心在于:
将卷积后的结果反卷积
在不同层进行特征融合,提高辨识度
最终得到与原图大小相同的图片,利用softmax判别类别
FCN8S会在第五个最大池化层进行反卷积,得到大小为w/16的特征,融合将其加上第四个最大池化后的数据后进行处理,再次反卷积得到w/8的特征。最后通过分类器,将特征维度转换为类别数量,判断每个像素点在每个特征维度上的概率,即可实现图像分割。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 class FCN8S (nn.Module): def __init__ (self,num_class ): ''' :param num_class: 训练数据的类别 ''' super (FCN8S, self).__init__() self.num_class=num_class model_vgg19=torchvision.models.vgg19(pretrained=True ) self.backbone=model_vgg19.features self.relu=nn.ReLU(inplace=True ) self.deconv1=nn.ConvTranspose2d(512 ,512 ,kernel_size=3 ,stride=2 ,padding=1 ,dilation=1 ,output_padding=1 ) self.bn1=nn.BatchNorm2d(512 ) self.deconv2=nn.ConvTranspose2d(512 ,256 ,kernel_size=3 ,stride=2 ,padding=1 ,dilation=1 ,output_padding=1 ) self.bn2=nn.BatchNorm2d(256 ) self.deconv3=nn.ConvTranspose2d(256 ,128 ,kernel_size=3 ,stride=2 ,padding=1 ,dilation=1 ,output_padding=1 ) self.bn3=nn.BatchNorm2d(128 ) self.deconv4 = nn.ConvTranspose2d(128 , 64 , kernel_size=3 , stride=2 , padding=1 , dilation=1 , output_padding=1 ) self.bn4 = nn.BatchNorm2d(64 ) self.deconv5 = nn.ConvTranspose2d(64 , 32 , kernel_size=3 , stride=2 , padding=1 , dilation=1 , output_padding=1 ) self.bn5 = nn.BatchNorm2d(32 ) self.classifier=nn.Conv2d(32 ,num_class,kernel_size=1 ) self.layers={"4" :"maxpool_1" , "9" :"maxpool_2" , "18" :"maxpool_3" , "27" :"maxpool_4" , "36" :"maxpool_5" } def forward (self,x ): output={} for name,layer in self.backbone._modules.items(): x=layer(x) if name in self.layers: output[self.layers[name]]=x x5=output["maxpool_5" ] x4=output["maxpool_4" ] x3=output["maxpool_3" ] score=self.relu(self.deconv1(x5)) score=self.bn1(score+x4) score=self.relu(self.deconv2(score)) score=self.bn2(score+x3) score=self.bn3(self.relu(self.deconv3(score))) score=self.bn4(self.relu(self.deconv4(score))) score=self.bn5(self.relu(self.deconv5(score))) score=self.classifier(score) return score
1.6 网络训练
正常去训练就好,注意这里没有使用残差网络,所以可以保留最好的参数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 def train_model (model,criterion,optimizer,traindataloader,valdataloader,num_epoch=25 ): since=time() best_model_wts=copy.deepcopy(model.state_dict()) best_loss=1e10 train_loss_all=[] train_acc_all=[] val_loss_all=[] val_acc_all=[] for epoch in range (num_epoch): print ("Epoch {}/{}" .format (epoch,num_epoch-1 )) print ("-" *10 ) train_loss=0.0 train_num=0 val_loss=0.0 val_num=0 model.train() for step,(bx,by) in enumerate (traindataloader): optimizer.zero_grad() bx=bx.float ().to(device) by=by.long().to(device) out=model(bx) out=F.log_softmax(out,dim=1 ) pre_lab=torch.argmax(out,1 ) loss=criterion(out,by) loss.backward() optimizer.step() train_loss+=loss.item()*len (by) train_num+=len (by) train_loss_all.append(train_loss/train_num) print ("{} Train Loss: {:.4f}" .format (epoch,train_loss_all[-1 ])) model.eval () for step,(bx,by) in enumerate (valdataloader): bx,by=bx.float ().to(device),by.long().to(device) out=model(bx) out=F.log_softmax(out,dim=1 ) pre_lab=torch.argmax(out,1 ) loss=criterion(out,by) val_loss+=loss.item()*len (by) val_num+=len (by) val_loss_all.append(val_loss/val_num) print ("{} Val Loss: {:.4f}" .format (epoch,val_loss_all[-1 ])) if val_loss_all[-1 ]<best_loss: best_loss=val_loss_all[-1 ] best_model_wts=copy.deepcopy(model.state_dict()) time_use=time()-since print ("Train and val complete in {:.0f}m {:.0f}s" .format (time_use//60 ,time_use%60 )) train_process=pd.DataFrame( data={"epoch" :range (num_epoch), "train_loss_all" :train_loss_all, "val_loss_all" :val_loss_all} ) model.load_state_dict(best_model_wts) return model,train_process lr=0.0003 criterion=nn.NLLLoss() optimizer=optim.Adam(fcn8s.parameters(),lr=lr,weight_decay=1e-4 ) fcn8s,train_process=train_model(fcn8s,criterion,optimizer,train_loader,val_loader,num_epoch=5 ) torch.save(fcn8s,"fcn8s.pkl" )
1.7 结果可视化
1 2 3 4 5 6 7 8 plt.figure(figsize=(10 ,6 )) plt.plot(train_process.epoch,train_process.train_loss_all,"ro-" ,label="Train Loss" ) plt.plot(train_process.epoch,train_process.val_loss_all,"bs-" ,label="Val Loss" ) plt.legend() plt.xlabel("epoch" ) plt.ylabel("Loss" ) plt.show()
查看在验证集上的效果
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 for step,(bx,by) in enumerate (val_loader): if step>0 : break fcn8s.eval () bx=bx.float ().to(device) by=by.long().to(device) out=fcn8s(bx) out=F.log_softmax(out,dim=1 ) pre_lab=torch.argmax(out,1 ) bx_numpy=bx.cpu().data.numpy() bx_numpy=bx_numpy.transpose(0 ,2 ,3 ,1 ) by_numpy=by.cpu().data.numpy() pre_lab_numpy=pre_lab.cpu().numpy() for i in range (4 ): plt.subplot(3 ,4 ,i+1 ) plt.imshow(inv_normalize_img(bx_numpy[i])) plt.axis("off" ) plt.subplot(3 ,4 ,i+5 ) plt.imshow(label2img(by_numpy[i],colormap)) plt.axis("off" ) plt.imshow(label2img(pre_lab_numpy[i],colormap)) plt.axis("off" ) plt.subplots_adjust(wspace=0.05 ,hspace=0.05 ) plt.show()