
Commit 2ed790a

liuzijing2014, markmc, and njhill authored and committed
[V1][Metrics] Allow V1 AsyncLLM to use custom logger (vllm-project#14661)
Signed-off-by: Zijing Liu <[email protected]>
Signed-off-by: Mark McLoughlin <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Mark McLoughlin <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 34c40e4 commit 2ed790a

4 files changed: +118 −30 lines changed
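In short, this commit replaces the V0-style `stat_loggers: dict[str, StatLoggerBase]` parameter with a list of factories, each invoked as `factory(vllm_config, engine_index)` and returning a `StatLoggerBase`. A plain subclass already fits that shape. A minimal usage sketch, based on the diffs below; the logger class and model name here are illustrative, not part of the commit:

```python
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import StatLoggerBase


class CountingStatLogger(StatLoggerBase):
    """Illustrative logger: counts scheduler steps, prints on log()."""

    def __init__(self, vllm_config, engine_index: int = 0):
        self.engine_index = engine_index
        self.steps = 0

    def record(self, scheduler_stats, iteration_stats):
        # Called by the engine on each iteration with fresh stats.
        self.steps += 1

    def log(self):
        # Called periodically, or explicitly via do_log_stats().
        print(f"[engine {self.engine_index}] steps recorded: {self.steps}")


# The class itself is a valid StatLoggerFactory: calling it as
# CountingStatLogger(vllm_config, engine_index) yields a StatLoggerBase.
# Requires a V1-capable environment (the test below sets VLLM_USE_V1=1).
engine = AsyncLLM.from_engine_args(
    AsyncEngineArgs(model="facebook/opt-125m"),  # illustrative model
    stat_loggers=[CountingStatLogger],
)
```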

tests/v1/engine/test_async_llm.py

Lines changed: 33 additions & 0 deletions
@@ -3,16 +3,19 @@
 import asyncio
 from contextlib import ExitStack
 from typing import Optional
+from unittest.mock import MagicMock
 
 import pytest
 
 from vllm import SamplingParams
 from vllm.assets.image import ImageAsset
+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.inputs import PromptType
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM
+from vllm.v1.metrics.loggers import LoggingStatLogger
 
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -216,3 +219,33 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
     # Assert only the last output has the finished flag set
     assert all(not out.finished for out in outputs[:-1])
     assert outputs[-1].finished
+
+
+class MockLoggingStatLogger(LoggingStatLogger):
+
+    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
+        super().__init__(vllm_config, engine_index)
+        self.log = MagicMock()
+
+
+@pytest.mark.asyncio
+async def test_customize_loggers(monkeypatch):
+    """Test that we can customize the loggers.
+    If a customized logger is provided at the init, it should
+    be used directly.
+    """
+
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine = AsyncLLM.from_engine_args(
+            TEXT_ENGINE_ARGS,
+            stat_loggers=[MockLoggingStatLogger],
+        )
+        after.callback(engine.shutdown)
+
+        await engine.do_log_stats()
+
+        assert len(engine.stat_loggers) == 1
+        assert len(engine.stat_loggers[0]) == 1
+        engine.stat_loggers[0][0].log.assert_called_once()

vllm/v1/engine/async_llm.py

Lines changed: 33 additions & 19 deletions
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 import asyncio
-import logging
 from collections.abc import AsyncGenerator, Mapping
 from copy import copy
 from typing import Optional, Union
@@ -33,8 +32,8 @@
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
-                                     StatLoggerBase)
+from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
+                                     setup_default_loggers)
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
 
 logger = init_logger(__name__)
@@ -52,7 +51,28 @@ def __init__(
         use_cached_outputs: bool = False,
         log_requests: bool = True,
         start_engine_loop: bool = True,
+        stat_loggers: Optional[list[StatLoggerFactory]] = None,
     ) -> None:
+        """
+        Create an AsyncLLM.
+
+        Args:
+            vllm_config: global configuration.
+            executor_class: an Executor impl, e.g. MultiprocExecutor.
+            log_stats: Whether to log stats.
+            usage_context: Usage context of the LLM.
+            mm_registry: Multi-modal registry.
+            use_cached_outputs: Whether to use cached outputs.
+            log_requests: Whether to log requests.
+            start_engine_loop: Whether to start the engine loop.
+            stat_loggers: customized stat loggers for the engine.
+                If not provided, default stat loggers will be used.
+                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
+                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
+
+        Returns:
+            None
+        """
         if not envs.VLLM_USE_V1:
             raise ValueError(
                 "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
@@ -66,15 +86,12 @@
         self.log_stats = log_stats
 
         # Set up stat loggers; independent set for each DP rank.
-        self.stat_loggers: list[list[StatLoggerBase]] = []
-        if self.log_stats:
-            for i in range(vllm_config.parallel_config.data_parallel_size):
-                loggers: list[StatLoggerBase] = []
-                if logger.isEnabledFor(logging.INFO):
-                    loggers.append(LoggingStatLogger(engine_index=i))
-                loggers.append(
-                    PrometheusStatLogger(vllm_config, engine_index=i))
-                self.stat_loggers.append(loggers)
+        self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
+            vllm_config=vllm_config,
+            log_stats=self.log_stats,
+            engine_num=vllm_config.parallel_config.data_parallel_size,
+            custom_stat_loggers=stat_loggers,
+        )
 
         # Tokenizer (+ ensure liveness if running in another process).
         self.tokenizer = init_tokenizer_from_configs(
@@ -118,7 +135,7 @@ def from_vllm_config(
         vllm_config: VllmConfig,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
+        stat_loggers: Optional[list[StatLoggerFactory]] = None,
         disable_log_requests: bool = False,
         disable_log_stats: bool = False,
     ) -> "AsyncLLM":
@@ -129,17 +146,12 @@
             "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
             "VLLM_USE_V1=0 or 1 and report this issue on Github.")
 
-        # FIXME(rob): refactor VllmConfig to include the StatLoggers
-        # include StatLogger in the Oracle decision.
-        if stat_loggers is not None:
-            raise ValueError("Custom StatLoggers are not yet supported on V1. "
-                             "Explicitly set VLLM_USE_V1=0 to disable V1.")
-
         # Create the LLMEngine.
         return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
+           stat_loggers=stat_loggers,
            log_requests=not disable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
@@ -151,6 +163,7 @@ def from_engine_args(
         engine_args: AsyncEngineArgs,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[list[StatLoggerFactory]] = None,
     ) -> "AsyncLLM":
         """Create an AsyncLLM from the EngineArgs."""
 
@@ -166,6 +179,7 @@
             log_stats=not engine_args.disable_log_stats,
             start_engine_loop=start_engine_loop,
             usage_context=usage_context,
+            stat_loggers=stat_loggers,
         )
 
     def __del__(self):
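Because `StatLoggerFactory` is just `Callable[[VllmConfig, int], StatLoggerBase]`, anything matching that call shape works as a factory: a class, a function, or a `functools.partial` that binds extra configuration. A hedged sketch of that last pattern; `WindowedStatLogger` and `window_size` are illustrative, not part of the commit:

```python
from functools import partial

from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import SchedulerStats


class WindowedStatLogger(StatLoggerBase):
    """Illustrative logger keeping only the most recent scheduler stats."""

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0,
                 window_size: int = 100):
        self.engine_index = engine_index
        self.window_size = window_size
        self.window: list[SchedulerStats] = []

    def record(self, scheduler_stats, iteration_stats):
        self.window.append(scheduler_stats)
        # Trim to the newest `window_size` entries.
        self.window = self.window[-self.window_size:]

    def log(self):
        pass  # summarize self.window here


# partial(...) still matches Callable[[VllmConfig, int], StatLoggerBase],
# so extra configuration can be bound without a wrapper class:
factory = partial(WindowedStatLogger, window_size=10)
# e.g. AsyncLLM.from_engine_args(engine_args, stat_loggers=[factory])
```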

vllm/v1/engine/llm_engine.py

Lines changed: 9 additions & 9 deletions
@@ -10,7 +10,6 @@
 from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import stateless_destroy_torch_distributed_process_group
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.metrics_types import StatLoggerBase
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -28,6 +27,7 @@
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.metrics.loggers import StatLoggerFactory
 
 logger = init_logger(__name__)
 
@@ -43,7 +43,7 @@ def __init__(
         executor_class: type[Executor],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
+        stat_loggers: Optional[list[StatLoggerFactory]] = None,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         use_cached_outputs: bool = False,
         multiprocess_mode: bool = False,
@@ -55,6 +55,11 @@
             "LLMEngine.from_vllm_config(...) or explicitly set "
             "VLLM_USE_V1=0 or 1 and report this issue on Github.")
 
+        if stat_loggers is not None:
+            raise NotImplementedError(
+                "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
+                "Set VLLM_USE_V1=0 and file an issue on Github.")
+
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -101,14 +106,9 @@ def from_vllm_config(
         cls,
         vllm_config: VllmConfig,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
+        stat_loggers: Optional[list[StatLoggerFactory]] = None,
         disable_log_stats: bool = False,
     ) -> "LLMEngine":
-        if stat_loggers is not None:
-            raise NotImplementedError(
-                "Passing StatLoggers to V1 is not yet supported. "
-                "Set VLLM_USE_V1=0 and file and issue on Github.")
-
         return cls(vllm_config=vllm_config,
                    executor_class=Executor.get_class(vllm_config),
                    log_stats=(not disable_log_stats),
@@ -121,7 +121,7 @@ def from_engine_args(
         cls,
         engine_args: EngineArgs,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
+        stat_loggers: Optional[list[StatLoggerFactory]] = None,
         enable_multiprocessing: bool = False,
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""

vllm/v1/metrics/loggers.py

Lines changed: 43 additions & 2 deletions
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import logging
 import time
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Callable, Optional
 
 import numpy as np
 import prometheus_client
@@ -18,8 +19,20 @@
 
 _LOCAL_LOGGING_INTERVAL_SEC = 5.0
 
+StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
+
 
 class StatLoggerBase(ABC):
+    """Interface for logging metrics.
+
+    API users may define custom loggers that implement this interface.
+    However, note that the `SchedulerStats` and `IterationStats` classes
+    are not considered stable interfaces and may change in future versions.
+    """
+
+    @abstractmethod
+    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
+        ...
 
     @abstractmethod
     def record(self, scheduler_stats: SchedulerStats,
@@ -32,7 +45,7 @@ def log(self): # noqa
 
 class LoggingStatLogger(StatLoggerBase):
 
-    def __init__(self, engine_index: int = 0):
+    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         self.engine_index = engine_index
         self._reset(time.monotonic())
         self.last_scheduler_stats = SchedulerStats()
@@ -462,3 +475,31 @@ def build_cudagraph_buckets(vllm_config: VllmConfig) -> list[int]:
         return buckets
     else:
         return [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
+
+
+def setup_default_loggers(
+    vllm_config: VllmConfig,
+    log_stats: bool,
+    engine_num: int,
+    custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
+) -> list[list[StatLoggerBase]]:
+    """Setup logging and prometheus metrics."""
+    if not log_stats:
+        return []
+
+    factories: list[StatLoggerFactory]
+    if custom_stat_loggers is not None:
+        factories = custom_stat_loggers
+    else:
+        factories = [PrometheusStatLogger]
+        if logger.isEnabledFor(logging.INFO):
+            factories.append(LoggingStatLogger)
+
+    stat_loggers: list[list[StatLoggerBase]] = []
+    for i in range(engine_num):
+        per_engine_stat_loggers: list[StatLoggerBase] = []
+        for logger_factory in factories:
+            per_engine_stat_loggers.append(logger_factory(vllm_config, i))
+        stat_loggers.append(per_engine_stat_loggers)
+
+    return stat_loggers
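Two behaviors of `setup_default_loggers` worth calling out: custom factories replace the defaults outright (so include `PrometheusStatLogger` in your list if you still want Prometheus metrics), and each of the `engine_num` data-parallel ranks gets its own logger instances. A small sketch exercising the function directly, assuming a `VllmConfig` named `vllm_config` built elsewhere:

```python
from vllm.v1.metrics.loggers import (LoggingStatLogger, StatLoggerBase,
                                     setup_default_loggers)

# `vllm_config` is assumed to be a VllmConfig constructed elsewhere.
per_rank = setup_default_loggers(
    vllm_config=vllm_config,
    log_stats=True,
    engine_num=2,  # e.g. data_parallel_size == 2
    custom_stat_loggers=[LoggingStatLogger],
)

assert len(per_rank) == 2  # one logger list per DP rank
assert all(len(rank) == 1 for rank in per_rank)  # defaults were replaced
assert all(isinstance(rank[0], StatLoggerBase) for rank in per_rank)

# With log_stats disabled, no loggers are created at all:
assert setup_default_loggers(vllm_config, log_stats=False, engine_num=2) == []
```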
