基于CNN的图像/语义分割算法主要有Unet FCN PSPnet DAnet DeepLabV3+,HRnet+OCR等,去年年底基于Transform的各类CV算法(如ViT,Swin等)在分割/分类任务上都表现了相比CNN更为优秀的分割精度。
这里就简单介绍一下基于Swin模块的Unet分割算法:来自慕尼黑工业大学的Swin-Unet
论文:https://arxiv.org/abs/2105.05537
代码:https://github.com/HuCaoFighting/Swin-Unet
首先我们看模型结构:
整个网络结构看起来非常的清楚,可以说基本上就相当于把Unet中的2D卷积换成了Swin模块。对于Swin提出的W-MSA和SW-MSA在前面Swinformer那一期介绍了一下。更详细的还是得看代码。Swin论文那里我认为为了讲故事这块结构写的的有点玄学了。
整体结构和算法部分下面我跟着代码一起详细介绍:
首先是数据增广:
Swin-Unet代码结构比较清晰清爽,整体逻辑非常清晰:
def random_rot_flip(image, label): #随机翻转 k = np.random.randint(0, 4) image = np.rot90(image, k) label = np.rot90(label, k) axis = np.random.randint(0, 2) image = np.flip(image, axis=axis).copy() label = np.flip(label, axis=axis).copy() return image, label def random_rotate(image, label): #正负20度旋转 angle = np.random.randint(-20, 20) image = ndimage.rotate(image, angle, order=0, reshape=False) label = ndimage.rotate(label, angle, order=0, reshape=False) return image, label
图像增广方面就用了两个,一个是图像和label同步进行随机翻转,一个是图像和label进行正负20度随机旋转
其他的就很常规了:
首先写了一个Synapse_dataset类,通过继承torch的Dataset类,复写Dataset中的__len__和__getitem__方法,其中__getitem__主要是读图像和label的numpy数组,利用上面的图像增广做同步矩阵变换之后转换成pytorch的torch.tensor后喂入模型,__getitem__主要是读到图像时同步为图像和label做对应的 *** 作。这里我为了博客轻量把具体实现代码去掉了。想看的同学可以去看这块源码。很简单。
class Synapse_dataset(Dataset): def __init__(self, base_dir, list_dir, split, transform=None): def __len__(self): return len(self.sample_list) def __getitem__(self, idx): return sample
最后老办法喂入torch的dataloader后通过epoch的for循环同步读取训练数据的image和label的tensor:
trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, worker_init_fn=worker_init_fn)
数据预处理说完了。接下来介绍网络实现步骤:
首先是transform的PatchEmbed结构:
整个结构基本上就是照搬Swin的PatchEmbed方法,直接通过一个2D卷积
表征位置信息(事实上目前很多基于Transform的算法都这么干的)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
class PatchEmbed(nn.Module): r""" Image to Patch Embedding Args: img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None """ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution self.num_patches = patches_resolution[0] * patches_resolution[1] self.in_chans = in_chans self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: self.norm = None def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) return x
熟悉Unet结构的同学应该清楚整个Unet核心其实就三部分:
编码头:
对图像特征进行聚合,同时下采样,WH减半,channel同步增加(由于Swin输入多少输出多少,所以下采样功能是通过torch的linear层实现的)
解码头:
将图像上采样要原图大小方便进行像素点分类
跳连接:
网络层越深得到的特征图,有着更大的感受野,浅层卷积关注纹理特征,深层网络关注本质的那种特征,通过跳连接可以使特征向量同时具有深层和表层特征(cat方法),由于图像在上采样过程(CNN的图像分割一般通过2Dconv+双线性插值进行上采样)本身不增加新的信息,但是每一次下采样提炼特征的同时,也必然会损失一些边缘特征,而失去的特征并不能从上采样中找回,因此通过特征的拼接,来实现边缘特征的一个找回。
由于SwinBlock相比CNN比较特殊,它的输入和输出是一样的,下采样主要
上采样:
作者尝试了双线性插值/转置卷积/Patch expand三种方法,通过对比实验证明了其有效性:
Patch expand方法其实很简单,首先通过一个线性层把长采样到两倍,然后通过torch.view()通道数变成1/4,wh各增加2倍。cat后刚好和encode对齐
class PatchExpand(nn.Module): def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): super().__init__() self.input_resolution = input_resolution self.dim = dim self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() #输出feature的channel加倍 self.norm = norm_layer(dim // dim_scale) def forward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution x = self.expand(x) B, L, C = x.shape assert L == H * W, "input feature has wrong size" x = x.view(B, H, W, C) x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) x = x.view(B,-1,C//4) #wh翻倍,channel减少4倍 x= self.norm(x) return x
跳连接个数:
如下表显示跳连接确实是work的
损失函数部分:
Swin-Unet的损失函数有任何的改进,是0.4的交叉熵+0.6的dice-loss构成
outputs = model(image_batch) loss_ce = ce_loss(outputs, label_batch[:].long()) loss_dice = dice_loss(outputs, label_batch, softmax=True) loss = 0.4 * loss_ce + 0.6 * loss_dice
效果:
Swin-Unet凭借Swin中MSA强大特征提取能力。相比一众算法展现了sota的效果:
总结:Swin-Unet只是在各个特征提取模块将Unet的2D卷积换成了Swin结构,在Swin结构和Unet结构上基本没有改变,损失函数也没有做变化。再次说明了Swin模块的强大特征提取能力(感觉创新不太够啊,不过代码挺清爽的)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)