cat_12数据集包含3个部分,训练集cat_12_train,测试集cat_test,以及存储图片名称及标签的train_list.txt
2.数据预处理首先,定义一个prepare_image函数,取出文本文件中的图片路径与标签,并且打乱顺序
def prepare_image(file_path):
X_train = []
y_train = []
with open(file_path) as f:
context = f.readlines()
random.shuffle(context)
for str in context:
str = str.strip('\n').split('\t')
X_train.append('./cat_12/' + str[0])
y_train.append(str[1])
return X_train, y_train
再定义一个preprocess_image进行图片归一化 *** 作,将像素值限制在0-1之间。
# 数据归一化
def preprocess_image(image):
image = tf.io.read_file(image)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize_with_
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)