在跑AICITY2020_DMT_HST的代码时遇到的问题,记录一下
参考文章
问题描述
解决方法:加载权重时,将不用的参数过滤掉。
# Original code (kept for reference):
# def load_param(self, model_path):
#     param_dict = torch.load(model_path)
#     if 'state_dict' in param_dict:
#         param_dict = param_dict['state_dict']
#     for i in param_dict:
#         if 'fc' in i:
#             continue
#         # embed()
#         self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])

# Modified code:
def load_param(self, model_path):
    """Copy checkpoint tensors into this model, skipping unused parameters.

    The checkpoint at ``model_path`` is read with ``torch.load``. If the
    loaded mapping has a ``'state_dict'`` entry, that nested mapping is used.
    Every remaining entry is copied in place into the tensor of the same name
    in ``self.state_dict()`` (with any ``'module.'`` substring removed from
    the name), except entries whose name contains one of the skip patterns
    below.

    Args:
        model_path: Path of the checkpoint file passed to ``torch.load``.

    Raises:
        KeyError: If a non-skipped checkpoint entry has no counterpart in
            ``self.state_dict()``.
    """
    # Substrings of parameter names that must not be loaded. The generic
    # 'bn1.*' patterns also cover the 'layer1.0.bn1.*' / 'layer1.1.bn1.*'
    # names that were previously listed one by one.
    skip_patterns = (
        'fc',
        'bn1.running_mean',
        'bn1.running_var',
        'bn1.weight',
        'bn1.bias',
    )

    param_dict = torch.load(model_path)
    # Unwrap BEFORE filtering: filtering the outer mapping would leave the
    # nested parameters untouched, so wrapped and unwrapped checkpoints would
    # be treated differently.
    if 'state_dict' in param_dict:
        param_dict = param_dict['state_dict']

    # Fetch the target tensors once instead of rebuilding the mapping on
    # every iteration.
    own_state = self.state_dict()
    for name, value in param_dict.items():
        if any(pattern in name for pattern in skip_patterns):
            continue
        own_state[name.replace('module.', '')].copy_(value)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)