Automatic Hyperparameter Tuning with Ray Tune


Table of Contents
  • Preface
  • I. What is Ray Tune?
  • II. Usage
    • 1. Installing the package
    • 2. Importing libraries
    • 3. Loading the data (unrelated to Ray Tune)
    • 4. Building the neural network model (unrelated to Ray Tune)
    • 5. Training and testing the model (unrelated to Ray Tune)
    • 6. Building a "Trainable"
    • 7. Hyperparameter search
  • Summary


Preface

This article shows how to do automatic hyperparameter tuning with Ray Tune, using a PyTorch convolutional network for image classification as the running example. The code is adapted from the official documentation.


I. What is Ray Tune?

Ray Tune is a Python library for experiment execution and hyperparameter tuning. It comes with search algorithms such as grid search, random search, and Bayesian optimization (BayesOptSearch), and integrates with optimization tools such as Optuna and Hyperopt. The model being tuned can be built with PyTorch, XGBoost, TensorFlow, Keras, or other frameworks.
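Before walking through the full example, here is a minimal sketch of the core API shape, modeled on the key-concepts example in the Ray docs (the trainable and metric names here are illustrative): a trainable function reports a metric, a config dict declares the search space, and tune.run() ties them together.

from ray import tune

def trainable(config):
    # A trivial "model": the score depends only on hyperparameter "a".
    for step in range(10):
        tune.report(score=config["a"] * step)  # send the metric back to Tune

analysis = tune.run(
    trainable,
    config={"a": tune.grid_search([0.001, 0.01, 0.1])},
    metric="score",
    mode="max",
)
print(analysis.best_config)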

II. Usage

1. Installing the package

You can install just the tune component of the ray package:

$ pip install -U "ray[tune]"

or install the full ray package:

$ pip install ray
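To check that the installation worked (optional), import the package and print its version:

$ python -c "import ray; from ray import tune; print(ray.__version__)"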
2. Importing libraries

Example code:

import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from filelock import FileLock
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as transforms
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
3. Loading the data (unrelated to Ray Tune)

Example code:

def load_data(data_dir="./data"):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    with FileLock(os.path.expanduser("~/.data.lock")):
        trainset = torchvision.datasets.CIFAR10(
            root=data_dir, train=True, download=True, transform=transform)

        testset = torchvision.datasets.CIFAR10(
            root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset
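The imports above include random_split; the official CIFAR-10 tutorial uses it to carve a validation set out of the training set, for example (a sketch; the 80/20 split is arbitrary):

trainset, testset = load_data()
n_train = int(len(trainset) * 0.8)
train_subset, val_subset = random_split(
    trainset, [n_train, len(trainset) - n_train])  # 80% train, 20% validation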
4. Building the neural network model (unrelated to Ray Tune)

Example code:

class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
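As a quick sanity check (illustrative): CIFAR-10 images are 3×32×32, and after the two conv+pool stages the feature map is 16×5×5, which is why fc1 takes 16 * 5 * 5 inputs:

net = Net()
dummy = torch.randn(1, 3, 32, 32)  # one fake CIFAR-10 image
print(net(dummy).shape)            # torch.Size([1, 10]), one logit per class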
5. Training and testing the model (unrelated to Ray Tune)

Example code:

EPOCH_SIZE = 512
TEST_SIZE = 256

def train(model, optimizer, train_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # We set this just for the example to run quickly.
        if batch_idx * len(data) > EPOCH_SIZE:
            return
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # Net outputs raw logits, so use cross-entropy rather than nll_loss
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()


def test(model, data_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            # We set this just for the example to run quickly.
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total
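Before wiring these into Tune, the pieces can be smoke-tested on their own; a sketch using the helpers defined above, with arbitrary fixed hyperparameters:

trainset, testset = load_data()
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
test_loader = DataLoader(testset, batch_size=64, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

train(model, optimizer, train_loader)
print("accuracy:", test(model, test_loader))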
6. Building a "Trainable"

A "Trainable" is what you pass to tune.run() (see below); it encapsulates one trial's training, evaluation, and checkpointing. It can be built either as a function or as a class that inherits from tune.Trainable (a class-based sketch follows the function example below). The trainable receives a config argument drawn from the hyperparameter search space and calls tune.report() each iteration to send the current results back to Tune; everything else looks like ordinary model training. This article uses the function form.

Example code:

def train_cifar(config):
    # Data setup (reuses the load_data helper defined above)
    trainset, testset = load_data()

    train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
    test_loader = DataLoader(testset, batch_size=64, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Net(config.get("l1", 120), config.get("l2", 84))
    model.to(device)

    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)

        # Send the current training result back to Tune
        tune.report(mean_accuracy=acc)

        if i % 5 == 0:
            # This saves the model to the trial directory
            torch.save(model.state_dict(), "./model.pth")

In the code above, tune.report() sends the result of each training epoch back to Tune, which uses it to guide the search and to decide whether to stop a trial early.
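For reference, the same Trainable can also be written as a class inheriting from tune.Trainable. A minimal sketch (setup() and step() are the method names used by the Tune class API; one step() call corresponds to one tune.report() in the function form):

class TrainCifar(tune.Trainable):
    def setup(self, config):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        trainset, testset = load_data()
        self.train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
        self.test_loader = DataLoader(testset, batch_size=64, shuffle=False)
        self.model = Net(config.get("l1", 120), config.get("l2", 84)).to(device)
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config["lr"], momentum=config["momentum"])

    def step(self):
        train(self.model, self.optimizer, self.train_loader)
        return {"mean_accuracy": test(self.model, self.test_loader)}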

7. Hyperparameter search

First, define the hyperparameter search space. Example code:

# Define the hyperparameter search space
search_space = {
    "lr": tune.sample_from(lambda spec: 10 ** (-10 * np.random.rand())),  # log-uniform in [1e-10, 1]
    "momentum": tune.uniform(0.1, 0.9),
}
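The tune.sample_from call above draws lr log-uniformly from [1e-10, 1]. Tune's built-in distributions express the same thing more directly; an equivalent definition:

search_space = {
    "lr": tune.loguniform(1e-10, 1),   # same distribution as the lambda above
    "momentum": tune.uniform(0.1, 0.9),
}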

To start the search, simply call tune.run(). Example code:

analysis = tune.run(
    train_cifar,
    num_samples=20,
    metric="mean_accuracy",  # set metric/mode here so the analysis.best_* accessors below work
    mode="max",
    scheduler=ASHAScheduler(),
    config=search_space,
)
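By default each trial is allocated one CPU and no GPU. To train on GPUs or change parallelism, pass resources_per_trial; a sketch (assumes a GPU is available):

analysis = tune.run(
    train_cifar,
    num_samples=20,
    metric="mean_accuracy",
    mode="max",
    scheduler=ASHAScheduler(),
    config=search_space,
    resources_per_trial={"cpu": 2, "gpu": 1},  # per-trial resource request
)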

Since train_cifar above takes no arguments other than config, it can be passed to tune.run() directly. If your trainable needs additional user-defined arguments, wrap it with tune.with_parameters(), which forwards the extra arguments to the trainable as keywords. Example code:

analysis = tune.run(
    tune.with_parameters(
        train_cifar,
        parameter1=value1,   # extra arguments must be passed as keywords
        parameter2=value2,
        ...
    ),
    num_samples=20,          # number of trials to run
    metric="mean_accuracy",
    mode="max",
    scheduler=ASHAScheduler(),
    config=search_space,     # hyperparameter search space
)
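As a concrete sketch, suppose the trainable takes a hypothetical extra epochs argument (the name and signature are illustrative, not part of the original example):

def train_cifar_epochs(config, epochs=10):  # hypothetical variant of train_cifar
    trainset, testset = load_data()
    train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
    test_loader = DataLoader(testset, batch_size=64, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net().to(device)
    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"])
    for _ in range(epochs):
        train(model, optimizer, train_loader)
        tune.report(mean_accuracy=test(model, test_loader))

analysis = tune.run(
    tune.with_parameters(train_cifar_epochs, epochs=5),  # extra args as keywords
    num_samples=20,
    metric="mean_accuracy",
    mode="max",
    scheduler=ASHAScheduler(),
    config=search_space,
)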

The experiment results, best configuration, and so on can all be retrieved from analysis. Example:

best_trial = analysis.best_trial  # Get best trial
best_config = analysis.best_config  # Get best trial's hyperparameters
best_logdir = analysis.best_logdir  # Get best trial's logdir
best_checkpoint = analysis.best_checkpoint  # Get best trial's best checkpoint
best_result = analysis.best_result  # Get best trial's last results

# Print the results
print("Best trial is:", best_trial)
print("Best config is:", best_config)
print("Best logdir is:", best_logdir)
print("Best checkpoint is:", best_checkpoint)
print("Best result is:", best_result)

Summary

Drawing on the official documentation's examples and my own usage, this article gave a brief introduction to Ray Tune. For more detail (for example, the other available search algorithms), see the official documentation.
