@@ -1,14 +1,8 @@
 import itertools
-import json

 import pytest
 from huggingface_hub import (
     InferenceClient,
-    TextGenerationOutput,
-    TextGenerationOutputDetails,
-    TextGenerationStreamOutput,
-    TextGenerationOutputToken,
-    TextGenerationStreamDetails,
 )
 from huggingface_hub.errors import OverloadedError

@@ -35,19 +29,15 @@ def test_nonstreaming_chat_completion(
     client = InferenceClient("some-model")
     if details_arg:
         client.post = mock.Mock(
-            return_value=json.dumps(
-                [
-                    TextGenerationOutput(
-                        generated_text="the model response",
-                        details=TextGenerationOutputDetails(
-                            finish_reason="TextGenerationFinishReason",
-                            generated_tokens=10,
-                            prefill=[],
-                            tokens=[],  # not needed for integration
-                        ),
-                    )
-                ]
-            ).encode("utf-8")
+            return_value=b"""[{
+                "generated_text": "the model response",
+                "details": {
+                    "finish_reason": "length",
+                    "generated_tokens": 10,
+                    "prefill": [],
+                    "tokens": []
+                }
+            }]"""
         )
     else:
         client.post = mock.Mock(
@@ -96,27 +86,13 @@ def test_streaming_chat_completion(
     client = InferenceClient("some-model")
     client.post = mock.Mock(
         return_value=[
-            b"data:"
-            + json.dumps(
-                TextGenerationStreamOutput(
-                    token=TextGenerationOutputToken(
-                        id=1, special=False, text="the model "
-                    ),
-                ),
-            ).encode("utf-8"),
-            b"data:"
-            + json.dumps(
-                TextGenerationStreamOutput(
-                    token=TextGenerationOutputToken(
-                        id=2, special=False, text="response"
-                    ),
-                    details=TextGenerationStreamDetails(
-                        finish_reason="length",
-                        generated_tokens=10,
-                        seed=0,
-                    ),
-                )
-            ).encode("utf-8"),
+            b"""data:{
+                "token":{"id":1, "special": false, "text": "the model "}
+            }""",
+            b"""data:{
+                "token":{"id":2, "special": false, "text": "response"},
+                "details":{"finish_reason": "length", "generated_tokens": 10, "seed": 0}
+            }""",
         ]
     )
     with start_transaction(name="huggingface_hub tx"):