图像分割之Swin-Unet分享

图像分割之Swin-Unet分享,第1张

图像分割之Swin-Unet分享

基于CNN的图像/语义分割算法主要有Unet FCN PSPnet DAnet DeepLabV3+,HRnet+OCR等,去年年底基于Transform的各类CV算法(如ViT,Swin等)在分割/分类任务上都表现了相比CNN更为优秀的分割精度。

这里就简单介绍一下基于Swin模块的Unet分割算法:来自慕尼黑工业大学的Swin-Unet

论文:https://arxiv.org/abs/2105.05537
代码:https://github.com/HuCaoFighting/Swin-Unet

首先我们看模型结构:

整个网络结构看起来非常的清楚,可以说基本上就相当于把Unet中的2D卷积换成了Swin模块。对于Swin提出的W-MSA和SW-MSA在前面Swinformer那一期介绍了一下。更详细的还是得看代码。Swin论文那里我认为为了讲故事这块结构写的的有点玄学了。

整体结构和算法部分下面我跟着代码一起详细介绍:

首先是数据增广:

Swin-Unet代码结构比较清晰清爽,整体逻辑非常清晰:

def random_rot_flip(image, label):  #随机翻转
    k = np.random.randint(0, 4)
    image = np.rot90(image, k)
    label = np.rot90(label, k)
    axis = np.random.randint(0, 2)
    image = np.flip(image, axis=axis).copy()
    label = np.flip(label, axis=axis).copy()
    return image, label


def random_rotate(image, label):  #正负20度旋转
    angle = np.random.randint(-20, 20)
    image = ndimage.rotate(image, angle, order=0, reshape=False)
    label = ndimage.rotate(label, angle, order=0, reshape=False)
    return image, label

图像增广方面就用了两个,一个是图像和label同步进行随机翻转,一个是图像和label进行正负20度随机旋转

其他的就很常规了:
首先写了一个Synapse_dataset类,通过继承torch的Dataset类,复写Dataset中的__len__和__getitem__方法,其中__getitem__主要是读图像和label的numpy数组,利用上面的图像增广做同步矩阵变换之后转换成pytorch的torch.tensor后喂入模型,__getitem__主要是读到图像时同步为图像和label做对应的 *** 作。这里我为了博客轻量把具体实现代码去掉了。想看的同学可以去看这块源码。很简单。

class Synapse_dataset(Dataset):
    def __init__(self, base_dir, list_dir, split, transform=None):  
    def __len__(self):
        return len(self.sample_list)
    def __getitem__(self, idx):
        return sample

最后老办法喂入torch的dataloader后通过epoch的for循环同步读取训练数据的image和label的tensor:

trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True,
                             worker_init_fn=worker_init_fn)

数据预处理说完了。接下来介绍网络实现步骤:

首先是transform的PatchEmbed结构:

整个结构基本上就是照搬Swin的PatchEmbed方法,直接通过一个2D卷积
表征位置信息(事实上目前很多基于Transform的算法都这么干的)

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], 
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

熟悉Unet结构的同学应该清楚整个Unet核心其实就三部分:

编码头:
对图像特征进行聚合,同时下采样,WH减半,channel同步增加(由于Swin输入多少输出多少,所以下采样功能是通过torch的linear层实现的)

解码头:
将图像上采样要原图大小方便进行像素点分类

跳连接:
网络层越深得到的特征图,有着更大的感受野,浅层卷积关注纹理特征,深层网络关注本质的那种特征,通过跳连接可以使特征向量同时具有深层和表层特征(cat方法),由于图像在上采样过程(CNN的图像分割一般通过2Dconv+双线性插值进行上采样)本身不增加新的信息,但是每一次下采样提炼特征的同时,也必然会损失一些边缘特征,而失去的特征并不能从上采样中找回,因此通过特征的拼接,来实现边缘特征的一个找回。

由于SwinBlock相比CNN比较特殊,它的输入和输出是一样的,下采样主要

上采样:
作者尝试了双线性插值/转置卷积/Patch expand三种方法,通过对比实验证明了其有效性:

Patch expand方法其实很简单,首先通过一个线性层把长采样到两倍,然后通过torch.view()通道数变成1/4,wh各增加2倍。cat后刚好和encode对齐

class PatchExpand(nn.Module):
    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()  #输出feature的channel加倍
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
        x = x.view(B,-1,C//4)   #wh翻倍,channel减少4倍
        x= self.norm(x)

        return x

跳连接个数:
如下表显示跳连接确实是work的

损失函数部分:

Swin-Unet的损失函数有任何的改进,是0.4的交叉熵+0.6的dice-loss构成

outputs = model(image_batch)
            loss_ce = ce_loss(outputs, label_batch[:].long())
            loss_dice = dice_loss(outputs, label_batch, softmax=True)
            loss = 0.4 * loss_ce + 0.6 * loss_dice

效果:

Swin-Unet凭借Swin中MSA强大特征提取能力。相比一众算法展现了sota的效果:


总结:Swin-Unet只是在各个特征提取模块将Unet的2D卷积换成了Swin结构,在Swin结构和Unet结构上基本没有改变,损失函数也没有做变化。再次说明了Swin模块的强大特征提取能力(感觉创新不太够啊,不过代码挺清爽的)

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5658264.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-16
下一篇 2022-12-16

发表评论

登录后才能评论

评论列表(0条)

保存