Overview
[Zero Basics 1] Learning BERT with PaddlePaddle (一只博客, CSDN blog): https://blog.csdn.net/qq_42276781/article/details/121488335
[Zero Basics 2] Learning BERT with PaddlePaddle (一只博客, CSDN blog): https://blog.csdn.net/qq_42276781/article/details/121523268
Cell 7
# Create the dataloader
def create_dataloader(dataset, mode='train', batch_size=1, batchify_fn=None, trans_fn=None):
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == 'train' else False
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)
Snippet 1
def create_dataloader(dataset, mode='train', batch_size=1, batchify_fn=None, trans_fn=None):
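create_dataloader builds the data loader. Its parameters are the dataset (dataset), the mode (mode, defaulting to 'train', i.e. the training set), the batch size (batch_size, defaulting to 1), batchify_fn (not yet clear to me; for now I read it as "batchify function", i.e. the function that turns samples into a batch), and trans_fn (the function that converts each sample).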
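Since Cell 7 passes batchify_fn to paddle.io.DataLoader as collate_fn, one plausible reading is that it receives a list of samples and merges them into a single batch. The following is a minimal pure-Python sketch of that idea, not PaddleNLP's implementation; the name pad_and_stack, the pad_id parameter, and the (input_ids, label) sample layout are all hypothetical and chosen only for illustration.

```python
def pad_and_stack(samples, pad_id=0):
    """Hypothetical collate function: pad token-id lists to equal length, group labels."""
    # The longest sequence in this batch decides the padded length.
    max_len = max(len(ids) for ids, _ in samples)
    # Right-pad every sequence with pad_id up to max_len.
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids, _ in samples]
    # Collect the labels in the same order as the samples.
    labels = [label for _, label in samples]
    return input_ids, labels


# Two made-up samples of different lengths.
batch = [([101, 7, 8, 102], 1), ([101, 9, 102], 0)]
input_ids, labels = pad_and_stack(batch)
print(input_ids)  # [[101, 7, 8, 102], [101, 9, 102, 0]]
print(labels)     # [1, 0]
```

In a real pipeline the padded lists would normally be converted to arrays or tensors; that step is left out here to keep the sketch free of dependencies.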
Snippet 2
if trans_fn:
    dataset = dataset.map(trans_fn)
If trans_fn is passed in, it is applied to dataset through dataset.map to convert the samples. The API documentation for dataset.map is here:
dataset (PaddleNLP documentation): https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.datasets.dataset.html?highlight=dataset.map#paddlenlp.datasets.dataset.MapDataset.map
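To make the effect of dataset.map(trans_fn) more concrete, here is a minimal sketch built on a hypothetical stand-in class, TinyMapDataset. It is not the PaddleNLP MapDataset, whose actual behavior is described in the documentation linked above. The sketch assumes that trans_fn converts one sample at a time and that map returns a dataset that can be reassigned, as Snippet 2 does; the trans_fn shown is also made up.

```python
class TinyMapDataset:
    """Hypothetical stand-in for a map-style dataset; not the PaddleNLP class."""

    def __init__(self, samples):
        self.samples = list(samples)

    def map(self, fn):
        # Apply fn to every sample and return the dataset so it can be reassigned.
        self.samples = [fn(sample) for sample in self.samples]
        return self

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)


def trans_fn(example):
    # Hypothetical conversion: replace raw text with a list of character codes.
    return {"input_ids": [ord(ch) for ch in example["text"]],
            "label": example["label"]}


dataset = TinyMapDataset([{"text": "ab", "label": 1}, {"text": "c", "label": 0}])
dataset = dataset.map(trans_fn)
print(dataset[0])  # {'input_ids': [97, 98], 'label': 1}
```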
Snippet 3
shuffle = True if mode == 'train' else False
For the training set the data is shuffled; otherwise it is not. The syntax here corresponds to the ternary operator in C and Java:
shuffle = mode == 'train' ? true : false
Snippet 4
if mode == 'train':
    batch_sampler = paddle.io.DistributedBatchSampler(
        dataset, batch_size=batch_size, shuffle=shuffle)
else:
    batch_sampler = paddle.io.BatchSampler(
        dataset, batch_size=batch_size, shuffle=shuffle)
If mode is 'train' (the training set), paddle.io.DistributedBatchSampler is called to produce batch_sampler; for any other mode, paddle.io.BatchSampler is called instead.
Why do we need a batch_sampler here?
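One way to approach this question: in Cell 7 the sampler is handed to paddle.io.DataLoader through its batch_sampler argument, and the loader uses it to decide which sample indices make up each mini-batch; the selected samples are then merged by collate_fn. The pure-Python sketch below imitates that division of labor. The helpers simple_batch_sampler and load_batches are hypothetical, written only for illustration, and do not reproduce Paddle's implementation (for instance, they ignore the splitting of data across multiple cards that DistributedBatchSampler performs).

```python
import random


def simple_batch_sampler(dataset_len, batch_size, shuffle=False, seed=0):
    """Hypothetical sampler: yield lists of sample indices, one list per batch."""
    indices = list(range(dataset_len))
    if shuffle:
        # A seeded shuffle keeps the sketch reproducible.
        random.Random(seed).shuffle(indices)
    for start in range(0, dataset_len, batch_size):
        yield indices[start:start + batch_size]


def load_batches(dataset, batch_sampler, collate_fn=list):
    """Hypothetical loader loop: fetch the sampled indices, then collate them."""
    for batch_indices in batch_sampler:
        yield collate_fn([dataset[i] for i in batch_indices])


data = ['a', 'b', 'c', 'd', 'e']
print(list(simple_batch_sampler(len(data), batch_size=2)))
# [[0, 1], [2, 3], [4]]
print(list(load_batches(data, simple_batch_sampler(len(data), batch_size=2))))
# [['a', 'b'], ['c', 'd'], ['e']]
```

Seen this way, the sampler only controls the order and grouping of indices, which is why the shuffle flag is given to the sampler rather than to the loader in Cell 7.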