2222from common .capture_utils import export_vars
2323from transformers import AutoTokenizer
2424
25+ from ucm .logger import init_logger
26+
2527_test_functions = {}
28+ logger = init_logger (__name__ )
2629
2730
2831def _run_test_in_spawn_process (test_id , args , kwargs , result_queue , error_queue ):
@@ -85,15 +88,11 @@ def wrapper(*args, **kwargs):
8588 from vllm .distributed import cleanup_dist_env_and_memory
8689 from vllm .engine .arg_utils import EngineArgs
8790
88- from ucm .logger import init_logger
89-
9091 VLLM_AVAILABLE = True
9192except ImportError :
9293 VLLM_AVAILABLE = False
9394 pytest .skip ("vLLM not available" , allow_module_level = True )
9495
95- logger = init_logger (__name__ )
96-
9796
9897@contextlib .contextmanager
9998def build_llm_with_uc (
@@ -125,9 +124,7 @@ def build_llm_with_uc(
125124 "gpu_memory_utilization" : 0.3 , # Reduced to prevent OOM after Phase 1
126125 "max_num_batched_tokens" : max_num_batched_tokens ,
127126 "block_size" : 128 ,
128- "enforce_eager" : llm_kwargs .get (
129- "enforce_eager" , True
130- ), # Allow override via llm_kwargs, default True
127+ "enforce_eager" : llm_kwargs .get ("enforce_eager" , True ),
131128 "trust_remote_code" : True ,
132129 "enable_prefix_caching" : enable_prefix_caching ,
133130 "tensor_parallel_size" : tensor_parallel_size ,
@@ -142,7 +139,6 @@ def build_llm_with_uc(
142139 finally :
143140 logger .info ("LLM engine is exiting" )
144141 del llm
145- # Use vLLM's cleanup function to properly release distributed resources
146142 cleanup_dist_env_and_memory (shutdown_ray = False )
147143
148144
@@ -176,7 +172,6 @@ def _run_phase_in_subprocess(
176172 error_queue: Queue to put errors
177173 """
178174 try :
179- # Clear any GPU memory in subprocess before starting
180175 if torch .cuda .is_available ():
181176 torch .cuda .empty_cache ()
182177 torch .cuda .synchronize ()
@@ -185,20 +180,18 @@ def _run_phase_in_subprocess(
185180 torch .npu .synchronize ()
186181 gc .collect ()
187182
188- # Recreate SamplingParams in subprocess
189183 sampling_params = SamplingParams (** sampling_params_dict )
190184
191185 with build_llm_with_uc (
192186 model_path = model_path ,
193187 ucm_config = ucm_config ,
194188 enable_prefix_caching = enable_prefix_caching ,
195- gpu_memory_utilization = 0.3 , # Lower utilization for subprocess
189+ gpu_memory_utilization = 0.3 ,
196190 max_num_batched_tokens = max_num_batched_tokens ,
197191 enforce_eager = enforce_eager ,
198192 ) as llm :
199193 outputs = run_inference (llm , prompts , sampling_params , phase_description )
200194
201- # Return all outputs
202195 result_queue .put (outputs )
203196 except Exception as e :
204197 import traceback
@@ -316,7 +309,7 @@ def run_inference(
316309 generated_texts .append (generated_text )
317310
318311 if description :
319- print (f"[INFO] { description } completed" )
312+ logger . info (f"{ description } completed" )
320313
321314 return generated_texts
322315
@@ -377,13 +370,13 @@ def test_offline_accuracy_hbm_ssd_mixed(
377370
378371 try :
379372 test_prompt , standard_answers = load_prompt_from_file ()
380- print (
381- f"[INFO] Loaded prompt from prompt.json (length: { len (test_prompt )} chars)"
373+ logger . info (
374+ f"Loaded prompt from prompt.json (length: { len (test_prompt )} chars)"
382375 )
383376 if standard_answers :
384- print (f"[INFO] Standard answers: { standard_answers } " )
377+ logger . info (f"Standard answers: { standard_answers } " )
385378 else :
386- print (f"[INFO] No standard answers found in prompt.json" )
379+ logger . info (f"No standard answers found in prompt.json" )
387380 except Exception as e :
388381 pytest .skip (f"Failed to load prompt from prompt.json: { e } " )
389382
@@ -424,19 +417,19 @@ def test_offline_accuracy_hbm_ssd_mixed(
424417 ignore_eos = False ,
425418 )
426419
427- print (f"\n [INFO] ===== HBM + SSD Mixed Accuracy Test =====" )
428- print (f"[INFO] Model: { model_path } " )
429- print (f"[INFO] Full prompt length: { len (test_prompt )} chars" )
430- print (f"[INFO] Max tokens: { max_tokens } " )
431- print (f"[INFO] Temperature: 0.0 (deterministic)" )
432- print (f"[INFO] UCM storage: { ucm_storage_dir } " )
433- print (f"[INFO] Prompt split ratio: { prompt_split_ratio } " )
434- print (f"[INFO] Enforce eager: { enforce_eager } " )
435- print (f"[INFO] Max num batched tokens: { max_num_batched_tokens } " )
420+ logger . info (f"\n ===== HBM + SSD Mixed Accuracy Test =====" )
421+ logger . info (f"Model: { model_path } " )
422+ logger . info (f"Full prompt length: { len (test_prompt )} chars" )
423+ logger . info (f"Max tokens: { max_tokens } " )
424+ logger . info (f"Temperature: 0.0 (deterministic)" )
425+ logger . info (f"UCM storage: { ucm_storage_dir } " )
426+ logger . info (f"Prompt split ratio: { prompt_split_ratio } " )
427+ logger . info (f"Enforce eager: { enforce_eager } " )
428+ logger . info (f"Max num batched tokens: { max_num_batched_tokens } " )
436429
437430 # ===== Phase 1: Disable HBM PC, save KV cache to SSD and load (baseline) =====
438431 # Run Phase 1 in a separate subprocess to ensure GPU memory is fully released
439- print (f"\n [INFO] ===== Phase 1: Save KV Cache to SSD And Load (Baseline) =====" )
432+ logger . info (f"\n ===== Phase 1: Save KV Cache to SSD And Load (Baseline) =====" )
440433
441434 if torch .cuda .is_available ():
442435 torch .cuda .empty_cache ()
@@ -445,7 +438,7 @@ def test_offline_accuracy_hbm_ssd_mixed(
445438 torch .npu .empty_cache ()
446439 torch .npu .synchronize ()
447440 gc .collect ()
448- time .sleep (2 ) # Wait a bit for GPU memory to be released
441+ time .sleep (2 )
449442
450443 ctx = multiprocessing .get_context ("spawn" )
451444 result_queue = ctx .Queue ()
@@ -467,7 +460,7 @@ def test_offline_accuracy_hbm_ssd_mixed(
467460 [
468461 formatted_full_prompt ,
469462 formatted_full_prompt ,
470- ], # Phase 1: send same prompt twice
463+ ],
471464 sampling_params_dict ,
472465 False , # enable_prefix_caching=False for Phase 1
473466 enforce_eager ,
@@ -500,34 +493,27 @@ def test_offline_accuracy_hbm_ssd_mixed(
500493 f"Phase 1 failed in subprocess with exit code { process .exitcode } "
501494 )
502495
503- # Get results from subprocess
504496 if result_queue .empty ():
505497 raise RuntimeError ("Phase 1 subprocess completed but no result in queue" )
506498 phase1_outputs = result_queue .get ()
507499 phase1_1_output = phase1_outputs [0 ] # Phase 1.1: SSD save
508500 phase1_2_output = phase1_outputs [1 ] # Phase 1.2: SSD load
509- print (
510- f"[INFO] Phase 1 completed in subprocess, GPU memory should be fully released"
511- )
512- print (f"[INFO] Phase 1.1 output: { phase1_1_output } " )
513- print (f"[INFO] Phase 1.2 output: { phase1_2_output } " )
501+ logger .info (f"Phase 1 completed in subprocess, GPU memory should be fully released" )
502+ logger .info (f"Phase 1.1 output: { phase1_1_output } " )
503+ logger .info (f"Phase 1.2 output: { phase1_2_output } " )
514504
515505 # ===== Phase 2: Enable HBM PC, test HBM + SSD mixed hit =====
516506 # Run Phase 2 in a separate subprocess to ensure GPU memory is fully released
517- print (f"\n [INFO] ===== Phase 2: HBM + SSD Mixed Hit Test =====" )
507+ logger . info (f"\n ===== Phase 2: HBM + SSD Mixed Hit Test =====" )
518508
519- # Clear GPU memory in main process before starting subprocess
520- print (
521- f"[INFO] Clearing GPU memory in main process before starting Phase 2 subprocess..."
522- )
523509 if torch .cuda .is_available ():
524510 torch .cuda .empty_cache ()
525511 torch .cuda .synchronize ()
526512 elif hasattr (torch , "npu" ) and torch .npu .is_available ():
527513 torch .npu .empty_cache ()
528514 torch .npu .synchronize ()
529515 gc .collect ()
530- time .sleep (2 ) # Wait a bit for GPU memory to be released
516+ time .sleep (2 )
531517
532518 ctx = multiprocessing .get_context ("spawn" )
533519 result_queue_2 = ctx .Queue ()
@@ -574,28 +560,19 @@ def test_offline_accuracy_hbm_ssd_mixed(
574560 f"Phase 2 failed in subprocess with exit code { process_2 .exitcode } "
575561 )
576562
577- # Get results from subprocess
578563 if result_queue_2 .empty ():
579564 raise RuntimeError ("Phase 2 subprocess completed but no result in queue" )
580565 phase2_outputs = result_queue_2 .get ()
581- phase2_partial_output = phase2_outputs [
582- 0
583- ] # Output from partial prompt (for reference)
584- phase2_full_output = phase2_outputs [
585- 1
586- ] # Output from full prompt (this is what we compare)
587- print (
588- f"[INFO] Phase 2 completed in subprocess, GPU memory should be fully released"
589- )
590- print (f"[INFO] Phase 2.1 output: { phase2_partial_output } " )
591- print (f"[INFO] Phase 2.2 output: { phase2_full_output } " )
566+ phase2_partial_output = phase2_outputs [0 ]
567+ phase2_full_output = phase2_outputs [1 ]
568+ logger .info (f"Phase 2 completed in subprocess, GPU memory should be fully released" )
569+ logger .info (f"Phase 2.1 output: { phase2_partial_output } " )
570+ logger .info (f"Phase 2.2 output: { phase2_full_output } " )
592571
593- # ===== Compare outputs =====
594- print (f"\n [INFO] ===== Accuracy Test Results =====" )
572+ logger .info (f"\n ===== Accuracy Test Results =====" )
595573
596574 def normalize_text (text : str ) -> str :
597575 """Normalize text for comparison by replacing similar punctuation."""
598- # Replace full-width punctuation with half-width for comparison
599576 text = text .replace ("," , "," )
600577 text = text .replace ("。" , "." )
601578 text = text .replace ("!" , "!" )
@@ -610,39 +587,39 @@ def normalize_text(text: str) -> str:
610587 phase1_2_output
611588 )
612589 if not phase1_identical :
613- print (f"\n [WARNING] ===== Phase 1: SSD Load Accuracy Test (Exact Match) =====" )
614- print (
615- f"[WARNING] Phase 1.1 (SSD save) output differs from Phase 1.2 (SSD load) output!"
590+ logger . warning (f"\n ===== Phase 1: SSD Load Accuracy Test (Exact Match) =====" )
591+ logger . warning (
592+ f"Phase 1.1 (SSD save) output differs from Phase 1.2 (SSD load) output!"
616593 )
617- print (f"[WARNING] Phase 1.1 output:\n { phase1_1_output } " )
618- print (f"[WARNING] Phase 1.2 output:\n { phase1_2_output } " )
594+ logger . warning (f"Phase 1.1 output:\n { phase1_1_output } " )
595+ logger . warning (f"Phase 1.2 output:\n { phase1_2_output } " )
619596 if phase1_normalized_identical :
620- print (
621- f"[INFO] But normalized outputs are identical (punctuation difference only)"
597+ logger . info (
598+ f"But normalized outputs are identical (punctuation difference only)"
622599 )
623600
624601 phase2_identical = phase1_1_output == phase2_full_output
625602 phase2_normalized_identical = normalize_text (phase1_1_output ) == normalize_text (
626603 phase2_full_output
627604 )
628605 if not phase2_identical :
629- print (
630- f"\n [WARNING] ===== Phase 2: HBM + SSD Mixed Accuracy Test (Exact Match) ====="
606+ logger . warning (
607+ f"\n ===== Phase 2: HBM + SSD Mixed Accuracy Test (Exact Match) ====="
631608 )
632- print (
633- f"[WARNING] Phase 1.1 (SSD save) output differs from Phase 2.2 (HBM + SSD mixed) output!"
609+ logger . warning (
610+ f"Phase 1.1 (SSD save) output differs from Phase 2.2 (HBM + SSD mixed) output!"
634611 )
635- print (f"[WARNING] Phase 1.1 output:\n { phase1_1_output } " )
636- print (f"[WARNING] Phase 2.2 output:\n { phase2_full_output } " )
612+ logger . warning (f"Phase 1.1 output:\n { phase1_1_output } " )
613+ logger . warning (f"Phase 2.2 output:\n { phase2_full_output } " )
637614 if phase2_normalized_identical :
638- print (
639- f"[INFO] But normalized outputs are identical (punctuation difference only)"
615+ logger . info (
616+ f"But normalized outputs are identical (punctuation difference only)"
640617 )
641- print (
642- f"[INFO] This is likely due to numerical precision differences in KV cache loading"
618+ logger . info (
619+ f"This is likely due to numerical precision differences in KV cache loading"
643620 )
644- print (f"[INFO] Normalized Phase 1.1: { normalize_text (phase1_1_output )} " )
645- print (f"[INFO] Normalized Phase 2.2: { normalize_text (phase2_full_output )} " )
621+ logger . info (f"Normalized Phase 1.1: { normalize_text (phase1_1_output )} " )
622+ logger . info (f"Normalized Phase 2.2: { normalize_text (phase2_full_output )} " )
646623
647624 # Assert outputs are identical (using normalized comparison for punctuation differences)
648625 # Note: Small numerical precision differences in KV cache loading can cause
@@ -663,13 +640,13 @@ def normalize_text(text: str) -> str:
663640 )
664641
665642 if phase2_identical :
666- print (f"\n [INFO] ✓ HBM + SSD mixed accuracy test passed: outputs are identical" )
643+ logger . info (f"\n ✓ HBM + SSD mixed accuracy test passed: outputs are identical" )
667644 else :
668- print (
669- f"\n [INFO] ✓ HBM + SSD mixed accuracy test passed: normalized outputs are identical"
645+ logger . info (
646+ f"\n ✓ HBM + SSD mixed accuracy test passed: normalized outputs are identical"
670647 )
671- print (
672- f"[INFO] Note: Punctuation difference detected (likely due to numerical precision in KV cache)"
648+ logger . info (
649+ f"Note: Punctuation difference detected (likely due to numerical precision in KV cache)"
673650 )
674651
675652 value_lists = {
0 commit comments