- Purpose and workflow of early stopping
- Early stopping strategies
- PyTorch usage example
- References
Purpose: to prevent the model from overfitting. A deep learning model can be trained for an unbounded number of iterations, so we want to stop training when it is about to overfit, or when further training brings negligible improvement.
The workflow is as follows:
- Split the dataset into three parts: training data (the largest share), validation data (the smallest share, usually around 10%-20%), and test data (the second largest share)
- Run the model over the training set to obtain the training loss $Loss_{train}$
- Then run the model over the validation set; this is evaluation, not training, so no backpropagation is needed, which yields the validation loss $Loss_{valid}$
- The early stopping strategy uses $Loss_{train}$ and $Loss_{valid}$ to decide whether training should be interrupted
Early stopping strategies all reason in terms of the validation set and the training set:
- Commonly used strategies:
  - If the training loss and the validation loss show no meaningful decrease for several consecutive epochs, stop early
  - If the validation loss rises instead of falling for n consecutive epochs (n is typically 3), stop early (first sketch after this list)
- Thresholding on the generalization loss:
  - Record the lowest validation loss observed so far, and measure how far the current validation loss sits above that minimum
  - Compute the quantity $GL(t) = 100 \cdot \left( \frac{E_{va}(t)}{E_{opt}(t)} - 1 \right)$, called the generalization loss, where $E_{va}(t)$ is the validation error at epoch $t$ and $E_{opt}(t)$ is the lowest validation error observed up to epoch $t$
  - Stop training once the generalization loss exceeds a threshold (second sketch after the list)
- Thresholding on training progress: we usually assume overfitting sets in once the training loss becomes hard to reduce, and forcing it down further at that point risks overfitting. Therefore:
  - Fix a strip of k training epochs and measure how much larger the average training loss over the strip is than the minimum training loss within it, via $P_k(t) = 1000 \cdot \left( \frac{\sum_{t'=t-k+1}^{t} E_{tr}(t')}{k \cdot \min_{t'=t-k+1}^{t} E_{tr}(t')} - 1 \right)$
  - Then combine this with the generalization loss above and compute the quotient $\frac{GL(t)}{P_k(t)}$
  - Stop training when this quotient exceeds a threshold (third sketch after the list)
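As a minimal illustration of the patience rule above, here is a sketch assuming validation losses are tracked in a plain Python list (the helper name and the example numbers are ours, not from any library):

```python
def rose_n_times_in_a_row(val_losses, n=3):
    """True if the last n validation losses each rose relative to the one before."""
    if len(val_losses) <= n:
        return False
    recent = val_losses[-(n + 1):]
    return all(b > a for a, b in zip(recent, recent[1:]))

# the loss falls, then rises three epochs in a row -> time to stop
print(rose_n_times_in_a_row([0.9, 0.7, 0.6, 0.65, 0.7, 0.8]))  # True
```

The generalization-loss check can be sketched the same way (the 5% threshold here is an illustrative value, not one prescribed by the strategy):

```python
def generalization_loss(val_loss_history):
    """GL(t) = 100 * (E_va(t) / E_opt(t) - 1): percent above the best validation loss so far."""
    e_va = val_loss_history[-1]    # current validation error E_va(t)
    e_opt = min(val_loss_history)  # lowest validation error so far E_opt(t)
    return 100.0 * (e_va / e_opt - 1.0)

history = [0.50, 0.40, 0.38, 0.41]
if generalization_loss(history) > 5.0:  # GL = 100 * (0.41/0.38 - 1) ≈ 7.9, so this triggers
    print("early stop")
```

And a sketch of the progress-based quotient, where the strip length k = 5 and the 0.5 threshold are illustrative assumptions:

```python
def training_progress(train_loss_strip):
    """P_k(t) = 1000 * (mean of the strip / min of the strip - 1)."""
    k = len(train_loss_strip)
    return 1000.0 * (sum(train_loss_strip) / (k * min(train_loss_strip)) - 1.0)

def pq_quotient(val_loss_history, train_loss_history, k=5):
    """GL(t) / P_k(t); note P_k(t) is 0 when the strip is perfectly flat, so guard against division by zero."""
    gl = 100.0 * (val_loss_history[-1] / min(val_loss_history) - 1.0)  # GL(t) as above
    return gl / training_progress(train_loss_history[-k:])

# stop once the quotient exceeds a threshold, e.g. pq_quotient(valid_losses, train_losses) > 0.5
```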
For the implementation we follow the early stopping strategy from the project https://github.com/Bjarten/early-stopping-pytorch
The EarlyStopping class lives at: https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
A complete deep learning example:
```python
import torch
import torch.nn as nn
from sklearn.datasets import make_regression
from torch.utils.data import Dataset, DataLoader
import numpy as np


class EarlyStopping:  # utility class from the referenced project; feel free to keep it in its own module
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0; np.inf works everywhere
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decreases.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


class MyDataSet(Dataset):  # wraps the arrays as a PyTorch Dataset
    def __init__(self, train_x, train_y, sample):
        self.train_x = train_x
        self.train_y = train_y
        self._len = sample

    def __getitem__(self, item: int):
        return self.train_x[item], self.train_y[item]

    def __len__(self):
        return self._len


def get_data():
    """Build the data."""
    sample = 20000
    data_x, data_y = make_regression(n_samples=sample, n_features=100)  # generate a synthetic regression dataset
    train_data_x = data_x[:int(sample * 0.8)]
    train_data_y = data_y[:int(sample * 0.8)]
    valid_data_x = data_x[int(sample * 0.8):]
    valid_data_y = data_y[int(sample * 0.8):]
    train_loader = DataLoader(MyDataSet(train_data_x, train_data_y, len(train_data_x)), batch_size=10)
    valid_loader = DataLoader(MyDataSet(valid_data_x, valid_data_y, len(valid_data_x)), batch_size=10)
    return train_loader, valid_loader


class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)  # number of inputs, number of outputs

    def forward(self, x):
        out = self.linear(x)
        return out


def main():
    train_loader, valid_loader = get_data()
    model = LinearRegressionModel(input_dim=100, output_dim=1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    early_stopping = EarlyStopping(patience=4, verbose=True)  # early stopping
    # start training the model
    for epoch in range(1000):
        # normal training
        print("Epoch {}".format(epoch))
        model.train()
        train_loss_list = []
        for train_x, train_y in train_loader:
            optimizer.zero_grad()
            outputs = model(train_x.float())
            loss = criterion(outputs.flatten(), train_y.float())
            loss.backward()
            train_loss_list.append(loss.item())
            optimizer.step()
        print("Training loss: {}".format(np.average(train_loss_list)))
        # early stopping check on the validation set
        model.eval()
        with torch.no_grad():
            valid_loss_list = []
            for valid_x, valid_y in valid_loader:
                outputs = model(valid_x.float())
                loss = criterion(outputs.flatten(), valid_y.float())
                valid_loss_list.append(loss.item())
        avg_valid_loss = np.average(valid_loss_list)
        print("Validation loss: {}".format(avg_valid_loss))
        early_stopping(avg_valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered!")
            break


if __name__ == '__main__':
    main()
```
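One follow-up worth noting (our addition, not part of the original script): when the loop breaks, the model in memory carries the weights from `patience` epochs past the best validation point, so it is common to reload the checkpoint that `EarlyStopping` saved:

```python
# restore the best weights written by EarlyStopping (default path: 'checkpoint.pt')
model.load_state_dict(torch.load('checkpoint.pt'))
```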
References
- Early Stopping (a deep learning technique): https://www.datalearner.com/blog/1051537860479157
- early-stopping-pytorch: https://github.com/Bjarten/early-stopping-pytorch