When we spawn multiple processes, say on a node with 8x A100 GPUs, each process keeps its own copy of the data, which can easily exceed the memory limit of the whole node. Instead, we can let local rank 0 load the data and have the remaining processes wait until it is ready. The code is given as follows:
import os
import time
import pickle
from multiprocessing import shared_memory

# rank is usually provided by the launcher (e.g. torchrun sets the RANK env var)
rank = int(os.environ.get("RANK", 0))

# dummy data to load
data = [{"url": "xxx", "exp": "yyy"}] * 20
# serialize each record so it can be stored in the shared list
data = [pickle.dumps(d) for d in data]

if rank % 8 == 0:
    # local rank 0 on each node loads the data and publishes it
    shm_a = shared_memory.ShareableList(data, name='shared_data')
    print(pickle.loads(shm_a[0]))
else:
    # other processes poll until the segment exists, then attach to it
    while True:
        try:
            shm_b = shared_memory.ShareableList(name="shared_data")
            print("Attached to shared memory with name: shared_data")
            print(pickle.loads(shm_b[0]))
            break
        except FileNotFoundError:
            print("Shared memory not found. It may not be created yet.")
            time.sleep(1)  # retry shortly; a long sleep here only delays startup
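One caveat the snippet leaves out: a shared-memory segment persists until it is unlinked, so it should be cleaned up once every process has read the data. Below is a minimal cleanup sketch; it assumes torch.distributed has already been initialized (e.g. by torchrun), which the original code does not show, and uses dist.barrier() so that no rank detaches before the others have attached.

import torch.distributed as dist

dist.barrier()          # wait until every rank has attached and read the data
if rank % 8 == 0:
    shm_a.shm.close()   # detach from the segment
    shm_a.shm.unlink()  # only the creating process removes it
else:
    shm_b.shm.close()   # readers just detach

Without the barrier, rank 0 could unlink the segment before a slow rank has attached, leaving that rank spinning on FileNotFoundError forever.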