看代码
# Print the name and size of every parameter of a torchvision model whose
# architecture is chosen on the command line (default: resnet50).
import argparse

import torch.nn as nn
import torchvision.models as models

# Kept at module level so the parser object is available on import; the
# '--arch' option is registered in main(), once the model list is known.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')


def _available_architectures():
    """Return the sorted names of lowercase, non-dunder callables in ``torchvision.models``."""
    # NOTE(review): on newer torchvision releases this filter may also match
    # lowercase helper functions that are not zero-argument model builders;
    # selecting one of those with --arch would fail at construction time.
    # Verify against the installed torchvision version.
    return sorted(
        name
        for name, obj in vars(models).items()
        if name.islower() and not name.startswith("__") and callable(obj)
    )


def main(argv=None):
    """Parse the command line, build the requested model and print its parameters.

    Args:
        argv: Optional list of argument strings. ``None`` makes argparse
            read ``sys.argv[1:]``.
    """
    model_names = _available_architectures()
    parser.add_argument(
        '-a', '--arch',
        metavar='ARCH',
        default='resnet50',
        choices=model_names,
        help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet50)',
    )
    args = parser.parse_args(argv)

    # Look the builder up by name and call it with no arguments.
    model = models.__dict__[args.arch]()

    for name, param in model.named_parameters():
        print(name, param.size())


if __name__ == '__main__':
    main()
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)