因为要做一个简单的实验,这里把用 UNet 做多分类语义分割的过程记录一下。
1. 数据准备:以公共数据集 CamVid 为例。按 data.py 的读取方式,一级目录为 train、val、test,
二级目录为 images(原图)和 labels(标签图),标签图与对应原图的文件名相同。
2. 数据加载文件 data.py:脚本前半部分是数据增强函数;我自己的数据加载代码在后半部分(原脚本中从第128行开始),包括 read_own_data、own_data_loader、own_data_test_loader 三个函数以及 ImageFolder 数据集类。
# -*- coding: utf-8 -*-
import torch
import torch.utils.data as data
from torch.autograd import Variable as V
from PIL import Image
import cv2
import numpy as np
import os
import scipy.misc as misc
def randomHueSaturationValue(image, hue_shift_limit=(-180, 180),
                             sat_shift_limit=(-255, 255),
                             val_shift_limit=(-255, 255), u=0.5):
    """Randomly jitter hue, saturation and value of a BGR image.

    With probability ``u`` the image is converted to HSV, each channel is
    shifted by a random amount drawn from its ``*_shift_limit`` range, and
    the result is converted back to BGR. Otherwise the input is returned
    untouched.

    Args:
        image: BGR image; assumed to be 8-bit (as returned by ``cv2.imread``),
            since the hue wrap below relies on OpenCV's 8-bit hue range.
        hue_shift_limit: inclusive (low, high) integer range for the hue shift.
        sat_shift_limit: (low, high) range for the saturation shift.
        val_shift_limit: (low, high) range for the value shift.
        u: probability of applying the augmentation.

    Returns:
        The augmented image, or ``image`` itself when no jitter is applied.
    """
    if not (np.random.random() < u):
        return image

    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(hsv)

    hue_shift = np.random.randint(hue_shift_limit[0], hue_shift_limit[1] + 1)
    # 8-bit OpenCV hue lives in [0, 180). Do the shift in a wider signed type
    # and wrap modulo 180: casting a negative shift to uint8 raises on
    # NumPy >= 2.0, and uint8 wrap-around (mod 256) yields invalid hues.
    h = ((h.astype(np.int32) + hue_shift) % 180).astype(np.uint8)

    sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
    s = cv2.add(s, sat_shift)  # cv2.add saturates instead of wrapping

    val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
    v = cv2.add(v, val_shift)

    hsv = cv2.merge((h, s, v))
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
def randomShiftScaleRotate(image, mask,
                           shift_limit=(-0.0, 0.0),
                           scale_limit=(-0.0, 0.0),
                           rotate_limit=(-0.0, 0.0),
                           aspect_limit=(-0.0, 0.0),
                           borderMode=cv2.BORDER_CONSTANT, u=0.5):
    """Apply the same random shift / scale / rotation to an image and its mask.

    With probability ``u`` a random transform is sampled and applied to both
    inputs through ``cv2.warpPerspective``; otherwise both are returned
    unchanged.

    Args:
        image: input image; only its first two dimensions (height, width)
            are read.
        mask: label mask warped with the same matrix as ``image``.
        shift_limit: (low, high) translation range as a fraction of the size.
        scale_limit: (low, high) range added to 1 to get the scale factor.
        rotate_limit: (low, high) rotation range in degrees.
        aspect_limit: (low, high) range added to 1 to get the aspect ratio.
        borderMode: OpenCV border mode used for pixels warped in from outside.
        u: probability of applying the augmentation.

    Returns:
        Tuple ``(image, mask)``.
    """
    if not (np.random.random() < u):
        return image, mask

    # Only height/width are needed; slicing also accepts 2-D images.
    height, width = image.shape[:2]

    angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
    scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
    aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
    sx = scale * aspect / (aspect ** 0.5)
    sy = scale / (aspect ** 0.5)
    dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
    dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)

    # np.math was removed in NumPy 2.0; use NumPy's own functions instead.
    cc = np.cos(angle / 180 * np.pi) * sx
    ss = np.sin(angle / 180 * np.pi) * sy
    rotate_matrix = np.array([[cc, -ss], [ss, cc]])

    # Map the four image corners through the transform about the centre and
    # derive the perspective matrix from the corner correspondences.
    box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
    box1 = box0 - np.array([width / 2, height / 2])
    box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])
    box0 = box0.astype(np.float32)
    box1 = box1.astype(np.float32)
    mat = cv2.getPerspectiveTransform(box0, box1)

    image = cv2.warpPerspective(image, mat, (width, height),
                                flags=cv2.INTER_LINEAR,
                                borderMode=borderMode,
                                borderValue=(0, 0, 0,))
    # Masks hold labels: nearest-neighbour keeps every warped pixel a valid
    # label value, whereas linear interpolation would blend neighbouring ids.
    mask = cv2.warpPerspective(mask, mat, (width, height),
                               flags=cv2.INTER_NEAREST,
                               borderMode=borderMode,
                               borderValue=(0, 0, 0,))
    return image, mask
def randomHorizontalFlip(image, mask, u=0.5):
    """Mirror image and mask left-right together with probability ``u``.

    Returns the pair ``(image, mask)``; both inputs are returned unchanged
    when the flip is not applied.
    """
    if not (np.random.random() < u):
        return image, mask
    flip_code = 1  # cv2 flip code 1: flip around the vertical axis
    flipped_image = cv2.flip(image, flip_code)
    flipped_mask = cv2.flip(mask, flip_code)
    return flipped_image, flipped_mask
def randomVerticleFlip(image, mask, u=0.5):
    """Mirror image and mask top-bottom together with probability ``u``.

    Returns the pair ``(image, mask)``; both inputs are returned unchanged
    when the flip is not applied.
    """
    if not (np.random.random() < u):
        return image, mask
    flip_code = 0  # cv2 flip code 0: flip around the horizontal axis
    flipped_image = cv2.flip(image, flip_code)
    flipped_mask = cv2.flip(mask, flip_code)
    return flipped_image, flipped_mask
def randomRotate90(image, mask, u=0.5):
    """Rotate image and mask by 90 degrees together with probability ``u``.

    The rotation is ``np.rot90`` over the first two axes. Returns the pair
    ``(image, mask)``; both inputs are returned unchanged when the rotation
    is not applied.
    """
    if not (np.random.random() < u):
        return image, mask
    rotated_image, rotated_mask = (np.rot90(arr) for arr in (image, mask))
    return rotated_image, rotated_mask
def default_loader(img_path, mask_path):
    """Load, augment and normalise one image / binary-mask pair.

    Both files are resized to 448x448. The grayscale mask is inverted
    (``255 - mask``), the random augmentations are applied, and the mask is
    finally binarised at 0.5 after scaling to [0, 1].

    Args:
        img_path: path of the colour image.
        mask_path: path of the mask image (read as grayscale).

    Returns:
        ``(img, mask)`` as float32 arrays in channel-first layout: ``img`` is
        scaled to [-1.6, 1.6], ``mask`` contains only 0 and 1.

    Raises:
        FileNotFoundError: if either file cannot be read by OpenCV.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError("Cannot read image: {}".format(img_path))
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError("Cannot read mask: {}".format(mask_path))

    img = cv2.resize(img, (448, 448))
    mask = 255. - cv2.resize(mask, (448, 448))

    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-30, 30),
                                   sat_shift_limit=(-5, 5),
                                   val_shift_limit=(-15, 15))
    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.1, 0.1),
                                       scale_limit=(-0.1, 0.1),
                                       aspect_limit=(-0.1, 0.1),
                                       rotate_limit=(-0, 0))
    img, mask = randomHorizontalFlip(img, mask)
    img, mask = randomVerticleFlip(img, mask)
    img, mask = randomRotate90(img, mask)

    # Give the mask an explicit channel axis so it transposes like the image.
    mask = np.expand_dims(mask, axis=2)

    img = np.array(img, np.float32).transpose(2, 0, 1) / 255.0 * 3.2 - 1.6
    mask = np.array(mask, np.float32).transpose(2, 0, 1) / 255.0
    # Binarise in one step: values >= 0.5 become 1, everything else 0.
    mask = (mask >= 0.5).astype(np.float32)
    return img, mask
def read_own_data(root_path, mode = 'train'):
    """Collect image / label path pairs for one dataset split.

    The label directory ``<root_path>/<mode>/labels`` is listed, and every
    file name found there is paired with the same file name under
    ``<root_path>/<mode>/images``. The existence of the image files is not
    checked here.

    Args:
        root_path: dataset root directory.
        mode: split sub-directory name (e.g. ``'train'``).

    Returns:
        ``(images, masks)``: two equally long lists of paths, sorted by file
        name so the sample order is reproducible across runs and machines
        (``os.listdir`` alone returns entries in arbitrary order).
    """
    image_root = os.path.join(root_path, mode, 'images')
    gt_root = os.path.join(root_path, mode, 'labels')
    names = sorted(os.listdir(gt_root))
    images = [os.path.join(image_root, name) for name in names]
    masks = [os.path.join(gt_root, name) for name in names]
    return images, masks
def own_data_loader(img_path, mask_path):
    """Load, augment and normalise one training image / label-mask pair.

    Both files are resized to 512x512 with nearest-neighbour interpolation
    and passed through the random augmentations. The mask values are kept as
    read from the file (no thresholding or rescaling).

    Args:
        img_path: path of the colour image.
        mask_path: path of the label mask (read as a single-channel image).

    Returns:
        ``(img, mask)`` as float32 arrays in channel-first layout: ``img`` is
        scaled to [-1.6, 1.6]; ``mask`` has one channel.

    Raises:
        FileNotFoundError: if either file cannot be read by OpenCV.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError("Cannot read image: {}".format(img_path))
    mask = cv2.imread(mask_path, 0)
    if mask is None:
        raise FileNotFoundError("Cannot read mask: {}".format(mask_path))

    img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
    mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)

    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-30, 30),
                                   sat_shift_limit=(-5, 5),
                                   val_shift_limit=(-15, 15))
    # NOTE(review): the mask presumably stores class indices, so the geometric
    # augmentation must not interpolate between label values - confirm that
    # randomShiftScaleRotate warps the mask with nearest-neighbour sampling.
    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.1, 0.1),
                                       scale_limit=(-0.1, 0.1),
                                       aspect_limit=(-0.1, 0.1),
                                       rotate_limit=(-0, 0))
    img, mask = randomHorizontalFlip(img, mask)
    img, mask = randomVerticleFlip(img, mask)
    img, mask = randomRotate90(img, mask)

    # Give the mask an explicit channel axis so it transposes like the image.
    mask = np.expand_dims(mask, axis=2)

    # Convert once to float32, normalise, then switch to channel-first.
    img = (np.array(img, np.float32) / 255.0 * 3.2 - 1.6).transpose(2, 0, 1)
    mask = np.array(mask, np.float32).transpose(2, 0, 1)
    return img, mask
def own_data_test_loader(img_path, mask_path):
    """Load and normalise one image / label-mask pair without augmentation.

    Both files are resized to 512x512 with nearest-neighbour interpolation.
    The mask values are kept as read from the file.

    Args:
        img_path: path of the colour image.
        mask_path: path of the label mask (read as a single-channel image).

    Returns:
        ``(img, mask)`` as float32 arrays in channel-first layout: ``img`` is
        scaled to [-1.6, 1.6]; ``mask`` has one channel.

    Raises:
        FileNotFoundError: if either file cannot be read by OpenCV.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError("Cannot read image: {}".format(img_path))
    mask = cv2.imread(mask_path, 0)
    if mask is None:
        raise FileNotFoundError("Cannot read mask: {}".format(mask_path))

    img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
    mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)

    # Give the mask an explicit channel axis so it transposes like the image.
    mask = np.expand_dims(mask, axis=2)

    # Convert once to float32, normalise, then switch to channel-first.
    img = (np.array(img, np.float32) / 255.0 * 3.2 - 1.6).transpose(2, 0, 1)
    mask = np.array(mask, np.float32).transpose(2, 0, 1)
    return img, mask
class ImageFolder(data.Dataset):
    """Dataset yielding ``(image, mask)`` tensor pairs for one split.

    The file lists come from ``read_own_data(root_path, mode)``. Samples are
    loaded with ``own_data_test_loader`` when ``mode`` is ``'test'`` and with
    ``own_data_loader`` for any other mode.
    """

    def __init__(self, root_path, mode='train'):
        self.root = root_path
        self.mode = mode
        # Parallel lists: images[i] belongs to labels[i].
        self.images, self.labels = read_own_data(self.root, self.mode)

    def __getitem__(self, index):
        """Load sample ``index`` and return it as two ``torch.Tensor`` objects."""
        if self.mode == 'test':
            loader = own_data_test_loader
        else:
            loader = own_data_loader
        img_array, mask_array = loader(self.images[index], self.labels[index])
        img = torch.Tensor(img_array)
        mask = torch.Tensor(mask_array)
        return img, mask

    def __len__(self):
        """Return the number of samples, asserting both lists agree in length."""
        n_images = len(self.images)
        assert n_images == len(self.labels), 'The number of images must be equal to labels'
        return n_images
3.训练
训练脚本用到了两个比较好用的包:segmentation_models_pytorch 和 pytorch_toolbelt。其中 segmentation_models_pytorch 包含了很多常见语义分割模型的实现,同时支持二分类和多分类,直接 pip install segmentation_models_pytorch 就可以安装。另外注意,训练的类别数要在脚本开头的 n_classes = 12 处修改。
# -*- coding: utf-8 -*-
import time
import warnings
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.swa_utils import AveragedModel, SWALR
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss, SoftCrossEntropyLoss, LovaszLoss
from pytorch_toolbelt import losses as L
from data import ImageFolder
from sklearn import metrics
# Suppress every warning emitted through the `warnings` module.
warnings.filterwarnings('ignore')
# Explicitly enable the cuDNN backend.
torch.backends.cudnn.enabled = True
# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Number of segmentation classes; used for the model's output channels and
# as the default class count of iou_mean. Change it here for another dataset.
n_classes = 12
def cal_cm(y_true, y_pred):
    """Return the confusion matrix of two label arrays of any shape.

    Both inputs are flattened to one dimension and passed to
    ``sklearn.metrics.confusion_matrix``.

    Args:
        y_true: ground-truth labels (any object with a ``reshape`` method).
        y_pred: predicted labels, same number of elements as ``y_true``.

    Returns:
        The confusion matrix computed by scikit-learn.
    """
    # reshape(-1) always yields a 1-D vector; the former
    # reshape(1, -1).squeeze() collapsed a single-element input to a 0-d
    # array, which is not a valid label vector.
    y_true = y_true.reshape(-1)
    y_pred = y_pred.reshape(-1)
    return metrics.confusion_matrix(y_true, y_pred)
def iou_mean(pred, target, n_classes = n_classes):
    """Mean IoU over the foreground classes present in ``pred`` or ``target``.

    ``pred`` and ``target`` are hard label maps (class indices), not
    probability maps. Class 0 is treated as background and skipped. A class
    that occurs in neither input has an undefined IoU and is left out of the
    average.

    Args:
        pred: predicted class indices (tensor or array of any shape).
        target: ground-truth class indices with the same number of elements.
        n_classes: total number of classes, background included; classes
            ``1 .. n_classes - 1`` are evaluated.

    Returns:
        The average IoU of the evaluated classes as a Python float, or 0.0
        when no foreground class occurs in either input.
    """
    # Flatten on the CPU; as_tensor also accepts NumPy arrays.
    pred = torch.as_tensor(pred).detach().reshape(-1).cpu()
    target = torch.as_tensor(target).detach().reshape(-1).cpu()

    ious = []
    for cls in range(1, n_classes):  # class 0 (background) is ignored
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = pred_inds.sum().item() + target_inds.sum().item() - intersection
        if union == 0:
            # Class absent from both prediction and ground truth: skip it.
            continue
        ious.append(intersection / union)

    if not ious:
        return 0.0
    # Average over the classes actually evaluated. Dividing by n_classes, as
    # before, also counted the skipped background/absent classes and so
    # systematically under-reported the score.
    return sum(ious) / len(ious)
def multi_acc(pred, label):
    """Return the fraction of positions whose predicted class equals ``label``.

    The predicted class is the index of the largest log-softmax score along
    dimension 1 of ``pred``; the result is a scalar tensor.
    """
    log_probs = torch.log_softmax(pred, dim=1)
    tags = torch.max(log_probs, dim=1).indices
    hits = torch.eq(tags, label).int()
    n_hits = hits.sum()
    n_total = hits.numel()
    return n_hits / n_total
def train(EPOCHES, BATCH_SIZE, data_root, channels, optimizer_name,
          model_path, swa_model_path, loss, early_stop):
    """Train a U-Net segmentation model and keep the best checkpoint.

    Args:
        EPOCHES: maximum number of training epochs.
        BATCH_SIZE: batch size for both the training and validation loaders.
        data_root: dataset root passed to ``ImageFolder`` ('train' and 'val'
            splits are used).
        channels: number of input channels of the model.
        optimizer_name: ``"sgd"`` selects SGD; anything else selects AdamW.
        model_path: file the best model ``state_dict`` is written to.
        swa_model_path: accepted for interface compatibility; not used here.
        loss: ``"SoftCE_dice"`` selects Dice + soft cross-entropy; anything
            else selects Lovasz + soft cross-entropy.
        early_stop: stop once this many epochs have passed without a new
            best validation mIoU.

    Returns:
        ``(train_loss_epochs, val_mIoU_epochs, lr_epochs)``: per-epoch mean
        training loss, validation mIoU and learning rate.
    """
    import os  # local import: this script's import header does not include os

    train_dataset = ImageFolder(data_root, mode='train')
    val_dataset = ImageFolder(data_root, mode='val')
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0)
    # Validation metrics do not depend on sample order, so do not shuffle.
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0)

    # Model definition. Other smp architectures (e.g. smp.UnetPlusPlus) and
    # encoders can be swapped in here.
    model = smp.Unet(
        encoder_name="resnext50_32x4d",  # e.g. mobilenet_v2, efficientnet-b7, resnet34
        encoder_weights="imagenet",      # ImageNet pre-trained encoder
        in_channels=channels,            # 1 for gray-scale images, 3 for RGB, ...
        classes=n_classes,               # number of output channels / classes
        # The smp losses below apply softmax themselves (they expect raw
        # logits), so the model must not add its own activation; otherwise
        # softmax would be applied twice. The saved weights are unaffected
        # because the activation has no parameters.
        activation=None,
    )
    model.to(DEVICE)
    # To resume from a checkpoint, enable the next line and point model_path
    # at the existing weights:
    # model.load_state_dict(torch.load(model_path))

    if optimizer_name == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=1e-4, weight_decay=1e-3, momentum=0.9)
    else:
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=1e-3, weight_decay=1e-3)

    # Cosine annealing with warm restarts.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=2,        # epochs until the first restart
        T_mult=2,     # the restart period is multiplied by this after each restart
        eta_min=1e-5  # minimum learning rate
    )

    # Loss mode is 'multiclass'; use 'binary' for two-class problems.
    if loss == "SoftCE_dice":
        # Dice loss mitigates class imbalance to some degree but can make
        # training unstable, so it is combined with a label-smoothed
        # cross-entropy, which tends to generalise better.
        first_loss = DiceLoss(mode='multiclass')
        first_weight, second_weight = 0.8, 0.2
    else:
        # Lovasz loss directly optimises a convex surrogate of the mIoU.
        first_loss = LovaszLoss(mode='multiclass')
        first_weight, second_weight = 0.5, 0.5
    soft_ce_loss = SoftCrossEntropyLoss(smooth_factor=0.1)
    # Move the loss to the same device as the model; a hard-coded .cuda()
    # fails on CPU-only machines even though DEVICE falls back to 'cpu'.
    loss_fn = L.JointLoss(first=first_loss, second=soft_ce_loss,
                          first_weight=first_weight,
                          second_weight=second_weight).to(DEVICE)

    # Make sure the checkpoint directory exists before any time is spent
    # training; torch.save does not create missing directories.
    weights_dir = os.path.dirname(model_path)
    if weights_dir:
        os.makedirs(weights_dir, exist_ok=True)

    best_miou = 0
    best_miou_epoch = 0
    train_loss_epochs, val_mIoU_epochs, lr_epochs = [], [], []

    for epoch in range(1, EPOCHES + 1):
        losses = []
        start_time = time.time()

        model.train()
        for image, target in tqdm(train_data_loader, ncols=20, total=len(train_data_loader)):
            image = image.to(DEVICE)
            target = target.to(DEVICE).long()  # class indices must be int64
            output = model(image)
            batch_loss = loss_fn(output, target)
            losses.append(batch_loss.item())
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
        scheduler.step()

        # Validation: switch to eval mode (fixed batch-norm statistics, no
        # dropout) and disable gradient tracking to save memory.
        val_acc = []
        val_iou = []
        model.eval()
        with torch.no_grad():
            for val_img, val_mask in tqdm(val_data_loader, ncols=20, total=len(val_data_loader)):
                val_img, label = val_img.to(DEVICE), val_mask.to(DEVICE)
                predict = model(val_img)
                label = label.squeeze(1)
                acc = multi_acc(predict, label)
                val_acc.append(acc.item())
                predict = torch.argmax(predict, dim=1)
                val_iou.append(iou_mean(predict, label, n_classes))

        epoch_loss = np.mean(losses)
        epoch_miou = np.mean(val_iou)
        train_loss_epochs.append(epoch_loss)
        val_mIoU_epochs.append(epoch_miou)
        lr_epochs.append(optimizer.param_groups[0]['lr'])
        print('Epoch:' + str(epoch) + ' Loss:' + str(epoch_loss) + ' Val_Acc:' + str(np.array(val_acc).mean()) + ' Val_IOU:' + str(epoch_miou) + ' Time_use:' + str((time.time() - start_time) / 60.0))

        if best_miou < epoch_miou:
            best_miou = epoch_miou
            best_miou_epoch = epoch
            torch.save(model.state_dict(), model_path)
            print(" valid mIoU is improved. the model is saved.")
        else:
            print("")
            if (epoch - best_miou_epoch) >= early_stop:
                break

    return train_loss_epochs, val_mIoU_epochs, lr_epochs
if __name__ == '__main__':
    # ---- run configuration ----
    EPOCHES = 100
    BATCH_SIZE = 4
    channels = 3
    optimizer_name = "adamw"
    early_stop = 400
    loss = "SoftCE_dice"  # alternative: "SoftCE_Lovasz"
    data_root = "./data/CamVid/"
    model_path = "./weights/CamVid_" + loss + '.pth'
    swa_model_path = model_path + "_swa.pth"

    # ---- training ----
    history = train(EPOCHES, BATCH_SIZE, data_root, channels, optimizer_name,
                    model_path, swa_model_path, loss, early_stop)
    train_loss_epochs, val_mIoU_epochs, lr_epochs = history

    # ---- plots: loss / mIoU curves, then the learning-rate schedule ----
    import matplotlib.pyplot as plt

    epochs = range(1, len(train_loss_epochs) + 1)

    plt.plot(epochs, train_loss_epochs, 'r', label='train loss')
    plt.plot(epochs, val_mIoU_epochs, 'b', label='val mIoU')
    plt.title('train loss and val mIoU')
    plt.legend()
    plt.savefig("train loss and val mIoU.png", dpi=300)

    plt.figure()
    plt.plot(epochs, lr_epochs, 'r', label='learning rate')
    plt.title('learning rate')
    plt.legend()
    plt.savefig("learning rate.png", dpi=300)

    plt.show()
4.预测
# -*- coding: utf-8 -*-
import os
import glob
import time
import cv2
import numpy as np
import torch
import segmentation_models_pytorch as smp
from torch.optim.swa_utils import AveragedModel
from data import ImageFolder
# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
def test_1(channels, model_path, output_dir, test_path, n_classes=12):
    """Predict a class-index map for every image in a directory.

    Each readable image in ``test_path`` is resized to 512x512, normalised
    the same way as in training, segmented, and the predicted label map is
    resized back to the original size and written to ``output_dir`` under
    the same file name.

    Args:
        channels: number of input channels of the model.
        model_path: path of the saved ``state_dict``; if it contains
            ``"swa"`` the model is wrapped in ``AveragedModel`` first.
        output_dir: directory for the predictions (created if missing).
        test_path: directory containing the input images.
        n_classes: number of classes the checkpoint was trained with. The
            predictions are written as 8-bit images, so this is assumed to
            be at most 256.
    """
    model = smp.Unet(
        encoder_name="resnext50_32x4d",
        # Every weight is overwritten by load_state_dict below, so there is
        # no need to download the ImageNet encoder weights first.
        encoder_weights=None,
        in_channels=channels,
        classes=n_classes,
        # argmax over raw logits picks the same class as argmax over softmax
        # probabilities, so no output activation is needed.
        activation=None,
    )
    # Checkpoints of SWA-averaged models are stored with the wrapper's keys.
    if "swa" in model_path:
        model = AveragedModel(model)
    model.to(DEVICE)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():  # inference only: no autograd bookkeeping needed
        for name in sorted(os.listdir(test_path)):
            full_path = os.path.join(test_path, name)
            img = cv2.imread(full_path)
            if img is None:
                # Sub-directories and non-image files cannot be decoded.
                print("Skipping unreadable file: {}".format(full_path))
                continue
            h, w = img.shape[:2]

            # Training used 512x512 inputs; the prediction is resized back
            # to the original size below.
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
            image = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
            image = np.ascontiguousarray(image.transpose(2, 0, 1)[np.newaxis])
            image = torch.from_numpy(image).to(DEVICE)

            output = model(image)
            pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
            # Store class ids as 8-bit so the image writer handles them as-is.
            pred = pred.astype(np.uint8)
            # Nearest-neighbour keeps the resized map made of valid class ids.
            pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)

            save_path = os.path.join(output_dir, name)
            cv2.imwrite(save_path, pred)
if __name__ == "__main__":
    # Input images, trained weights and destination of the predicted maps.
    data_root = "./data/CamVid/test/"
    model_path = "./weights/CamVid_SoftCE_dice.pth"
    output_dir = './data/CamVid/test_pre/'
    test_1(
        channels=3,
        model_path=model_path,
        output_dir=output_dir,
        test_path=data_root,
    )
结果
原图 标签
预测结果
注意:预测图的颜色是随机分配的,和标签图颜色不一致并不代表分错了。这个结果明显还没有训练到收敛(我提前停止了训练),继续训练效果还能提升。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)