【python】函数参数*args的作用

【python】函数参数*args的作用,第1张

经常看到很多代码里定义函数的时候最后会有*args,当时不知道什么意思和作用,但并不影响阅读,但并不影响阅读,就没有继续深入研究了。
后来一次看代码时候忽然有所感触,所以记录下来。

作用:表示函数后续的传入的参数个数未知

我们在使用过程中,往往有可能无法固定下来所有参数的数量。所以将不确定的参数全部用替代*args。当调用时候的参数超过固定参数时,不会报错。

例如,定义函数sum:

def sum(a,b):
    return a+b

然后,调用时:

result = sum(1,2,3)

会进行报错:

然后加入*args参数重新定义函数sum:

def sum(a,b,*args):
    return a+b

则不会报错。

常见的使用场景:

从例子上看感觉并没有什么用,然后这样的好处之一就是可以统一格式,显得更简洁,更有规范:

例如有一系列作用相似的损失函数,需要用到很多相同但并不是完全一致的参数,那么增加参数*args进行定义:

def compute_diversity(pred, *args):
    if pred.shape[0] == 1:
        return 0.0
    dist = pdist(pred.reshape(pred.shape[0], -1))
    diversity = dist.mean().item()
    return diversity


def compute_ade(pred, gt, *args):
    diff = pred - gt
    dist = np.linalg.norm(diff, axis=2).mean(axis=1)
    return dist.min()


def compute_fde(pred, gt, *args):
    diff = pred - gt
    dist = np.linalg.norm(diff, axis=2)[:, -1]
    return dist.min()


def compute_mmade(pred, gt, gt_multi):
    gt_dist = []
    for gt_multi_i in gt_multi:
        dist = compute_ade(pred, gt_multi_i)
        gt_dist.append(dist)
    gt_dist = np.array(gt_dist).mean()
    return gt_dist


def compute_mmfde(pred, gt, gt_multi):
    gt_dist = []
    for gt_multi_i in gt_multi:
        dist = compute_fde(pred, gt_multi_i)
        gt_dist.append(dist)
    gt_dist = np.array(gt_dist).mean()
    return gt_dist

然后在调用的时候,可以直接生成一个字典类型:

    stats_func = {'Diversity': compute_diversity, 'ADE': compute_ade,
                  'FDE': compute_fde, 'MMADE': compute_mmade, 'MMFDE': compute_mmfde}

然后调用损失函数计算的时候直接循环调用,调用的参数直接取并集,这样整个代码非常清晰整洁:

            for stats in stats_names:
                val = 0
                for pred_i in pred:
                    val += stats_func[stats](pred_i, gt, gt_multi) / num_seeds

而不借助*args,这种情况通常只会这样写,非常的繁琐:

                compute_diversity_val = 0
                compute_ade_val = 0
                compute_fde_val = 0
                compute_mmade_val = 0
                compute_mmfde_val = 0
                for pred_i in pred:
                    compute_diversity_val+=compute_diversity(pred) / num_seeds
                    compute_ade_val+=compute_ade(pred, gt) / num_seeds
                    compute_fde_val+= compute_fde(pred, gt) / num_seeds
                    compute_mmade_val+=compute_mmade(pred, gt, gt_multi)/ num_seeds
                    compute_mmfde_val+=compute_mmfde(pred, gt, gt_multi) / num_seeds

当然,*args肯定也有其他作用,可以在函数内使用,方法很简单。

例如:

def sum(*args):
    sum=0
    for arg in args:
        sum+= arg
    return sum

print(sum(1,2,3,4))

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/942463.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-18
下一篇 2022-05-18

发表评论

登录后才能评论

评论列表(0条)

保存