将数据集按比例划分为 train、test、val。
对平行语料处理后,会得到分离的英文/中文文件(如 train.en、train.zh)。
步骤:
- 随机打乱数据集
- 划分数据集
- 划分平行语料
import os import random def data_split(config, file, train_ratio=0.98, shuffle=True): """ :param config: 数据文件所在的文件夹名 :param file: 要处理数据的文件名(全称) :param train_ratio: 训练集占比 :param shuffle: 是否打乱 :return: None """ with open(os.path.join(config, file), 'r', encoding='utf-8') as fp: # 用拼接config与file lines = fp.read().strip().split('n') n = len(lines) if shuffle: random.shuffle(lines) # 随机打乱数据集 train_len = int(n * train_ratio) val_len = int(n * (1 - train_ratio) / 2) train_data = lines[:train_len] # 训练集 val_data = lines[train_len:(train_len + val_len + 1)] # 验证集 test_data = lines[(train_len + val_len + 1):] # 测试集 train = 'train.txt' test = 'test.txt' val = 'val.txt' s_config = os.path.join(config, 'dataset') # 划分后数据集存放位置 with open(os.path.join(config, train), 'w', encoding='utf-8') as fp: fp.write("n".join(train_data)) para_divide(s_config, train) with open(os.path.join(config, test), 'w', encoding='utf-8') as fp: fp.write("n".join(test_data)) para_divide(s_config, test) with open(os.path.join(config, val), 'w', encoding='utf-8') as fp: fp.write("n".join(val_data)) para_divide(s_config, val) print('总共有数据:{}条'.format(n)) print('训练集:{}条'.format(len(train_data))) print('测试集:{}条'.format(len(test_data))) print('验证集:{}条'.format(len(val_data))) def para_divide(config, data): # 划分平行语料 f_data = open(os.path.join(config[:4], data), 'r', encoding='utf-8', errors='ignore') # config[:4]='data' en = open(os.path.join(config, (data[:-4]+'.en')), 'w', encoding='utf-8') # data[:-4]去除文件后缀 zh = open(os.path.join(config, (data[:-4]+'.zh')), 'w', encoding='utf-8') line = f_data.readline() while line: l, r = line.strip().split('t') # 按t划分 l = l.strip() + 'n' r = r.strip() + 'n' en.writelines(l) zh.writelines(r) line = f_data.readline() en.close() zh.close() f_data.close() def main(): data_split("data", "en-zh.txt") if __name__ == '__main__': main()
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)