timm库使用

timm库使用,第1张


一、timm库简介

PyTorch Image Models,简称timm,是一个巨大的PyTorch代码集合,整合了常用的models、layers、utilities、optimizers、schedulers、data-loaders/augmentations和reference training/validation scripts。



二、安装

pip install timm

三、使用

  1. 查看所有模型
    model_list = timm.list_models()
    print(model_list)
  1. 查看具有预训练参数的模型
    model_pretrain_list = timm.list_models(pretrained=True)
    print(model_pretrain_list)
  1. 检索特定模型
    model_resnet = timm.list_models('*resnet*')
    print(model_resnet)
  1. 创建模型
    x = torch.randn((1, 3, 256, 512))
    modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True)
    out = modle_mobilenetv2(x)
    # print(out.shape)
    # torch.Size([1, 1000])
  1. 创建模型–改变输出类别数
    x = torch.randn((1, 3, 256, 512))
    modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True, num_classes=100)
    out = modle_mobilenetv2(x)
    # print(out.shape)
    # torch.Size([1, 100])
  1. 创建模型–改变输入通道数
    x = torch.randn((1, 10, 256, 512))
    modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True, in_chans=10)
    out = modle_mobilenetv2(x)
    # print(out.shape)
    # torch.Size([1, 1000])
  1. 创建模型–只提取特征
    x = torch.randn((1, 3, 256, 512))
    modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True, features_only=True)
    out = modle_mobilenetv2(x)
    # for o in out:
    #     print(o.shape)
    # torch.Size([1, 16, 128, 256])
    # torch.Size([1, 24, 64, 128])
    # torch.Size([1, 32, 32, 64])
    # torch.Size([1, 96, 16, 32])
    # torch.Size([1, 320, 8, 16])

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/569593.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-09
下一篇 2022-04-09

发表评论

登录后才能评论

评论列表(0条)

保存