Implementation snippet
from torch import nn
import torch


class u_dowm(nn.Module):
    """Encoder block: two 3x3 conv + BN + ReLU layers, followed by 2x2 max pooling."""
    def __init__(self, in_channel, out_channel):
        super(u_dowm, self).__init__()
        self.convANDbn = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
        )
        self.downsample = nn.Sequential(
            nn.MaxPool2d(2, 2)
        )

    def forward(self, x):
        output = self.convANDbn(x)  # output is kept as the skip connection for the upsampling path
        out = self.downsample(output)
        return output, out


class up(nn.Module):
    """Decoder block: two conv layers, bilinear upsampling, then concatenation with the skip connection."""
    def __init__(self, in_channel, out_channel):
        super(up, self).__init__()
        self.conv2bn = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=2 * out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(2 * out_channel),
            nn.ReLU(),
            nn.Conv2d(in_channels=2 * out_channel, out_channels=2 * out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(2 * out_channel),
            nn.ReLU()
        )
        self.upSample = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(in_channels=2 * out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x, out):
        output = self.conv2bn(x)
        output = self.upSample(output)
        output = torch.cat((output, out), dim=1)  # channel-wise concatenation with the encoder feature map
        return output


class UNET(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(UNET, self).__init__()
        # Encoder: 64 -> 128 -> 256 -> 512 channels
        self.downSample1 = u_dowm(in_channel=in_channel, out_channel=64)
        self.downSample2 = u_dowm(in_channel=64, out_channel=128)
        self.downSample3 = u_dowm(in_channel=128, out_channel=256)
        self.downSample4 = u_dowm(in_channel=256, out_channel=512)
        # Decoder: each block upsamples and concatenates the matching encoder feature map
        self.upSample4 = up(in_channel=512, out_channel=512)
        self.upSample3 = up(in_channel=1024, out_channel=256)
        self.upSample2 = up(in_channel=512, out_channel=128)
        self.upSample1 = up(in_channel=256, out_channel=64)
        # Output head: 128 -> 64 -> 64 -> out_channel
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        out_1, out1 = self.downSample1(x)    # out_1: skip connection, out1: pooled feature map
        out_2, out2 = self.downSample2(out1)
        out_3, out3 = self.downSample3(out2)
        out_4, out4 = self.downSample4(out3)
        out5 = self.upSample4(out4, out_4)
        out6 = self.upSample3(out5, out_3)
        out7 = self.upSample2(out6, out_2)
        out8 = self.upSample1(out7, out_1)
        out = self.conv(out8)
        return out
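To see how the two return values of u_dowm and the concatenation inside up fit together, a minimal shape check on a single encoder/decoder pair can be run. This sketch only reuses the classes defined above; the 1x64x56x56 input tensor is an arbitrary example size, not something from the original post.

# Minimal shape check for one encoder/decoder pair (assumes the classes above are in scope).
down_block = u_dowm(in_channel=64, out_channel=128)
up_block = up(in_channel=128, out_channel=128)

feat = torch.rand(1, 64, 56, 56)    # arbitrary example feature map
skip, pooled = down_block(feat)
print(skip.shape, pooled.shape)     # torch.Size([1, 128, 56, 56]) torch.Size([1, 128, 28, 28])

decoded = up_block(pooled, skip)    # upsample pooled back to 56x56, then concat with skip
print(decoded.shape)                # torch.Size([1, 256, 56, 56])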
Instantiation
x = torch.rand((1, 1, 224, 224))
print(x.shape)
net = UNET(1, 2)
output = net(x)
print(output.shape)
Run result:
C:\ProgramData\Anaconda3\envs\torch37\python.exe E:/taidi/yolov5-master/yolov5-master/data/unet.py
torch.Size([1, 1, 224, 224])
torch.Size([1, 2, 224, 224])
Process finished with exit code 0
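Since the network is instantiated with out_channel=2, its output can be read as per-pixel class logits for a two-class segmentation task. The following is a minimal, hypothetical training step under that assumption; the loss, optimizer, and dummy labels are illustrative and not part of the original code.

# Illustrative single training step (assumption: two-class per-pixel segmentation).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

images = torch.rand(1, 1, 224, 224)            # dummy input batch
masks = torch.randint(0, 2, (1, 224, 224))     # dummy per-pixel labels in {0, 1}

logits = net(images)                           # shape: [1, 2, 224, 224]
loss = criterion(logits, masks)                # CrossEntropyLoss takes class-index targets
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())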