- 获得box的八个角点坐标
- 在无限高的圆柱中随机采样
- Embedding和Encoding
paper: Improving 3D Object Detection with Channel-wise Transformer
code:https://github.com/hlsheng1/CT3D
在/pcdet/models/roi_heads/ct3d_head.py 中 130行左右
# corner #输入的rois大小为(batch_size,roi的个数,roi信息)(2,128,7) #其中7包括roi的x,y,z,l,w,h,方向角 #最后得到的corner_points是(B, 128, 2*2*2, 3),每个角点的坐标 corner_points, _ = self.get_global_grid_points_of_roi(rois) # (BxN, 2x2x2, 3) corner_points = corner_points.view(batch_size, num_rois, -1, corner_points.shape[-1]) # (2, 128, 2x2x2, 3)
下面看一下 self.get_global_grid_points_of_roi(rois)这个函数,同样在class CT3DHead这个类中
def get_global_grid_points_of_roi(self, rois): rois = rois.view(-1, rois.shape[-1]) # (256,7) batch_size_rcnn = rois.shape[0] # (256) # 得到了box八个角点到中心点的坐标距离 # 中心点的三个轴坐标只要在每个轴加上相应的坐标距离 # 就能得到每一个角点的坐标了。 local_roi_grid_points = self.get_corner_points(rois, batch_size_rcnn) #(256,8,3) #注意!!! #上面得到的只考虑了box的大小,并没有考虑角度 #如果直接拿来计算的话,得到的所有角点都不是真正的,而是平行于坐标轴的! #所以还要通过下面的函数将每个roi的角度考虑信息。 #最后得到真正的角点全局坐标 global_roi_grid_points = common_utils.rotate_points_along_z( local_roi_grid_points.clone(), rois[:, 6] ).squeeze(dim=1) #(256, 8 ,3) #得到中心点的全局坐标 global_center = rois[:, 0:3].clone() # (256, 3) # pdb.set_trace() #中心点加坐标距离最终得到每个box八个点的角点 global_roi_grid_points += global_center.unsqueeze(dim=1) #(256,8,3 ) #(256,8,3) (256,8,3) return global_roi_grid_points, local_roi_grid_points
看一下 self.get_corner_points(rois, batch_size_rcnn)这个函数,同样在class CT3DHead这个类中
@staticmethod def get_corner_points(rois, batch_size_rcnn): #得到一个(2, 2, 2) 的tensor,用来在后面求box八个角点的索引 ,这里还是很巧妙的,可以学习一下。 faked_features = rois.new_ones((2, 2, 2)) # nonzero() 返回每个这个tensor中非零元素的索引 # 例如这里建立了一个(2,2,2)的全是1的tensor,每一个元素都是非零的, #所以得到的非零元素的索引为 [0,0,0], [0,0,1], [0,1,0], [0,1,1] .....[1,1,1] 共8个 dense_idx = faked_features.nonzero() # (8, 3) [x_idx, y_idx, z_idx] dense_idx = dense_idx.repeat(batch_size_rcnn, 1, 1).float() # (256, 2x2x2, 3) # 取每一个RoI的长宽高 local_roi_size = rois.view(batch_size_rcnn, -1)[:, 3:6] # (256,3) #这一步是关键,求每个角点在每个轴上相对于roi中心点的距离 #例如:[0, 0, 1] * [l, w, h] - [l/2,w/2, h/2] # = [-l/2, -w/2, h/2] #即这个角点到中心点的坐标距离 #中心点坐标 在x轴减去一半长,在y轴减去一半宽,在z轴加上一般高。就是这个角点的坐标了。 roi_grid_points = dense_idx * local_roi_size.unsqueeze(dim=1) - (local_roi_size.unsqueeze(dim=1) / 2) # (B, 2x2x2, 3) return roi_grid_points #(256,8,3)
这个函数得到了box每一个角点到中心点的坐标距离,中心点的三个轴坐标只要在每个轴加上相应的坐标距离,就能得到每一个角点的坐标了。
在无限高的圆柱中随机采样论文中在RoI中采样点时,将RoI转换为一个高度无限高的圆柱体,并从中随机采样256个点作为RoI的表征。采样方法如下
圆柱体底面的半径为:
代码如下
num_sample = self.num_points # 这里是256个点 src = rois.new_zeros(batch_size, num_rois, num_sample, 4) #(2,128,256,4) for bs_idx in range(batch_size): # 每次循环一个batch_size # batch_dict[points]是所有batch_size中所有的点 # batch_dict[points][(batch_dict['points'][:, 0] == bs_idx)]是第bs_idx个batch中的点 # batch_dict[points][(batch_dict['points'][:, 0] == bs_idx)][:,1:5]是每个点的坐标和反射强度 #cur_points是一个batch中所有点的坐标+反射强度(194165,4) cur_points = batch_dict['points'][(batch_dict['points'][:, 0] == bs_idx)][:,1:5] #(194165,4) # 每个batch的roi box cur_batch_boxes = batch_dict['rois'][bs_idx] #(128,7) # 求半径公式如上图所示 (128) cur_radiis = torch.sqrt((cur_batch_boxes[:,3]/2) ** 2 + (cur_batch_boxes[:,4]/2) ** 2) * 1.2 # 所有点到roi box中心的距离, # 共128个RoI,每一个RoI 都计算19165个点到其中心的距离 dis = torch.norm((cur_points[:,:2].unsqueeze(0) - cur_batch_boxes[:,:2].unsqueeze(1).repeat(1,cur_points.shape[0],1)), dim = 2) # (128,19165) #过滤出半径内的点 point_mask = (dis <= cur_radiis.unsqueeze(-1)) # 遍历每一个roi for roi_box_idx in range(0, num_rois): # point_mask[roi_box_idx] 是第roi_box_idx个roi的mask #cur_roi_points 这里是(465,4),即这个圆柱roi中有465个点 cur_roi_points = cur_points[point_mask[roi_box_idx]] # 如果roi内部的点大于要求采样的数量的话,就随机取256个(源码中num_sample=256) # 如果roi内部的点个数等于0,就用0填充256个 # 如果roi内部点在0到256之间,采样所有点,剩余的用0填充 if cur_roi_points.shape[0] >= num_sample: random.seed(0) index = np.random.randint(cur_roi_points.shape[0], size=num_sample) cur_roi_points_sample = cur_roi_points[index] elif cur_roi_points.shape[0] == 0: cur_roi_points_sample = cur_roi_points.new_zeros(num_sample, 4) else: empty_num = num_sample - cur_roi_points.shape[0] add_zeros = cur_roi_points.new_zeros(empty_num, 4) add_zeros = cur_roi_points[0].repeat(empty_num, 1) cur_roi_points_sample = torch.cat([cur_roi_points, add_zeros], dim = 0) #这个roi的采样结束,记录到src中 src[bs_idx, roi_box_idx, :, :] = cur_roi_points_sample #采样结束,经过整理得到src,共b个batch,每个batch中128个roi,每个roi取256个点(x,y,z,r) src = src.view(batch_size * num_rois, -1, src.shape[-1]) # (b*128, 256, 4)Embedding和Encoding
论文的主要结构图如下所示:
得到了圆柱体内部的采样点,和box角点,中心点的信息。可以进行embedding了
如下所示
以一个roi举例,roi内部256个采样点,每一个点都计算自己与roi八个角点以及中心点的距离,再包括自己的反射强度,通过线形层升维。具体公式如下所示:对于点采样点pi, fi是对该点的embedding
Enbedding的代码如下:
# src是采样得到的roi内部的点 src = src.view(batch_size * num_rois, -1, src.shape[-1]) # (b*128, 256, 4) #corner_points是上一步得到的roi八个角点坐标(256,8,3) -->(256,24) corner_points = corner_points.view(batch_size * num_rois, -1) #将每个roi的中心点坐标concat到八个角点坐标上 # corner_add_center_points (256,24) -->(256,27) corner_add_center_points = torch.cat([corner_points, rois.view(-1, rois.shape[-1])[:,:3]], dim = -1) # 计算每个点与角点中心点得坐标距离 #pos_fea(b*roi, num_sample, 27) pos_fea = src[:,:,:3].repeat(1,1,9) - corner_add_center_points.unsqueeze(1).repeat(1,num_sample,1) # 27 维度 #roi的长宽高 lwh (b*roi, num_sample, 3) lwh = rois.view(-1, rois.shape[-1])[:,3:6].unsqueeze(1).repeat(1,num_sample,1) # (l*l + w*w + h*h) ** 0.5 diag_dist = (lwh[:,:,0]**2 + lwh[:,:,1]**2 + lwh[:,:,2]**2) ** 0.5 # pos_fea(256,256,27) 每一个点到角点,中心点的 球形坐标距离 pos_fea = self.spherical_coordinate(pos_fea, diag_dist = diag_dist.unsqueeze(-1)) # src(256,256,28) src = torch.cat([pos_fea, src[:,:,-1].unsqueeze(-1)], dim = -1) #src(256,256,256) src = self.up_dimension(src)
下面看一下 pos_fea = self.spherical_coordinate(pos_fea, diag_dist = diag_dist.unsqueeze(-1))这个函数:
def spherical_coordinate(self, src, diag_dist): # src(256,256,27)每个采样点到角点中心点的坐标距离 # diag_dist(256,256,1) assert (src.shape[-1] == 27) device = src.device #分别取每个采样点到角点中心点坐标距离 indices_x = torch.LongTensor([0,3,6,9,12,15,18,21,24]).to(device) # indices_y = torch.LongTensor([1,4,7,10,13,16,19,22,25]).to(device) # indices_z = torch.LongTensor([2,5,8,11,14,17,20,23,26]).to(device) src_x = torch.index_select(src, -1, indices_x) # (256, 256, 9) src_y = torch.index_select(src, -1, indices_y) # (256, 256, 9) src_z = torch.index_select(src, -1, indices_z) # (256, 256, 9) #把坐标距离转换成球形坐标距离?? dis = (src_x ** 2 + src_y ** 2 + src_z ** 2) ** 0.5 phi = torch.atan(src_y / (src_x + 1e-5)) the = torch.acos(src_z / (dis + 1e-5)) dis = dis / diag_dist #src(256,256,256)球形坐标距离?? src = torch.cat([dis, phi, the], dim = -1) return src
这里为什么要把坐标系转换成球形坐标,作者在github上解释说,换不换没有什么差别,所以在论文中就没有提到…
Tranformer如下所示
def forward(self, src, query_embed, pos_embed): # src (256,256,256) # query_embed (1,256) # pos_embed (256,256,256) 位置编码在Emcoder中q,k相加 bs, n, c = src.shape src = src.permute(1, 0, 2) pos_embed = pos_embed.permute(1, 0, 2) query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (1,256,256) tgt = torch.zeros_like(query_embed) #(1,256,256) # memory (256,256,256) memory = self.encoder(src, src_key_padding_mask=None, pos=pos_embed) #因为在整体代码中,TransformerDecoder的初始化参数return_intermediate设置为True #因此,Decoder的输出包含了每层的结果,共有一层,shape是[1,num_querie,batch_size,hidden_dim] hs = self.decoder(tgt, memory, memory_key_padding_mask=None, pos=pos_embed, query_pos=query_embed) #(1,256,1,256) return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, n)
Encoder模块没有什么可说的,和DETR一样的,共有三层。encoder得到memory(256,256,256),256个roi,每个roi256个采样点,通道数为256。与输入相同
Decoder共一层,与DETR一样,输入的num_query = 1 ,代表每个roi里256个采样点生成一个box 。
k,v来自encoder,每一个roi得到一个box。
唯一不同的是,decoder在计算self-attention时候。公式由标准的
变成了论文中的:
再乘以V
代码如下
def attention(query, key, value): dim = query.shape[1] scores_1 = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5 scores_2 = torch.einsum('abcd, aced->abcd', key, scores_1) prob = torch.nn.functional.softmax(scores_2, dim=-1) output = torch.einsum('bnhm,bdhm->bdhn', prob, value) return output, prob
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)