@@ -320,17 +320,16 @@ def test_eagle3_acceptance_length(
320 320   @dataclass
321 321   class MTPModelConfig:
322 322       """Model configuration for MTP acceptance length tests."""
    323 +
323 324       verifier: str
324 325       expected_acceptance_length: float
325     -     expected_acceptance_lengths_per_pos: list[float] = field(
326     -         default_factory=list)
    326 +     expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list)
327 327       id: str = ""
328 328       num_speculative_tokens: int = 1
329 329       tensor_parallel_size: int = 1
330 330       max_model_len: int = DEFAULT_MAX_MODEL_LEN
331 331       gpu_memory_utilization: float = 0.7
332     -     excluded_backends: set[AttentionBackendEnum] = field(
333     -         default_factory=set)
    332 +     excluded_backends: set[AttentionBackendEnum] = field(default_factory=set)
334 333       marks: list = field(default_factory=list)
335 334       rtol: float | None = None
336 335
@@ -364,8 +363,7 @@ class MTPModelConfig:
364 363           for config in MTP_MODEL_CONFIGS
365 364       ],
366 365   )
367     - @pytest.mark.parametrize("attention_backend",
368     -                          get_attention_backend_params())
    366 + @pytest.mark.parametrize("attention_backend", get_attention_backend_params())
369 367   def test_mtp_acceptance_length(
370 368       model_config: MTPModelConfig,
371 369       attention_backend: str,
@@ -381,8 +379,7 @@ def test_mtp_acceptance_length(
381 379       """
382 380       backend_enum = AttentionBackendEnum[attention_backend]
383 381       if backend_enum in model_config.excluded_backends:
384     -         pytest.skip(
385     -             f"{attention_backend} incompatible with {model_config.id}")
    382 +         pytest.skip(f"{attention_backend} incompatible with {model_config.id}")
386 383
387 384       num_spec_tokens = model_config.num_speculative_tokens
388 385
@@ -403,35 +400,27 @@ def test_mtp_acceptance_length(
403 400           trust_remote_code=True,
404 401       ) as vllm_runner:
405 402           tokenizer = vllm_runner.llm.get_tokenizer()
406     -         prompt_ids = get_mt_bench_prompts(
407     -             tokenizer, DEFAULT_NUM_PROMPTS)
    403 +         prompt_ids = get_mt_bench_prompts(tokenizer, DEFAULT_NUM_PROMPTS)
408 404
409 405           sampling_params = SamplingParams(
410 406               temperature=0,
411 407               max_tokens=DEFAULT_OUTPUT_LEN,
412 408           )
413 409           vllm_runner.llm.generate(
414     -             [TokensPrompt(prompt_token_ids=ids)
415     -              for ids in prompt_ids],
    410 +             [TokensPrompt(prompt_token_ids=ids) for ids in prompt_ids],
416 411               sampling_params=sampling_params,
417 412           )
418 413
419 414           metrics = vllm_runner.llm.get_metrics()
420     -         results = extract_acceptance_metrics(
421     -             metrics, num_spec_tokens)
    415 +         results = extract_acceptance_metrics(metrics, num_spec_tokens)
422 416
423 417           actual = results["acceptance_length"]
424 418           expected = model_config.expected_acceptance_length
425 419           actual_per_pos = results["acceptance_lengths_per_pos"]
426     -         expected_per_pos = (
427     -             model_config.expected_acceptance_lengths_per_pos)
    420 +         expected_per_pos = model_config.expected_acceptance_lengths_per_pos
428 421
429 422           rel_error = abs(actual - expected) / expected
430     -         rtol = (
431     -             model_config.rtol
432     -             if model_config.rtol is not None
433     -             else DEFAULT_RTOL
434     -         )
    423 +         rtol = model_config.rtol if model_config.rtol is not None else DEFAULT_RTOL
435 424
436 425           assert rel_error <= rtol, (
437 426               f"MTP acceptance length regression for "
@@ -441,16 +430,14 @@ def test_mtp_acceptance_length(
441 430               f"  Relative error: {rel_error:.2%} "
442 431               f"(tolerance: {rtol:.2%})\n"
443 432               f"  Drafts: {results['num_drafts']}, "
444     -             f"Accepted: {results['num_accepted_tokens']}")
445     -
446     -         if (expected_per_pos
447     -                 and len(expected_per_pos) == len(actual_per_pos)):
448     -             rtol = (model_config.rtol
449     -                     if model_config.rtol is not None
450     -                     else DEFAULT_RTOL)
451     -             for pos, (act, exp) in enumerate(
452     -                 zip(actual_per_pos, expected_per_pos)
453     -             ):
    433 +             f"Accepted: {results['num_accepted_tokens']}"
    434 +         )
    435 +
    436 +         if expected_per_pos and len(expected_per_pos) == len(actual_per_pos):
    437 +             rtol = (
    438 +                 model_config.rtol if model_config.rtol is not None else DEFAULT_RTOL
    439 +             )
    440 +             for pos, (act, exp) in enumerate(zip(actual_per_pos, expected_per_pos)):
454 441                   if exp > 0:
455 442                       pos_err = abs(act - exp) / exp
456 443                       assert pos_err <= rtol, (
@@ -459,19 +446,17 @@ def test_mtp_acceptance_length(
459 446                       f"  Expected: {exp:.3f}\n"
460 447                       f"  Actual: {act:.3f}\n"
461 448                       f"  Error: {pos_err:.2%} "
462     -                     f"(tolerance: {rtol:.2%})")
    449 +                     f"(tolerance: {rtol:.2%})"
    450 +                 )
463 451
464 452       print(
465 453           f"\n{model_config.id} "
466 454           f"[tp={model_config.tensor_parallel_size}, "
467 455           f"backend={attention_backend}]: "
468 456           f"acceptance_length={actual:.3f} "
469 457           f"(expected={expected:.3f}, "
470     -         f"rel_error={rel_error:.2%})")
471     -     print(
472     -         f"  Per-position: "
473     -         f"{[f'{v:.3f}' for v in actual_per_pos]}")
    458 +         f"rel_error={rel_error:.2%})"
    459 +     )
    460 +     print(f"  Per-position: {[f'{v:.3f}' for v in actual_per_pos]}")
474 461       if expected_per_pos:
475     -         print(
476     -             f"  Expected: "
477     -             f"{[f'{v:.3f}' for v in expected_per_pos]}")
    462 +         print(f"  Expected: {[f'{v:.3f}' for v in expected_per_pos]}")
0 commit comments