
Improve configs - LoRAConfig + PromptAdapterConfig #16980

Merged

Changes from all commits
tests/lora/test_lora_manager.py: 30 changes (22 additions, 8 deletions)
@@ -31,6 +31,8 @@
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])

DEFAULT_DTYPE = torch.get_default_dtype()


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
@@ -125,8 +127,10 @@ def test_replace_submodules(dist_init, dummy_model):
model = dummy_model
manager = LoRAModelManager(
model, 1, 1, 1,
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
torch.device(DEVICES[0]))
LoRAConfig(max_lora_rank=8,
max_cpu_loras=8,
max_loras=8,
lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0]))
model = manager.model
assert isinstance(model.get_submodule("dense1"),
ColumnParallelLinearWithLoRA)
@@ -155,7 +159,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=3,
max_loras=2),
max_loras=2,
lora_dtype=DEFAULT_DTYPE),
device=device)
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_adapter(model_lora1)
@@ -221,7 +226,8 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=3,
max_loras=2),
max_loras=2,
lora_dtype=DEFAULT_DTYPE),
device=device)
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_adapter(model_lora1)
@@ -316,7 +322,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=2,
max_loras=2),
max_loras=2,
lora_dtype=DEFAULT_DTYPE),
device=device)

assert all(x is None for x in manager.lora_index_to_id)
@@ -424,7 +431,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
lora_config = LoRAConfig(max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE)
worker_adapter_manager = LRUCacheWorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
@@ -504,7 +514,10 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
lora_config = LoRAConfig(max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE)
worker_adapter_manager = WorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
@@ -600,7 +613,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
2,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=2,
max_loras=2),
max_loras=2,
lora_dtype=DEFAULT_DTYPE),
device=device)
model = manager.model

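The pattern repeated throughout these test changes is the same: `lora_dtype` now defaults to `"auto"`, which is only resolved against a model config inside the engine, so the standalone manager tests pin it to the process-wide default dtype instead (an inference from the diff, not stated in the PR description). A minimal sketch of the construction they use:

# Illustrative sketch, not part of the diff. Mirrors the updated test setup;
# the "auto" -> model-dtype resolution is assumed to happen elsewhere.
import torch
from vllm.config import LoRAConfig

DEFAULT_DTYPE = torch.get_default_dtype()  # torch.float32 unless overridden

lora_config = LoRAConfig(max_lora_rank=8,
                         max_cpu_loras=8,
                         max_loras=8,
                         lora_dtype=DEFAULT_DTYPE)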
vllm/config.py: 46 changes (38 additions, 8 deletions)
@@ -2565,18 +2565,41 @@ def __repr__(self) -> str:
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"


LoRADType = Literal["auto", "float16", "bfloat16"]


@config
@dataclass
class LoRAConfig:
max_lora_rank: int
max_loras: int
"""Configuration for LoRA."""

max_lora_rank: int = 16
"""Max LoRA rank."""
max_loras: int = 1
"""Max number of LoRAs in a single batch."""
fully_sharded_loras: bool = False
"""By default, only half of the LoRA computation is sharded with tensor
parallelism. Enabling this will use the fully sharded layers. At high
sequence length, max rank or tensor parallel size, this is likely faster.
"""
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[Union[torch.dtype, str]] = None
"""Maximum number of LoRAs to store in CPU memory. Must be >= than
`max_loras`."""
lora_dtype: Union[torch.dtype, LoRADType] = "auto"
"""Data type for LoRA. If auto, will default to base model dtype."""
lora_extra_vocab_size: int = 256
"""Maximum size of extra vocabulary that can be present in a LoRA adapter
(added to the base model vocabulary)."""
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
long_lora_scaling_factors: Optional[tuple[float]] = None
long_lora_scaling_factors: Optional[tuple[float, ...]] = None
"""Specify multiple scaling factors (which can be different from base model
scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
trained with those scaling factors to be used at the same time. If not
specified, only adapters trained with the base model scaling factor are
allowed."""
bias_enabled: bool = False
"""Enable bias for LoRA adapters."""

def compute_hash(self) -> str:
"""
@@ -2641,12 +2664,19 @@ def verify_lora_support(self):
"V1 LoRA does not support long LoRA, please use V0.")


@config
@dataclass
class PromptAdapterConfig:
max_prompt_adapters: int
max_prompt_adapter_token: int
max_prompt_adapters: int = 1
"""Max number of PromptAdapters in a batch."""
max_prompt_adapter_token: int = 0
"""Max number of PromptAdapters tokens."""
max_cpu_prompt_adapters: Optional[int] = None
prompt_adapter_dtype: Optional[torch.dtype] = None
"""Maximum number of PromptAdapters to store in CPU memory. Must be >= than
`max_prompt_adapters`."""
prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
"""Data type for PromptAdapter. If auto, will default to base model dtype.
"""

def compute_hash(self) -> str:
"""
@@ -2678,7 +2708,7 @@ def __post_init__(self):
self.max_cpu_prompt_adapters = self.max_prompt_adapters

def verify_with_model_config(self, model_config: ModelConfig):
if self.prompt_adapter_dtype in (None, "auto"):
if self.prompt_adapter_dtype == "auto":
self.prompt_adapter_dtype = model_config.dtype
elif isinstance(self.prompt_adapter_dtype, str):
self.prompt_adapter_dtype = getattr(torch,
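Taken together, the `vllm/config.py` changes turn both classes into fully defaulted, documented dataclasses: they can be constructed with keyword overrides only, and string dtypes are resolved against the base model dtype later (see the `verify_with_model_config` hunk above). A hedged sketch of what this allows, using only fields visible in this diff and assuming no additional `__post_init__` validation rejects these values:

# Illustrative sketch, not part of the diff.
from vllm.config import LoRAConfig, PromptAdapterConfig

lora_config = LoRAConfig()                 # every field now has a default
assert lora_config.max_lora_rank == 16
assert lora_config.lora_dtype == "auto"    # resolved to the base model dtype later

lora_config = LoRAConfig(max_lora_rank=32,
                         max_loras=4,
                         lora_dtype="bfloat16",
                         long_lora_scaling_factors=(4.0, 8.0))

prompt_adapter_config = PromptAdapterConfig(max_prompt_adapters=2,
                                            max_prompt_adapter_token=128)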
vllm/engine/arg_utils.py: 145 changes (70 additions, 75 deletions)
@@ -7,7 +7,7 @@
import re
import threading
from dataclasses import MISSING, dataclass, fields
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Type,
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
TypeVar, Union, cast, get_args, get_origin)

import torch
@@ -192,18 +192,23 @@ class EngineArgs:
get_field(MultiModalConfig, "limit_per_prompt")
mm_processor_kwargs: Optional[Dict[str, Any]] = None
disable_mm_preprocessor_cache: bool = False
# LoRA fields
enable_lora: bool = False
enable_lora_bias: bool = False
max_loras: int = 1
max_lora_rank: int = 16
enable_lora_bias: bool = LoRAConfig.bias_enabled
max_loras: int = LoRAConfig.max_loras
max_lora_rank: int = LoRAConfig.max_lora_rank
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
long_lora_scaling_factors: Optional[tuple[float, ...]] = \
LoRAConfig.long_lora_scaling_factors
# PromptAdapter fields
enable_prompt_adapter: bool = False
max_prompt_adapters: int = 1
max_prompt_adapter_token: int = 0
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
max_cpu_loras: Optional[int] = None
max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
max_prompt_adapter_token: int = \
PromptAdapterConfig.max_prompt_adapter_token

device: Device = DeviceConfig.device
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
@@ -338,10 +343,21 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name]["choices"] = choices
choice_type = type(choices[0])
assert all(type(c) is choice_type for c in choices), (
f"All choices must be of the same type. "
"All choices must be of the same type. "
f"Got {choices} with types {[type(c) for c in choices]}"
)
kwargs[name]["type"] = choice_type
elif can_be_type(field_type, tuple):
if is_type_in_union(field_type, tuple):
field_type = get_type_from_union(field_type, tuple)
dtypes = get_args(field_type)
dtype = dtypes[0]
assert all(
d is dtype for d in dtypes if d is not Ellipsis
), ("All non-Ellipsis tuple elements must be of the same "
f"type. Got {dtypes}.")
kwargs[name]["type"] = dtype
kwargs[name]["nargs"] = "+"
elif can_be_type(field_type, int):
kwargs[name]["type"] = optional_int if optional else int
elif can_be_type(field_type, float):
@@ -685,70 +701,49 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
'inputs.')

# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
help='If True, enable handling of LoRA adapters.')
parser.add_argument('--enable-lora-bias',
action='store_true',
help='If True, enable bias for LoRA adapters.')
parser.add_argument('--max-loras',
type=int,
default=EngineArgs.max_loras,
help='Max number of LoRAs in a single batch.')
parser.add_argument('--max-lora-rank',
type=int,
default=EngineArgs.max_lora_rank,
help='Max LoRA rank.')
parser.add_argument(
'--lora-extra-vocab-size',
type=int,
default=EngineArgs.lora_extra_vocab_size,
help=('Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'))
parser.add_argument(
lora_kwargs = get_kwargs(LoRAConfig)
lora_group = parser.add_argument_group(
title="LoRAConfig",
description=LoRAConfig.__doc__,
)
lora_group.add_argument(
'--enable-lora',
action=argparse.BooleanOptionalAction,
help='If True, enable handling of LoRA adapters.')
lora_group.add_argument('--enable-lora-bias',
**lora_kwargs["bias_enabled"])
lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
lora_group.add_argument('--max-lora-rank',
**lora_kwargs["max_lora_rank"])
lora_group.add_argument('--lora-extra-vocab-size',
**lora_kwargs["lora_extra_vocab_size"])
lora_group.add_argument(
'--lora-dtype',
type=str,
default=EngineArgs.lora_dtype,
choices=['auto', 'float16', 'bfloat16'],
help=('Data type for LoRA. If auto, will default to '
'base model dtype.'))
parser.add_argument(
'--long-lora-scaling-factors',
type=optional_str,
default=EngineArgs.long_lora_scaling_factors,
help=('Specify multiple scaling factors (which can '
'be different from base model scaling factor '
'- see eg. Long LoRA) to allow for multiple '
'LoRA adapters trained with those scaling '
'factors to be used at the same time. If not '
'specified, only adapters trained with the '
'base model scaling factor are allowed.'))
parser.add_argument(
'--max-cpu-loras',
type=int,
default=EngineArgs.max_cpu_loras,
help=('Maximum number of LoRAs to store in CPU memory. '
'Must be >= than max_loras.'))
parser.add_argument(
'--fully-sharded-loras',
action='store_true',
help=('By default, only half of the LoRA computation is '
'sharded with tensor parallelism. '
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
parser.add_argument('--enable-prompt-adapter',
action='store_true',
help='If True, enable handling of PromptAdapters.')
parser.add_argument('--max-prompt-adapters',
type=int,
default=EngineArgs.max_prompt_adapters,
help='Max number of PromptAdapters in a batch.')
parser.add_argument('--max-prompt-adapter-token',
type=int,
default=EngineArgs.max_prompt_adapter_token,
help='Max number of PromptAdapters tokens')
**lora_kwargs["lora_dtype"],
)
lora_group.add_argument('--long-lora-scaling-factors',
**lora_kwargs["long_lora_scaling_factors"])
lora_group.add_argument('--max-cpu-loras',
**lora_kwargs["max_cpu_loras"])
lora_group.add_argument('--fully-sharded-loras',
**lora_kwargs["fully_sharded_loras"])

# PromptAdapter related configs
prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
prompt_adapter_group = parser.add_argument_group(
title="PromptAdapterConfig",
description=PromptAdapterConfig.__doc__,
)
prompt_adapter_group.add_argument(
'--enable-prompt-adapter',
action=argparse.BooleanOptionalAction,
help='If True, enable handling of PromptAdapters.')
prompt_adapter_group.add_argument(
'--max-prompt-adapters',
**prompt_adapter_kwargs["max_prompt_adapters"])
prompt_adapter_group.add_argument(
'--max-prompt-adapter-token',
**prompt_adapter_kwargs["max_prompt_adapter_token"])

# Device arguments
device_kwargs = get_kwargs(DeviceConfig)
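With `get_kwargs` driving the parser, the LoRA and PromptAdapter flags are now generated from the dataclass fields and their docstrings: `BooleanOptionalAction` adds matching `--no-enable-lora` / `--no-enable-prompt-adapter` switches, and the new tuple branch turns `--long-lora-scaling-factors` into `nargs="+"` floats. A rough usage sketch, assuming the usual `FlexibleArgumentParser` / `EngineArgs.add_cli_args` entry points, which are not shown in this diff:

# Rough sketch only; the parser entry points are assumptions, while the flag
# names and behaviour follow the diff above.
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)

args = parser.parse_args([
    "--enable-lora",
    "--max-lora-rank", "32",
    "--lora-dtype", "bfloat16",
    "--long-lora-scaling-factors", "4.0", "8.0",  # parsed as floats via nargs="+"
    "--max-prompt-adapters", "2",
])
engine_args = EngineArgs.from_cli_args(args)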