经常看到很多代码里定义函数的时候最后会有*args
,当时不知道什么意思和作用,但并不影响阅读,但并不影响阅读,就没有继续深入研究了。
后来一次看代码时候忽然有所感触,所以记录下来。
我们在使用过程中,往往有可能无法固定下来所有参数的数量。所以将不确定的参数全部用替代*args
。当调用时候的参数超过固定参数时,不会报错。
例如,定义函数sum:
def sum(a,b):
return a+b
然后,调用时:
result = sum(1,2,3)
会进行报错:
然后加入*args
参数重新定义函数sum:
def sum(a,b,*args):
return a+b
则不会报错。
常见的使用场景:从例子上看感觉并没有什么用,然后这样的好处之一就是可以统一格式,显得更简洁,更有规范:
例如有一系列作用相似的损失函数,需要用到很多相同但并不是完全一致的参数,那么增加参数*args
进行定义:
def compute_diversity(pred, *args):
if pred.shape[0] == 1:
return 0.0
dist = pdist(pred.reshape(pred.shape[0], -1))
diversity = dist.mean().item()
return diversity
def compute_ade(pred, gt, *args):
diff = pred - gt
dist = np.linalg.norm(diff, axis=2).mean(axis=1)
return dist.min()
def compute_fde(pred, gt, *args):
diff = pred - gt
dist = np.linalg.norm(diff, axis=2)[:, -1]
return dist.min()
def compute_mmade(pred, gt, gt_multi):
gt_dist = []
for gt_multi_i in gt_multi:
dist = compute_ade(pred, gt_multi_i)
gt_dist.append(dist)
gt_dist = np.array(gt_dist).mean()
return gt_dist
def compute_mmfde(pred, gt, gt_multi):
gt_dist = []
for gt_multi_i in gt_multi:
dist = compute_fde(pred, gt_multi_i)
gt_dist.append(dist)
gt_dist = np.array(gt_dist).mean()
return gt_dist
然后在调用的时候,可以直接生成一个字典类型:
stats_func = {'Diversity': compute_diversity, 'ADE': compute_ade,
'FDE': compute_fde, 'MMADE': compute_mmade, 'MMFDE': compute_mmfde}
然后调用损失函数计算的时候直接循环调用,调用的参数直接取并集,这样整个代码非常清晰整洁:
for stats in stats_names:
val = 0
for pred_i in pred:
val += stats_func[stats](pred_i, gt, gt_multi) / num_seeds
而不借助*args
,这种情况通常只会这样写,非常的繁琐:
compute_diversity_val = 0
compute_ade_val = 0
compute_fde_val = 0
compute_mmade_val = 0
compute_mmfde_val = 0
for pred_i in pred:
compute_diversity_val+=compute_diversity(pred) / num_seeds
compute_ade_val+=compute_ade(pred, gt) / num_seeds
compute_fde_val+= compute_fde(pred, gt) / num_seeds
compute_mmade_val+=compute_mmade(pred, gt, gt_multi)/ num_seeds
compute_mmfde_val+=compute_mmfde(pred, gt, gt_multi) / num_seeds
当然,*args
肯定也有其他作用,可以在函数内使用,方法很简单。
例如:
def sum(*args):
sum=0
for arg in args:
sum+= arg
return sum
print(sum(1,2,3,4))
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)