Commit 0fa939e

Improve configs - LoRAConfig + PromptAdapterConfig (#16980)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 0422ce1 commit 0fa939e

File tree

3 files changed: +130 −91 lines changed


tests/lora/test_lora_manager.py

Lines changed: 22 additions & 8 deletions
@@ -31,6 +31,8 @@
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ] if current_platform.is_cuda_alike() else ["cpu"])
 
+DEFAULT_DTYPE = torch.get_default_dtype()
+
 
 @pytest.fixture(scope="function", autouse=True)
 def use_v0_only(monkeypatch: pytest.MonkeyPatch):
@@ -125,8 +127,10 @@ def test_replace_submodules(dist_init, dummy_model):
     model = dummy_model
     manager = LoRAModelManager(
         model, 1, 1, 1,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
-        torch.device(DEVICES[0]))
+        LoRAConfig(max_lora_rank=8,
+                   max_cpu_loras=8,
+                   max_loras=8,
+                   lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0]))
     model = manager.model
     assert isinstance(model.get_submodule("dense1"),
                       ColumnParallelLinearWithLoRA)
@@ -155,7 +159,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
         2,
         LoRAConfig(max_lora_rank=8,
                    max_cpu_loras=3,
-                   max_loras=2),
+                   max_loras=2,
+                   lora_dtype=DEFAULT_DTYPE),
         device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -221,7 +226,8 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
         2,
         LoRAConfig(max_lora_rank=8,
                    max_cpu_loras=3,
-                   max_loras=2),
+                   max_loras=2,
+                   lora_dtype=DEFAULT_DTYPE),
         device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -316,7 +322,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
         2,
         LoRAConfig(max_lora_rank=8,
                    max_cpu_loras=2,
-                   max_loras=2),
+                   max_loras=2,
+                   lora_dtype=DEFAULT_DTYPE),
         device=device)
 
     assert all(x is None for x in manager.lora_index_to_id)
@@ -424,7 +431,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
 @pytest.mark.parametrize("device", DEVICES)
 def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                           sql_lora_files, device):
-    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
+    lora_config = LoRAConfig(max_lora_rank=8,
+                             max_cpu_loras=4,
+                             max_loras=4,
+                             lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = LRUCacheWorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
@@ -504,7 +514,10 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
 def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                 sql_lora_files, device):
     # Should remove every LoRA not specified in the request.
-    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
+    lora_config = LoRAConfig(max_lora_rank=8,
+                             max_cpu_loras=4,
+                             max_loras=4,
+                             lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = WorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
@@ -600,7 +613,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
         2,
         LoRAConfig(max_lora_rank=8,
                    max_cpu_loras=2,
-                   max_loras=2),
+                   max_loras=2,
+                   lora_dtype=DEFAULT_DTYPE),
         device=device)
     model = manager.model
 
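For context: the tests above now pin lora_dtype explicitly because, after this change, LoRAConfig.lora_dtype defaults to the string "auto", and per its new docstring "auto" resolves to the base model dtype, which these manager-level tests never construct. A minimal illustrative sketch of the pattern (not part of the commit), assuming vLLM is importable:

import torch

from vllm.config import LoRAConfig

# torch.float32 unless the process-wide default dtype has been changed.
DEFAULT_DTYPE = torch.get_default_dtype()

# Mirrors the updated tests: pass a concrete dtype instead of relying on "auto".
lora_config = LoRAConfig(max_lora_rank=8,
                         max_cpu_loras=8,
                         max_loras=8,
                         lora_dtype=DEFAULT_DTYPE)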
vllm/config.py

Lines changed: 38 additions & 8 deletions
@@ -2565,18 +2565,41 @@ def __repr__(self) -> str:
         return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
 
 
+LoRADType = Literal["auto", "float16", "bfloat16"]
+
+
+@config
 @dataclass
 class LoRAConfig:
-    max_lora_rank: int
-    max_loras: int
+    """Configuration for LoRA."""
+
+    max_lora_rank: int = 16
+    """Max LoRA rank."""
+    max_loras: int = 1
+    """Max number of LoRAs in a single batch."""
     fully_sharded_loras: bool = False
+    """By default, only half of the LoRA computation is sharded with tensor
+    parallelism. Enabling this will use the fully sharded layers. At high
+    sequence length, max rank or tensor parallel size, this is likely faster.
+    """
     max_cpu_loras: Optional[int] = None
-    lora_dtype: Optional[Union[torch.dtype, str]] = None
+    """Maximum number of LoRAs to store in CPU memory. Must be >= than
+    `max_loras`."""
+    lora_dtype: Union[torch.dtype, LoRADType] = "auto"
+    """Data type for LoRA. If auto, will default to base model dtype."""
     lora_extra_vocab_size: int = 256
+    """Maximum size of extra vocabulary that can be present in a LoRA adapter
+    (added to the base model vocabulary)."""
     # This is a constant.
     lora_vocab_padding_size: ClassVar[int] = 256
-    long_lora_scaling_factors: Optional[tuple[float]] = None
+    long_lora_scaling_factors: Optional[tuple[float, ...]] = None
+    """Specify multiple scaling factors (which can be different from base model
+    scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
+    trained with those scaling factors to be used at the same time. If not
+    specified, only adapters trained with the base model scaling factor are
+    allowed."""
     bias_enabled: bool = False
+    """Enable bias for LoRA adapters."""
 
     def compute_hash(self) -> str:
         """
@@ -2641,12 +2664,19 @@ def verify_lora_support(self):
                 "V1 LoRA does not support long LoRA, please use V0.")
 
 
+@config
 @dataclass
 class PromptAdapterConfig:
-    max_prompt_adapters: int
-    max_prompt_adapter_token: int
+    max_prompt_adapters: int = 1
+    """Max number of PromptAdapters in a batch."""
+    max_prompt_adapter_token: int = 0
+    """Max number of PromptAdapters tokens."""
     max_cpu_prompt_adapters: Optional[int] = None
-    prompt_adapter_dtype: Optional[torch.dtype] = None
+    """Maximum number of PromptAdapters to store in CPU memory. Must be >= than
+    `max_prompt_adapters`."""
+    prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
+    """Data type for PromptAdapter. If auto, will default to base model dtype.
+    """
 
     def compute_hash(self) -> str:
         """
@@ -2678,7 +2708,7 @@ def __post_init__(self):
             self.max_cpu_prompt_adapters = self.max_prompt_adapters
 
     def verify_with_model_config(self, model_config: ModelConfig):
-        if self.prompt_adapter_dtype in (None, "auto"):
+        if self.prompt_adapter_dtype == "auto":
             self.prompt_adapter_dtype = model_config.dtype
         elif isinstance(self.prompt_adapter_dtype, str):
             self.prompt_adapter_dtype = getattr(torch,
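With the new defaults, both dataclasses can be constructed without arguments, and lora_dtype accepts either a torch.dtype or one of the "auto"/"float16"/"bfloat16" literals. An illustrative sketch (not part of the commit), assuming vLLM is importable:

import torch

from vllm.config import LoRAConfig, PromptAdapterConfig

# Every field now has a default: max_lora_rank=16, max_loras=1, lora_dtype="auto", ...
lora_config = LoRAConfig()

# A concrete torch.dtype is also accepted, as is any LoRADType literal.
tuned = LoRAConfig(max_lora_rank=32, max_loras=4, lora_dtype=torch.bfloat16)

# PromptAdapterConfig likewise: max_prompt_adapters=1, prompt_adapter_dtype="auto".
pa_config = PromptAdapterConfig()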

vllm/engine/arg_utils.py

Lines changed: 70 additions & 75 deletions
@@ -7,7 +7,7 @@
 import re
 import threading
 from dataclasses import MISSING, dataclass, fields
-from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Type,
+from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
                     TypeVar, Union, cast, get_args, get_origin)
 
 import torch
@@ -192,18 +192,23 @@ class EngineArgs:
     get_field(MultiModalConfig, "limit_per_prompt")
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
+    # LoRA fields
     enable_lora: bool = False
-    enable_lora_bias: bool = False
-    max_loras: int = 1
-    max_lora_rank: int = 16
+    enable_lora_bias: bool = LoRAConfig.bias_enabled
+    max_loras: int = LoRAConfig.max_loras
+    max_lora_rank: int = LoRAConfig.max_lora_rank
+    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
+    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
+    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
+    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
+    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
+        LoRAConfig.long_lora_scaling_factors
+    # PromptAdapter fields
     enable_prompt_adapter: bool = False
-    max_prompt_adapters: int = 1
-    max_prompt_adapter_token: int = 0
-    fully_sharded_loras: bool = False
-    lora_extra_vocab_size: int = 256
-    long_lora_scaling_factors: Optional[Tuple[float]] = None
-    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
-    max_cpu_loras: Optional[int] = None
+    max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
+    max_prompt_adapter_token: int = \
+        PromptAdapterConfig.max_prompt_adapter_token
+
     device: Device = DeviceConfig.device
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
@@ -338,10 +343,21 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
                 kwargs[name]["choices"] = choices
                 choice_type = type(choices[0])
                 assert all(type(c) is choice_type for c in choices), (
-                    f"All choices must be of the same type. "
+                    "All choices must be of the same type. "
                     f"Got {choices} with types {[type(c) for c in choices]}"
                 )
                 kwargs[name]["type"] = choice_type
+            elif can_be_type(field_type, tuple):
+                if is_type_in_union(field_type, tuple):
+                    field_type = get_type_from_union(field_type, tuple)
+                dtypes = get_args(field_type)
+                dtype = dtypes[0]
+                assert all(
+                    d is dtype for d in dtypes if d is not Ellipsis
+                ), ("All non-Ellipsis tuple elements must be of the same "
+                    f"type. Got {dtypes}.")
+                kwargs[name]["type"] = dtype
+                kwargs[name]["nargs"] = "+"
             elif can_be_type(field_type, int):
                 kwargs[name]["type"] = optional_int if optional else int
             elif can_be_type(field_type, float):
@@ -685,70 +701,49 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
                             'inputs.')
 
         # LoRA related configs
-        parser.add_argument('--enable-lora',
-                            action='store_true',
-                            help='If True, enable handling of LoRA adapters.')
-        parser.add_argument('--enable-lora-bias',
-                            action='store_true',
-                            help='If True, enable bias for LoRA adapters.')
-        parser.add_argument('--max-loras',
-                            type=int,
-                            default=EngineArgs.max_loras,
-                            help='Max number of LoRAs in a single batch.')
-        parser.add_argument('--max-lora-rank',
-                            type=int,
-                            default=EngineArgs.max_lora_rank,
-                            help='Max LoRA rank.')
-        parser.add_argument(
-            '--lora-extra-vocab-size',
-            type=int,
-            default=EngineArgs.lora_extra_vocab_size,
-            help=('Maximum size of extra vocabulary that can be '
-                  'present in a LoRA adapter (added to the base '
-                  'model vocabulary).'))
-        parser.add_argument(
+        lora_kwargs = get_kwargs(LoRAConfig)
+        lora_group = parser.add_argument_group(
+            title="LoRAConfig",
+            description=LoRAConfig.__doc__,
+        )
+        lora_group.add_argument(
+            '--enable-lora',
+            action=argparse.BooleanOptionalAction,
+            help='If True, enable handling of LoRA adapters.')
+        lora_group.add_argument('--enable-lora-bias',
+                                **lora_kwargs["bias_enabled"])
+        lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
+        lora_group.add_argument('--max-lora-rank',
+                                **lora_kwargs["max_lora_rank"])
+        lora_group.add_argument('--lora-extra-vocab-size',
+                                **lora_kwargs["lora_extra_vocab_size"])
+        lora_group.add_argument(
             '--lora-dtype',
-            type=str,
-            default=EngineArgs.lora_dtype,
-            choices=['auto', 'float16', 'bfloat16'],
-            help=('Data type for LoRA. If auto, will default to '
-                  'base model dtype.'))
-        parser.add_argument(
-            '--long-lora-scaling-factors',
-            type=optional_str,
-            default=EngineArgs.long_lora_scaling_factors,
-            help=('Specify multiple scaling factors (which can '
-                  'be different from base model scaling factor '
-                  '- see eg. Long LoRA) to allow for multiple '
-                  'LoRA adapters trained with those scaling '
-                  'factors to be used at the same time. If not '
-                  'specified, only adapters trained with the '
-                  'base model scaling factor are allowed.'))
-        parser.add_argument(
-            '--max-cpu-loras',
-            type=int,
-            default=EngineArgs.max_cpu_loras,
-            help=('Maximum number of LoRAs to store in CPU memory. '
-                  'Must be >= than max_loras.'))
-        parser.add_argument(
-            '--fully-sharded-loras',
-            action='store_true',
-            help=('By default, only half of the LoRA computation is '
-                  'sharded with tensor parallelism. '
-                  'Enabling this will use the fully sharded layers. '
-                  'At high sequence length, max rank or '
-                  'tensor parallel size, this is likely faster.'))
-        parser.add_argument('--enable-prompt-adapter',
-                            action='store_true',
-                            help='If True, enable handling of PromptAdapters.')
-        parser.add_argument('--max-prompt-adapters',
-                            type=int,
-                            default=EngineArgs.max_prompt_adapters,
-                            help='Max number of PromptAdapters in a batch.')
-        parser.add_argument('--max-prompt-adapter-token',
-                            type=int,
-                            default=EngineArgs.max_prompt_adapter_token,
-                            help='Max number of PromptAdapters tokens')
+            **lora_kwargs["lora_dtype"],
+        )
+        lora_group.add_argument('--long-lora-scaling-factors',
+                                **lora_kwargs["long_lora_scaling_factors"])
+        lora_group.add_argument('--max-cpu-loras',
+                                **lora_kwargs["max_cpu_loras"])
+        lora_group.add_argument('--fully-sharded-loras',
+                                **lora_kwargs["fully_sharded_loras"])
+
+        # PromptAdapter related configs
+        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
+        prompt_adapter_group = parser.add_argument_group(
+            title="PromptAdapterConfig",
+            description=PromptAdapterConfig.__doc__,
+        )
+        prompt_adapter_group.add_argument(
+            '--enable-prompt-adapter',
+            action=argparse.BooleanOptionalAction,
+            help='If True, enable handling of PromptAdapters.')
+        prompt_adapter_group.add_argument(
+            '--max-prompt-adapters',
+            **prompt_adapter_kwargs["max_prompt_adapters"])
+        prompt_adapter_group.add_argument(
            '--max-prompt-adapter-token',
            **prompt_adapter_kwargs["max_prompt_adapter_token"])
 
         # Device arguments
         device_kwargs = get_kwargs(DeviceConfig)