【点云处理】PointNet++点云分类与分割

【点云处理】PointNet++点云分类与分割,第1张

PointNet++网络详解

一、PointNet++改进思想

关于PointNet可以参考前一篇文档。


前文中已经提到,PointNet并没有做局部特征提取,而是通过最大池化层获取全局的信息。


这与当前主流的网络不符。


在CNN中,有着感受野的概念,通过不断卷积获得的高维特征点对应着低层的一个区域。


而在PointNet中,则没有这种局部特征融合的机制。


针对PointNet的不足,PointNet++应运而生。


PointNet++相较于PointNet,主要有以下几个改进项:

  • 针对点云图点对数量的不规则,采用最远点采样选取其中的N个点,既能保证每个数据能够有相同的形状,也能让其尽可能保留多的信息量。


  • 通过构建球形搜索区域,获取子区域的点对,实现局部特征提取
  • 提取多尺度特征,对不同子区域的特征进行提取与聚合。


  • 提出基于距离差值的分层特征传播算法,将局部特征上采样传播给在特征融合过程中丢失的点中。


下面我们针对这些改进项进行一些比较细致的分析。


注:B表示batch;N表示num;C和D都表示特征维度(C是xyz)。





二、最远点采样FPS算法

最远点采样能够对全局点进行采样,在保证每个点云数据具有相同的点数量的同时,尽可能保留更多的信息量。


其中的输入为:

  • xyz: 点云坐标数据,shape为 [B,N,3]
  • npoint: 需要提取的点云数量

输出为:

  • centroid: 点云中心点索引,shape为 [B,npoint]

FPS(Farthest Point Sample)的核心思想如下:

  • 对输入的每一批点云分别构建簇中心
  • 构建距离矩阵,用于每次最远距迭代
  • 在点云中随机选择一个点作为簇初始点
  • 选择与该簇距离最远的点,加入簇,并将该点作为下一次迭代的点
def farthest_point_sample(xyz,npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """

    device = xyz.device
    B, N, C = xyz.shape

    # 构建中心簇 , 大小为: [ batch , npoint ]
    centroids=torch.zeros(B,npoint,dtype=torch.long).to(device)
    # 构建距离矩阵
    distance=torch.ones(B,N).to(device)*1e10
    # 对batch细分
    batch_indices=torch.arange(B,dtype=torch.long).to(device)
    # 最远点初始化 随机选择一个点
    farthest=torch.randint(0,N,(B,),dtype=torch.long).to(device)

    for i in range(npoint):
        centroids[:,i]=farthest
        # 获取当前采样点坐标值
        centroid=xyz[batch_indices,farthest,:].view(B,1,3)
        # 计算当前点与其他点的距离
        dist=torch.sum((xyz-centroid)**2,-1)

        # 获取满足条件的逻辑矩阵
        mask=dist<distance
        # 选择距离最近的点来更新距离
        distance[mask]=dist[mask]
        farthest=torch.max(distance,-1)[1] # 获得最远点的索引
    return centroid

关于距离更新算法:

# 获取满足条件的逻辑矩阵
mask=dist<distance
# 选择距离最近的点来更新距离
distance[mask]=dist[mask]
farthest=torch.max(distance,-1)[1] # 获得最远点的索引

在初始化的时候,我们将distance初始化为1e10,那么在第一次更新时,就会将所有点距离进行更新。


且计算是会将自身计算进去的(自己到自己的距离是0),所以每更新一次矩阵,都有一个点的距离被更新为0。


distance在这里的作用,就相当于一个记录表,用来记录每次的状态变化。


这样,每有一个点被加入,就有有一个0值被寻得,说明该点已经被使用,不再参与更新。




三、局部特征提取算法

在CNN中的局部特征一般是通过不同大小的卷积核点乘得到的,而在PointNet++中,作者也采用了这类的思想,用来提取子区域。


其核心思想为:

  • 预设一个搜索半径radius和子区域的点数量k
  • 在最远点采样中获取的簇中心构造球体,半径等于搜索半径
  • 计算每个点离中心簇的距离,若该点落在球体内,则将其加入到簇中
  • 若球体内的点小于子区域点数量k,则复制最近的点,直到满足条件,若大于,则选取前k个点。


  • 现在每个中心都有k个点了,类似于CNN的k*k子区域

输入为:

  • radius: 搜索半径
  • nsample: 采样点数量
  • xyz: 所有点的位置信息
  • new_xyz: 簇中心

输出为:

  • 一组簇点的索引,shape为 [B,S,nsample]

如何去获取各点的距离呢?这里采用了如下算法:

对于输入src,shape为[B,N,3];对于输入dst,shape为[B,S,3]

距离公式表示为:
d i s = ( x n − x m ) 2 + ( y n − y m ) 2 + ( z n − z m ) 2 = x n 2 + x m 2 − 2 x n x m + y n 2 + y m 2 − 2 y n y m + z n 2 + z m 2 − 2 z n z m = s r c 2 + d s t 2 − ( s r c T ∗ d s t ) dis=(x_n-x_m)^2+(y_n-y_m)^2+(z_n-z_m)^2 \=x_n^2+x_m^2-2x_nx_m+y_n^2+y_m^2-2y_ny_m+\ z_n^2+z_m^2-2z_nz_m \ =src^2+dst^2-(src^T*dst) dis=(xnxm)2+(ynym)2+(znzm)2=xn2+xm22xnxm+yn2+ym22ynym+zn2+zm22znzm=src2+dst2(srcTdst)

def square_distance(src,dst):
    # 计算各点间的距离
    B,N,_=src.shape
    _,M,_=dst.shape
    # shape: [B,N,M]
    dist=-2*torch.matmul(src,dst.permute(0,2,1))
    # shape: [B,N,M]+[B,N,1]
    dist+=torch.sum(src**2,-1).view(B,N,1)
    # shape: [B,N,M]+[B,1,M]
    dist+=torch.sum(dst**2,-1).view(B,1,M)
    return dist

在计算中,需要先构建一个索引组。


根据计算得到的距离张量,将超过搜索半径的距离点索引设置为最大值。


这样,我们就得到了实际落在圆内的点。


接着再做升序排序,选取我们需要的nsample个点。


当然,会出现点数不足的情况,所以我们复制最近的点,取最大值的位置做掩膜mask=group_idx==N,将掩膜位置修正为第一个点。


def query_ball_point(radius,nsample,xyz,new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device=xyz.device
    B,N,C=xyz.shape
    _,S,_=new_xyz.shape
    group_idx=torch.range(N,dtype=torch.long).to(device).view(1,1,N).repeat([B,S,1])
    # 得到一组[ B , S , N ] 的数据,即new_xyz与xyz中每个点的距离
    sqrdists=square_distance(new_xyz,xyz)
    # (x1-x2)**2+(y1-y2)**2+(z1-z2)**2>r**2
    # 这部分点已经超过了搜索半径了,所以令其索引等于最大值N(之前有效的最大值是N-1)
    group_idx[sqrdists>radius**2]=N
    # tensor.sort会返回一个value和一个index
    # 做了个升序排序,nsample就是我们需要的簇间点
    group_idx=group_idx.sort(dim=-1)[0][...,:nsample]
    # 针对半径点不足的情况,使用第一个点进行替代
    # shape: [B,S,nsample]
    group_first=group_idx[...,0].view(B,S,1).repeat([1,1,nsample])
    mask=group_idx==N
    group_idx=group_first[mask]
    # [B,S,nsample]
    return group_idx


四、采样打组

在原文中,二和三被定义为Sampling layer和Grouping layer,张量维度的变换如下:

# input
shape: [B,N,d+C]
----->
# sampling layer
shape: [B,S,d+C]
----->
# grouping layer
shape: [B,S,K,d+C]

'''
B: batch
N: 点云总数
S: 采样簇数量
K: 簇中点云数量
d: 位置信息xyz
C: 特征
'''

在此之前,需要定义一个函数,用于从索引中获取点。


def index_points(points,idx):
    """
    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]
    Return:
        new_points:, indexed points data, [B, S, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    # [ B , 1 ]
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    # [ 1 , S ]
    repeat_shape[0] = 1
    # [ B , S ]
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    # row: batch_indices
    # col: idx
    new_points = points[batch_indices, idx, :]
    return new_points

其中有一点需要注意的是,关于tensor索引为一个矩阵的情况。


例如:

# row
b=[[1,2],[3,4]]
# col
i=[[4,3],[2,1]]

# poi
points=torch.arange(25).view(5,5)

'''
points:
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24]])
'''

points[b,i]
'''
tensor([[ 9, 13],
        [17, 21]])
'''

这种情况下,是对b和i做组合,也就是说,实际取得的点对为:

[[points[1,4], # 9
points[2,3]] # 13
[points[3,2], # 17
points[4,1]]] # 21

实现的算法为:

输入:

  • npoint: 簇中心数量
  • radius: 搜索半径
  • nsample: 簇内点数量
  • xyz: 位置信息
  • points: 全局点,主要是有其他维度时使用
  • returnfps: 是否返回最近点信息
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Input:
        npoint: the number of points
        radius: search radius
        nsample: the number of points which in cluster
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, nsample, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B,N,C=xyz.shape
    S=npoint
   
    # sampling layer
    # 数据的形状为: [ B , npoint , C ]
    fps_idx=farthest_point_sample(xyz,npoint)
    torch.cuda.empty_cache() # 清空显存
    # 获取当前点对
    new_xyz=index_points(xyz,fps_idx)
    
    
    # groupling layer
    # [ B , npoint , nsample ]
    idx=query_ball_point(radius,nsample,xyz,new_xyz)
    torch.cuda.empty_cache() # 清空显存
    # [ B , npoint , nsample , C]
    grouped_xyz=index_points(xyz,idx)
    torch.cuda.empty_cache() # 清空显存
    
    # 中心化
    # 主要是减去中心点的坐标
    grouped_xyz_norm=grouped_xyz-new_xyz.view(B,S,1,C)
    torch.cuda.empty_cache() # 清空显存

    # 其他维度特征融合
    if points is not None:
        grouped_points=index_points(points,idx)
        new_points=torch.cat([grouped_xyz_norm,grouped_points],dim=-1)
    else:
        new_points=grouped_xyz_norm
    if returnfps:
        return new_xyz,new_points,grouped_xyz,fps_idx
    else:
        return new_xyz,new_points


五、局部特征提取

PointNet++的局部特征提取与PointNet相同,都是通过一个max pool来实现的。


与CNN不同,CNN是在做卷积加权求和,而PointNet++则是通过最大池化来完成。


在网络中,作者使用了sampling layer+grouping layer+pointnet来完成整个流程。


并将该过程称作set abstraction


SA采样能得到一个融合了局部特征的全局特征。



输入:

  • xyz: N个点的位置
    • 类似于CNN的卷积,多次SA *** 作后输入的N会变成npoint
  • points: 全部的数据

输出:

  • new_xyz: 对原始数据进行采样后,融合了局部特征的新的xyz。


    shape: [B , C , npoint]

  • new_points: shape: [B , C+N , npoint]

class PointNetSetAbstraction(nn.Module):

    def __init__(self,npoint,radius,nsample,in_channel,mlp,group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint=npoint
        self.radius=radius
        self.nsample=nsample
        self.mlp_convs=nn.ModuleList()
        self.mlp_bns=nn.ModuleList()
        last_channel=in_channel
        for out_channel in mlp:

            self.mlp_convs.append(nn.Conv2d(last_channel,out_channel,1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel=out_channel
        self.group_all=group_all

    def forward(self,xyz,points):
        """
              Input:
                  xyz: input points position data, [B, C, N]
                  points: input points data, [B, D, N]
              Return:
                  new_xyz: sampled points position data, [B, C, S]
                  new_points_concat: sample points feature data, [B, D', S]
              """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)
            

        if self.group_all:
            new_xyz,new_points=sample_and_group_all(xyz,points)
        else:
		 new_xyz,new_points=sample_and_group(self.npoint,self.radius,self.nsample,xyz,points)

        # new_xyz: 带有位置的采样点数据,形状为: [ B , npoint ,C ]
        # new_points: 采样点数据(聚类后的) [ B , npoint, nsample ,C+D ]
        new_points=new_points.permute(0,3,2,1) # [ B , C+D , nsample , npoint ]

        # 这步是一个PointNet
        for i,conv in enumerate(self.mlp_convs):
            bn=self.mlp_bns[i]
            new_points=F.relu(bn(conv(new_points)))
        
        new_points=torch.max(new_points,2)[0] # [ B , C+D , npoint]
       
        new_xyz=new_xyz.permute(0,2,1)
        
        return new_xyz,new_points


六、点云不均匀区域融合

作者在原文中提到:

Features learned in dense data may not generalize to sparsely sampled regions.

密集区特征与稀疏区特征可能会出现不适配,这是因为采样时在稀疏区域采用了最近点补全的方法,且受于尺度的影像,在稀疏区的点往往分布的很开,密集区则相对集中,这也会对结果造成较大的影像。


作者提出了两种特征融合的方法,分别是Multi-scale grouping(MSG 多尺度组合),Multiresolution grouping(MRG 多分辨率组合)。



关于尺度和分辨率,尺度就是观测事物的一种度量,例如看到一辆车,观察车窗和观察车身就是不同的尺度。


在图像上的表现为感受野的不同,或者说不同尺寸的卷积核卷积后的尺度不同。


而分辨率则是观察汽车,戴眼镜看和不戴眼镜看,都是汽车,但是有模糊和清楚之分。


在图像上类似于同一层做池化。


对于多尺度组合MSG而言,就是选取不同半径的子区域(在图像上就是选择不同大小的卷积核)进行特征提取后堆叠。


其代码如下:

class PointNetSetAbstractionMsg(nn.Module):
    def __init__(self,npoint,radius_list,nasmple_list,in_channel,mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint=npoint
        self.radius_list=radius_list
        self.nsample_list=nasmple_list
        self.conv_block=nn.ModuleList()
        self.bn_block=nn.ModuleList()
        for idx,mlp in mlp_list:
            convs=nn.ModuleList()
            bns=nn.ModuleList()
            last_channel=in_channel+3
            for output in mlp:
                convs.append(nn.Conv2d(last_channel,output,1))
                bns.append(nn.BatchNorm2d(output))
                last_channel=output
            self.conv_block.append(convs)
            self.bn_block.append(bns)

    def forward(self,xyz,points):
        '''
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        '''
        # xyz是坐标点位置特征
        xyz=xyz.permute(0,2,1) # [B,N,C]
        if points is not None:
            # 提取到的额外特征
            points=points.permute(0,2,1) # [B,N,D]

        B,N,C=xyz.shape
        S=self.npoint
        # 采样后的坐标点
        new_xyz=index_points(xyz,farthest_point_sample(xyz,S))

        new_points_list=[]
        # 多尺度特征提取
        for i,radius in enumerate(self.radius_list):
            k=self.nsample_list[i]
            group_idx=query_ball_point(radius,k,xyz,new_xyz)
            group_xyz=index_points(xyz,group_idx)
            # 中心化
            group_xyz-=new_xyz.view(B,S,1,C)
            if points is not None:
                group_points=index_points(points,group_idx)
                group_points=torch.cat([group_points,group_xyz],dim=-1)
                
            else:
                group_points=group_xyz

            group_points=group_points.permute(0,3,2,1) # [B,D,K,S]
           
            for j in range(len(self.conv_block[i])):
                conv=self.conv_block[i][j]
                bn=self.bn_block[i][j]
                group_points=F.relu(bn(conv(group_points)))
           
            # [B,D',S]
            new_points=torch.max(group_points,2)[0]
            
            new_points_list.append(new_points)

        new_xyz=new_xyz.permute(0,2,1)
        # 多尺度特征融合
        new_points_concat=torch.cat(new_points_list,dim=1)

        return new_xyz,new_points_concat


七、点云上采样

在连续的SA层中,不断对原始点进行下采样而获得数量更少的特征点,但若是做分割任务,则需要把点云中的所有点都带上语义标签。


若是用之前分类的思想,也就是对所有点做圆进行局部特征提取,实在是太耗费时间了。


于是作者提出了基于上采样的方式,将已提取特征的点传递给其他点、

在本部分,作者提出一种基于反距离权重差值的特征传播算法。


其核心思想在于:

  • 反距离插值,对每个点的k个临近点按照IDW进行差值。


    公式如下:

    • f ( j ) ( x ) = ∑ i = 1 k w i ( x ) f i ( j ) ∑ i = 1 k w i ( x ) f^{(j)}(x)=\frac{\sum_{i=1}^k w_i(x)f_i^{(j)}}{\sum_{i=1}^kw_i(x)} f(j)(x)=i=1kwi(x)i=1kwi(x)fi(j)

      • w i ( x ) = 1 d ( x , x i ) p w_i(x)=\frac {1}{d(x,x_i)^p} wi(x)=d(x,xi)p1
  • 将插值得到的特征与SA阶段得到的特征通过skip-link连接后进行特征堆叠。


  • 特征堆叠后输入到unit pointnet中进一步提取

输入:

  • xyz1: 所有点对坐标
  • xyz2: 降采样后的点坐标
  • points1: SA层的点
  • points2: 降采样后的点

输出:

  • skip-link后的特征点
class PointNetFeaturePropagation(nn.Module):
    
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel
            
    def forward(self,xyz1,xyz2,points1,points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1,xyz2=xyz1.permute(0,2,1),xyz2.permute(0,2,1)
        points2=points2.permute(0,2,1)

        B,N,C=xyz1.shape
        _,S,_=xyz2.shape

        if S==1:
            # 此时仅有一个采样点
            interpolated_points=points2.repeat(1,N,1)
            # 上采样,把当前点的特征copy N 次
        else:
            dists=square_distance(xyz1,xyz2)
            # 距离张量
            print(dists.shape)
            dists,idx=dists.sort(dim=-1)
            dists,idx=dists[...,:3],idx[...,:3]
            # 反距离权重法
            dist_recip=1.0/(dists+1e-8)
            # 为了让权重归一
            norm=torch.sum(dist_recip,dim=2,keepdim=True)
            weight=dist_recip/norm
            interpolated_points=torch.sum(index_points(points2,idx)* weight.view(B, N, 3, 1), dim=2)

        if points1 is not None:
            points1=points1.permute(0,2,1)
            # 跟原先位置的点做skip-link
            new_points=torch.cat([points1,interpolated_points],dim=-1)
        else:
            new_points=interpolated_points
        new_points=new_points.permute(0,2,1)
        for i,conv in enumerate(self.mlp_convs):
            bn=self.mlp_bns[i]
            new_points=F.relu(bn(conv(new_points)))
        return new_points

整个分类任务如下:

class get_model(nn.Module):
    def __init__(self,num_class,normal_channel=True):
        super(get_model, self).__init__()
        in_channel=3 if normal_channel else 0
        self.normal_channel=normal_channel
        
        # SA层
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320,[[64, 64, 128], [128, 128, 256], [128, 128, 256]])
        self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True)
		# SA返回来的是全局特征
        
        # 最后的判别层
        self.fc1=nn.Linear(1024,512)
        self.bn1=nn.BatchNorm1d(512)
        self.drop1=nn.Dropout(0.4)
        self.fc2=nn.Linear(512,256)
        self.bn2=nn.BatchNorm1d(256)
        self.drop2=nn.Dropout(0.5)
        self.fc3=nn.Linear(256,num_class)

    def forward(self,xyz):
        B,_,_=xyz.shape
        if self.normal_channel:
            norm=xyz[:,3:,:] # 这个是特征
            xyz=xyz[:,:3,:] # 这个是位置
        else:
            norm=None

        l1_xyz,l1_points=self.sa1(xyz,norm) # return [B,C+D,npoint]
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        
        # 线性展平
        x=l3_points.view(B,1024)
        # 预测类别
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        x=F.log_softmax(x,-1)
        return x,l3_points

class get_loss(nn.Module):
    def __init__(self):
        super(get_loss, self).__init__()

    def forward(self,pred,target,trans_feat=None):
        if trans_feat:
            total_loss=trans_feat(pred,target)
        else:
            total_loss=F.nll_loss(pred,target)
        return total_loss

而分割任务则是使用了特征传递层的特征融合,代码如下:

class get_model(nn.Module):
    def __init__(self, num_classes, normal_channel=False):
        super(get_model, self).__init__()
        if normal_channel:
            additional_channel = 3
        else:
            additional_channel = 0
        self.normal_channel = normal_channel
        
        # SA层
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3+additional_channel, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.4,0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]])
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True)
        
        # FP层
        self.fp3 = PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256])
        self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128])
        self.fp1 = PointNetFeaturePropagation(in_channel=150+additional_channel, mlp=[128, 128])
        
        # MLP层
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz, cls_label):
        # Set Abstraction layers
        B,C,N = xyz.shape
        if self.normal_channel:
            l0_points = xyz
            l0_xyz = xyz[:,:3,:]
        else:
            l0_points = xyz
            l0_xyz = xyz
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        # Feature Propagation layers
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        # 获取类别的one-hot编码
        cls_label_one_hot = cls_label.view(B,16,1).repeat(1,1,N)
        l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat([cls_label_one_hot,l0_xyz,l0_points],1), l1_points)
        # FC layers
        feat = F.relu(self.bn1(self.conv1(l0_points)))
        x = self.drop1(feat)
        x = self.conv2(x)
        x = F.log_softmax(x, dim=1)
        x = x.permute(0, 2, 1)
        return x, l3_points

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/langs/567401.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-09
下一篇 2022-04-09

发表评论

登录后才能评论

评论列表(0条)

保存