数据增强就是增强一个已有数据集,使得有更多的多样性。对于图片数据来说,就是改变图片的颜色和形状等等。比如常见的:
- 左右翻转,对于大多数数据集都可以使用;
- 上下翻转:部分数据集不适合使用;
- 图片切割:从图片中切割出一个固定的形状,
- 随机高宽比(e.g. [3/4, 4/3)
- 随机大小(e.g. [8%, 100%])
- 随机位置
- 改变图片的颜色
- 改变色调,饱和度,明亮度(e.g. [0.5, 1.5])
加载相关包。
import torch
import torchvision
import matplotlib
from torch import nn
from torchvision import transforms
from PIL import Image
from IPython import display
from matplotlib import pyplot as plt
选取一个狗的图片作为示例:
def set_figsize(figsize=(3.5, 2.5)):
display.set_matplotlib_formats('svg')
plt.rcParams['figure.figsize'] = figsize
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
r"""
展示一列图片
img: Image对象的列表
"""
figsize = (num_cols * scale, num_rows * scale)
fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
ax.imshow(img.numpy())
else:
ax.imshow(img)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_label('x')
if titles:
ax.set_title(titles[i])
return axes
set_figsize()
img = Image.open('img/dog1.jpg')
plt.imshow(img);
2. 图片增广
def apply(img, aug, num_rows=2, num_clos=4, scale=1.5):
# 对图片应用图片增广
# img: Image object
# aug: 增广 *** 作
Y = [aug(img) for _ in range(num_clos * num_rows)]
d2l.show_images(Y, num_rows, num_clos, scale=scale)
2.1 图片水平翻转
class RandomHorizontalFlip(torch.nn.modules.module.Module):
r'''
RandomHorizontalFlip(p=0.5)
给图片一个一定概率的水平翻转 *** 作,如果是Tensor,要求形状为[..., H, W]
Args:
p: float, 图片翻转的概率,默认值0.5
'''
def __init__(self, p=0.5):
pass
示例。可以看到,有一般的几率对图片进行了水平翻转。
aug = transforms.RandomHorizontalFlip(0.5)
apply(img, aug)
2.2 图片上下翻转
class RandomVerticalFlip(torch.nn.modules.module.Module):
r'''
RandomVerticalFlip(p=0.5)
给图片一个一定概率的上下翻转 *** 作,如果是Tensor,要求形状为[..., H, W]
Args:
p: float, 图片翻转的概率,默认值0.5
'''
def __init__(self, p=0.5):
pass
示例。可以看到,有一般的几率对图片进行了上下翻转。
aug = transforms.RandomHorizontalFlip(0.5)
apply(img, aug)
2.3 图片旋转
class RandomRotation(torch.nn.modules.module.Module):
r'''
将图片旋转一定角度。
'''
def __init__(self,
degrees,
interpolation=<InterpolationMode.NEAREST: 'nearest'>,
expand=False,
center=None,
fill=0):
r"""
Args:
degrees: number or sequence, 可选择的角度范围(min, max),
如果是一个数字,则范围是(-degrees, +degrees)
interpolation: Default is ``InterpolationMode.NEAREST``.
expand: bool, 如果为True,则扩展输出,使其足够大来容纳整个旋转的图像
如果为False, 将输出图像与输入图像的大小相同。
center: sequence, 以左上角为原点的旋转中心,默认是图片中心。
fill: sequence or number: 旋转图像外部区域的像素填充值,默认0。
"""
pass
def forward(self, input):
r"""
Args:
img: PIL Image or Tensor, 被旋转的图片。
Return:
PIL Image or Tensor: 旋转后的图片。
"""
pass
使用实例:
aug = transforms.RandomRotation(degrees=(-90, 90), fill=128)
apply(img, aug)
2.4 中心裁切
class CenterCrop(torch.nn.modules.module.Module):
r'''
中心裁切。
'''
def __init__(self, size):
r"""
Args:
size: sequence or int, 裁切尺寸(H, W), 如果是int,尺寸为(size, size)
"""
pass
def forward(self, input):
r"""
Args:
img: PIL Image or Tensor, 被裁切的图片。
Return:
PIL Image or Tensor: 裁切后的图片。
"""
pass
实例:
aug = transforms.CenterCrop((200, 300))
apply(img, aug)
2.5 随机裁切
class RandomCrop(torch.nn.modules.module.Module):
r'''
随机裁切。
'''
def __init__(self, size):
r"""
Args:
size: sequence or int, 裁切尺寸(H, W), 如果是int,尺寸为(size, size)
padding: sequence or int, 填充大小,
如果值为 a , 四周填充a个像素
如果值为 (a, b), 左右填充a,上下填充b
如果值为 (a, b, c, d), 左上右下依次填充
pad_if_need: bool, 如果裁切尺寸大于原图片,则填充
fill: number or str or tuple: 填充像素的值
padding_mode: str, 填充类型。
`constant`: 使用 fill 填充
`edge`: 使用边缘的最后一个值填充在图像边缘。
`reflect`: 镜像填充
"""
pass
def forward(self, input):
r"""
Args:
img: PIL Image or Tensor, 被裁切的图片。
Return:
PIL Image or Tensor: 裁切后的图片。
"""
pass
示例:
aug = transforms.RandomCrop((200, 300))
apply(img, aug)
输出:
class RandomResizedCrop(torch.nn.modules.module.Module):
r'''
随机裁切, 并重设尺寸。
'''
def __init__(self, size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333)):
r"""
Args:
size: sequence or int, 需要输出的尺寸(H, W), 如果是int,尺寸为(size, size)
scale: tuple of float, 原始图片中裁切大小,百分比
ratio: tuple of float, resize前的裁切的纵横比范围
"""
pass
def forward(self, input):
r"""
Args:
img: PIL Image or Tensor, 被裁切的图片。
Return:
PIL Image or Tensor: 输出的图片。
"""
pass
示例:
aug = transforms.RandomResizedCrop((200, 200), scale=(0.2, 1))
apply(img, aug)
2. 7 修改图片颜色
class ColorJitter(torch.nn.modules.module.Module):
r'''
修改颜色。
'''
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
r"""
Args:
brightness: float or tuple of float (min, max), 亮度的偏移幅度,范围[max(0, 1 - brightness), 1 + brightness]
contrast: float or tuple of float (min, max), 对比度偏移幅度,范围[max(0, 1 - contrast), 1 + contrast]
saturation: float or tuple of float (min, max), 饱和度偏移幅度,范围[max(0, 1 - saturation), 1 + saturation]
hue: float or tuple of float (min, max), 色相偏移幅度,范围[-hue, hue]
"""
pass
def forward(self, input):
r"""
Args:
img: PIL Image or Tensor, 输入的图片。
Return:
PIL Image or Tensor: 输出的图片。
"""
pass
示例:
aug = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
apply(img, aug)
3. 训练数据集加载
train_augs = transforms.Compose([transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor()])
dataset = torchvision.datasets.CIFAR10(root="../data", train=is_train,
transform=augs, download=True)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)