Python 3.8 introduced a new module, multiprocessing.shared_memory, that provides
shared memory for direct access across processes. My test shows that it
significantly reduces memory usage and also speeds up the program by cutting
the cost of copying and moving data between processes.
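Before getting to the test, here is a minimal sketch of the core API just to show the moving parts (the block size and contents below are arbitrary): one process creates a named block, any other process attaches to it by name, and the creator unlinks it when everyone is done.

    from multiprocessing.shared_memory import SharedMemory

    # Create a named block of shared memory (size and contents are arbitrary here)
    shm = SharedMemory(create=True, size=1024)
    shm.buf[:5] = b'hello'          # write straight into the shared buffer

    # Another process would attach to the same block by its name
    attached = SharedMemory(shm.name)
    print(bytes(attached.buf[:5]))  # b'hello'

    attached.close()                # each process closes its own handle
    shm.close()
    shm.unlink()                    # the creator frees the block when done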
In this test, I generated a 240MB numpy.recarray from a pandas.DataFrame with
datetime, int and str columns. I used numpy.recarray because it preserves the
dtype of each column, so that later I can reconstruct the same array from the
buffer of the shared memory.
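To make the "reconstruct from a buffer" part concrete, the sketch below (using a tiny throwaway frame, not the 240MB one) shows that shape and dtype are all that is needed to rebuild the recarray from raw bytes; in the benchmark, shm.buf plays the role of that buffer.

    import numpy as np
    import pandas as pd
    from datetime import datetime

    # A tiny stand-in for the real data frame
    df = pd.DataFrame(
        [(datetime.today(), 1.0, 'abc')],
        columns=['date', 'val', 'character_col'],
    )
    rec = df.to_records(index=False, column_dtypes={'character_col': 'S6'})
    print(rec.dtype)    # one field per column, each with its own dtype

    # Rebuild the same recarray from raw bytes
    buf = bytearray(rec.tobytes())
    rebuilt = np.recarray(shape=rec.shape, dtype=rec.dtype, buf=buf)
    print(rebuilt.val)  # identical data, recovered from shape + dtype + buffer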
I performed a simple numpy.nansum on the numeric column of the data using two
methods. The first method uses multiprocessing.shared_memory, so the four
spawned worker processes access the data directly in shared memory. The second
method passes the data to the spawned processes as an argument, which means
each process ends up with its own copy of the data.
A quick run of the test code below shows that the first method, based on
shared_memory, uses minimal memory (peak usage of 0.33MB) and is much faster
(2.09s) than the second, in which the entire dataset is copied into each
process (peak memory usage of 1.8GB, 216s). More importantly, memory usage
under the second method stays consistently high throughout the run.
    from multiprocessing.shared_memory import SharedMemory
    from multiprocessing.managers import SharedMemoryManager
    from concurrent.futures import ProcessPoolExecutor, as_completed
    from multiprocessing import current_process, cpu_count, Process
    from datetime import datetime
    import numpy as np
    import pandas as pd
    import tracemalloc
    import time


    def work_with_shared_memory(shm_name, shape, dtype):
        print(f'With SharedMemory: {current_process()=}')
        # Locate the shared memory by its name
        shm = SharedMemory(shm_name)
        # Create the np.recarray from the buffer of the shared memory
        np_array = np.recarray(shape=shape, dtype=dtype, buf=shm.buf)
        return np.nansum(np_array.val)


    def work_no_shared_memory(np_array: np.recarray):
        print(f'No SharedMemory: {current_process()=}')
        # Without shared memory, the np_array is copied into the child process
        return np.nansum(np_array.val)


    if __name__ == "__main__":
        # Make a large data frame with date, float and character columns
        a = [
            (datetime.today(), 1, 'string'),
            (datetime.today(), np.nan, 'abc'),
        ] * 5000000
        df = pd.DataFrame(a, columns=['date', 'val', 'character_col'])
        # Convert into numpy recarray to preserve the dtypes
        # (see the note on string columns below)
        np_array = df.to_records(index=False,
                                 column_dtypes={'character_col': 'S6'})
        del df
        shape, dtype = np_array.shape, np_array.dtype
        print(f"np_array's size={np_array.nbytes/1e6}MB")

        # With shared memory
        # Start tracking memory usage
        tracemalloc.start()
        start_time = time.time()
        with SharedMemoryManager() as smm:
            # Create a shared memory of size np_array.nbytes
            shm = smm.SharedMemory(np_array.nbytes)
            # Create a np.recarray using the buffer of shm
            shm_np_array = np.recarray(shape=shape, dtype=dtype, buf=shm.buf)
            # Copy the data into the shared memory
            np.copyto(shm_np_array, np_array)
            # Spawn some processes to do some work
            with ProcessPoolExecutor(cpu_count()) as exe:
                fs = [exe.submit(work_with_shared_memory, shm.name, shape, dtype)
                      for _ in range(cpu_count())]
                for _ in as_completed(fs):
                    pass
        # Check memory usage
        current, peak = tracemalloc.get_traced_memory()
        print(f"Current memory usage {current/1e6}MB; Peak: {peak/1e6}MB")
        print(f'Time elapsed: {time.time()-start_time:.2f}s')
        tracemalloc.stop()

        # Without shared memory
        tracemalloc.start()
        start_time = time.time()
        with ProcessPoolExecutor(cpu_count()) as exe:
            fs = [exe.submit(work_no_shared_memory, np_array)
                  for _ in range(cpu_count())]
            for _ in as_completed(fs):
                pass
        # Check memory usage
        current, peak = tracemalloc.get_traced_memory()
        print(f"Current memory usage {current/1e6}MB; Peak: {peak/1e6}MB")
        print(f'Time elapsed: {time.time()-start_time:.2f}s')
        tracemalloc.stop()
A very important note about using multiprocessing.shared_memory, as of June
2020, is that the numpy.ndarray placed in shared memory cannot have
dtype('O'), i.e. dtype(object). If it does, the child processes will hit a
segmentation fault when they access the shared memory and dereference it: an
object array only stores pointers to Python objects in the parent process's
memory, so those pointers are meaningless in the children. This typically
happens when a column contains strings, which pandas stores as object by
default.
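One way to see the problem, assuming the same df as in the listing above: without column_dtypes, the string column comes through as object.

    # Assuming the df from the listing above
    rec = df.to_records(index=False)   # no column_dtypes specified
    print(rec.dtype)                   # character_col comes through with dtype object ('O')
    # An object field stores pointers to Python objects in the parent's memory,
    # so a child process dereferencing them through shm.buf will segfault.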
To solve this problem, specify a fixed-width dtype for the string column via the column_dtypes argument of df.to_records(). For example:
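    np_array = df.to_records(index=False, column_dtypes={'character_col': 'S6'})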
Here, we specify that character_col holds byte strings of length 6. If the column contains Unicode text, we can use 'U6' instead. Strings longer than the specified length will be truncated, and with a fixed-width dtype there is no longer any segfault.
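If you want to see the truncation behaviour for yourself, here is a throwaway example (the variable names are purely illustrative):

    import pandas as pd

    df_demo = pd.DataFrame({'character_col': ['abcdefgh', 'xy']})
    rec_demo = df_demo.to_records(index=False, column_dtypes={'character_col': 'U6'})
    print(rec_demo.character_col)   # ['abcdef' 'xy'] -- values longer than 6 are cut off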