NMT平行语料划分数据集

NMT平行语料划分数据集,第1张

NMT平行语料划分数据集 目标:

将数据集按比例划分为 train、test、val。

对平行语料处理后如下图所示:

步骤:
  1. 随机打乱数据集
  2. 划分数据集
  3. 划分平行语料
代码如下:
import os
import random


def data_split(config, file, train_ratio=0.98, shuffle=True):
    """
    :param config: 数据文件所在的文件夹名
    :param file: 要处理数据的文件名(全称)
    :param train_ratio: 训练集占比
    :param shuffle: 是否打乱
    :return: None
    """
    with open(os.path.join(config, file), 'r', encoding='utf-8') as fp:  # 用拼接config与file
        lines = fp.read().strip().split('n')
    n = len(lines)
    if shuffle:
        random.shuffle(lines)  # 随机打乱数据集

    train_len = int(n * train_ratio)
    val_len = int(n * (1 - train_ratio) / 2)

    train_data = lines[:train_len]  # 训练集
    val_data = lines[train_len:(train_len + val_len + 1)]  # 验证集
    test_data = lines[(train_len + val_len + 1):]  # 测试集

    train = 'train.txt'
    test = 'test.txt'
    val = 'val.txt'
    s_config = os.path.join(config, 'dataset')  # 划分后数据集存放位置

    with open(os.path.join(config, train), 'w', encoding='utf-8') as fp:
        fp.write("n".join(train_data))
        para_divide(s_config, train)
    with open(os.path.join(config, test), 'w', encoding='utf-8') as fp:
        fp.write("n".join(test_data))
        para_divide(s_config, test)
    with open(os.path.join(config, val), 'w', encoding='utf-8') as fp:
        fp.write("n".join(val_data))
        para_divide(s_config, val)

    print('总共有数据:{}条'.format(n))
    print('训练集:{}条'.format(len(train_data)))
    print('测试集:{}条'.format(len(test_data)))
    print('验证集:{}条'.format(len(val_data)))


def para_divide(config, data):  # 划分平行语料
    f_data = open(os.path.join(config[:4], data), 'r', encoding='utf-8', errors='ignore')  # config[:4]='data'
    en = open(os.path.join(config, (data[:-4]+'.en')), 'w', encoding='utf-8')  # data[:-4]去除文件后缀
    zh = open(os.path.join(config, (data[:-4]+'.zh')), 'w', encoding='utf-8')

    line = f_data.readline()
    while line:
        l, r = line.strip().split('t')  # 按t划分

        l = l.strip() + 'n'
        r = r.strip() + 'n'

        en.writelines(l)
        zh.writelines(r)

        line = f_data.readline()

    en.close()
    zh.close()
    f_data.close()


def main():
    data_split("data", "en-zh.txt")


if __name__ == '__main__':
    main()

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/zaji/5689516.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)

保存