
Commit f7c2e5e

finish accuracy test.
1 parent de61649 commit f7c2e5e

File tree: 1 file changed (+57, -80 lines)


test/suites/E2E/test_uc_accuracy_offline.py

Lines changed: 57 additions & 80 deletions
@@ -22,7 +22,10 @@
 from common.capture_utils import export_vars
 from transformers import AutoTokenizer
 
+from ucm.logger import init_logger
+
 _test_functions = {}
+logger = init_logger(__name__)
 
 
 def _run_test_in_spawn_process(test_id, args, kwargs, result_queue, error_queue):
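
Note: this hunk hoists the logger import to module scope so a single logger = init_logger(__name__) is shared by the whole test module, including code re-imported in spawned subprocesses. A minimal sketch of the pattern, assuming init_logger is a thin wrapper over the stdlib logging module (the stand-in below is hypothetical, not ucm's actual implementation):

# Minimal sketch of a module-level logger; this init_logger is a hypothetical
# stand-in for ucm.logger.init_logger and simply wraps the stdlib logging module.
import logging


def init_logger(name: str) -> logging.Logger:
    log = logging.getLogger(name)
    if not log.handlers:  # avoid stacking handlers when the module is re-imported
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
        )
        log.addHandler(handler)
        log.setLevel(logging.INFO)
    return log


logger = init_logger(__name__)
logger.info("accuracy test module loaded")

Because the logging framework already emits the level name, the hand-written "[INFO]"/"[WARNING]" prefixes in the old print calls become redundant, which is why the later hunks drop them.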
@@ -85,15 +88,11 @@ def wrapper(*args, **kwargs):
     from vllm.distributed import cleanup_dist_env_and_memory
     from vllm.engine.arg_utils import EngineArgs
 
-    from ucm.logger import init_logger
-
     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False
     pytest.skip("vLLM not available", allow_module_level=True)
 
-logger = init_logger(__name__)
-
 
 @contextlib.contextmanager
 def build_llm_with_uc(
@@ -125,9 +124,7 @@ def build_llm_with_uc(
         "gpu_memory_utilization": 0.3,  # Reduced to prevent OOM after Phase 1
         "max_num_batched_tokens": max_num_batched_tokens,
         "block_size": 128,
-        "enforce_eager": llm_kwargs.get(
-            "enforce_eager", True
-        ),  # Allow override via llm_kwargs, default True
+        "enforce_eager": llm_kwargs.get("enforce_eager", True),
         "trust_remote_code": True,
         "enable_prefix_caching": enable_prefix_caching,
         "tensor_parallel_size": tensor_parallel_size,
@@ -142,7 +139,6 @@ def build_llm_with_uc(
     finally:
         logger.info("LLM engine is exiting")
         del llm
-        # Use vLLM's cleanup function to properly release distributed resources
         cleanup_dist_env_and_memory(shutdown_ray=False)
 
 
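Note: build_llm_with_uc is a context manager that always tears the engine down in finally: drop the reference, then call vLLM's cleanup_dist_env_and_memory. A stripped-down sketch of that lifecycle shape; make_engine() below is a placeholder, not the real construction with the UCM connector config:

# Stripped-down sketch of the teardown order in build_llm_with_uc; make_engine()
# is a placeholder for the real LLM construction.
import contextlib
import gc


@contextlib.contextmanager
def engine_session(make_engine):
    llm = make_engine()
    try:
        yield llm
    finally:
        # Drop the reference before releasing distributed/GPU state, mirroring
        # `del llm` followed by cleanup_dist_env_and_memory(shutdown_ray=False).
        del llm
        gc.collect()


with engine_session(lambda: object()) as llm:
    pass  # run inference here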
@@ -176,7 +172,6 @@ def _run_phase_in_subprocess(
         error_queue: Queue to put errors
     """
     try:
-        # Clear any GPU memory in subprocess before starting
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.synchronize()
@@ -185,20 +180,18 @@
             torch.npu.synchronize()
         gc.collect()
 
-        # Recreate SamplingParams in subprocess
         sampling_params = SamplingParams(**sampling_params_dict)
 
         with build_llm_with_uc(
             model_path=model_path,
             ucm_config=ucm_config,
             enable_prefix_caching=enable_prefix_caching,
-            gpu_memory_utilization=0.3,  # Lower utilization for subprocess
+            gpu_memory_utilization=0.3,
             max_num_batched_tokens=max_num_batched_tokens,
             enforce_eager=enforce_eager,
         ) as llm:
             outputs = run_inference(llm, prompts, sampling_params, phase_description)
 
-            # Return all outputs
             result_queue.put(outputs)
     except Exception as e:
         import traceback
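
Note: each phase body runs in a spawned subprocess, so only picklable arguments cross the process boundary; the parent passes sampling_params_dict and the child rebuilds SamplingParams from it. A sketch of that hand-off, with a hypothetical dataclass standing in for vllm.SamplingParams:

# Sketch of passing constructor kwargs (picklable) across the spawn boundary and
# rebuilding the object in the child. GenParams is a hypothetical stand-in for
# vllm.SamplingParams; the real code rebuilds SamplingParams(**sampling_params_dict).
import multiprocessing
from dataclasses import dataclass


@dataclass
class GenParams:
    temperature: float = 0.0
    max_tokens: int = 64
    ignore_eos: bool = False


def _child(params_dict, result_queue):
    params = GenParams(**params_dict)  # recreate the object inside the subprocess
    result_queue.put(params.max_tokens)


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    queue = ctx.Queue()
    proc = ctx.Process(target=_child, args=({"temperature": 0.0, "max_tokens": 8}, queue))
    proc.start()
    proc.join()
    assert queue.get() == 8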
@@ -316,7 +309,7 @@ def run_inference(
         generated_texts.append(generated_text)
 
     if description:
-        print(f"[INFO] {description} completed")
+        logger.info(f"{description} completed")
 
     return generated_texts
 
@@ -377,13 +370,13 @@ def test_offline_accuracy_hbm_ssd_mixed(
 
     try:
         test_prompt, standard_answers = load_prompt_from_file()
-        print(
-            f"[INFO] Loaded prompt from prompt.json (length: {len(test_prompt)} chars)"
+        logger.info(
+            f"Loaded prompt from prompt.json (length: {len(test_prompt)} chars)"
         )
         if standard_answers:
-            print(f"[INFO] Standard answers: {standard_answers}")
+            logger.info(f"Standard answers: {standard_answers}")
         else:
-            print(f"[INFO] No standard answers found in prompt.json")
+            logger.info(f"No standard answers found in prompt.json")
     except Exception as e:
         pytest.skip(f"Failed to load prompt from prompt.json: {e}")
 
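Note: load_prompt_from_file and the prompt.json schema are not part of this diff; the hunk only shows that the helper returns a prompt plus optional reference answers and that the test skips if loading fails. A purely hypothetical sketch of such a loader, with the "prompt"/"answers" keys assumed rather than taken from the repo:

# Hypothetical sketch only: the real load_prompt_from_file and the exact
# prompt.json schema are not shown in this diff ("prompt"/"answers" keys assumed).
import json
from pathlib import Path


def load_prompt_from_file(path: str = "prompt.json") -> tuple[str, list[str]]:
    data = json.loads(Path(path).read_text(encoding="utf-8"))
    prompt = data["prompt"]
    answers = data.get("answers", [])  # may be empty; the test logs either way
    return prompt, answers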
@@ -424,19 +417,19 @@ def test_offline_accuracy_hbm_ssd_mixed(
         ignore_eos=False,
     )
 
-    print(f"\n[INFO] ===== HBM + SSD Mixed Accuracy Test =====")
-    print(f"[INFO] Model: {model_path}")
-    print(f"[INFO] Full prompt length: {len(test_prompt)} chars")
-    print(f"[INFO] Max tokens: {max_tokens}")
-    print(f"[INFO] Temperature: 0.0 (deterministic)")
-    print(f"[INFO] UCM storage: {ucm_storage_dir}")
-    print(f"[INFO] Prompt split ratio: {prompt_split_ratio}")
-    print(f"[INFO] Enforce eager: {enforce_eager}")
-    print(f"[INFO] Max num batched tokens: {max_num_batched_tokens}")
+    logger.info(f"\n===== HBM + SSD Mixed Accuracy Test =====")
+    logger.info(f"Model: {model_path}")
+    logger.info(f"Full prompt length: {len(test_prompt)} chars")
+    logger.info(f"Max tokens: {max_tokens}")
+    logger.info(f"Temperature: 0.0 (deterministic)")
+    logger.info(f"UCM storage: {ucm_storage_dir}")
+    logger.info(f"Prompt split ratio: {prompt_split_ratio}")
+    logger.info(f"Enforce eager: {enforce_eager}")
+    logger.info(f"Max num batched tokens: {max_num_batched_tokens}")
 
     # ===== Phase 1: Disable HBM PC, save KV cache to SSD and load (baseline) =====
     # Run Phase 1 in a separate subprocess to ensure GPU memory is fully released
-    print(f"\n[INFO] ===== Phase 1: Save KV Cache to SSD And Load (Baseline) =====")
+    logger.info(f"\n===== Phase 1: Save KV Cache to SSD And Load (Baseline) =====")
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -445,7 +438,7 @@ def test_offline_accuracy_hbm_ssd_mixed(
         torch.npu.empty_cache()
         torch.npu.synchronize()
     gc.collect()
-    time.sleep(2)  # Wait a bit for GPU memory to be released
+    time.sleep(2)
 
     ctx = multiprocessing.get_context("spawn")
     result_queue = ctx.Queue()
@@ -467,7 +460,7 @@ def test_offline_accuracy_hbm_ssd_mixed(
             [
                 formatted_full_prompt,
                 formatted_full_prompt,
-            ],  # Phase 1: send same prompt twice
+            ],
             sampling_params_dict,
             False,  # enable_prefix_caching=False for Phase 1
             enforce_eager,
@@ -500,34 +493,27 @@ def test_offline_accuracy_hbm_ssd_mixed(
             f"Phase 1 failed in subprocess with exit code {process.exitcode}"
         )
 
-    # Get results from subprocess
     if result_queue.empty():
         raise RuntimeError("Phase 1 subprocess completed but no result in queue")
     phase1_outputs = result_queue.get()
     phase1_1_output = phase1_outputs[0]  # Phase 1.1: SSD save
     phase1_2_output = phase1_outputs[1]  # Phase 1.2: SSD load
-    print(
-        f"[INFO] Phase 1 completed in subprocess, GPU memory should be fully released"
-    )
-    print(f"[INFO] Phase 1.1 output: {phase1_1_output}")
-    print(f"[INFO] Phase 1.2 output: {phase1_2_output}")
+    logger.info(f"Phase 1 completed in subprocess, GPU memory should be fully released")
+    logger.info(f"Phase 1.1 output: {phase1_1_output}")
+    logger.info(f"Phase 1.2 output: {phase1_2_output}")
 
     # ===== Phase 2: Enable HBM PC, test HBM + SSD mixed hit =====
     # Run Phase 2 in a separate subprocess to ensure GPU memory is fully released
-    print(f"\n[INFO] ===== Phase 2: HBM + SSD Mixed Hit Test =====")
+    logger.info(f"\n===== Phase 2: HBM + SSD Mixed Hit Test =====")
 
-    # Clear GPU memory in main process before starting subprocess
-    print(
-        f"[INFO] Clearing GPU memory in main process before starting Phase 2 subprocess..."
-    )
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
     elif hasattr(torch, "npu") and torch.npu.is_available():
         torch.npu.empty_cache()
         torch.npu.synchronize()
     gc.collect()
-    time.sleep(2)  # Wait a bit for GPU memory to be released
+    time.sleep(2)
 
     ctx = multiprocessing.get_context("spawn")
     result_queue_2 = ctx.Queue()
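
Note: both phases use the same parent-side harness: spawn context, bounded join, fail on a non-zero exit code, and read the queue only if it is non-empty. A condensed sketch of that control flow; the phase_fn argument and the timeout value are placeholders for _run_phase_in_subprocess and whatever timeout the real test wires in:

# Condensed sketch of the parent-side harness shared by Phase 1 and Phase 2:
# spawn context, bounded join, exit-code check, then a guarded queue read.
import multiprocessing


def run_phase(phase_fn, args, timeout_s: float = 600.0):
    ctx = multiprocessing.get_context("spawn")
    result_queue = ctx.Queue()
    process = ctx.Process(target=phase_fn, args=(*args, result_queue))
    process.start()
    process.join(timeout=timeout_s)

    if process.is_alive():  # the phase hung; kill it and fail the test
        process.terminate()
        process.join()
        raise RuntimeError("phase timed out in subprocess")
    if process.exitcode != 0:
        raise RuntimeError(f"phase failed in subprocess with exit code {process.exitcode}")
    if result_queue.empty():
        raise RuntimeError("phase subprocess completed but no result in queue")
    return result_queue.get()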
@@ -574,28 +560,19 @@
             f"Phase 2 failed in subprocess with exit code {process_2.exitcode}"
         )
 
-    # Get results from subprocess
     if result_queue_2.empty():
         raise RuntimeError("Phase 2 subprocess completed but no result in queue")
     phase2_outputs = result_queue_2.get()
-    phase2_partial_output = phase2_outputs[
-        0
-    ]  # Output from partial prompt (for reference)
-    phase2_full_output = phase2_outputs[
-        1
-    ]  # Output from full prompt (this is what we compare)
-    print(
-        f"[INFO] Phase 2 completed in subprocess, GPU memory should be fully released"
-    )
-    print(f"[INFO] Phase 2.1 output: {phase2_partial_output}")
-    print(f"[INFO] Phase 2.2 output: {phase2_full_output}")
+    phase2_partial_output = phase2_outputs[0]
+    phase2_full_output = phase2_outputs[1]
+    logger.info(f"Phase 2 completed in subprocess, GPU memory should be fully released")
+    logger.info(f"Phase 2.1 output: {phase2_partial_output}")
+    logger.info(f"Phase 2.2 output: {phase2_full_output}")
 
-    # ===== Compare outputs =====
-    print(f"\n[INFO] ===== Accuracy Test Results =====")
+    logger.info(f"\n===== Accuracy Test Results =====")
 
     def normalize_text(text: str) -> str:
         """Normalize text for comparison by replacing similar punctuation."""
-        # Replace full-width punctuation with half-width for comparison
         text = text.replace("，", ",")
         text = text.replace("。", ".")
         text = text.replace("！", "!")
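
Note: normalize_text folds full-width CJK punctuation to ASCII before comparing outputs, so a punctuation-only flip caused by tiny numerical differences in the KV-cache path does not fail the test. An equivalent sketch using str.maketrans, covering only the marks visible in this hunk (not the code in the repo):

# Equivalent sketch of the punctuation normalization using str.maketrans; the
# table only covers the full-width marks visible in this hunk.
_FULLWIDTH_TO_ASCII = str.maketrans({"，": ",", "。": ".", "！": "!"})


def normalize_text(text: str) -> str:
    """Fold full-width punctuation to ASCII so comparisons ignore punctuation style."""
    return text.translate(_FULLWIDTH_TO_ASCII)


assert normalize_text("你好，世界。") == normalize_text("你好,世界.")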
@@ -610,39 +587,39 @@ def normalize_text(text: str) -> str:
         phase1_2_output
     )
     if not phase1_identical:
-        print(f"\n[WARNING] ===== Phase 1: SSD Load Accuracy Test (Exact Match) =====")
-        print(
-            f"[WARNING] Phase 1.1 (SSD save) output differs from Phase 1.2 (SSD load) output!"
+        logger.warning(f"\n===== Phase 1: SSD Load Accuracy Test (Exact Match) =====")
+        logger.warning(
+            f"Phase 1.1 (SSD save) output differs from Phase 1.2 (SSD load) output!"
         )
-        print(f"[WARNING] Phase 1.1 output:\n{phase1_1_output}")
-        print(f"[WARNING] Phase 1.2 output:\n{phase1_2_output}")
+        logger.warning(f"Phase 1.1 output:\n{phase1_1_output}")
+        logger.warning(f"Phase 1.2 output:\n{phase1_2_output}")
         if phase1_normalized_identical:
-            print(
-                f"[INFO] But normalized outputs are identical (punctuation difference only)"
+            logger.info(
+                f"But normalized outputs are identical (punctuation difference only)"
             )
 
     phase2_identical = phase1_1_output == phase2_full_output
     phase2_normalized_identical = normalize_text(phase1_1_output) == normalize_text(
         phase2_full_output
     )
     if not phase2_identical:
-        print(
-            f"\n[WARNING] ===== Phase 2: HBM + SSD Mixed Accuracy Test (Exact Match) ====="
+        logger.warning(
+            f"\n===== Phase 2: HBM + SSD Mixed Accuracy Test (Exact Match) ====="
         )
-        print(
-            f"[WARNING] Phase 1.1 (SSD save) output differs from Phase 2.2 (HBM + SSD mixed) output!"
+        logger.warning(
+            f"Phase 1.1 (SSD save) output differs from Phase 2.2 (HBM + SSD mixed) output!"
         )
-        print(f"[WARNING] Phase 1.1 output:\n{phase1_1_output}")
-        print(f"[WARNING] Phase 2.2 output:\n{phase2_full_output}")
+        logger.warning(f"Phase 1.1 output:\n{phase1_1_output}")
+        logger.warning(f"Phase 2.2 output:\n{phase2_full_output}")
         if phase2_normalized_identical:
-            print(
-                f"[INFO] But normalized outputs are identical (punctuation difference only)"
+            logger.info(
+                f"But normalized outputs are identical (punctuation difference only)"
             )
-            print(
-                f"[INFO] This is likely due to numerical precision differences in KV cache loading"
+            logger.info(
+                f"This is likely due to numerical precision differences in KV cache loading"
            )
-            print(f"[INFO] Normalized Phase 1.1: {normalize_text(phase1_1_output)}")
-            print(f"[INFO] Normalized Phase 2.2: {normalize_text(phase2_full_output)}")
+            logger.info(f"Normalized Phase 1.1: {normalize_text(phase1_1_output)}")
+            logger.info(f"Normalized Phase 2.2: {normalize_text(phase2_full_output)}")
 
     # Assert outputs are identical (using normalized comparison for punctuation differences)
     # Note: Small numerical precision differences in KV cache loading can cause
@@ -663,13 +640,13 @@ def normalize_text(text: str) -> str:
     )
 
     if phase2_identical:
-        print(f"\n[INFO] ✓ HBM + SSD mixed accuracy test passed: outputs are identical")
+        logger.info(f"\n✓ HBM + SSD mixed accuracy test passed: outputs are identical")
     else:
-        print(
-            f"\n[INFO] ✓ HBM + SSD mixed accuracy test passed: normalized outputs are identical"
+        logger.info(
+            f"\n✓ HBM + SSD mixed accuracy test passed: normalized outputs are identical"
        )
-        print(
-            f"[INFO] Note: Punctuation difference detected (likely due to numerical precision in KV cache)"
+        logger.info(
+            f"Note: Punctuation difference detected (likely due to numerical precision in KV cache)"
         )
 
     value_lists = {
