mpi4py的wrapper

mpi4py的wrapper,第1张

mpi4py的wrapper

给mpi4py写了个wrapper。包括并行写入,对于numpy array的split并scatter,bcast和gather,基本完成。如果有新想法应该会持续更新,加入新功能

import h5py as h5
from mpi4py import MPI
import time
import numpy as np

# Module-level MPI handles shared by every helper below.
mpi_comm = MPI.COMM_WORLD
mpi_size = mpi_comm.Get_size()  # total number of MPI processes
mpi_rank = mpi_comm.Get_rank()  # this process's 0-based rank

def process_size(total_size, rank=mpi_rank, size=mpi_size):
    """Return how many elements *rank* owns when *total_size* items are
    divided as evenly as possible among *size* processes.

    The first ``total_size % size`` ranks each get one extra element.
    """
    extra = int(total_size % size)
    base = int(total_size // size)
    return base + 1 if rank < extra else base

def ind_end(total_size, rank=mpi_rank, size=mpi_size):
    """Return the exclusive end index of the slice owned by *rank*.

    The first ``total_size % size`` ranks each own ``total_size//size + 1``
    elements, the remaining ranks own ``total_size//size`` elements; the
    end index is the cumulative size up to and including *rank*.
    """
    n_extra = int(total_size % size)
    # One entry per rank: the number of elements that rank owns.
    all_size = [int(total_size//size + 1)] * n_extra
    # BUGFIX: the short-share list needs one entry per remaining rank,
    # i.e. size - n_extra entries.  The original used total_size - n_extra,
    # which crashed with an IndexError whenever total_size < size (and
    # built a needlessly long list otherwise).
    all_size += [int(total_size//size)] * (size - n_extra)
    return np.cumsum(all_size)[rank]

def ind_start(total_size, rank=mpi_rank, size=mpi_size):
    """Return the inclusive start index of the slice owned by *rank*."""
    end = ind_end(total_size, rank=rank, size=size)
    own = process_size(total_size, rank=rank, size=size)
    return end - own

def paralle_save_dataset(filename, key, data, axis=0):
    """Collectively save one HDF5 dataset, each rank writing its own chunk.

    Every rank contributes *data* (a numpy array); the chunks are stacked
    along *axis* in rank order into a single dataset *key* in *filename*.
    Rank 0 creates the full-size dataset first; the ranks then take turns
    writing their slice (barrier-serialized, retrying up to 10 times on
    IOError, since an HDF5 file admits only one writer at a time).

    NOTE(review): assumes all ranks pass arrays with identical dtype and
    identical shape except along *axis* -- this is not checked.
    """
    data = np.asarray(data)
    shp = list(data.shape)
    num = shp[axis]
    # Total length of the merged dataset along `axis`.
    len_axis = mpi_comm.gather(num, root=0)
    if mpi_rank == 0:
        len_axis = sum(len_axis)
    len_axis = mpi_comm.bcast(len_axis, root=0)
    # This rank owns the half-open index range [ist, ied) along `axis`.
    ist = ind_start(len_axis)
    ied = ind_end(len_axis)
    save_slice = [slice(None, None, None)] * len(shp)
    save_slice[axis] = slice(ist, ied, None)
    save_slice = tuple(save_slice)
    shp[axis] = len_axis
    if mpi_rank == 0:
        # Create the dataset before anyone writes; rank 0 also writes first
        # in the loop below, so no other rank opens the file before this.
        with h5.File(filename, 'a') as filein:
            filein.create_dataset(key, shape=shp, dtype=data.dtype)
    for ii in range(mpi_size):
        if ii == mpi_rank:
            for _ in range(10):
                try:
                    with h5.File(filename, 'a') as filein:
                        filein[key][save_slice] = data
                    # BUGFIX: the original single-quoted string contained
                    # unescaped single quotes and was a SyntaxError.
                    print("Rank %d save dataset '%s' %d to %d into %s!"
                          % (mpi_rank, key, ist, ied, filename))
                    time.sleep(0.5)
                    break
                except IOError as e:
                    print('%s for rank %d, sleep 0.5 second!' % (e, mpi_rank))
                    time.sleep(0.5)
            else:
                # All retries failed.  BUGFIX: the original message (besides
                # being a SyntaxError) read like a success report; make it a
                # failure message consistent with paralle_save_multi_dataset.
                raise IOError("Rank %d cannot save dataset '%s' %d to %d into %s!"
                              % (mpi_rank, key, ist, ied, filename))
        mpi_comm.barrier()

def paralle_save_multi_dataset(filename, key, data):
    """Each rank saves its own dataset *key* into the shared HDF5 *filename*.

    Ranks take turns (barrier-serialized) opening the file in append mode;
    each attempt is retried up to 10 times on IOError, since an HDF5 file
    admits only one writer at a time.  *key* should be distinct per rank
    (e.g. include the rank number), otherwise later ranks fail.
    """
    for ii in range(mpi_size):
        if ii == mpi_rank:
            for _ in range(10):
                try:
                    with h5.File(filename, 'a') as filein:
                        filein[key] = data
                    # BUGFIX: the original single-quoted string contained
                    # unescaped single quotes and was a SyntaxError.
                    print("Rank %d save dataset '%s' into %s!"
                          % (mpi_rank, key, filename))
                    time.sleep(0.5)
                    break
                except IOError as e:
                    print('%s for rank %d, sleep 0.5 second!' % (e, mpi_rank))
                    time.sleep(0.5)
            else:
                raise IOError('Rank %d cannot save %s into %s!'
                              % (mpi_rank, key, filename))
        mpi_comm.barrier()

def split_uneven_array(data, root=0, axis=0):
    """Split *data* on *root* along *axis* and scatter the pieces.

    Uses object (pickle-based) scatter, so the per-rank chunk lengths may
    differ.  Returns this rank's chunk.
    """
    if mpi_rank == root:
        pieces = np.array_split(np.asarray(data), mpi_size, axis=axis)
    else:
        pieces = data  # ignored by scatter on non-root ranks
    return mpi_comm.scatter(pieces, root=root)

def split_even_array(data, root=0, axis=0):
    """Split *data* on *root* along *axis* into equal chunks and Scatter them.

    Fast path using buffer-based Scatter; the length of *axis* must be
    exactly divisible by the number of MPI processes.  Returns this rank's
    chunk as a numpy array.
    """
    if mpi_rank == root:
        data = np.asarray(data)
        shp = list(data.shape)
        assert shp[axis] % mpi_size == 0, \
            'Axis %d with length %d cannot exactly divided by mpi size %d!' \
            % (axis, shp[axis], mpi_size)
        dtype = data.dtype
        # Stack the equal pieces along a new leading axis so that the i-th
        # contiguous chunk of the send buffer is exactly rank i's piece.
        data = np.asarray(np.array_split(data, mpi_size, axis=axis))
    else:
        dtype = None
        shp = None
    # Non-root ranks need dtype/shape to allocate the receive buffer.
    dtype = mpi_comm.bcast(dtype, root=root)
    shp = mpi_comm.bcast(shp, root=root)
    shp[axis] = process_size(shp[axis])
    recv_buf = np.empty(shp, dtype=dtype)
    mpi_comm.Scatter(data, recv_buf, root=root)
    return recv_buf

def split_array(data, root=0, axis=0):
    """Split *data* on *root* along *axis* and scatter it to all ranks.

    Chooses the fast buffer path when the axis length divides evenly by
    the number of processes, otherwise falls back to object scatter.
    """
    even = None
    if mpi_rank == root:
        data = np.asarray(data)
        even = data.shape[axis] % mpi_size == 0
    even = mpi_comm.bcast(even, root=root)
    if even:
        print('Split and scatter as numpy array!')
        return split_even_array(data, root=root, axis=axis)
    print('Split and scatter as python object!')
    return split_uneven_array(data, root=root, axis=axis)

def bcast_array(data, root=0):
    """Broadcast a numpy array from *root* to every rank and return it.

    The dtype and shape are broadcast first (as python objects) so that
    non-root ranks can allocate a matching receive buffer for the fast
    buffer-based Bcast.
    """
    if mpi_rank == root:
        data = np.asarray(data)
        meta_dtype, meta_shp = data.dtype, data.shape
    else:
        meta_dtype, meta_shp = None, None
    meta_dtype = mpi_comm.bcast(meta_dtype, root=root)
    meta_shp = mpi_comm.bcast(meta_shp, root=root)
    if mpi_rank != root:
        data = np.empty(meta_shp, dtype=meta_dtype)
    mpi_comm.Bcast(data, root=root)
    return data

def gather_array(data, root=0, axis=0, expand_dim=False, ascontiguous=True):
    """Gather per-rank numpy arrays onto *root*, merged along *axis*.

    Parameters
    ----------
    data : array_like
        This rank's contribution.
    root : int
        Rank that receives the merged result.
    axis : int
        Axis along which the per-rank chunks are concatenated (or where a
        new axis is inserted when *expand_dim* is True).
    expand_dim : bool
        If True, insert a new axis of length ``mpi_size`` at *axis* instead
        of concatenating; all ranks must then pass the same shape.
    ascontiguous : bool
        If True, make the result on *root* C-contiguous.

    Returns
    -------
    The merged array on *root*; None on other ranks when the buffer path
    is used (the object path returns None on non-root ranks as well).
    """
    data = np.asarray(data)
    shp = list(data.shape)
    if expand_dim:
        print('Gather as numpy array and expand axis=%d!'%axis)
        even = True
        new_shp = [mpi_size] + shp
        # Guard against non-contiguous views: Gather needs a buffer.
        sendbuf = np.ascontiguousarray(data)
    else:
        # Validate shapes across ranks: all dims except `axis` must agree.
        all_shp = mpi_comm.gather(shp, root=root)
        all_shp = mpi_comm.bcast(all_shp, root=root)
        shp0 = all_shp[0]
        even = True
        total_len = shp0[axis]
        for ii in all_shp[1:]:
            assert len(shp0) == len(ii), 'Data from different mpi process should have the same number of dimensions! Shapes are: %s'%all_shp
            shp1 = shp0.copy()
            shp2 = ii.copy()
            del shp1[axis]
            del shp2[axis]
            assert np.array_equal(shp1, shp2), 'Data from different mpi process should have the same shape except for the merge axis! Shapes are: %s'%all_shp
            if ii[axis] != shp0[axis]:
                even = False
            total_len += ii[axis]
        if even:
            print('Gather as numpy array!')
            new_shp = shp0.copy()
            del new_shp[axis]
            new_shp = [total_len] + new_shp
            # BUGFIX: Gather concatenates each rank's raw C-order buffer,
            # so the merge axis must be the leading axis of the send buffer
            # to match new_shp.  The original sent `data` unchanged, which
            # scrambled the result for every axis != 0.
            sendbuf = np.ascontiguousarray(np.moveaxis(data, axis, 0))
        else:
            print('Gather as python object!')
    if even:
        if mpi_rank == root:
            new_data = np.empty(new_shp, dtype=data.dtype)
        else:
            new_data = None
        mpi_comm.Gather(sendbuf, new_data, root=root)
        if mpi_rank == root:
            # Move the stacked leading axis back to its logical position.
            new_data = np.moveaxis(new_data, 0, axis)
            if ascontiguous:
                new_data = np.ascontiguousarray(new_data)
        return new_data
    else:
        # Uneven chunk lengths: fall back to object gather + concatenate.
        new_data = mpi_comm.gather(data, root=root)
        if mpi_rank == root:
            new_data = np.concatenate(new_data, axis=axis)
        return new_data

if __name__ == '__main__':
    
    # Ad-hoc self-test; run under mpiexec with several ranks.  The earlier
    # experiments (split/bcast/parallel-save, timing comparisons) are kept
    # commented out; the active section below exercises gather_array with
    # per-rank arrays of different lengths along `axis` (the uneven path).

    #if mpi_rank == 1:
    #    with h5.File('test.hdf5', 'w') as filein:
    #        pass
    #    a = np.random.rand(10, 2000, 800)
    #else:
    #    a = None
    #
    #
    #from timeit import timeit
    #def c1():
    #    b = split_even_array(a, root=1, axis=-1)
    #def c2():
    #    b = split_uneven_array(a, root=1, axis=-1)
    #
    #print(mpi_rank, timeit(c2, number=20), 2)
    #print(mpi_rank, timeit(c1, number=20), 1)
    #
    #
    #exit()
    
    #b = split_array(a, root=1, axis=-1)
    ##a = mpi_comm.bcast(a, root=1)
    #a = bcast_array(a, root=1)
    #print(mpi_rank, b.shape)
    #print(np.abs(a[...,a.shape[-1]//mpi_size*mpi_rank:a.shape[-1]//mpi_size*(mpi_rank+1)] - b).max())
    #paralle_save_dataset('test.hdf5', 'a', b, axis=-1)
    #if mpi_rank == 0:
    #    with h5.File('test.hdf5', 'r') as filein:
    #        print(np.abs(a - filein['a'][:]).max())
    
    
    #if mpi_rank == 0:
    #    a = np.random.rand(mpi_size, 30)
    #    with h5.File('test.hdf5', 'w') as filein:
    #        pass
    #else:
    #    a = None
    #a = mpi_comm.scatter(a, root=0)
    #paralle_save_multi_dataset('test.hdf5', '%d'%mpi_rank, a)


    axis = 1
    expand_dim = False
    #a = np.random.rand(10, 3, 20)
    # Seed per rank so each rank's random array can be regenerated on the
    # root rank below for verification.
    np.random.seed(mpi_rank+1)
    #a = np.random.rand(10, 3, 20)
    # Rank-dependent length along axis 1 forces the uneven gather path.
    a = np.random.rand(10, mpi_rank+1, 20)
    print(np.shape(a), mpi_rank)
    a = gather_array(a, root=1, axis=axis, expand_dim=expand_dim)
    print(np.shape(a), mpi_rank)
    if mpi_rank == 1:
        # Rebuild every rank's array locally (same seeds) and compare the
        # expected concatenation with the gathered result; should print 0.0.
        b = []
        for ii in range(mpi_size):
            np.random.seed(ii+1)
            #b.append(np.random.rand(10, 3, 20))
            b.append(np.random.rand(10, ii+1, 20))
            if expand_dim:
                b[-1] = np.expand_dims(b[-1], axis=axis)
        b = np.concatenate(b, axis=axis)
        print(np.abs(a - b).max())

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5670472.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-16
下一篇 2022-12-16

发表评论

登录后才能评论

评论列表(0条)

保存