
Commit 5d41d27

gmarinho2 and maxdebayser authored and Mu Huai committed
Truncation control for embedding models (vllm-project#14776)
Signed-off-by: Gabriel Marinho <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
Co-authored-by: Max de Bayser <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent 3f748f9 commit 5d41d27

21 files changed (+332, -70 lines)
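
As context for the diff below: the commit adds a truncate_prompt_tokens option to the embedding entrypoints, with -1 meaning "truncate to max_model_len". A minimal sketch of how the option can be exercised against the OpenAI-compatible server follows; the local URL, API key and server flags are assumptions, and the option is passed through the client's extra_body because it is a vLLM extension rather than part of the upstream OpenAI embeddings API.

# Illustrative sketch, not part of the commit. Assumes a server started with e.g.:
#   vllm serve sentence-transformers/all-MiniLM-L12-v2 --task embed --max-model-len 128
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.embeddings.create(
    model="sentence-transformers/all-MiniLM-L12-v2",
    input="Immerse yourself in the enchanting chronicle of calculus ...",
    # vLLM extension: clip the prompt to at most 10 tokens; -1 would instead
    # mean "truncate to the model's max_model_len".
    extra_body={"truncate_prompt_tokens": 10},
)
print(resp.usage.prompt_tokens)  # expected to be 10, per the tests below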
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import openai
+import pytest
+import pytest_asyncio
+
+from tests.utils import RemoteOpenAIServer
+
+MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
+max_model_len = 128
+
+input = """Immerse yourself in the enchanting chronicle of calculus, a
+mathematical domain that has radically transformed our comprehension of
+change and motion. Despite its roots in ancient civilizations, the
+formal birth of calculus predominantly occurred in the 17th century,
+primarily under the influential guidance of Sir Isaac Newton and Gottfried
+Wilhelm Leibniz. The earliest traces of calculus concepts are found in
+ancient Greek mathematics,most notably in the works of Eudoxus and
+Archimedes, around 300 BCE. They utilized the 'method of exhaustion'—a
+technique for computing areas and volumes through the use of finite sums.
+This methodology laid crucial foundational work for integral calculus.
+In the 17th century, both Newton and Leibniz independently pioneered
+calculus, each contributing unique perspectives that would shape this new
+field."""
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--task",
+        "embed",
+        "--dtype",
+        "bfloat16",
+        "--enforce-eager",
+        "--max-model-len",
+        str(max_model_len),
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+async def test_smaller_truncation_size(client: openai.AsyncOpenAI):
+    truncation_size = 10
+    kwargs: dict[str, Any] = {
+        "model": MODEL_NAME,
+        "input": input,
+        "truncate_prompt_tokens": truncation_size
+    }
+
+    response = await client.post(path="embeddings",
+                                 cast_to=object,
+                                 body={**kwargs})
+
+    assert response["usage"]["prompt_tokens"] == truncation_size
+
+
+@pytest.mark.asyncio
+async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
+    truncation_size = max_model_len + 1
+    kwargs: dict[str, Any] = {
+        "model": MODEL_NAME,
+        "input": input,
+        "truncate_prompt_tokens": truncation_size
+    }
+
+    with pytest.raises(openai.BadRequestError) as err:
+        err = await client.post(path="embeddings",
+                                cast_to=object,
+                                body={**kwargs})
+
+        assert str(err) == f"""openai.BadRequestError:
+Error code: 400 - {{'object': 'error',
+'message': 'truncate_prompt_tokens value
+({truncation_size})
+is greater than max_model_len ({max_model_len}).
+Please, select a smaller truncation size.',
+'type': 'BadRequestError',
+'param': None, 'code': 400}}"""
+
+
+@pytest.mark.asyncio
+async def test_max_truncation_size(client: openai.AsyncOpenAI):
+    truncation_size = -1
+    kwargs: dict[str, Any] = {
+        "model": MODEL_NAME,
+        "input": input,
+        "truncate_prompt_tokens": truncation_size
+    }
+
+    response = await client.post(path="embeddings",
+                                 cast_to=object,
+                                 body={**kwargs})
+
+    assert response["usage"]["prompt_tokens"] == max_model_len
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+
+MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
+max_model_len = 128
+
+input_str = """Immerse yourself in the enchanting chronicle of calculus, a
+mathematical domain that has radically transformed our comprehension of
+change and motion. Despite its roots in ancient civilizations, the
+formal birth of calculus predominantly occurred in the 17th century,
+primarily under the influential guidance of Sir Isaac Newton and Gottfried
+Wilhelm Leibniz. The earliest traces of calculus concepts are found in
+ancient Greek mathematics,most notably in the works of Eudoxus and
+Archimedes, around 300 BCE. They utilized the 'method of exhaustion'—a
+technique for computing areas and volumes through the use of finite sums.
+This methodology laid crucial foundational work for integral calculus.
+In the 17th century, both Newton and Leibniz independently pioneered
+calculus, each contributing unique perspectives that would shape this new
+field."""
+
+
+def test_smaller_truncation_size(vllm_runner,
+                                 model_name=MODEL_NAME,
+                                 input_str=input_str):
+
+    truncate_prompt_tokens = 10
+
+    with vllm_runner(model_name, task="embed",
+                     max_model_len=max_model_len) as vllm_model:
+        vllm_output = vllm_model.model.encode(
+            input_str, truncate_prompt_tokens=truncate_prompt_tokens)
+
+    prompt_tokens = vllm_output[0].prompt_token_ids
+
+    assert len(prompt_tokens) == truncate_prompt_tokens
+
+
+def test_max_truncation_size(vllm_runner,
+                             model_name=MODEL_NAME,
+                             input_str=input_str):
+    truncate_prompt_tokens = -1
+
+    with vllm_runner(model_name, task="embed",
+                     max_model_len=max_model_len) as vllm_model:
+        vllm_output = vllm_model.model.encode(
+            input_str, truncate_prompt_tokens=truncate_prompt_tokens)
+
+    prompt_tokens = vllm_output[0].prompt_token_ids
+
+    assert len(prompt_tokens) == max_model_len
+
+
+def test_bigger_truncation_size(vllm_runner,
+                                model_name=MODEL_NAME,
+                                input_str=input_str):
+
+    truncate_prompt_tokens = max_model_len + 1
+
+    with pytest.raises(ValueError), vllm_runner(
+            model_name, task="embed",
+            max_model_len=max_model_len) as vllm_model:
+
+        llm_output = vllm_model.model.encode(
+            input_str, truncate_prompt_tokens=truncate_prompt_tokens)
+
+        assert llm_output == f"""truncate_prompt_tokens value
+({truncate_prompt_tokens}) is greater than
+max_model_len ({max_model_len}). Please, select
+a smaller truncation size."""
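
Not part of the commit: a short offline sketch of the same control through LLM.embed, which this change extends with a truncate_prompt_tokens keyword (see the vllm/entrypoints/llm.py diff below). The model name and the output attribute mirror the test above and are assumptions as applied to embed's return value.

# Illustrative sketch, not part of the diff.
from vllm import LLM

llm = LLM(model="sentence-transformers/all-MiniLM-L12-v2",
          task="embed",
          max_model_len=128)

# Clip each prompt to at most 10 tokens before pooling; -1 would instead
# truncate to max_model_len (128 here).
outputs = llm.embed(["Immerse yourself in the enchanting chronicle of calculus ..."],
                    truncate_prompt_tokens=10)
print(len(outputs[0].prompt_token_ids))  # expected: 10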

vllm/engine/llm_engine.py

Lines changed: 3 additions & 0 deletions
@@ -645,6 +645,7 @@ def add_request(
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -678,6 +679,7 @@ def add_request(
         params: Optional[Union[SamplingParams, PoolingParams]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -758,6 +760,7 @@ def add_request(

         processed_inputs = self.input_preprocessor.preprocess(
             prompt,
+            tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )
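
The tokenization_kwargs dict threaded through add_request above is handed to the input preprocessor and, from there, to the tokenizer. This diff does not show the preprocessor internals, but as the removed block in _cross_encoding_score further down suggests, the dict carries the standard Hugging Face truncation and max_length options, whose effect can be illustrated in isolation:

# Standalone illustration of the truncation kwargs; not vLLM code.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L12-v2")
text = "Immerse yourself in the enchanting chronicle of calculus. " * 20

ids = tok(text, truncation=True, max_length=10)["input_ids"]
assert len(ids) <= 10  # prompt is clipped to the requested token budget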

vllm/engine/protocol.py

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@

 import asyncio
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, List, Mapping, Optional
+from typing import AsyncGenerator, Mapping, Optional

 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, ModelConfig, VllmConfig
@@ -256,7 +256,7 @@ async def is_tracing_enabled(self) -> bool:
     async def do_log_stats(
         self,
         scheduler_outputs: Optional[SchedulerOutputs] = None,
-        model_output: Optional[List[SamplerOutput]] = None,
+        model_output: Optional[list[SamplerOutput]] = None,
     ) -> None:
         ...

vllm/entrypoints/llm.py

Lines changed: 22 additions & 3 deletions
@@ -25,6 +25,7 @@
                                              resolve_chat_template_content_format)
 from vllm.entrypoints.score_utils import (_cosine_similarity,
                                           _validate_score_input_lens)
+from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
 from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -793,6 +794,7 @@ def encode(
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         *,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -807,6 +809,7 @@ def encode(
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[list[int]] = None,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -821,6 +824,7 @@ def encode(
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[list[list[int]]] = None,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -836,6 +840,7 @@ def encode(
                                        Sequence[PoolingParams]]] = None,
         *,
         prompt_token_ids: list[int],
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -851,6 +856,7 @@ def encode(
                                        Sequence[PoolingParams]]] = None,
         *,
         prompt_token_ids: list[list[int]],
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -864,6 +870,7 @@ def encode(
         prompts: None,
         pooling_params: None,
         prompt_token_ids: Union[list[int], list[list[int]]],
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -882,6 +889,7 @@ def encode(
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -946,10 +954,15 @@ def encode(
         for pooling_param in pooling_params:
             pooling_param.verify(self.llm_engine.model_config)

+        tokenization_kwargs: dict[str, Any] = {}
+        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
+                                  truncate_prompt_tokens, tokenization_kwargs)
+
         self._validate_and_add_requests(
             prompts=parsed_prompts,
             params=pooling_params,
             lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
             prompt_adapter_request=prompt_adapter_request,
         )

@@ -962,6 +975,7 @@ def embed(
         prompts: Union[PromptType, Sequence[PromptType]],
         /,
         *,
+        truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
@@ -995,6 +1009,7 @@ def embed(
                 "Embedding API is only enabled for `--task embed`")

         items = self.encode(prompts,
+                            truncate_prompt_tokens=truncate_prompt_tokens,
                             use_tqdm=use_tqdm,
                             pooling_params=pooling_params,
                             lora_request=lora_request,
@@ -1055,6 +1070,7 @@ def _embedding_score(

         encoded_output: list[PoolingRequestOutput] = self.encode(
             text_1 + text_2,
+            truncate_prompt_tokens=truncate_prompt_tokens,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request)
@@ -1098,9 +1114,8 @@ def _cross_encoding_score(
         pooling_params = PoolingParams()

         tokenization_kwargs: dict[str, Any] = {}
-        if truncate_prompt_tokens is not None:
-            tokenization_kwargs["truncation"] = True
-            tokenization_kwargs["max_length"] = truncate_prompt_tokens
+        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
+                                  truncate_prompt_tokens, tokenization_kwargs)

         parsed_prompts = []

@@ -1323,6 +1338,7 @@ def _validate_and_add_requests(
                                       Sequence[PoolingParams]],
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
         prompt_adapter_request: Optional[PromptAdapterRequest],
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         guided_options: Optional[GuidedDecodingRequest] = None,
         priority: Optional[list[int]] = None,
     ) -> None:
@@ -1359,6 +1375,7 @@ def _validate_and_add_requests(
             self._add_request(
                 prompt,
                 params[i] if isinstance(params, Sequence) else params,
+                tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
                 prompt_adapter_request=prompt_adapter_request,
@@ -1369,6 +1386,7 @@ def _add_request(
         self,
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
@@ -1379,6 +1397,7 @@ def _add_request(
             prompt,
             params,
             lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
         )
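
The llm.py changes above delegate to _validate_truncation_size from vllm.entrypoints.utils, whose body is not included in this page. Based on the behaviour the new tests encode (-1 resolves to max_model_len, values above max_model_len raise an error with the message asserted above, and the kwargs dict gains truncation/max_length), here is a hedged sketch of what such a helper is expected to do; the signature is inferred from the call sites and the return value is an assumption.

# Sketch only; the real helper lives in vllm/entrypoints/utils.py and may differ.
from typing import Any, Optional


def _validate_truncation_size(
        max_model_len: int,
        truncate_prompt_tokens: Optional[int],
        tokenization_kwargs: Optional[dict[str, Any]] = None) -> Optional[int]:
    if truncate_prompt_tokens is None:
        return None
    if truncate_prompt_tokens <= -1:
        # -1 is the sentinel for "truncate to the model's maximum length".
        truncate_prompt_tokens = max_model_len
    if truncate_prompt_tokens > max_model_len:
        raise ValueError(
            f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
            f"is greater than max_model_len ({max_model_len}). "
            f"Please, select a smaller truncation size.")
    if tokenization_kwargs is not None:
        # Forwarded to the tokenizer as standard Hugging Face options.
        tokenization_kwargs["truncation"] = True
        tokenization_kwargs["max_length"] = truncate_prompt_tokens
    return truncate_prompt_tokens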

vllm/entrypoints/openai/protocol.py

Lines changed: 4 additions & 4 deletions
@@ -1014,7 +1014,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
     user: Optional[str] = None
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

     # doc: begin-embedding-pooling-params
     additional_data: Optional[Any] = None
@@ -1049,7 +1049,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
     user: Optional[str] = None
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

     # doc: begin-chat-embedding-pooling-params
     additional_data: Optional[Any] = None
@@ -1116,7 +1116,7 @@ class ScoreRequest(OpenAIBaseModel):
     model: Optional[str] = None
     text_1: Union[list[str], str]
     text_2: Union[list[str], str]
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

     # doc: begin-score-pooling-params
     additional_data: Optional[Any] = None
@@ -1142,7 +1142,7 @@ class RerankRequest(OpenAIBaseModel):
     query: str
     documents: list[str]
     top_n: int = Field(default_factory=lambda: 0)
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

     # doc: begin-rerank-pooling-params
     additional_data: Optional[Any] = None
