【零基础-3】PaddlePaddle学习Bert

【零基础-3】PaddlePaddle学习Bert,第1张

【零基础-3】PaddlePaddle学习Bert

概要

【零基础-1】PaddlePaddle学习Bert_ 一只博客-CSDN博客https://blog.csdn.net/qq_42276781/article/details/121488335【零基础-2】PaddlePaddle学习Bert_ 一只博客-CSDN博客https://blog.csdn.net/qq_42276781/article/details/121523268

Cell 7
# 创建dataloader
def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == 'train' else False
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)
Snippet  1
def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):

create_dataloader,创建数据加载器,输入数据集dataset、模式mode(默认为训练集)、batchify_fn(未知,暂时理解成batchify_function,即batch化的函数)、trans_fn(转换样本的函数)。

Snippet 2
if trans_fn:
        dataset = dataset.map(trans_fn)

如果传入了trans_fn,就使用trans_fn将dataset进行一个转换,dataset.map的api文档如下

dataset — PaddleNLP 文档https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.datasets.dataset.html?highlight=dataset.map#paddlenlp.datasets.dataset.MapDataset.map

Snippet 3
shuffle = True if mode == 'train' else False

如果是训练集,就打乱,否则不打乱,这里的语法相当于C、Java的三目运算符

shuffle = mode == 'train' ? true : false 
Snippet 4
if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

如果传入的是训练集,则调用paddle.io.DistributedBatchSampler处理得到batch_sampler,如果传入的不是训练集,则调用paddle.io.BatchSampler处理得到batch_sampler。

这里为什么要得到batch_sampler呢?

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5605981.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-15
下一篇 2022-12-15

发表评论

登录后才能评论

评论列表(0条)

保存