Friday, September 6, 2024

Load a huge amount of data in process 0 and use it in other processes with shared_memory

When we spawn multiple processes, say one per GPU on a node with 8x A100 GPUs, each process usually loads its own copy of the data, which can easily exceed the memory limit of the whole node. Instead, we can let rank 0 load the data into shared memory and have the other processes wait until it is ready. The code is given as follows:

import os
import time
import pickle
from multiprocessing import shared_memory

# rank is assumed to be provided by the launcher, e.g. torchrun exports RANK
rank = int(os.environ.get("RANK", "0"))

# dummy data to load
data = [{"url": "xxx", "exp": "yyy"}] * 20

# serialize each record to bytes so it can be stored in a ShareableList
data = [pickle.dumps(d) for d in data]

# local rank 0 on each node creates the shared memory block
if rank % 8 == 0:
    shm_a = shared_memory.ShareableList(data, name="shared_data")
    print(pickle.loads(shm_a[0]))
# the other processes attach to it once it exists
else:
    while True:
        try:
            shm_b = shared_memory.ShareableList(name="shared_data")
            print("Attached to shared memory with name: shared_data")
            print(pickle.loads(shm_b[0]))
            break
        except FileNotFoundError:
            # rank 0 has not created the block yet; wait and retry
            print("Shared memory not found. It may not be created yet.")
            time.sleep(100)
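One thing to keep in mind: the shared memory block is not freed automatically, so every process should close its handle once it has read the data, and exactly one process (rank 0) should unlink it, otherwise it lingers in /dev/shm. Below is a minimal cleanup sketch, assuming torch.distributed has already been initialized so dist.barrier() can be used to make sure nobody unlinks the block before all ranks have read from it:

import torch.distributed as dist

# ... after every rank has deserialized what it needs ...
dist.barrier()  # wait until all ranks on the node are done reading

if rank % 8 == 0:
    shm_a.shm.close()   # release this process's handle
    shm_a.shm.unlink()  # remove the block from /dev/shm (do this only once)
else:
    shm_b.shm.close()   # release this process's handle

With a launcher such as torchrun, RANK is set for each process automatically, so the snippet above can be run as-is on a single node.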
