2021SC@SDUSC
我们继续分析Field类中的其他类函数。首先看preprocess()函数。
def preprocess(self, x): if self.sequential and isinstance(x, str): x = self.tokenize(x.rstrip('n')) if self.lower: x = Pipeline(str.lower)(x) if self.sequential and self.use_vocab and self.stop_words is not None: x = [w for w in x if w not in self.stop_words] if self.preprocessing is not None: return self.preprocessing(x) else: return x
这个函数首先判断序列x是否为顺序的字符串类型的数据,如果x满足条件,就被标记。
然后判断序列x是否为小写,如果是的话,就把x传递给用户提供的“预处理”管道。
如果序列x是顺序的序列且是使用Vocab对象的,并且预处理步骤中有需要丢弃的令牌,那么就对x进行数据的清洗。
最后返回预处理后的x或者x。
再看process()函数。
def process(self, batch, device=None): padded = self.pad(batch) tensor = self.numericalize(padded, device=device) return tensor
process函数来处理一系列的例子来创建一个torch.Tensor。对批处理进行pad、数字化和后处理,然后创建一个张量。
参数batch(list(object)):来自一批示例的对象列表。返回的tensor是给定输入的处理对象和自定义后处理管道。
在Field类中还有以下三个函数,接下来我会逐个解释他们的作用。
def pad(self, minibatch): minibatch = list(minibatch) if not self.sequential: return minibatch if self.fix_length is None: max_len = max(len(x) for x in minibatch) else: max_len = self.fix_length + ( self.init_token, self.eos_token).count(None) - 2 padded, lengths = [], [] for x in minibatch: if self.pad_first: padded.append( [self.pad_token] * max(0, max_len - len(x)) + ([] if self.init_token is None else [self.init_token]) + list(x[-max_len:] if self.truncate_first else x[:max_len]) + ([] if self.eos_token is None else [self.eos_token])) else: padded.append( ([] if self.init_token is None else [self.init_token]) + list(x[-max_len:] if self.truncate_first else x[:max_len]) + ([] if self.eos_token is None else [self.eos_token]) + [self.pad_token] * max(0, max_len - len(x))) lengths.append(len(padded[-1]) - max(0, max_len - len(x))) if self.include_lengths: return (padded, lengths) return padded
首先是pad函数,用来填充一批示例。如果提供了fix_length,那么就填充这批例子中最长的那个长度。如果这些属性不是none,就突出init_token,附加eos_token。如果include_lengths和sequential是true,也就是说,如果返回带填充的minibatch和的元组,一个包含每个示例长度的列表,或者只是一个填充的minibatch且是顺序序列,就返回填充列表和包含每个示例长度的列表的一个元组,否则就返回这个填充的list,如果序列不是顺序的,就不返回填充序列。
def build_vocab(self, *args, **kwargs): counter = Counter() sources = [] for arg in args: if isinstance(arg, Dataset): sources += [getattr(arg, name) for name, field in arg.fields.items() if field is self] else: sources.append(arg) for data in sources: for x in data: if not self.sequential: x = [x] try: counter.update(x) except TypeError: counter.update(chain.from_iterable(x)) specials = list(OrderedDict.fromkeys( tok for tok in [self.unk_token, self.pad_token, self.init_token, self.eos_token] + kwargs.pop('specials', []) if tok is not None)) self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
接下来是build_vocab函数,用来从一个或多个数据集为该字段构造Vocab对象。位置参数表示数据集对象或其他可迭代数据源,从中构造表示此字段可能值集的Vocab对象。如果提供了Dataset对象,则使用该字段对应的所有列;单独的列也可以直接提供。剩下的关键字参数表示传递给Vocab的构造函数。
def numericalize(self, arr, device=None): if self.include_lengths and not isinstance(arr, tuple): raise ValueError("Field has include_lengths set to True, but " "input data is not a tuple of " "(data batch, batch lengths).") if isinstance(arr, tuple): arr, lengths = arr lengths = torch.tensor(lengths, dtype=self.dtype, device=device) if self.use_vocab: if self.sequential: arr = [[self.vocab.stoi[x] for x in ex] for ex in arr] else: arr = [self.vocab.stoi[x] for x in arr] if self.postprocessing is not None: arr = self.postprocessing(arr, self.vocab) else: if self.dtype not in self.dtypes: raise ValueError( "Specified Field dtype {} can not be used with " "use_vocab=False because we do not know how to numericalize it. " "Please raise an issue at " "https://github.com/pytorch/text/issues".format(self.dtype)) numericalization_func = self.dtypes[self.dtype] if not self.sequential: arr = [numericalization_func(x) if isinstance(x, str) else x for x in arr] if self.postprocessing is not None: arr = self.postprocessing(arr, None) var = torch.tensor(arr, dtype=self.dtype, device=device) if self.sequential and not self.batch_first: var.t_() if self.sequential: var = var.contiguous() if self.include_lengths: return var, lengths return var
然后实numericalize函数,即实现数值化的函数,将一批使用该字段的示例转换为一个变量。如果字段包含include_length_true,则返回值中将包含一个长度张量。
arr:标记化和填充示例的列表,或标记化和填充示例的列表的元组和每个示例if self的长度列表。
device:一个“token.device”的字符串或实例,指定要在哪个设备上创建变量。如果保持默认值,张量将在cpu上创建。
分析完以上,我们回到dataset类中,看其他的函数,之后只选择重要的函数来分析。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)