@@ -7,7 +7,7 @@
 import re
 import threading
 from dataclasses import MISSING, dataclass, fields
-from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Type,
+from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
                     TypeVar, Union, cast, get_args, get_origin)
 
 import torch
@@ -192,18 +192,23 @@ class EngineArgs:
         get_field(MultiModalConfig, "limit_per_prompt")
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
+    # LoRA fields
     enable_lora: bool = False
-    enable_lora_bias: bool = False
-    max_loras: int = 1
-    max_lora_rank: int = 16
+    enable_lora_bias: bool = LoRAConfig.bias_enabled
+    max_loras: int = LoRAConfig.max_loras
+    max_lora_rank: int = LoRAConfig.max_lora_rank
+    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
+    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
+    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
+    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
+    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
+        LoRAConfig.long_lora_scaling_factors
+    # PromptAdapter fields
     enable_prompt_adapter: bool = False
-    max_prompt_adapters: int = 1
-    max_prompt_adapter_token: int = 0
-    fully_sharded_loras: bool = False
-    lora_extra_vocab_size: int = 256
-    long_lora_scaling_factors: Optional[Tuple[float]] = None
-    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
-    max_cpu_loras: Optional[int] = None
+    max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
+    max_prompt_adapter_token: int = \
+        PromptAdapterConfig.max_prompt_adapter_token
+
     device: Device = DeviceConfig.device
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
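For readers skimming the hunk above: the `EngineArgs` defaults now reference the config dataclasses instead of repeating literal values, so the CLI defaults cannot drift from the config defaults. A minimal standalone sketch of that single-source-of-defaults pattern (hypothetical `Demo*` classes, not the actual vLLM configs):

```python
# Illustration only (assumed Demo* names, not vLLM's LoRAConfig/EngineArgs):
# the config dataclass owns the default values, and the args dataclass
# reads them as class attributes, so there is a single source of truth.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DemoLoRAConfig:
    max_loras: int = 1
    max_lora_rank: int = 16
    max_cpu_loras: Optional[int] = None


@dataclass
class DemoEngineArgs:
    # Defaults come from the config class instead of duplicated literals.
    max_loras: int = DemoLoRAConfig.max_loras
    max_lora_rank: int = DemoLoRAConfig.max_lora_rank
    max_cpu_loras: Optional[int] = DemoLoRAConfig.max_cpu_loras


assert DemoEngineArgs().max_lora_rank == DemoLoRAConfig().max_lora_rank
```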
@@ -338,10 +343,21 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
             kwargs[name]["choices"] = choices
             choice_type = type(choices[0])
             assert all(type(c) is choice_type for c in choices), (
-                f"All choices must be of the same type. "
+                "All choices must be of the same type. "
                 f"Got {choices} with types {[type(c) for c in choices]}"
             )
             kwargs[name]["type"] = choice_type
+        elif can_be_type(field_type, tuple):
+            if is_type_in_union(field_type, tuple):
+                field_type = get_type_from_union(field_type, tuple)
+            dtypes = get_args(field_type)
+            dtype = dtypes[0]
+            assert all(
+                d is dtype for d in dtypes if d is not Ellipsis
+            ), ("All non-Ellipsis tuple elements must be of the same "
+                f"type. Got {dtypes}.")
+            kwargs[name]["type"] = dtype
+            kwargs[name]["nargs"] = "+"
         elif can_be_type(field_type, int):
             kwargs[name]["type"] = optional_int if optional else int
         elif can_be_type(field_type, float):
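A rough, self-contained sketch of what the new tuple branch achieves (the helper below is hypothetical and much simpler than the real `get_kwargs`): a field annotated `tuple[float, ...]` becomes a repeatable CLI option via `nargs='+'` with the element type as the argparse `type`.

```python
# Standalone sketch (hypothetical helper, not vLLM's get_kwargs): turn a
# tuple[T, ...] annotation into argparse kwargs that accept a variable
# number of T values on the command line.
import argparse
from typing import get_args


def tuple_arg_kwargs(field_type) -> dict:
    dtypes = get_args(field_type)  # e.g. (float, Ellipsis) for tuple[float, ...]
    dtype = dtypes[0]
    assert all(d is dtype for d in dtypes if d is not Ellipsis), (
        f"All non-Ellipsis tuple elements must share a type, got {dtypes}")
    return {"type": dtype, "nargs": "+"}


parser = argparse.ArgumentParser()
parser.add_argument("--long-lora-scaling-factors",
                    **tuple_arg_kwargs(tuple[float, ...]))
args = parser.parse_args(["--long-lora-scaling-factors", "4.0", "8.0"])
print(args.long_lora_scaling_factors)  # [4.0, 8.0] (argparse yields a list)
```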
@@ -685,70 +701,49 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
             'inputs.')
 
         # LoRA related configs
-        parser.add_argument('--enable-lora',
-                            action='store_true',
-                            help='If True, enable handling of LoRA adapters.')
-        parser.add_argument('--enable-lora-bias',
-                            action='store_true',
-                            help='If True, enable bias for LoRA adapters.')
-        parser.add_argument('--max-loras',
-                            type=int,
-                            default=EngineArgs.max_loras,
-                            help='Max number of LoRAs in a single batch.')
-        parser.add_argument('--max-lora-rank',
-                            type=int,
-                            default=EngineArgs.max_lora_rank,
-                            help='Max LoRA rank.')
-        parser.add_argument(
-            '--lora-extra-vocab-size',
-            type=int,
-            default=EngineArgs.lora_extra_vocab_size,
-            help=('Maximum size of extra vocabulary that can be '
-                  'present in a LoRA adapter (added to the base '
-                  'model vocabulary).'))
-        parser.add_argument(
+        lora_kwargs = get_kwargs(LoRAConfig)
+        lora_group = parser.add_argument_group(
+            title="LoRAConfig",
+            description=LoRAConfig.__doc__,
+        )
+        lora_group.add_argument(
+            '--enable-lora',
+            action=argparse.BooleanOptionalAction,
+            help='If True, enable handling of LoRA adapters.')
+        lora_group.add_argument('--enable-lora-bias',
+                                **lora_kwargs["bias_enabled"])
+        lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
+        lora_group.add_argument('--max-lora-rank',
+                                **lora_kwargs["max_lora_rank"])
+        lora_group.add_argument('--lora-extra-vocab-size',
+                                **lora_kwargs["lora_extra_vocab_size"])
+        lora_group.add_argument(
             '--lora-dtype',
-            type=str,
-            default=EngineArgs.lora_dtype,
-            choices=['auto', 'float16', 'bfloat16'],
-            help=('Data type for LoRA. If auto, will default to '
-                  'base model dtype.'))
-        parser.add_argument(
-            '--long-lora-scaling-factors',
-            type=optional_str,
-            default=EngineArgs.long_lora_scaling_factors,
-            help=('Specify multiple scaling factors (which can '
-                  'be different from base model scaling factor '
-                  '- see eg. Long LoRA) to allow for multiple '
-                  'LoRA adapters trained with those scaling '
-                  'factors to be used at the same time. If not '
-                  'specified, only adapters trained with the '
-                  'base model scaling factor are allowed.'))
-        parser.add_argument(
-            '--max-cpu-loras',
-            type=int,
-            default=EngineArgs.max_cpu_loras,
-            help=('Maximum number of LoRAs to store in CPU memory. '
-                  'Must be >= than max_loras.'))
-        parser.add_argument(
-            '--fully-sharded-loras',
-            action='store_true',
-            help=('By default, only half of the LoRA computation is '
-                  'sharded with tensor parallelism. '
-                  'Enabling this will use the fully sharded layers. '
-                  'At high sequence length, max rank or '
-                  'tensor parallel size, this is likely faster.'))
-        parser.add_argument('--enable-prompt-adapter',
-                            action='store_true',
-                            help='If True, enable handling of PromptAdapters.')
-        parser.add_argument('--max-prompt-adapters',
-                            type=int,
-                            default=EngineArgs.max_prompt_adapters,
-                            help='Max number of PromptAdapters in a batch.')
-        parser.add_argument('--max-prompt-adapter-token',
-                            type=int,
-                            default=EngineArgs.max_prompt_adapter_token,
-                            help='Max number of PromptAdapters tokens')
+            **lora_kwargs["lora_dtype"],
+        )
+        lora_group.add_argument('--long-lora-scaling-factors',
+                                **lora_kwargs["long_lora_scaling_factors"])
+        lora_group.add_argument('--max-cpu-loras',
+                                **lora_kwargs["max_cpu_loras"])
+        lora_group.add_argument('--fully-sharded-loras',
+                                **lora_kwargs["fully_sharded_loras"])
+
+        # PromptAdapter related configs
+        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
+        prompt_adapter_group = parser.add_argument_group(
+            title="PromptAdapterConfig",
+            description=PromptAdapterConfig.__doc__,
+        )
+        prompt_adapter_group.add_argument(
+            '--enable-prompt-adapter',
+            action=argparse.BooleanOptionalAction,
+            help='If True, enable handling of PromptAdapters.')
+        prompt_adapter_group.add_argument(
+            '--max-prompt-adapters',
+            **prompt_adapter_kwargs["max_prompt_adapters"])
+        prompt_adapter_group.add_argument(
+            '--max-prompt-adapter-token',
+            **prompt_adapter_kwargs["max_prompt_adapter_token"])
 
         # Device arguments
         device_kwargs = get_kwargs(DeviceConfig)
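To show the effect of the grouping change above (standalone example, not the vLLM parser and with illustrative defaults): each config class gets its own argparse group, and boolean flags use `argparse.BooleanOptionalAction`, so both `--enable-lora` and `--no-enable-lora` are accepted.

```python
# Standalone sketch of the argument-group pattern; values and help text are
# illustrative, not vLLM's real configuration.
import argparse

parser = argparse.ArgumentParser()
lora_group = parser.add_argument_group(
    title="LoRAConfig",
    description="LoRA adapter settings (illustrative description).",
)
lora_group.add_argument("--enable-lora",
                        action=argparse.BooleanOptionalAction,
                        help="If True, enable handling of LoRA adapters.")
lora_group.add_argument("--max-loras", type=int, default=1,
                        help="Max number of LoRAs in a single batch.")

ns = parser.parse_args(["--no-enable-lora", "--max-loras", "4"])
print(ns.enable_lora, ns.max_loras)  # False 4
```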