Skip to content

Commit d147bb4

Browse files
author
Albert Cheng (Engrg-Hardware 1)
committed
[CI] Fix formatting for DeepSeek MTP acceptance test
Signed-off-by: Albert Cheng (Engrg-Hardware 1) <albecheng@login-lyris02.lyris.clusters.nvidia.com>
1 parent eb9e053 commit d147bb4

File tree

1 file changed

+24
-39
lines changed

1 file changed

+24
-39
lines changed

tests/v1/spec_decode/test_acceptance_length.py

Lines changed: 24 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -320,17 +320,16 @@ def test_eagle3_acceptance_length(
320320
@dataclass
321321
class MTPModelConfig:
322322
"""Model configuration for MTP acceptance length tests."""
323+
323324
verifier: str
324325
expected_acceptance_length: float
325-
expected_acceptance_lengths_per_pos: list[float] = field(
326-
default_factory=list)
326+
expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list)
327327
id: str = ""
328328
num_speculative_tokens: int = 1
329329
tensor_parallel_size: int = 1
330330
max_model_len: int = DEFAULT_MAX_MODEL_LEN
331331
gpu_memory_utilization: float = 0.7
332-
excluded_backends: set[AttentionBackendEnum] = field(
333-
default_factory=set)
332+
excluded_backends: set[AttentionBackendEnum] = field(default_factory=set)
334333
marks: list = field(default_factory=list)
335334
rtol: float | None = None
336335

@@ -364,8 +363,7 @@ class MTPModelConfig:
364363
for config in MTP_MODEL_CONFIGS
365364
],
366365
)
367-
@pytest.mark.parametrize("attention_backend",
368-
get_attention_backend_params())
366+
@pytest.mark.parametrize("attention_backend", get_attention_backend_params())
369367
def test_mtp_acceptance_length(
370368
model_config: MTPModelConfig,
371369
attention_backend: str,
@@ -381,8 +379,7 @@ def test_mtp_acceptance_length(
381379
"""
382380
backend_enum = AttentionBackendEnum[attention_backend]
383381
if backend_enum in model_config.excluded_backends:
384-
pytest.skip(
385-
f"{attention_backend} incompatible with {model_config.id}")
382+
pytest.skip(f"{attention_backend} incompatible with {model_config.id}")
386383

387384
num_spec_tokens = model_config.num_speculative_tokens
388385

@@ -403,35 +400,27 @@ def test_mtp_acceptance_length(
403400
trust_remote_code=True,
404401
) as vllm_runner:
405402
tokenizer = vllm_runner.llm.get_tokenizer()
406-
prompt_ids = get_mt_bench_prompts(
407-
tokenizer, DEFAULT_NUM_PROMPTS)
403+
prompt_ids = get_mt_bench_prompts(tokenizer, DEFAULT_NUM_PROMPTS)
408404

409405
sampling_params = SamplingParams(
410406
temperature=0,
411407
max_tokens=DEFAULT_OUTPUT_LEN,
412408
)
413409
vllm_runner.llm.generate(
414-
[TokensPrompt(prompt_token_ids=ids)
415-
for ids in prompt_ids],
410+
[TokensPrompt(prompt_token_ids=ids) for ids in prompt_ids],
416411
sampling_params=sampling_params,
417412
)
418413

419414
metrics = vllm_runner.llm.get_metrics()
420-
results = extract_acceptance_metrics(
421-
metrics, num_spec_tokens)
415+
results = extract_acceptance_metrics(metrics, num_spec_tokens)
422416

423417
actual = results["acceptance_length"]
424418
expected = model_config.expected_acceptance_length
425419
actual_per_pos = results["acceptance_lengths_per_pos"]
426-
expected_per_pos = (
427-
model_config.expected_acceptance_lengths_per_pos)
420+
expected_per_pos = model_config.expected_acceptance_lengths_per_pos
428421

429422
rel_error = abs(actual - expected) / expected
430-
rtol = (
431-
model_config.rtol
432-
if model_config.rtol is not None
433-
else DEFAULT_RTOL
434-
)
423+
rtol = model_config.rtol if model_config.rtol is not None else DEFAULT_RTOL
435424

436425
assert rel_error <= rtol, (
437426
f"MTP acceptance length regression for "
@@ -441,16 +430,14 @@ def test_mtp_acceptance_length(
441430
f" Relative error: {rel_error:.2%} "
442431
f"(tolerance: {rtol:.2%})\n"
443432
f" Drafts: {results['num_drafts']}, "
444-
f"Accepted: {results['num_accepted_tokens']}")
445-
446-
if (expected_per_pos
447-
and len(expected_per_pos) == len(actual_per_pos)):
448-
rtol = (model_config.rtol
449-
if model_config.rtol is not None
450-
else DEFAULT_RTOL)
451-
for pos, (act, exp) in enumerate(
452-
zip(actual_per_pos, expected_per_pos)
453-
):
433+
f"Accepted: {results['num_accepted_tokens']}"
434+
)
435+
436+
if expected_per_pos and len(expected_per_pos) == len(actual_per_pos):
437+
rtol = (
438+
model_config.rtol if model_config.rtol is not None else DEFAULT_RTOL
439+
)
440+
for pos, (act, exp) in enumerate(zip(actual_per_pos, expected_per_pos)):
454441
if exp > 0:
455442
pos_err = abs(act - exp) / exp
456443
assert pos_err <= rtol, (
@@ -459,19 +446,17 @@ def test_mtp_acceptance_length(
459446
f" Expected: {exp:.3f}\n"
460447
f" Actual: {act:.3f}\n"
461448
f" Error: {pos_err:.2%} "
462-
f"(tolerance: {rtol:.2%})")
449+
f"(tolerance: {rtol:.2%})"
450+
)
463451

464452
print(
465453
f"\n{model_config.id} "
466454
f"[tp={model_config.tensor_parallel_size}, "
467455
f"backend={attention_backend}]: "
468456
f"acceptance_length={actual:.3f}"
469457
f" (expected={expected:.3f}, "
470-
f"rel_error={rel_error:.2%})")
471-
print(
472-
f" Per-position: "
473-
f"{[f'{v:.3f}' for v in actual_per_pos]}")
458+
f"rel_error={rel_error:.2%})"
459+
)
460+
print(f" Per-position: {[f'{v:.3f}' for v in actual_per_pos]}")
474461
if expected_per_pos:
475-
print(
476-
f" Expected: "
477-
f"{[f'{v:.3f}' for v in expected_per_pos]}")
462+
print(f" Expected: {[f'{v:.3f}' for v in expected_per_pos]}")

0 commit comments

Comments
 (0)