您可以使用如下函数:
import tensorflow as tfdef split_tfrecord(tfrecord_path, split_size): with tf.Graph().as_default(), tf.Session() as sess: ds = tf.data.TFRecordDataset(tfrecord_path).batch(split_size) batch = ds.make_one_shot_iterator().get_next() part_num = 0 while True: try: records = sess.run(batch) part_path = tfrecord_path + '.{:03d}'.format(part_num) with tf.python_io.TFRecordWriter(part_path) as writer: for record in records: writer.write(record) part_num += 1 except tf.errors.OutOfRangeError: break
例如,要将文件
my_records.tfrecord分成100条记录,您可以执行以下 *** 作:
split_tfrecord(my_records.tfrecord, 100)
这将创建多个较小的记录文件
my_records.tfrecord.000,
my_records.tfrecord.001等等。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)