给 mpi4py 写了个 wrapper,包括并行写入 HDF5 文件,以及对 numpy array 的 split 并 scatter、bcast 和 gather,目前基本完成。如果有新想法会持续更新,加入新功能。
import time

import h5py as h5
import numpy as np
from mpi4py import MPI

# Thin helpers around mpi4py: serialized HDF5 writes plus split/scatter,
# broadcast and gather of numpy arrays over MPI.COMM_WORLD.
mpi_comm = MPI.COMM_WORLD
mpi_size = mpi_comm.Get_size()
mpi_rank = mpi_comm.Get_rank()


def process_size(total_size, rank=mpi_rank, size=mpi_size):
    """Return how many of ``total_size`` items belong to ``rank``.

    Items are dealt out as evenly as possible over ``size`` processes; the
    first ``total_size % size`` ranks receive one extra item (the same
    layout ``np.array_split`` produces).
    """
    base, extra = divmod(int(total_size), int(size))
    return base + 1 if rank < extra else base


def ind_end(total_size, rank=mpi_rank, size=mpi_size):
    """Return the exclusive end index of the slice owned by ``rank``.

    Computed in closed form, so it is also valid when ``total_size`` is
    smaller than ``size`` (trailing ranks then own empty slices).
    """
    base, extra = divmod(int(total_size), int(size))
    return (rank + 1) * base + min(rank + 1, extra)


def ind_start(total_size, rank=mpi_rank, size=mpi_size):
    """Return the inclusive start index of the slice owned by ``rank``."""
    return (ind_end(total_size, rank=rank, size=size)
            - process_size(total_size, rank=rank, size=size))


def _serialized_write(filename, write, success_msg, failure_msg,
                      retries=10, delay=0.5):
    """Let every rank, one after another, open ``filename`` and write.

    ``write`` is called with the open ``h5py.File`` (append mode). Each
    rank retries up to ``retries`` times on ``IOError``, sleeping ``delay``
    seconds between attempts, and raises ``IOError(failure_msg)`` when all
    attempts fail. A barrier after every turn keeps the writes serialized.

    NOTE(review): if one rank raises, the remaining ranks will wait at the
    barrier; callers may want to abort the communicator in that case.
    """
    for turn in range(mpi_size):
        if turn == mpi_rank:
            for _ in range(retries):
                try:
                    with h5.File(filename, 'a') as filein:
                        write(filein)
                except IOError as e:
                    print('%s for rank %d, sleep 0.5 second!' % (e, mpi_rank))
                    time.sleep(delay)
                else:
                    print(success_msg)
                    # Pause kept from the original implementation before
                    # handing the file over to the next rank.
                    time.sleep(delay)
                    break
            else:
                raise IOError(failure_msg)
        mpi_comm.barrier()


def paralle_save_dataset(filename, key, data, axis=0):
    """Write per-rank pieces of one array into a single HDF5 dataset.

    Every rank passes its local piece ``data``; the pieces are laid out in
    rank order along ``axis`` of dataset ``key`` in ``filename``. Rank 0
    creates the dataset, then the ranks write their slices one at a time.

    Raises:
        IOError: if this rank could not write after all retries.
    """
    data = np.asarray(data)
    shp = list(data.shape)
    num = shp[axis]

    # Offsets come from the actual local lengths, so any distribution of
    # the pieces (not only the one produced by array_split) is stored
    # at the right place.
    all_num = mpi_comm.allgather(num)
    len_axis = sum(all_num)
    ist = sum(all_num[:mpi_rank])
    ied = ist + num

    save_slice = [slice(None)] * len(shp)
    save_slice[axis] = slice(ist, ied)
    save_slice = tuple(save_slice)
    shp[axis] = len_axis

    if mpi_rank == 0:
        with h5.File(filename, 'a') as filein:
            filein.create_dataset(key, shape=tuple(shp), dtype=data.dtype)

    def write(filein):
        filein[key][save_slice] = data

    detail = (mpi_rank, key, ist, ied, filename)
    _serialized_write(
        filename, write,
        "Rank %d save dataset '%s' %d to %d into %s!" % detail,
        "Rank %d cannot save dataset '%s' %d to %d into %s!" % detail,
    )


def paralle_save_multi_dataset(filename, key, data):
    """Store ``data`` as dataset ``key`` in ``filename``, one rank at a time.

    Each rank is expected to pass its own ``key``; the ranks take turns
    opening the file in append mode.

    Raises:
        IOError: if this rank could not write after all retries.
    """
    def write(filein):
        filein[key] = data

    _serialized_write(
        filename, write,
        "Rank %d save dataset '%s' into %s!" % (mpi_rank, key, filename),
        'Rank %d cannot save %s into %s!' % (mpi_rank, key, filename),
    )


def split_uneven_array(data, root=0, axis=0):
    """Split ``data`` on ``root`` with ``np.array_split`` and scatter it.

    The pieces travel as pickled Python objects, so their lengths along
    ``axis`` may differ. ``data`` is only read on ``root``.
    """
    chunks = None
    if mpi_rank == root:
        chunks = np.array_split(np.asarray(data), mpi_size, axis=axis)
    return mpi_comm.scatter(chunks, root=root)


def split_even_array(data, root=0, axis=0):
    """Split ``data`` on ``root`` into equal pieces and scatter them.

    Uses the buffer-based ``Scatter``; the length of ``axis`` must be an
    exact multiple of the number of processes. ``data`` is only read on
    ``root``.
    """
    if mpi_rank == root:
        data = np.asarray(data)
        meta = (data.dtype, list(data.shape))
    else:
        meta = None
    dtype, shp = mpi_comm.bcast(meta, root=root)

    # Checked on every rank (after the broadcast) so that a bad shape fails
    # everywhere instead of leaving the non-root ranks blocked.
    assert shp[axis] % mpi_size == 0, \
        'Axis %d with length %d cannot exactly divided by mpi size %d!' % (
            axis, shp[axis], mpi_size)

    sendbuf = None
    if mpi_rank == root:
        # Stack the pieces into one contiguous (mpi_size, ...) send buffer.
        sendbuf = np.stack(np.array_split(data, mpi_size, axis=axis))

    shp[axis] = shp[axis] // mpi_size
    new_data = np.empty(shp, dtype=dtype)
    mpi_comm.Scatter(sendbuf, new_data, root=root)
    return new_data


def split_array(data, root=0, axis=0):
    """Split ``data`` along ``axis`` and scatter the pieces from ``root``.

    Chooses the buffer-based path when the axis length divides evenly by
    the number of processes and the pickle-based path otherwise.
    """
    even = None
    if mpi_rank == root:
        data = np.asarray(data)
        even = data.shape[axis] % mpi_size == 0
    even = mpi_comm.bcast(even, root=root)

    if even:
        print('Split and scatter as numpy array!')
        return split_even_array(data, root=root, axis=axis)
    print('Split and scatter as python object!')
    return split_uneven_array(data, root=root, axis=axis)


def bcast_array(data, root=0):
    """Broadcast the array held by ``root`` to every process.

    dtype and shape are sent first so the other ranks can allocate a
    receive buffer; ``data`` is only read on ``root``.
    """
    if mpi_rank == root:
        # Bcast needs a contiguous buffer.
        data = np.asarray(data, order='C')
        meta = (data.dtype, data.shape)
    else:
        meta = None
    dtype, shp = mpi_comm.bcast(meta, root=root)
    if mpi_rank != root:
        data = np.empty(shp, dtype=dtype)
    mpi_comm.Bcast(data, root=root)
    return data


def gather_array(data, root=0, axis=0, expand_dim=False, ascontiguous=True):
    """Gather the per-rank arrays onto ``root``.

    Args:
        data: local array of this rank.
        root: rank that receives the result.
        axis: axis along which the pieces are joined, or the position of
            the new rank axis when ``expand_dim`` is true.
        expand_dim: if true, all pieces must share one shape and are
            stacked along a new axis inserted at ``axis``.
        ascontiguous: if true, the buffer-gathered result is returned as a
            C-contiguous array.

    Returns:
        The combined array on ``root``; ``None`` on every other rank.
    """
    # Gather needs a contiguous send buffer.
    data = np.asarray(data, order='C')
    shp = list(data.shape)

    if expand_dim:
        print('Gather as numpy array and expand axis=%d!' % axis)
        even = True
    else:
        all_shp = mpi_comm.allgather(shp)
        shp0 = all_shp[0]
        even = True
        for ii in all_shp[1:]:
            assert len(shp0) == len(ii), 'Data from different mpi process should have the same number of dimensions! Shapes are: %s' % all_shp
            shp1 = shp0.copy()
            shp2 = ii.copy()
            del shp1[axis]
            del shp2[axis]
            assert np.array_equal(shp1, shp2), 'Data from different mpi process should have the same shape except for the merge axis! Shapes are: %s' % all_shp
            if ii[axis] != shp0[axis]:
                even = False
        if even:
            print('Gather as numpy array!')
        else:
            print('Gather as python object!')

    if not even:
        gathered = mpi_comm.gather(data, root=root)
        if mpi_rank == root:
            return np.concatenate(gathered, axis=axis)
        return None

    # Gather places the buffers of the ranks back to back, so the receive
    # buffer has to be viewed as (mpi_size, *local_shape).
    if mpi_rank == root:
        new_data = np.empty([mpi_size] + shp, dtype=data.dtype)
    else:
        new_data = None
    mpi_comm.Gather(data, new_data, root=root)
    if mpi_rank != root:
        return None

    if expand_dim:
        new_data = np.moveaxis(new_data, 0, axis)
    else:
        # Put the rank axis directly in front of the merge axis and fuse
        # the two; this equals np.concatenate(pieces, axis=axis).
        ax = axis % len(shp)
        new_data = np.moveaxis(new_data, 0, ax)
        merged_shape = shp[:ax] + [mpi_size * shp[ax]] + shp[ax + 1:]
        new_data = new_data.reshape(merged_shape)
    if ascontiguous:
        new_data = np.ascontiguousarray(new_data)
    return new_data


if __name__ == '__main__':
    # Self-test for gather_array: every rank builds a reproducible random
    # array, the root rebuilds all of them locally and prints the maximum
    # absolute difference (expected 0.0). Run under mpirun; both the
    # unequal-length (pickle) path and the equal-length (buffer) path are
    # exercised.
    axis = 1
    expand_dim = False
    root = min(1, mpi_size - 1)

    for uneven in (True, False):
        np.random.seed(mpi_rank + 1)
        a = np.random.rand(10, mpi_rank + 1 if uneven else 3, 20)
        print(np.shape(a), mpi_rank)
        a = gather_array(a, root=root, axis=axis, expand_dim=expand_dim)
        print(np.shape(a), mpi_rank)

        if mpi_rank == root:
            b = []
            for ii in range(mpi_size):
                np.random.seed(ii + 1)
                piece = np.random.rand(10, ii + 1 if uneven else 3, 20)
                if expand_dim:
                    piece = np.expand_dims(piece, axis=axis)
                b.append(piece)
            b = np.concatenate(b, axis=axis)
            print(np.abs(a - b).max())