import os
import tensorflow as tf
from PIL import Image
# 源数据地址
cwd = 'C:\\Users\\xiaodeng\\Desktop\\UCMerced_LandUse\\Images'
# 生成record路径及文件名
train_record_path = r"C:\\Users\\xiaodeng\\Desktop\\train.tfrecords"
test_record_path = r"C:\\Users\\xiaodeng\\Desktop\\test.tfrecords"
# 分类
classes = {'agricultural','airplane','baseballdiamond',
'beach','buildings','chaparral','denseresidential',
'forest','freeway','golfcourse','harbor',
'intersection','mediumresidential','mobilehomepark','overpass',
'parkinglot','river','runway','sparseresidential','storagetanks','tenniscourt'}
def _byteslist(value):
"""二进制属性"""
return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
def _int64list(value):
"""整数属性"""
return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
#def create_train_record(cwd,classes):
"""创建训练集tfrecord"""
writer = tf.python_io.TFRecordWriter(train_record_path) # 创建一个writer
NUM = 1 # 显示创建过程(计数)
for index, name in enumerate(classes):
class_path = cwd + "/" + name + '/'
l = int(len(os.listdir(class_path)) * 0.7) # 取前70%创建训练集
for img_name in os.listdir(class_path)[:l]:
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((256, 256))# resize图片大小
img_raw = img.tobytes() # 将图片转化为原生bytes
example = tf.train.Example( # 封装到Example中
features=tf.train.Features(feature={
"label":_int64list(index), # label必须为整数类型属性
'img_raw':_byteslist(img_raw) # 图片必须为二进制属性
}))
writer.write(example.SerializeToString())
print('Creating train record in ',NUM)
NUM += 1
writer.close() # 关闭writer
print("Create train_record successful!")
#def create_test_record(cwd,classes):
"""创建测试tfrecord"""
writer = tf.python_io.TFRecordWriter(test_record_path)
NUM = 1
for index, name in enumerate(classes):
class_path = cwd + '/' + name + '/'
l = int(len(os.listdir(class_path)) * 0.7)
for img_name in os.listdir(class_path)[l:]: # 剩余30%作为测试集
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((256, 256))
img_raw = img.tobytes() # 将图片转化为原生bytes
# print(index,img_raw)
example = tf.train.Example(
features=tf.train.Features(feature={
"label":_int64list(index),
'img_raw':_byteslist(img_raw)
}))
writer.write(example.SerializeToString())
print('Creating test record in ',NUM)
NUM += 1
writer.close()
在发布逾一周年之际,TensorFlow 终于将迎来史上最重大更新:TensorFlow 1.0。对于不熟悉开源框架的读者,TensorFlow 是谷歌 2015 年底推出的深度学习框架,在开发者社区享有盛誉。去年,它已成为 GitHub 最受欢迎的机器学习开源项目。因其高度普及率,尤其是在 Python 生态圈中,TensorFlow 的功能变化会对全世界的机器学习开发者造成重大影响。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)