Single machine, multiple GPUs, on one Linux server.
Reference code: Multi-GPU Computing with Pytorch (Draft), https://srijithr.gitlab.io/post/pytorchdist/
The reference code contains a complete training and testing pipeline; it needs a few small changes before it will run.
Below is a simple example copied from the reference code, slightly modified so that it runs.
Save the following code as ddp_1.py and run it from the command line with python -m torch.distributed.launch ddp_1.py. (Because the script sets MASTER_ADDR/MASTER_PORT itself and spawns its worker processes with mp.spawn, it can also be started directly with python ddp_1.py.)
import os
import argparse

import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.distributed as dist
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"  # use physical GPUs 6 and 7


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.ones(200, 10).to(rank))
    labels = torch.randn(200, 5).to(rank)
    loss = loss_fn(outputs, labels)
    print("Loss is ", loss.item())
    loss.backward()
    optimizer.step()

    cleanup()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpus', default=2, type=int,
                        help='number of gpus per node')
    parser.add_argument('--epochs', default=2, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--local_rank', default=0, type=int,
                        help='local rank (passed by torch.distributed.launch)')
    args = parser.parse_args()
    world_size = args.gpus
    print("We have available ", torch.cuda.device_count(),
          "GPUs! but using ", world_size, " GPUs")

    #########################################################
    mp.spawn(demo_basic,
             args=(world_size,),
             nprocs=world_size,
             join=True)
    #########################################################
    # MASTER_ADDR / MASTER_PORT are set in setup(), so this script can be run
    # directly (python ddp_1.py) or via the launcher:
    # python -m torch.distributed.launch ddp_1.py
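The toy example above runs a single forward/backward pass on the same data in every process. A full training script like the one in the reference post also shards the dataset across GPUs, typically with torch.utils.data.DistributedSampler. The sketch below is a minimal illustration of that pattern under stated assumptions, not code from the reference post: the random TensorDataset, the batch size of 64, and the helper name run_epoch are made up for illustration.

# Minimal sketch (assumption, not the reference code): shard data across ranks
# with DistributedSampler inside an epoch loop.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def run_epoch(ddp_model, optimizer, rank, world_size, epoch):
    # Toy random data shaped to match ToyModel (10 inputs -> 5 targets).
    dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 5))
    # Each rank draws a disjoint 1/world_size slice of the dataset.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    sampler.set_epoch(epoch)  # reshuffle the shards differently every epoch
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)

    loss_fn = nn.MSELoss()
    for inputs, labels in loader:
        inputs, labels = inputs.to(rank), labels.to(rank)
        optimizer.zero_grad()
        loss = loss_fn(ddp_model(inputs), labels)
        loss.backward()   # DDP all-reduces gradients across ranks here
        optimizer.step()

In the example above, demo_basic would call run_epoch(ddp_model, optimizer, rank, world_size, epoch) once per epoch, between the DDP wrapping and cleanup(); with the default --epochs 2 that is two passes over each rank's shard.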