U-Net 的 PyTorch 实现

Unet实现,第1张

实现代码片段:

from  torch import nn
import torch


class u_dowm(nn.Module):
    """Encoder stage: two 3x3 conv + BatchNorm + ReLU layers, then 2x2 max pooling.

    Args:
        in_channel: Number of channels of the incoming feature map.
        out_channel: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # kernel 3 / stride 1 / padding 1 leaves height and width unchanged.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        feature_layers = [
            nn.Conv2d(in_channel, out_channel, **conv_kwargs),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, **conv_kwargs),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
        ]
        self.convANDbn = nn.Sequential(*feature_layers)
        # 2x2 window with stride 2 halves the spatial size (floor division).
        self.downsample = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))

    def forward(self, x):
        """Return ``(features, pooled)`` for the input ``x``.

        ``features`` is the full-resolution activation before pooling; it is
        the tensor meant to be concatenated with the upsampled decoder map.
        ``pooled`` is the max-pooled version of ``features``.
        """
        features = self.convANDbn(x)
        pooled = self.downsample(features)
        return features, pooled


class up(nn.Module):
    """Decoder stage: double conv, 2x bilinear upsampling, skip concatenation.

    The input is first expanded to ``2 * out_channel`` channels by two
    3x3 conv + BatchNorm + ReLU layers, then upsampled by a factor of 2 and
    reduced to ``out_channel`` channels by a third 3x3 conv + BatchNorm + ReLU.
    The result is concatenated with the skip tensor along the channel axis.

    Args:
        in_channel: Number of channels of the incoming feature map.
        out_channel: Number of channels of the upsampled map before it is
            concatenated with the skip connection.
    """

    def __init__(self, in_channel, out_channel):
        super(up, self).__init__()
        self.conv2bn = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=2 * out_channel,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(2 * out_channel),
            nn.ReLU(),
            nn.Conv2d(in_channels=2 * out_channel, out_channels=2 * out_channel,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(2 * out_channel),
            nn.ReLU()
        )
        self.upSample = nn.Sequential(
            # Same operation as the deprecated nn.UpsamplingBilinear2d
            # (bilinear with align_corners=True); it has no parameters, so
            # state_dict keys are unaffected.
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(in_channels=2 * out_channel, out_channels=out_channel,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x, out):
        """Upsample ``x`` and concatenate it with the skip tensor ``out``.

        Args:
            x: Feature map coming from the deeper (lower-resolution) stage.
            out: Skip-connection feature map to concatenate with.

        Returns:
            Tensor whose channels are the upsampled map followed by ``out``
            (``dim=1`` concatenation), at the spatial size of ``out``.
        """
        output = self.conv2bn(x)
        output = self.upSample(output)
        # Max pooling floors odd sizes, so doubling may not reproduce the
        # skip's height/width; resize only in that case so torch.cat cannot
        # fail, while matching sizes keep the original behaviour exactly.
        if output.shape[2:] != out.shape[2:]:
            output = nn.functional.interpolate(
                output, size=out.shape[2:], mode="bilinear", align_corners=True
            )
        output = torch.cat((output, out), dim=1)
        return output


class UNET(nn.Module):
    """U-Net built from four ``u_dowm`` encoder stages and four ``up`` stages.

    Encoder widths are 64, 128, 256 and 512 channels. Each decoder stage
    returns its upsampled map concatenated with the matching encoder skip, so
    the decoder inputs are 512, 1024, 512 and 256 channels and the final head
    receives 128 channels. With height and width divisible by 16, every
    upsampled map has exactly the spatial size of its skip connection.

    Args:
        in_channel: Number of channels of the input image.
        out_channel: Number of channels of the output map (raw scores, no
            activation is applied).
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        # Encoder path.
        self.downSample1 = u_dowm(in_channel, 64)
        self.downSample2 = u_dowm(64, 128)
        self.downSample3 = u_dowm(128, 256)
        self.downSample4 = u_dowm(256, 512)

        # Decoder path; in_channel of each stage = previous stage's
        # upsampled channels + skip channels.
        self.upSample4 = up(512, 512)
        self.upSample3 = up(1024, 256)
        self.upSample2 = up(512, 128)
        self.upSample1 = up(256, 64)

        # Output head: two 3x3 conv + BN + ReLU layers, then a 3x3 projection.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        head_layers = [
            nn.Conv2d(128, 64, **conv_kwargs),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, **conv_kwargs),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, out_channel, **conv_kwargs),
        ]
        self.conv = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the head."""
        skip1, feat = self.downSample1(x)
        skip2, feat = self.downSample2(feat)
        skip3, feat = self.downSample3(feat)
        skip4, feat = self.downSample4(feat)

        # Deepest skip first, mirroring the encoder order.
        decoder_plan = (
            (self.upSample4, skip4),
            (self.upSample3, skip3),
            (self.upSample2, skip2),
            (self.upSample1, skip1),
        )
        for stage, skip in decoder_plan:
            feat = stage(feat, skip)

        return self.conv(feat)

实例化并进行前向测试:

# Smoke test: one random single-channel 224x224 image through the network.
x = torch.rand(1, 1, 224, 224)
print(x.size())
# 1 input channel, 2 output channels.
net = UNET(in_channel=1, out_channel=2)
output = net(x)
print(output.size())

运行结果:

C:\ProgramData\Anaconda3\envs\torch37\python.exe E:/taidi/yolov5-master/yolov5-master/data/unet.py
torch.Size([1, 1, 224, 224])
torch.Size([1, 2, 224, 224])

Process finished with exit code 0

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/727632.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-26
下一篇 2022-04-26

发表评论

登录后才能评论

评论列表(0条)

保存