|
33 | 33 | import warnings
|
34 | 34 | import weakref
|
35 | 35 | from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
|
36 |
| - ArgumentTypeError) |
| 36 | + ArgumentTypeError, _ArgumentGroup) |
37 | 37 | from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
38 | 38 | from collections import UserDict, defaultdict
|
39 | 39 | from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
|
40 | 40 | Iterable, Iterator, KeysView, Mapping)
|
41 | 41 | from concurrent.futures.process import ProcessPoolExecutor
|
42 | 42 | from dataclasses import dataclass, field
|
43 | 43 | from functools import cache, lru_cache, partial, wraps
|
| 44 | +from gettext import gettext as _gettext |
44 | 45 | from types import MappingProxyType
|
45 | 46 | from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
|
46 | 47 | Optional, Sequence, Tuple, Type, TypeVar, Union, cast,
|
|
70 | 71 | from vllm.logger import enable_trace_function_call, init_logger
|
71 | 72 |
|
72 | 73 | if TYPE_CHECKING:
|
| 74 | + from argparse import Namespace |
| 75 | + |
73 | 76 | from vllm.config import ModelConfig, VllmConfig
|
74 | 77 |
|
75 | 78 | logger = init_logger(__name__)
|
@@ -1323,16 +1326,78 @@ def add_arguments(self, actions):
|
1323 | 1326 | super().add_arguments(actions)
|
1324 | 1327 |
|
1325 | 1328 |
|
| 1329 | +class _FlexibleArgumentGroup(_ArgumentGroup): |
| 1330 | + |
| 1331 | + def __init__(self, parser: FlexibleArgumentParser, *args, **kwargs): |
| 1332 | + self._parser = parser |
| 1333 | + super().__init__(*args, **kwargs) |
| 1334 | + |
| 1335 | + def add_argument(self, *args: Any, **kwargs: Any): |
| 1336 | + if sys.version_info < (3, 13): |
| 1337 | + deprecated = kwargs.pop('deprecated', False) |
| 1338 | + action = super().add_argument(*args, **kwargs) |
| 1339 | + object.__setattr__(action, 'deprecated', deprecated) |
| 1340 | + if deprecated and action.dest not in \ |
| 1341 | + self._parser.__class__._deprecated: |
| 1342 | + self._parser._deprecated.add(action) |
| 1343 | + return action |
| 1344 | + |
| 1345 | + # python>3.13 |
| 1346 | + return super().add_argument(*args, **kwargs) |
| 1347 | + |
| 1348 | + |
1326 | 1349 | class FlexibleArgumentParser(ArgumentParser):
|
1327 | 1350 | """ArgumentParser that allows both underscore and dash in names."""
|
1328 | 1351 |
|
| 1352 | + _deprecated: set[Action] = set() |
| 1353 | + _seen: set[str] = set() |
| 1354 | + |
1329 | 1355 | def __init__(self, *args, **kwargs):
|
1330 | 1356 | # Set the default 'formatter_class' to SortedHelpFormatter
|
1331 | 1357 | if 'formatter_class' not in kwargs:
|
1332 | 1358 | kwargs['formatter_class'] = SortedHelpFormatter
|
1333 | 1359 | super().__init__(*args, **kwargs)
|
1334 | 1360 |
|
1335 |
| - def parse_args(self, args=None, namespace=None): |
| 1361 | + if sys.version_info < (3, 13): |
| 1362 | + |
| 1363 | + def parse_known_args( # type: ignore[override] |
| 1364 | + self, |
| 1365 | + args: Sequence[str] | None = None, |
| 1366 | + namespace: Namespace | None = None, |
| 1367 | + ) -> tuple[Namespace | None, list[str]]: |
| 1368 | + namespace, args = super().parse_known_args(args, namespace) |
| 1369 | + for action in FlexibleArgumentParser._deprecated: |
| 1370 | + if action.dest not in FlexibleArgumentParser._seen and getattr( |
| 1371 | + namespace, action.dest, |
| 1372 | + None) != action.default: # noqa: E501 |
| 1373 | + self._warning( |
| 1374 | + _gettext("argument '%(argument_name)s' is deprecated") |
| 1375 | + % {'argument_name': action.dest}) |
| 1376 | + FlexibleArgumentParser._seen.add(action.dest) |
| 1377 | + return namespace, args |
| 1378 | + |
| 1379 | + def add_argument(self, *args: Any, **kwargs: Any): |
| 1380 | + # add a deprecated=True compatibility |
| 1381 | + # for python < 3.13 |
| 1382 | + deprecated = kwargs.pop('deprecated', False) |
| 1383 | + action = super().add_argument(*args, **kwargs) |
| 1384 | + object.__setattr__(action, 'deprecated', deprecated) |
| 1385 | + if deprecated and \ |
| 1386 | + action not in FlexibleArgumentParser._deprecated: |
| 1387 | + self._deprecated.add(action) |
| 1388 | + |
| 1389 | + return action |
| 1390 | + |
| 1391 | + def _warning(self, message: str): |
| 1392 | + self._print_message( |
| 1393 | + _gettext('warning: %(message)s\n') % {'message': message}, |
| 1394 | + sys.stderr) |
| 1395 | + |
| 1396 | + def parse_args( # type: ignore[override] |
| 1397 | + self, |
| 1398 | + args: list[str] | None = None, |
| 1399 | + namespace: Namespace | None = None, |
| 1400 | + ): |
1336 | 1401 | if args is None:
|
1337 | 1402 | args = sys.argv[1:]
|
1338 | 1403 |
|
@@ -1503,6 +1568,15 @@ def _load_config_file(self, file_path: str) -> list[str]:
|
1503 | 1568 |
|
1504 | 1569 | return processed_args
|
1505 | 1570 |
|
| 1571 | + def add_argument_group( |
| 1572 | + self, |
| 1573 | + *args: Any, |
| 1574 | + **kwargs: Any, |
| 1575 | + ) -> _FlexibleArgumentGroup: |
| 1576 | + group = _FlexibleArgumentGroup(self, self, *args, **kwargs) |
| 1577 | + self._action_groups.append(group) |
| 1578 | + return group |
| 1579 | + |
1506 | 1580 |
|
1507 | 1581 | async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
|
1508 | 1582 | **kwargs):
|
|
0 commit comments