A PyTorch Implementation and Analysis of gMLP


1. Importing the required libraries

import torch
from torch import nn
from torch.nn import functional as F
from torchsummary import summary
from tensorboardX import SummaryWriter

2. Designing the Spatial Gating Unit module

        Following the spatial_gating_unit(x) function in the paper's pseudo-code (figure above) and Equation (6) of the paper, the idea is: split the input Z along the channel dimension into Z1 and Z2 (u and v in the code below, corresponding to U and V); normalize Z2 and pass it through a linear spatial projection; then output the element-wise product. The formula is:

s(Z) = Z1 ⊙ f_{W,b}(Z2),   where f_{W,b}(Z) = W·Z + b is the spatial (token-mixing) projection.

class SpatialGatingUnit(nn.Module):  # input: [-1, seq_len, 2 * d_ffn] = [-1, 256, 1024]
    def __init__(self, d_ffn, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(d_ffn)  # normalizes the gate half: [-1, 256, 512] -> [-1, 256, 512]
        # kernel_size=1 Conv1d over the sequence dimension: a linear projection that mixes tokens
        self.spatial_proj = nn.Conv1d(seq_len, seq_len, kernel_size=1)  # [-1, 256, 512] -> [-1, 256, 512]
        # The paper initializes W near zero and b to 1 so that f_{W,b}(Z2) ~= 1 early in
        # training, i.e. the gate starts close to identity; this code only sets the bias.
        nn.init.constant_(self.spatial_proj.bias, 1.0)

    def forward(self, x):
        # torch.chunk splits x into 2 equal parts along the last (channel) dimension:
        # [-1, 256, 1024] -> two tensors of shape [-1, 256, 512]
        u, v = x.chunk(2, dim=-1)
        v = self.norm(v)
        v = self.spatial_proj(v)
        out = u * v
        return out
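
A quick shape check of the unit (a minimal sketch, not from the original post; the sizes follow the d_ffn=512, seq_len=256 configuration used below):

sgu = SpatialGatingUnit(d_ffn=512, seq_len=256)
z = torch.randn(2, 256, 1024)  # [batch, seq_len, 2 * d_ffn], i.e. what channel_proj1 will emit
print(sgu(z).shape)            # torch.Size([2, 256, 512]) -- the channel dim is halved by the chunk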

        The structure of this module is shown below:

[Figure: SpatialGatingUnit structure]

3. Building the gMLPBlock module

        That is, we build the structure shown in the following figure:

[Figure: gMLPBlock structure to be implemented]

class gMLPBlock(nn.Module):
    def __init__(self, d_model, d_ffn, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.channel_proj1 = nn.Linear(d_model, d_ffn * 2)  # (256, 1024): [-1, 256, 256] -> [-1, 256, 1024]
        self.sgu = SpatialGatingUnit(d_ffn, seq_len)
        self.channel_proj2 = nn.Linear(d_ffn, d_model)      # (512, 256): [-1, 256, 512] -> [-1, 256, 256]

    def forward(self, x):
        residual = x
        x = self.norm(x)                    # [-1, 256, 256]
        x = F.gelu(self.channel_proj1(x))   # GELU activation: [-1, 256, 1024]
        x = self.sgu(x)                     # [-1, 256, 512]
        x = self.channel_proj2(x)           # [-1, 256, 256]
        out = x + residual                  # residual connection
        return out
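
The block preserves the input shape, which is what makes the residual addition valid; a minimal sketch (again not from the post) confirming this:

block = gMLPBlock(d_model=256, d_ffn=512, seq_len=256)
tokens = torch.randn(2, 256, 256)  # [batch, seq_len, d_model]
print(block(tokens).shape)         # torch.Size([2, 256, 256]), same as the input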

        The structure of this module is shown below:

[Figure: gMLPBlock structure]

4. Building the gMLP module (to make stacking num_layers blocks easy)

class gMLP(nn.Module):
    def __init__(self, d_model=256, d_ffn=512, seq_len=256, num_layers=6):
        super().__init__()
        self.model = nn.Sequential(
            *[gMLPBlock(d_model, d_ffn, seq_len) for _ in range(num_layers)]
        )

    def forward(self, x):
        return self.model(x)
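
Usage is straightforward; a small sketch (the tensor sizes are just the defaults above):

mlp = gMLP(d_model=256, d_ffn=512, seq_len=256, num_layers=6)
tokens = torch.randn(2, 256, 256)  # [batch, seq_len, d_model]
print(mlp(tokens).shape)           # torch.Size([2, 256, 256])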

5. The overall model: gMLPForImageClassification
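
        Note that the constructor below calls a check_sizes helper that the post never defines; here is a minimal sketch consistent with how num_patches is used (the exact body is an assumption, just the obvious divisibility check):

def check_sizes(image_size, patch_size):
    # The image must split evenly into non-overlapping square patches.
    sqrt_num_patches, remainder = divmod(image_size, patch_size)
    assert remainder == 0, "image_size must be divisible by patch_size"
    return sqrt_num_patches ** 2  # e.g. (256 // 16) ** 2 = 256 patches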

class gMLPForImageClassification(gMLP):
    def __init__(
        self,
        image_size=256,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        d_model=256,
        d_ffn=512,
        seq_len=256,
        num_layers=6,
    ):
        num_patches = check_sizes(image_size, patch_size)  # (256 // 16) ** 2 = 256; must equal seq_len
        super().__init__(d_model, d_ffn, seq_len, num_layers)
        self.patcher = nn.Conv2d(
            in_channels, d_model, kernel_size=patch_size, stride=patch_size
        )  # [2, 3, 256, 256] -> [2, 256, 16, 16]
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # x: [2, 3, 256, 256]
        patches = self.patcher(x)                             # [2, 256, 16, 16]
        batch_size, num_channels, _, _ = patches.shape
        patches = patches.permute(0, 2, 3, 1)                 # channels last: [2, 16, 16, 256]
        patches = patches.view(batch_size, -1, num_channels)  # flatten the patch grid: [2, 256, 256]
        embedding = self.model(patches)                       # [2, 256, 256]
        embedding = embedding.mean(dim=1)                     # average over the 256 tokens: [2, 256]
        out = self.classifier(embedding)                      # [2, 1000]
        return out
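
The permute/view pair in forward flattens the 16x16 patch grid into a sequence of 256 tokens; the two reshaping lines could equivalently be written with flatten and transpose (a sketch, not the author's code):

patches = self.patcher(x)                     # [B, d_model, 16, 16]
patches = patches.flatten(2).transpose(1, 2)  # [B, 256, d_model], same token order as permute + view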

        The overall structure is shown below:

[Figure: gMLPForImageClassification structure]

6. Testing the network

# Test gMLP
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = gMLPForImageClassification(
        image_size=256,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        d_model=256,
        d_ffn=512,
        seq_len=256,
        num_layers=1,
    ).to(device)
    summary(model, (3, 256, 256))  # torchsummary adds the batch dimension itself

    # torch.randn instead of the uninitialized torch.Tensor(...), so the input holds valid values
    inputs = torch.randn(2, 3, 256, 256)
    inputs = inputs.to(device)
    print(inputs.shape)

    # Save the model as a graph for TensorBoard
    with SummaryWriter(log_dir='logs', comment='model') as w:
        w.add_graph(model, (inputs,))
        print("success")

        Testing with a [2, 3, 256, 256] input yields the layer summary below; the exported architecture graph is shown at the end.

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 256, 16, 16]         196,864
         LayerNorm-2             [-1, 256, 256]             512
            Linear-3            [-1, 256, 1024]         263,168
         LayerNorm-4             [-1, 256, 512]           1,024
            Conv1d-5             [-1, 256, 512]          65,792
 SpatialGatingUnit-6             [-1, 256, 512]               0
            Linear-7             [-1, 256, 256]         131,328
         gMLPBlock-8             [-1, 256, 256]               0
            Linear-9                 [-1, 1000]         257,000
================================================================
Total params: 915,688
Trainable params: 915,688
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 7.01
Params size (MB): 3.49
Estimated Total Size (MB): 11.25
----------------------------------------------------------------
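
The per-layer parameter counts can be checked by hand; for instance (plain arithmetic, matching the table above):

print(3 * 16 * 16 * 256 + 256)  # Conv2d-1:  196864 = in_ch * k * k * d_model + bias
print(256 * 1024 + 1024)        # Linear-3:  263168 = d_model * 2*d_ffn + bias
print(256 * 256 + 256)          # Conv1d-5:   65792 = seq_len * seq_len * 1 + bias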

        The network's structural diagram (one gMLP layer) is shown below:

[Figure: TensorBoard graph of the model with num_layers=1]

Source: http://outofmemory.cn/zaji/4655187.html