torchvision.transforms处理模块用法详解

torchvision.transforms处理模块用法详解,第1张

torchvision.transforms处理模块用法详解
    • 常用方法介绍
    • 应用实例
      • 处理单张张量图像示例
      • 处理多张张量图像示例

torchvision.transforms是pytorch中的图像预处理包(图像变换),一般用Compose把多个步骤整合到一起。可单独处理张量图像的变换,也可以接受成批的张量图像。

常用方法介绍

transforms包中可实现图像的各种变换,如图像形状、颜色处理,还可进行模糊、曝光、类型转换等 *** 作。
常用方法(示例),更多方法见pytorch官方文档:

MethodParametersFunction
CenterCrop(size)若给定一个参数,则图片crop为(size,size),两个则为(h,w)在中心位置裁剪成给定图像大小
Grayscale([num_output_channels])( int ) - (1 或 3) 输出图像所需的通道数将图像转换为灰度
RandomCrop(size[, padding, pad_if_needed, …])size期望输出大小,padding边框填充(可选)随机位置裁剪给定图像
Resize (size)(size为序列或整数),若h>w,图像则调整为(size * height / width, size).将输入图像调整为给定大小
Normalize(mean, std[, inplace])(std[1],…,std[n])ntorch.*Tensoroutput[channel] = (input[channel] - mean[channel]) / std[channel]用均值和标准差对张量图像进行归一化
ToTensor()将 [0, 255] 范围内的 PIL 图像或 numpy.ndarray (H x W x C) 转换为 [0.0, 1.0] 范围内形状 (C x H x W) 的 torch.FloatTensor将 PIL Image or numpy.ndarray 转换为 tensor.
应用实例 处理单张张量图像示例

其中transforms方法调用工作原理:
transforms.Resize(256):将图片的(h,w)均转换为(256,256)
transforms.CenterCrop(224):从中心位置将图像裁剪为(224,224)
transforms.ToTensor():可以将图像(H x W x C)从灰度范围从0-255变换到形状为 (C x H x W)的 0-1之间,并转成tensor数据类型
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]):把图像灰度范围从把(0,1)变换到(-1,1);对每个通道而言,Normalize执行:image=(image-mean)/std;其中mean和std分别通过(0.5,0.5,0.5)和(0.5,0.5,0.5)进行指定。原来的0-1最小值0则变成(0-0.5)/0.5=-1,而最大值1则变成(1-0.5)/0.5=1.

import os
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

'''
torchvision.transforms是pytorch中的图像预处理包(图像变换),一般用Compose把多个步骤整合到一起
处理单张张量图像示例
'''
def main():
    # 实例化transforms对象,用Compose整合多个 *** 作
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    # load image
    img_path = "data/test_data/1.png"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    print("图片的size:{}".format(img.size))
    plt.imshow(img)
    # [C, H, W]
    img = data_transform(img)
    print("转换后的图片size:{}".format(img.size()))
    print("转换后的图片类型:{}".format(img.type()))
    # expand batch dimension  [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)
    print("增加图片维度后的size:{}".format(img.size()))
    plt.title("tulip")
    plt.show()

if __name__ == '__main__':
    main()

output:
图片的size:(700, 497)
转换后的图片size:torch.Size([3, 224, 224])
转换后的图片类型:torch.FloatTensor
增加图片维度后的size:torch.Size([1, 3, 224, 224])
处理多张张量图像示例

实例化自定义数据集时加入transform转换即可

data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/716028.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-25
下一篇 2022-04-25

发表评论

登录后才能评论

评论列表(0条)

保存