@@ -300,44 +300,46 @@ def target_wrapper(*args, **kwargs):
300300 )
301301
302302
303- def _download_files_worker(files_to_download, other_params, chunks, bucket_type):
304- # For regional buckets, a new client must be created for each process.
305- # For zonal, the same is done for consistency.
303+ # --- Global Variables for Worker Process ---
304+ worker_loop = None
305+ worker_client = None
306+ worker_json_client = None
307+
308+
309+ def _worker_init(bucket_type):
310+ """Initializes a persistent event loop and client for each worker process."""
311+ global worker_loop, worker_client, worker_json_client
306312 if bucket_type == "zonal":
307- loop = asyncio.new_event_loop()
308- asyncio.set_event_loop(loop)
309- client = loop.run_until_complete(create_client())
310- try:
311- # download_files_using_mrd_multi_coro returns max latency of coros
312- result = download_files_using_mrd_multi_coro(
313- loop, client, files_to_download, other_params, chunks
314- )
315- finally:
316- tasks = asyncio.all_tasks(loop=loop)
317- for task in tasks:
318- task.cancel()
319- loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
320- loop.close()
321- return result
313+ worker_loop = asyncio.new_event_loop()
314+ asyncio.set_event_loop(worker_loop)
315+ worker_client = worker_loop.run_until_complete(create_client())
322316 else: # regional
323317 from google.cloud import storage
324318
325- json_client = storage.Client()
319+ worker_json_client = storage.Client()
320+
321+
322+ def _download_files_worker(files_to_download, other_params, chunks, bucket_type):
323+ if bucket_type == "zonal":
324+ # The loop and client are already initialized in _worker_init.
325+ # download_files_using_mrd_multi_coro returns max latency of coros
326+ return download_files_using_mrd_multi_coro(
327+ worker_loop, worker_client, files_to_download, other_params, chunks
328+ )
329+ else: # regional
326330 # download_files_using_json_multi_threaded returns max latency of threads
327331 return download_files_using_json_multi_threaded(
328- None, json_client , files_to_download, other_params, chunks
332+                None, worker_json_client, files_to_download, other_params, chunks
329333 )
330334
331335
332- def download_files_mp_mc_wrapper(files_names, params, chunks, bucket_type):
333- num_processes = params.num_processes
336+ def download_files_mp_mc_wrapper(pool, files_names, params, chunks, bucket_type):
334337 num_coros = params.num_coros # This is n, number of files per process
335338
336339 # Distribute filenames to processes
337340 filenames_per_process = [
338341 files_names[i : i + num_coros] for i in range(0, len(files_names), num_coros)
339342 ]
340-
341343 args = [
342344 (
343345 filenames,
@@ -348,10 +350,7 @@ def download_files_mp_mc_wrapper(files_names, params, chunks, bucket_type):
348350 for filenames in filenames_per_process
349351 ]
350352
351- ctx = multiprocessing.get_context("spawn")
352- with ctx.Pool(processes=num_processes) as pool:
353- results = pool.starmap(_download_files_worker, args)
354-
353+ results = pool.starmap(_download_files_worker, args)
355354 return max(results)
356355
357356
@@ -386,10 +385,16 @@ def test_downloads_multi_proc_multi_coro(
386385 logging.info("randomizing chunks")
387386 random.shuffle(chunks)
388387
388+ ctx = multiprocessing.get_context("spawn")
389+ pool = ctx.Pool(
390+ processes=params.num_processes,
391+ initializer=_worker_init,
392+ initargs=(params.bucket_type,),
393+ )
389394 output_times = []
390395
391396 def target_wrapper(*args, **kwargs):
392- result = download_files_mp_mc_wrapper(*args, **kwargs)
397+ result = download_files_mp_mc_wrapper(pool, *args, **kwargs)
393398 output_times.append(result)
394399 return output_times
395400
@@ -407,6 +412,8 @@ def target_wrapper(*args, **kwargs):
407412 ),
408413 )
409414 finally:
415+ pool.close()
416+ pool.join()
410417 publish_benchmark_extra_info(benchmark, params, true_times=output_times)
411418 publish_resource_metrics(benchmark, m)
412419
0 commit comments