2021SC@SDUSC
- Biological Macromolecule Platform (7)
- 0 Preface
- 1 preprocess
- 1.1 Importing the required libraries
- Detailed notes on each library
- 1.2 Reading the methods
- Importing the data sources (online datasets)
- Download and extraction method
- Creating a new directory
- Merging and encoding files (with detailed notes)
- 1.3 Reading the main function
- Create the directories
- Download the raw data sets
- Merge all files into one
- Build the BPE codes file from the training data
- Call the text-data processing function
- Using the argparse module
- Building the SRC and TRG fields
This week I continued reading the Transformer code implementation, focusing on the preprocess part.
1 preprocess

1.1 Importing the required libraries

import os
import argparse
import logging
import dill as pickle
import urllib
from tqdm import tqdm
import sys
import codecs
import spacy
import torch
import tarfile
import torchtext.data
import torchtext.datasets
from torchtext.datasets import TranslationDataset
import transformer.Constants as Constants
from learn_bpe import learn_bpe
from apply_bpe import BPE

Detailed notes on each library:
- os provides general, basic operating-system interaction functions.
- argparse is a built-in Python module for parsing command-line options and arguments; it makes it easy to write user-friendly command-line interfaces.
- logging is used to record logs.
- dill extends pickle, Python's module for serializing and deserializing objects, to most built-in Python types. For example, nested functions cannot be stored by pickle, but dill can handle them. dill exposes the same interface as pickle, so `import dill as pickle` is all that is needed (see the sketch after this list).
- urllib is Python's official standard library for making URL requests.
- tqdm displays a progress bar for time-consuming operations so that the processing progress can be observed in real time. It uses very little CPU, works on Windows, Linux, and macOS, and supports loops, multiprocessing, recursion, and combination with Linux commands for monitoring progress.
- sys provides a set of very useful functions and variables for working with the Python runtime configuration and resources, allowing the program to interact with the environment outside itself, such as the Python interpreter.
- codecs is dedicated to encoding conversion; when text encodings need to be converted, codecs makes it straightforward.
- spacy is an NLP library.
- torch is a deep-learning library.
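A minimal sketch (not from the preprocess code itself; make_adder is made up for illustration) of why dill is imported in place of pickle: the standard pickle module cannot serialize a nested function, while dill can.

import pickle as std_pickle
import dill

def make_adder(n):
    # nested function: standard pickle cannot serialize this local object
    def add(x):
        return x + n
    return add

add3 = make_adder(3)

try:
    std_pickle.dumps(add3)
except Exception as e:
    print("pickle failed:", e)   # Can't pickle local object 'make_adder.<locals>.add'

restored = dill.loads(dill.dumps(add3))  # dill serializes the closure without trouble
print(restored(4))                       # prints 7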
__author__ = "Yu-Hsiang Huang" _TRAIN_DATA_SOURCES = [ {"url": "http://data.statmt.org/wmt17/translation-task/" "training-parallel-nc-v12.tgz", "trg": "news-commentary-v12.de-en.en", "src": "news-commentary-v12.de-en.de"}, #{"url": "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz", # "trg": "commoncrawl.de-en.en", # "src": "commoncrawl.de-en.de"}, #{"url": "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz", # "trg": "europarl-v7.de-en.en", # "src": "europarl-v7.de-en.de"} ] _VAL_DATA_SOURCES = [ {"url": "http://data.statmt.org/wmt17/translation-task/dev.tgz", "trg": "newstest2013.en", "src": "newstest2013.de"}] _TEST_DATA_SOURCES = [ {"url": "https://storage.googleapis.com/tf-perf-public/" "official_transformer/test_data/newstest2014.tgz", "trg": "newstest2014.en", "src": "newstest2014.de"}]下载并变换格式方法
def download_and_extract(download_dir, url, src_filename, trg_filename):
    src_path = file_exist(download_dir, src_filename)
    trg_path = file_exist(download_dir, trg_filename)

    if src_path and trg_path:
        sys.stderr.write(f"Already downloaded and extracted {url}.\n")
        return src_path, trg_path

    compressed_file = _download_file(download_dir, url)

    sys.stderr.write(f"Extracting {compressed_file}.\n")
    with tarfile.open(compressed_file, "r:gz") as corpus_tar:
        corpus_tar.extractall(download_dir)

    src_path = file_exist(download_dir, src_filename)
    trg_path = file_exist(download_dir, trg_filename)

    if src_path and trg_path:
        return src_path, trg_path

    raise OSError(f"Download/extraction failed for url {url} to path {download_dir}")

Creating a new directory when needed:
def mkdir_if_needed(dir_name):
    if not os.path.isdir(dir_name):
        os.makedirs(dir_name)

Merging and encoding files (with detailed notes):
def compile_files(raw_dir, raw_files, prefix):
    src_fpath = os.path.join(raw_dir, f"raw-{prefix}.src")
    trg_fpath = os.path.join(raw_dir, f"raw-{prefix}.trg")

    if os.path.isfile(src_fpath) and os.path.isfile(trg_fpath):
        sys.stderr.write(f"Merged files found, skip the merging process.\n")
        return src_fpath, trg_fpath

    sys.stderr.write(f"Merge files into two files: {src_fpath} and {trg_fpath}.\n")

    with open(src_fpath, 'w') as src_outf, open(trg_fpath, 'w') as trg_outf:
        for src_inf, trg_inf in zip(raw_files['src'], raw_files['trg']):
            sys.stderr.write(f'  Input files: \n'
                             f'    - SRC: {src_inf}, and\n'
                             f'    - TRG: {trg_inf}.\n')
            with open(src_inf, newline='\n') as src_inf, open(trg_inf, newline='\n') as trg_inf:
                cntr = 0
                for i, line in enumerate(src_inf):
                    cntr += 1
                    src_outf.write(line.replace('\r', ' ').strip() + '\n')
                for j, line in enumerate(trg_inf):
                    cntr -= 1
                    trg_outf.write(line.replace('\r', ' ').strip() + '\n')
                assert cntr == 0, 'Number of lines in two files are inconsistent.'

    return src_fpath, trg_fpath


def encode_file(bpe, in_file, out_file):
    sys.stderr.write(f"Read raw content from {in_file} and \n"
                     f"Write encoded content to {out_file}\n")

    with codecs.open(in_file, encoding='utf-8') as in_f:
        with codecs.open(out_file, 'w', encoding='utf-8') as out_f:
            for line in in_f:
                out_f.write(bpe.process_line(line))


def encode_files(bpe, src_in_file, trg_in_file, data_dir, prefix):
    src_out_file = os.path.join(data_dir, f"{prefix}.src")
    trg_out_file = os.path.join(data_dir, f"{prefix}.trg")

    if os.path.isfile(src_out_file) and os.path.isfile(trg_out_file):
        sys.stderr.write(f"Encoded files found, skip the encoding process ...\n")

    encode_file(bpe, src_in_file, src_out_file)
    encode_file(bpe, trg_in_file, trg_out_file)
    return src_out_file, trg_out_file
- First, status and error messages are written to sys.stderr: if the merged files already exist, a message is printed and the merging step is skipped.
- Then the names of the input files being processed are written to stderr:
sys.stderr.write(f'  Input files: \n'
                 f'    - SRC: {src_inf}, and\n'
                 f'    - TRG: {trg_inf}.\n')
- Open the input files and read them line by line, writing cleaned lines to the merged output files:
for i, line in enumerate(src_inf):
    cntr += 1
    src_outf.write(line.replace('\r', ' ').strip() + '\n')
for j, line in enumerate(trg_inf):
    cntr -= 1
    trg_outf.write(line.replace('\r', ' ').strip() + '\n')
assert cntr == 0, 'Number of lines in two files are inconsistent.'
- assert is Python's assertion statement, used to catch program errors during debugging; here it checks that the source and target files contain the same number of lines:
assert cntr == 0, 'Number of lines in two files are inconsistent.'
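As a small illustrative sketch (the value of cntr is made up), a failing assertion raises AssertionError with the given message; note that assertions are stripped when Python runs with the -O flag, so they are meant for debugging rather than for validating user input.

cntr = 2   # e.g. the source file had two more lines than the target file
try:
    assert cntr == 0, 'Number of lines in two files are inconsistent.'
except AssertionError as e:
    print("AssertionError:", e)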
- Encoding a file: a status message is written to stderr, then the input file is opened and each line is BPE-encoded and written to the output file:
with codecs.open(in_file, encoding='utf-8') as in_f:
    with codecs.open(out_file, 'w', encoding='utf-8') as out_f:
        for line in in_f:
            out_f.write(bpe.process_line(line))
- Encoding a whole data split (both the source and the target file):
- os.path.join() joins path components into a single file path; multiple path segments can be passed in.
def encode_files(bpe, src_in_file, trg_in_file, data_dir, prefix):
    src_out_file = os.path.join(data_dir, f"{prefix}.src")
    trg_out_file = os.path.join(data_dir, f"{prefix}.trg")

    if os.path.isfile(src_out_file) and os.path.isfile(trg_out_file):
        sys.stderr.write(f"Encoded files found, skip the encoding process ...\n")

    encode_file(bpe, src_in_file, src_out_file)
    encode_file(bpe, trg_in_file, trg_out_file)
    return src_out_file, trg_out_file

1.3 Reading the main function

Create the directories:
mkdir_if_needed(opt.raw_dir)
mkdir_if_needed(opt.data_dir)

Download the raw data sets:
raw_train = get_raw_files(opt.raw_dir, _TRAIN_DATA_SOURCES)
raw_val = get_raw_files(opt.raw_dir, _VAL_DATA_SOURCES)
raw_test = get_raw_files(opt.raw_dir, _TEST_DATA_SOURCES)

Merge all the files into one:
train_src, train_trg = compile_files(opt.raw_dir, raw_train, opt.prefix + '-train')
val_src, val_trg = compile_files(opt.raw_dir, raw_val, opt.prefix + '-val')
test_src, test_trg = compile_files(opt.raw_dir, raw_test, opt.prefix + '-test')

Build the BPE codes file from the training data:
opt.codes = os.path.join(opt.data_dir, opt.codes)
if not os.path.isfile(opt.codes):
    sys.stderr.write(f"Collect codes from training data and save to {opt.codes}.\n")
    learn_bpe(raw_train['src'] + raw_train['trg'], opt.codes, opt.symbols, opt.min_frequency, True)
sys.stderr.write(f"BPE codes prepared.\n")

sys.stderr.write(f"Build up the tokenizer.\n")
with codecs.open(opt.codes, encoding='utf-8') as codes:
    bpe = BPE(codes, separator=opt.separator)

sys.stderr.write(f"Encoding ...\n")
encode_files(bpe, train_src, train_trg, opt.data_dir, opt.prefix + '-train')
encode_files(bpe, val_src, val_trg, opt.data_dir, opt.prefix + '-val')
encode_files(bpe, test_src, test_trg, opt.data_dir, opt.prefix + '-test')
sys.stderr.write(f"Done.\n")
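To make the effect of this step concrete, here is a hedged illustration (the sentence and its segmentation are made up; the actual splits depend on the learned merge operations and on opt.separator): BPE splits rare words into subword units and marks non-final pieces with the separator, which in subword-nmt defaults to '@@'.

# Hypothetical input/output of bpe.process_line(); not taken from the repo.
raw_line     = "the unrelated newcomer arrived\n"
encoded_line = "the un@@ related new@@ comer arrived\n"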
Call the text-data processing function:

field = torchtext.data.Field(
    tokenize=str.split,
    lower=True,
    pad_token=Constants.PAD_WORD,
    init_token=Constants.BOS_WORD,
    eos_token=Constants.EOS_WORD)

Using the argparse parser module:
- Reference link: https://blog.csdn.net/qq_34243930/article/details/106517985
parser = argparse.ArgumentParser()
parser.add_argument('-lang_src', required=True, choices=spacy_support_langs)
parser.add_argument('-lang_trg', required=True, choices=spacy_support_langs)
parser.add_argument('-save_data', required=True)
parser.add_argument('-data_src', type=str, default=None)
parser.add_argument('-data_trg', type=str, default=None)
parser.add_argument('-max_len', type=int, default=100)
parser.add_argument('-min_word_count', type=int, default=3)
parser.add_argument('-keep_case', action='store_true')
parser.add_argument('-share_vocab', action='store_true')
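For intuition, a hedged sketch of how the parsed options are used, relying on the parser defined above; the argument list passed to parse_args() is made up for illustration and is not taken from the post.

# 'de' and 'en' are assumed to be in spacy_support_langs; the output file name is hypothetical.
opt = parser.parse_args([
    '-lang_src', 'de', '-lang_trg', 'en',
    '-save_data', 'm30k_deen_shr.pkl',
    '-share_vocab'])
print(opt.max_len, opt.min_word_count, opt.share_vocab)   # 100 3 True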
Building the SRC and TRG fields:

SRC = torchtext.data.Field(
    tokenize=tokenize_src, lower=not opt.keep_case,
    pad_token=Constants.PAD_WORD, init_token=Constants.BOS_WORD, eos_token=Constants.EOS_WORD)

TRG = torchtext.data.Field(
    tokenize=tokenize_trg, lower=not opt.keep_case,
    pad_token=Constants.PAD_WORD, init_token=Constants.BOS_WORD, eos_token=Constants.EOS_WORD)

MAX_LEN = opt.max_len
MIN_FREQ = opt.min_word_count
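MAX_LEN and MIN_FREQ are then used to filter overly long sentence pairs and to limit the vocabulary. The post does not quote that part of the script, so the following is only a hedged sketch of typical usage with the legacy torchtext API; the dataset path and field names are assumptions.

# Sketch only: keep pairs whose source and target are both within MAX_LEN tokens,
# then build vocabularies, dropping words that appear fewer than MIN_FREQ times.
def filter_examples_with_length(x):
    return len(vars(x)['src']) <= MAX_LEN and len(vars(x)['trg']) <= MAX_LEN

train = TranslationDataset(
    fields=(SRC, TRG),
    path=os.path.join(opt.data_dir, opt.prefix + '-train'),
    exts=('.src', '.trg'),
    filter_pred=filter_examples_with_length)

SRC.build_vocab(train.src, min_freq=MIN_FREQ)
TRG.build_vocab(train.trg, min_freq=MIN_FREQ)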