keras预训练模型的加载

keras预训练模型的加载,第1张


from keras.applications import ResNet50,VGG16,InceptionV3,MobileNet
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np

res_path = 'G:\python_work1\learn\weights\keras_h5\resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
vgg_path = 'G:\python_work1\learn\weights\keras_h5\vgg16_weights_tf_dim_ordering_tf_kernels.h5'
inc_path = 'G:\python_work1\learn\weights\keras_h5\inception_v3_weights_tf_dim_ordering_tf_kernels (1).h5'
mbs_path = 'G:\python_work1\learn\weights\keras_h5\mobilenet_1_0_224_tf.h5'

#不包含最后的分类层的时候要让include_top = False,权重文件也要选不带top的
#当报错 original_keras_version = f.attrs['keras_version'].decode('utf8'),可能是h5py版本太高,重新安装2.10版本的

res = ResNet50(weights=res_path,include_top = False)
res.summary()


vgg = VGG16(weights =  vgg_path)
vgg.summary()


inc = InceptionV3(weights = inc_path)
inc.summary()


mbs = MobileNet(weights = mbs_path)
mbs.summary()

这里加载4种模型ResNet50,VGG16,InceptionV3,MobileNet,上面那些路径就是你自己把权重文件下载到自己电脑里的路径,根据自己的情况改

下载权重文件的时候要注意下载正确的权重文件:

        1.注意文件名尾部是否带notop,如果有说明没有分类头

        2.注意文件名的中间是‘tf’还是‘th’,'tf'代表tensorflow后端,'th'代表theano后端

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/877952.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-13
下一篇 2022-05-13

发表评论

登录后才能评论

评论列表(0条)

保存