(自学笔记,都是抄的,长期更新)
利用pytorch写一个数据类,可以加载数据图片并显示。在datasets文件夹下有两个子文件夹,分别是SegmentationClass
用以存放语义分割数据的mask.png,JPEGImages
用以存放语义分割的数据图片.jpg。
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
from PIL import Image
# 将图片转换为tensor
transform=transforms.Compose([
transforms.ToTensor()
])
# 将tensor转化为图片
show_tensor_pic=transforms.ToPILImage()
#保持图片最长的边进行缩放,默认缩放到(256,256)
def keep_png_size(path,size=(256,256)):
img=Image.open(path)
temp=max(img.size)
mask=Image.new('P',(temp,temp))
mask.paste(img,(0,0))
mask=mask.resize(size)
return mask
#保持图片最长的边进行缩放,默认缩放到(26,256)
def keep_jpg_size(path,size=(256,256)):
img=Image.open(path)
temp=max(img.size)
mask=Image.new('RGB',(temp,temp))
mask.paste(img,(0,0))
mask=mask.resize(size)
return mask
class MyDataset(Dataset):
def __init__(self,path):
self.path=path#输入图片路径
#保存所有mask的名字
self.name=os.listdir(os.path.join(path,'SegmentationClass'))
def __len__(self):
return len(self.name)
def __getitem__(self, item):
# 根据索引获取一张mask的名字
segment_name=self.name[item]
# 获取这张mask的路径
segment_path=os.path.join(self.path,'SegmentationClass',segment_name)
# 根据mask的名字取索引图片的名字
image_path=os.path.join(self.path,'JPEGImages',segment_name.replace('png','jpg'))
# 将mask压缩到一定的尺寸
segment_image=keep_png_size(segment_path)
# 将图片压缩到一定的尺寸
image=keep_jpg_size(image_path)
return transform(image),transform(segment_image)
if __name__ == '__main__':
#建立一个dataloader,2个图片为一组,随机打乱
data_loader=DataLoader(MyDataset('datasets'),batch_size=2,shuffle=True)
#读取一组图片和标签
img, seg=next(iter(data_loader))
#显示该组的第一张图片和标签
show_tensor_pic(img[0]).show()
show_tensor_pic(seg[0]).show()
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)