Skip to content

Commit d5f268d

Browse files
hmellorlk-chen
authored andcommitted
Improve literal dataclass field conversion to argparse argument (vllm-project#17391)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 82f87d2 commit d5f268d

File tree

4 files changed

+97
-18
lines changed

4 files changed

+97
-18
lines changed

tests/engine/test_arg_utils.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from vllm.config import PoolerConfig, config
1212
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
1313
get_type, is_not_builtin, is_type,
14-
nullable_kvs, optional_type)
14+
literal_to_kwargs, nullable_kvs,
15+
optional_type)
1516
from vllm.utils import FlexibleArgumentParser
1617

1718

@@ -71,6 +72,21 @@ def test_get_type(type_hints, type, expected):
7172
assert get_type(type_hints, type) == expected
7273

7374

75+
@pytest.mark.parametrize(("type_hints", "expected"), [
76+
({Literal[1, 2]}, {
77+
"type": int,
78+
"choices": [1, 2]
79+
}),
80+
({Literal[1, "a"]}, Exception),
81+
])
82+
def test_literal_to_kwargs(type_hints, expected):
83+
context = nullcontext()
84+
if expected is Exception:
85+
context = pytest.raises(expected)
86+
with context:
87+
assert literal_to_kwargs(type_hints) == expected
88+
89+
7490
@config
7591
@dataclass
7692
class DummyConfigClass:
@@ -81,11 +97,15 @@ class DummyConfigClass:
8197
optional_literal: Optional[Literal["x", "y"]] = None
8298
"""Optional literal with default None"""
8399
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
84-
"""Tuple with default (1, 2, 3)"""
100+
"""Tuple with variable length"""
85101
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
86-
"""Tuple with default (1, 2)"""
102+
"""Tuple with fixed length"""
87103
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
88-
"""List with default [1, 2, 3]"""
104+
"""List with variable length"""
105+
list_literal: list[Literal[1, 2]] = field(default_factory=list)
106+
"""List with literal choices"""
107+
literal_literal: Literal[Literal[1], Literal[2]] = 1
108+
"""Literal of literals with default 1"""
89109

90110

91111
@pytest.mark.parametrize(("type_hint", "expected"), [
@@ -111,6 +131,12 @@ def test_get_kwargs():
111131
# lists should work
112132
assert kwargs["list_n"]["type"] is int
113133
assert kwargs["list_n"]["nargs"] == "+"
134+
# lists with literals should have the correct choices
135+
assert kwargs["list_literal"]["type"] is int
136+
assert kwargs["list_literal"]["nargs"] == "+"
137+
assert kwargs["list_literal"]["choices"] == [1, 2]
138+
# literals of literals should have merged choices
139+
assert kwargs["literal_literal"]["choices"] == [1, 2]
114140

115141

116142
@pytest.mark.parametrize(("arg", "expected"), [

tests/test_config.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,47 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
from dataclasses import MISSING, Field, asdict, dataclass, field
4+
from typing import Literal, Union
45

56
import pytest
67

7-
from vllm.config import ModelConfig, PoolerConfig, get_field
8+
from vllm.config import ModelConfig, PoolerConfig, config, get_field
89
from vllm.model_executor.layers.pooler import PoolingType
910
from vllm.platforms import current_platform
1011

1112

13+
class TestConfig1:
14+
pass
15+
16+
17+
@dataclass
18+
class TestConfig2:
19+
a: int
20+
"""docstring"""
21+
22+
23+
@dataclass
24+
class TestConfig3:
25+
a: int = 1
26+
27+
28+
@dataclass
29+
class TestConfig4:
30+
a: Union[Literal[1], Literal[2]] = 1
31+
"""docstring"""
32+
33+
34+
@pytest.mark.parametrize(("test_config", "expected_error"), [
35+
(TestConfig1, "must be a dataclass"),
36+
(TestConfig2, "must have a default"),
37+
(TestConfig3, "must have a docstring"),
38+
(TestConfig4, "must use a single Literal"),
39+
])
40+
def test_config(test_config, expected_error):
41+
with pytest.raises(Exception, match=expected_error):
42+
config(test_config)
43+
44+
1245
def test_get_field():
1346

1447
@dataclass

vllm/config.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from importlib.util import find_spec
1818
from pathlib import Path
1919
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
20-
Optional, Protocol, TypeVar, Union, get_args)
20+
Optional, Protocol, TypeVar, Union, get_args, get_origin)
2121

2222
import torch
2323
from pydantic import BaseModel, Field, PrivateAttr
@@ -177,9 +177,19 @@ def config(cls: ConfigT) -> ConfigT:
177177
raise ValueError(
178178
f"Field '{f.name}' in {cls.__name__} must have a default value."
179179
)
180+
180181
if f.name not in attr_docs:
181182
raise ValueError(
182183
f"Field '{f.name}' in {cls.__name__} must have a docstring.")
184+
185+
if get_origin(f.type) is Union:
186+
args = get_args(f.type)
187+
literal_args = [arg for arg in args if get_origin(arg) is Literal]
188+
if len(literal_args) > 1:
189+
raise ValueError(
190+
f"Field '{f.name}' in {cls.__name__} must use a single "
191+
"Literal type. Please use 'Literal[Literal1, Literal2]' "
192+
"instead of 'Union[Literal1, Literal2]'.")
183193
return cls
184194

185195

@@ -3166,16 +3176,17 @@ def get_served_model_name(model: str,
31663176
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
31673177
"xgrammar", "guidance"]
31683178
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
3179+
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
3180+
GuidedDecodingBackendV1]
31693181

31703182

31713183
@config
31723184
@dataclass
31733185
class DecodingConfig:
31743186
"""Dataclass which contains the decoding strategy of the engine."""
31753187

3176-
guided_decoding_backend: Union[
3177-
GuidedDecodingBackendV0,
3178-
GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar"
3188+
guided_decoding_backend: GuidedDecodingBackend = \
3189+
"auto" if envs.VLLM_USE_V1 else "xgrammar"
31793190
"""Which engine will be used for guided decoding (JSON schema / regex etc)
31803191
by default. With "auto", we will make opinionated choices based on request
31813192
contents and what the backend libraries currently support, so the behavior

vllm/engine/arg_utils.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,18 @@ def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
116116
return next((th for th in type_hints if is_type(th, type)), None)
117117

118118

119+
def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]:
120+
"""Convert Literal type hints to argparse kwargs."""
121+
type_hint = get_type(type_hints, Literal)
122+
choices = get_args(type_hint)
123+
choice_type = type(choices[0])
124+
if not all(isinstance(choice, choice_type) for choice in choices):
125+
raise ValueError(
126+
"All choices must be of the same type. "
127+
f"Got {choices} with types {[type(c) for c in choices]}")
128+
return {"type": choice_type, "choices": sorted(choices)}
129+
130+
119131
def is_not_builtin(type_hint: TypeHint) -> bool:
120132
"""Check if the class is not a built-in type."""
121133
return type_hint.__module__ != "builtins"
@@ -151,15 +163,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
151163
# Creates --no-<name> and --<name> flags
152164
kwargs[name]["action"] = argparse.BooleanOptionalAction
153165
elif contains_type(type_hints, Literal):
154-
# Creates choices from Literal arguments
155-
type_hint = get_type(type_hints, Literal)
156-
choices = sorted(get_args(type_hint))
157-
kwargs[name]["choices"] = choices
158-
choice_type = type(choices[0])
159-
assert all(type(c) is choice_type for c in choices), (
160-
"All choices must be of the same type. "
161-
f"Got {choices} with types {[type(c) for c in choices]}")
162-
kwargs[name]["type"] = choice_type
166+
kwargs[name].update(literal_to_kwargs(type_hints))
163167
elif contains_type(type_hints, tuple):
164168
type_hint = get_type(type_hints, tuple)
165169
types = get_args(type_hint)
@@ -191,6 +195,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
191195
raise ValueError(
192196
f"Unsupported type {type_hints} for argument {name}.")
193197

198+
# If the type hint was a sequence of literals, use the helper function
199+
# to update the type and choices
200+
if get_origin(kwargs[name].get("type")) is Literal:
201+
kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]}))
202+
194203
# If None is in type_hints, make the argument optional.
195204
# But not if it's a bool, argparse will handle this better.
196205
if type(None) in type_hints and not contains_type(type_hints, bool):

0 commit comments

Comments
 (0)