PointNet++网络详解
一、PointNet++改进思想
关于PointNet可以参考前一篇文档 。
前文中已经提到,PointNet并没有做局部特征提取,而是通过最大池化层获取全局的信息。这与当前主流的网络不符。在CNN中,有着感受野的概念,通过不断卷积获得的高维特征点对应着低层的一个区域。而在PointNet中,则没有这种局部特征融合的机制。
针对PointNet的不足,PointNet++应运而生。
PointNet++相较于PointNet,主要有以下几个改进项:
针对点云图点对数量的不规则,采用最远点采样 选取其中的N个点,既能保证每个数据能够有相同的形状,也能让其尽可能保留多的信息量。
通过构建球形搜索区域,获取子区域的点对,实现局部特征提取
提取多尺度特征,对不同子区域的特征进行提取与聚合。
提出基于距离差值的分层特征传播算法,将局部特征上采样传播给在特征融合过程中丢失的点中。
下面我们针对这些改进项进行一些比较细致的分析。
注:B 表示batch;N 表示num;C 和D 都表示特征维度(C是xyz)。
二、最远点采样FPS算法
最远点采样能够对全局点进行采样,在保证每个点云数据具有相同的点数量的同时,尽可能保留更多的信息量。
其中的输入 为:
xyz: 点云坐标数据,shape为 [B,N,3]
npoint: 需要提取的点云数量
输出 为:
centroid: 点云中心点索引 ,shape为 [B,npoint]
FPS(Farthest Point Sample)的核心思想如下:
对输入的每一批点云分别构建簇中心
构建距离矩阵 ,用于每次最远距迭代
在点云中随机选择一个点作为簇初始点
选择与该簇距离最远的点,加入簇,并将该点作为下一次迭代的点
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 def farthest_point_sample (xyz,npoint ): """ Input: xyz: pointcloud data, [B, N, 3] npoint: number of samples Return: centroids: sampled pointcloud index, [B, npoint] """ device = xyz.device B, N, C = xyz.shape centroids=torch.zeros(B,npoint,dtype=torch.long).to(device) distance=torch.ones(B,N).to(device)*1e10 batch_indices=torch.arange(B,dtype=torch.long).to(device) farthest=torch.randint(0 ,N,(B,),dtype=torch.long).to(device) for i in range (npoint): centroids[:,i]=farthest centroid=xyz[batch_indices,farthest,:].view(B,1 ,3 ) dist=torch.sum ((xyz-centroid)**2 ,-1 ) mask=dist<distance distance[mask]=dist[mask] farthest=torch.max (distance,-1 )[1 ] return centroid
关于距离更新算法:
1 2 3 4 5 mask=dist<distance distance[mask]=dist[mask] farthest=torch.max (distance,-1 )[1 ]
在初始化的时候,我们将distance初始化为1e10,那么在第一次更新时,就会将所有点距离进行更新。
且计算是会将自身计算进去的(自己到自己的距离是0),所以每更新一次矩阵,都有一个点的距离被更新为0。
distance在这里的作用,就相当于一个记录表,用来记录每次的状态变化。这样,每有一个点被加入,就有有一个0值被寻得,说明该点已经被使用,不再参与更新。
三、局部特征提取算法
在CNN中的局部特征一般是通过不同大小的卷积核点乘得到的,而在PointNet++中,作者也采用了这类的思想,用来提取子区域。
其核心思想为:
预设一个搜索半径 radius和子区域 的点数量k
在最远点采样中获取的簇中心构造球体 ,半径等于搜索半径
计算每个点离中心簇的距离,若该点落在球体内 ,则将其加入到簇中
若球体内的点小于子区域点数量k,则复制最近的点,直到满足条件,若大于,则选取前k个点。
现在每个中心都有k个点了,类似于CNN的k*k子区域
输入 为:
radius: 搜索半径
nsample: 采样点数量
xyz: 所有点的位置信息
new_xyz: 簇中心
输出 为:
一组簇点的索引,shape为 [B,S,nsample]
如何去获取各点的距离呢?这里采用了如下算法:
对于输入src,shape为[B,N,3];对于输入dst,shape为[B,S,3]
距离公式表示为:
d i s = ( x n − x m ) 2 + ( y n − y m ) 2 + ( z n − z m ) 2 = x n 2 + x m 2 − 2 x n x m + y n 2 + y m 2 − 2 y n y m + z n 2 + z m 2 − 2 z n z m = s r c 2 + d s t 2 − ( s r c T ∗ d s t ) dis=(x_n-x_m)^2+(y_n-y_m)^2+(z_n-z_m)^2
\\=x_n^2+x_m^2-2x_nx_m+y_n^2+y_m^2-2y_ny_m+\\
z_n^2+z_m^2-2z_nz_m
\\
=src^2+dst^2-(src^T*dst)
d i s = ( x n − x m ) 2 + ( y n − y m ) 2 + ( z n − z m ) 2 = x n 2 + x m 2 − 2 x n x m + y n 2 + y m 2 − 2 y n y m + z n 2 + z m 2 − 2 z n z m = sr c 2 + d s t 2 − ( sr c T ∗ d s t )
1 2 3 4 5 6 7 8 9 10 11 12 def square_distance (src,dst ): B,N,_=src.shape _,M,_=dst.shape dist=-2 *torch.matmul(src,dst.permute(0 ,2 ,1 )) dist+=torch.sum (src**2 ,-1 ).view(B,N,1 ) dist+=torch.sum (dst**2 ,-1 ).view(B,1 ,M) return dist
在计算中,需要先构建一个索引组。根据计算得到的距离张量,将超过搜索半径的距离点索引设置为最大值 。这样,我们就得到了实际 落在圆内的点。
接着再做升序排序,选取我们需要的nsample个点。当然,会出现点数不足的情况,所以我们复制最近的点,取最大值的位置做掩膜mask=group_idx==N,将掩膜位置修正为第一个点。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 def query_ball_point (radius,nsample,xyz,new_xyz ): """ Input: radius: local region radius nsample: max sample number in local region xyz: all points, [B, N, 3] new_xyz: query points, [B, S, 3] Return: group_idx: grouped points index, [B, S, nsample] """ device=xyz.device B,N,C=xyz.shape _,S,_=new_xyz.shape group_idx=torch.range (N,dtype=torch.long).to(device).view(1 ,1 ,N).repeat([B,S,1 ]) sqrdists=square_distance(new_xyz,xyz) group_idx[sqrdists>radius**2 ]=N group_idx=group_idx.sort(dim=-1 )[0 ][...,:nsample] group_first=group_idx[...,0 ].view(B,S,1 ).repeat([1 ,1 ,nsample]) mask=group_idx==N group_idx=group_first[mask] return group_idx
四、采样打组
在原文中,二和三被定义为Sampling layer和Grouping layer,张量维度的变换如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 shape: [B,N,d+C] -----> shape: [B,S,d+C] -----> shape: [B,S,K,d+C] ''' B: batch N: 点云总数 S: 采样簇数量 K: 簇中点云数量 d: 位置信息xyz C: 特征 '''
在此之前,需要定义一个函数,用于从索引中获取点。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 def index_points (points,idx ): """ Input: points: input points data, [B, N, C] idx: sample index data, [B, S] Return: new_points:, indexed points data, [B, S, C] """ device = points.device B = points.shape[0 ] view_shape = list (idx.shape) view_shape[1 :] = [1 ] * (len (view_shape) - 1 ) repeat_shape = list (idx.shape) repeat_shape[0 ] = 1 batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) new_points = points[batch_indices, idx, :] return new_points
其中有一点需要注意的是,关于tensor索引为一个矩阵的情况。
例如:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 b=[[1 ,2 ],[3 ,4 ]] i=[[4 ,3 ],[2 ,1 ]] points=torch.arange(25 ).view(5 ,5 ) ''' points: tensor([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]) ''' points[b,i] ''' tensor([[ 9, 13], [17, 21]]) '''
这种情况下,是对b和i做组合,也就是说,实际取得的点对为:
1 2 3 4 [[points[1 ,4 ], points[2 ,3 ]] [points[3 ,2 ], points[4 ,1 ]]]
实现的算法为:
输入 :
npoint: 簇中心数量
radius: 搜索半径
nsample: 簇内点数量
xyz: 位置信息
points: 全局点,主要是有其他维度时使用
returnfps: 是否返回最近点信息
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 def sample_and_group (npoint, radius, nsample, xyz, points, returnfps=False ): """ Input: npoint: the number of points radius: search radius nsample: the number of points which in cluster xyz: input points position data, [B, N, 3] points: input points data, [B, N, D] Return: new_xyz: sampled points position data, [B, npoint, nsample, 3] new_points: sampled points data, [B, npoint, nsample, 3+D] """ B,N,C=xyz.shape S=npoint fps_idx=farthest_point_sample(xyz,npoint) torch.cuda.empty_cache() new_xyz=index_points(xyz,fps_idx) idx=query_ball_point(radius,nsample,xyz,new_xyz) torch.cuda.empty_cache() grouped_xyz=index_points(xyz,idx) torch.cuda.empty_cache() grouped_xyz_norm=grouped_xyz-new_xyz.view(B,S,1 ,C) torch.cuda.empty_cache() if points is not None : grouped_points=index_points(points,idx) new_points=torch.cat([grouped_xyz_norm,grouped_points],dim=-1 ) else : new_points=grouped_xyz_norm if returnfps: return new_xyz,new_points,grouped_xyz,fps_idx else : return new_xyz,new_points
五、局部特征提取
PointNet++的局部特征提取与PointNet相同,都是通过一个max pool来实现的。与CNN不同,CNN是在做卷积加权求和,而PointNet++则是通过最大池化来完成。在网络中,作者使用了sampling layer+grouping layer+pointnet来完成整个流程。并将该过程称作set abstraction。SA采样能得到一个融合了局部特征的全局特征。
输入 :
xyz: N个点的位置
类似于CNN的卷积,多次SA操作后输入的N会变成npoint
points: 全部的数据
输出 :
new_xyz: 对原始数据进行采样后,融合了局部特征的新的xyz。shape: [B , C , npoint]
new_points: shape: [B , C+N , npoint]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 class PointNetSetAbstraction (nn.Module): def __init__ (self,npoint,radius,nsample,in_channel,mlp,group_all ): super (PointNetSetAbstraction, self).__init__() self.npoint=npoint self.radius=radius self.nsample=nsample self.mlp_convs=nn.ModuleList() self.mlp_bns=nn.ModuleList() last_channel=in_channel for out_channel in mlp: self.mlp_convs.append(nn.Conv2d(last_channel,out_channel,1 )) self.mlp_bns.append(nn.BatchNorm2d(out_channel)) last_channel=out_channel self.group_all=group_all def forward (self,xyz,points ): """ Input: xyz: input points position data, [B, C, N] points: input points data, [B, D, N] Return: new_xyz: sampled points position data, [B, C, S] new_points_concat: sample points feature data, [B, D', S] """ xyz = xyz.permute(0 , 2 , 1 ) if points is not None : points = points.permute(0 , 2 , 1 ) if self.group_all: new_xyz,new_points=sample_and_group_all(xyz,points) else : new_xyz,new_points=sample_and_group(self.npoint,self.radius,self.nsample,xyz,points) new_points=new_points.permute(0 ,3 ,2 ,1 ) for i,conv in enumerate (self.mlp_convs): bn=self.mlp_bns[i] new_points=F.relu(bn(conv(new_points))) new_points=torch.max (new_points,2 )[0 ] new_xyz=new_xyz.permute(0 ,2 ,1 ) return new_xyz,new_points
六、点云不均匀区域融合
作者在原文中提到:
Features learned in dense data may not generalize to sparsely sampled regions .
密集区特征与稀疏区特征可能会出现不适配,这是因为采样时在稀疏区域采用了最近点补全的方法,且受于尺度的影像,在稀疏区的点往往分布的很开,密集区则相对集中,这也会对结果造成较大的影像。
作者提出了两种特征融合的方法,分别是Multi-scale grouping(MSG 多尺度组合),Multiresolution grouping(MRG 多分辨率组合)。
关于尺度和分辨率,尺度就是观测事物的一种度量,例如看到一辆车,观察车窗和观察车身就是不同的尺度。在图像上的表现为感受野的不同,或者说不同尺寸的卷积核卷积后的尺度不同。而分辨率则是观察汽车,戴眼镜看和不戴眼镜看,都是汽车,但是有模糊和清楚之分。在图像上类似于同一层做池化。
对于多尺度组合MSG而言,就是选取不同半径的子区域(在图像上就是选择不同大小的卷积核)进行特征提取后堆叠。
其代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 class PointNetSetAbstractionMsg (nn.Module): def __init__ (self,npoint,radius_list,nasmple_list,in_channel,mlp_list ): super (PointNetSetAbstractionMsg, self).__init__() self.npoint=npoint self.radius_list=radius_list self.nsample_list=nasmple_list self.conv_block=nn.ModuleList() self.bn_block=nn.ModuleList() for idx,mlp in mlp_list: convs=nn.ModuleList() bns=nn.ModuleList() last_channel=in_channel+3 for output in mlp: convs.append(nn.Conv2d(last_channel,output,1 )) bns.append(nn.BatchNorm2d(output)) last_channel=output self.conv_block.append(convs) self.bn_block.append(bns) def forward (self,xyz,points ): ''' Input: xyz: input points position data, [B, C, N] points: input points data, [B, D, N] Return: new_xyz: sampled points position data, [B, C, S] new_points_concat: sample points feature data, [B, D', S] ''' xyz=xyz.permute(0 ,2 ,1 ) if points is not None : points=points.permute(0 ,2 ,1 ) B,N,C=xyz.shape S=self.npoint new_xyz=index_points(xyz,farthest_point_sample(xyz,S)) new_points_list=[] for i,radius in enumerate (self.radius_list): k=self.nsample_list[i] group_idx=query_ball_point(radius,k,xyz,new_xyz) group_xyz=index_points(xyz,group_idx) group_xyz-=new_xyz.view(B,S,1 ,C) if points is not None : group_points=index_points(points,group_idx) group_points=torch.cat([group_points,group_xyz],dim=-1 ) else : group_points=group_xyz group_points=group_points.permute(0 ,3 ,2 ,1 ) for j in range (len (self.conv_block[i])): conv=self.conv_block[i][j] bn=self.bn_block[i][j] group_points=F.relu(bn(conv(group_points))) new_points=torch.max (group_points,2 )[0 ] new_points_list.append(new_points) new_xyz=new_xyz.permute(0 ,2 ,1 ) new_points_concat=torch.cat(new_points_list,dim=1 ) return new_xyz,new_points_concat
七、点云上采样
在连续的SA层中,不断对原始点进行下采样而获得数量更少的特征点,但若是做分割任务,则需要把点云中的所有点都带上语义标签。若是用之前分类的思想,也就是对所有点做圆进行局部特征提取,实在是太耗费时间了。于是作者提出了基于上采样的方式,将已提取特征的点传递给其他点、
在本部分,作者提出一种基于反距离权重差值的特征传播算法。
其核心思想在于:
反距离插值 ,对每个点的k个临近点按照IDW进行差值。公式如下:
将插值得到的特征与SA阶段得到的特征通过skip-link连接后进行特征堆叠。
特征堆叠后输入到unit pointnet中进一步提取
输入 :
xyz1: 所有点对坐标
xyz2: 降采样后的点坐标
points1: SA层的点
points2: 降采样后的点
输出 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 class PointNetFeaturePropagation (nn.Module): def __init__ (self, in_channel, mlp ): super (PointNetFeaturePropagation, self).__init__() self.mlp_convs = nn.ModuleList() self.mlp_bns = nn.ModuleList() last_channel = in_channel for out_channel in mlp: self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1 )) self.mlp_bns.append(nn.BatchNorm1d(out_channel)) last_channel = out_channel def forward (self,xyz1,xyz2,points1,points2 ): """ Input: xyz1: input points position data, [B, C, N] xyz2: sampled input points position data, [B, C, S] points1: input points data, [B, D, N] points2: input points data, [B, D, S] Return: new_points: upsampled points data, [B, D', N] """ xyz1,xyz2=xyz1.permute(0 ,2 ,1 ),xyz2.permute(0 ,2 ,1 ) points2=points2.permute(0 ,2 ,1 ) B,N,C=xyz1.shape _,S,_=xyz2.shape if S==1 : interpolated_points=points2.repeat(1 ,N,1 ) else : dists=square_distance(xyz1,xyz2) print (dists.shape) dists,idx=dists.sort(dim=-1 ) dists,idx=dists[...,:3 ],idx[...,:3 ] dist_recip=1.0 /(dists+1e-8 ) norm=torch.sum (dist_recip,dim=2 ,keepdim=True ) weight=dist_recip/norm interpolated_points=torch.sum (index_points(points2,idx)* weight.view(B, N, 3 , 1 ), dim=2 ) if points1 is not None : points1=points1.permute(0 ,2 ,1 ) new_points=torch.cat([points1,interpolated_points],dim=-1 ) else : new_points=interpolated_points new_points=new_points.permute(0 ,2 ,1 ) for i,conv in enumerate (self.mlp_convs): bn=self.mlp_bns[i] new_points=F.relu(bn(conv(new_points))) return new_points
整个分类任务如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 class get_model (nn.Module): def __init__ (self,num_class,normal_channel=True ): super (get_model, self).__init__() in_channel=3 if normal_channel else 0 self.normal_channel=normal_channel self.sa1 = PointNetSetAbstractionMsg(512 , [0.1 , 0.2 , 0.4 ], [16 , 32 , 128 ], in_channel,[[32 , 32 , 64 ], [64 , 64 , 128 ], [64 , 96 , 128 ]]) self.sa2 = PointNetSetAbstractionMsg(128 , [0.2 , 0.4 , 0.8 ], [32 , 64 , 128 ], 320 ,[[64 , 64 , 128 ], [128 , 128 , 256 ], [128 , 128 , 256 ]]) self.sa3 = PointNetSetAbstraction(None , None , None , 640 + 3 , [256 , 512 , 1024 ], True ) self.fc1=nn.Linear(1024 ,512 ) self.bn1=nn.BatchNorm1d(512 ) self.drop1=nn.Dropout(0.4 ) self.fc2=nn.Linear(512 ,256 ) self.bn2=nn.BatchNorm1d(256 ) self.drop2=nn.Dropout(0.5 ) self.fc3=nn.Linear(256 ,num_class) def forward (self,xyz ): B,_,_=xyz.shape if self.normal_channel: norm=xyz[:,3 :,:] xyz=xyz[:,:3 ,:] else : norm=None l1_xyz,l1_points=self.sa1(xyz,norm) l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) x=l3_points.view(B,1024 ) x = self.drop1(F.relu(self.bn1(self.fc1(x)))) x = self.drop2(F.relu(self.bn2(self.fc2(x)))) x = self.fc3(x) x=F.log_softmax(x,-1 ) return x,l3_points class get_loss (nn.Module): def __init__ (self ): super (get_loss, self).__init__() def forward (self,pred,target,trans_feat=None ): if trans_feat: total_loss=trans_feat(pred,target) else : total_loss=F.nll_loss(pred,target) return total_loss
而分割任务则是使用了特征传递层的特征融合,代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 class get_model (nn.Module): def __init__ (self, num_classes, normal_channel=False ): super (get_model, self).__init__() if normal_channel: additional_channel = 3 else : additional_channel = 0 self.normal_channel = normal_channel self.sa1 = PointNetSetAbstractionMsg(512 , [0.1 , 0.2 , 0.4 ], [32 , 64 , 128 ], 3 +additional_channel, [[32 , 32 , 64 ], [64 , 64 , 128 ], [64 , 96 , 128 ]]) self.sa2 = PointNetSetAbstractionMsg(128 , [0.4 ,0.8 ], [64 , 128 ], 128 +128 +64 , [[128 , 128 , 256 ], [128 , 196 , 256 ]]) self.sa3 = PointNetSetAbstraction(npoint=None , radius=None , nsample=None , in_channel=512 + 3 , mlp=[256 , 512 , 1024 ], group_all=True ) self.fp3 = PointNetFeaturePropagation(in_channel=1536 , mlp=[256 , 256 ]) self.fp2 = PointNetFeaturePropagation(in_channel=576 , mlp=[256 , 128 ]) self.fp1 = PointNetFeaturePropagation(in_channel=150 +additional_channel, mlp=[128 , 128 ]) self.conv1 = nn.Conv1d(128 , 128 , 1 ) self.bn1 = nn.BatchNorm1d(128 ) self.drop1 = nn.Dropout(0.5 ) self.conv2 = nn.Conv1d(128 , num_classes, 1 ) def forward (self, xyz, cls_label ): B,C,N = xyz.shape if self.normal_channel: l0_points = xyz l0_xyz = xyz[:,:3 ,:] else : l0_points = xyz l0_xyz = xyz l1_xyz, l1_points = self.sa1(l0_xyz, l0_points) l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) cls_label_one_hot = cls_label.view(B,16 ,1 ).repeat(1 ,1 ,N) l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat([cls_label_one_hot,l0_xyz,l0_points],1 ), l1_points) feat = F.relu(self.bn1(self.conv1(l0_points))) x = self.drop1(feat) x = self.conv2(x) x = F.log_softmax(x, dim=1 ) x = x.permute(0 , 2 , 1 ) return x, l3_points