基于ResNet的猫十二分类

基于ResNet的猫十二分类,第1张

        在这次实战训练中,首先对下载的猫十二数据集进行预处理,使用了tensorflow构建resnet模型,在学习率调度上,使用了1周期调度,并且使用了动量优化和Nesterov加速梯度

1.导包
from tensorflow import keras
import tensorflow as tf
from keras.preprocessing import image
import random
from matplotlib import pyplot as plt
import cv2
from tqdm import tqdm
import numpy as np
import math
2.数据预处理

cat_12数据集包含3个部分,训练集cat_12_train,测试集cat_test,以及存储图片名称及标签的train_list.txt

(1)定义prepare_image函数从文件中分离路径和标签

def prepare_image(file_path):
    X_train = []
    y_train = []
 
    with open(file_path) as f:
        context = f.readlines()
    random.shuffle(context)
 
    for str in context:
        str = str.strip('\n').split('\t')
 
        X_train.append('./image/cat_12/' + str[0])
        y_train.append(str[1])
 
 

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/918213.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-16
下一篇 2022-05-16

发表评论

登录后才能评论

评论列表(0条)

保存