Skip to content

Commit 41aa99b

Browse files
Huggingface Hub integration (#3033)
Adds integration for Huggingface Hub. --------- Co-authored-by: Anton Pirker <[email protected]>
1 parent eac253a commit 41aa99b

File tree

11 files changed

+364
-1
lines changed

11 files changed

+364
-1
lines changed

.github/workflows/test-integrations-data-processing.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ jobs:
7070
run: |
7171
set -x # print commands that are executed
7272
./scripts/runtox.sh "py${{ matrix.python-version }}-openai-latest" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
73+
- name: Test huggingface_hub latest
74+
run: |
75+
set -x # print commands that are executed
76+
./scripts/runtox.sh "py${{ matrix.python-version }}-huggingface_hub-latest" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
7377
- name: Test rq latest
7478
run: |
7579
set -x # print commands that are executed
@@ -134,6 +138,10 @@ jobs:
134138
run: |
135139
set -x # print commands that are executed
136140
./scripts/runtox.sh --exclude-latest "py${{ matrix.python-version }}-openai" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
141+
- name: Test huggingface_hub pinned
142+
run: |
143+
set -x # print commands that are executed
144+
./scripts/runtox.sh --exclude-latest "py${{ matrix.python-version }}-huggingface_hub" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
137145
- name: Test rq pinned
138146
run: |
139147
set -x # print commands that are executed

mypy.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ ignore_missing_imports = True
7373
ignore_missing_imports = True
7474
[mypy-openai.*]
7575
ignore_missing_imports = True
76+
[mypy-huggingface_hub.*]
77+
ignore_missing_imports = True
7678
[mypy-arq.*]
7779
ignore_missing_imports = True
7880
[mypy-grpc.*]

scripts/split-tox-gh-actions/split-tox-gh-actions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
"huey",
7474
"langchain",
7575
"openai",
76+
"huggingface_hub",
7677
"rq",
7778
],
7879
"Databases": [

sentry_sdk/consts.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,9 @@ class OP:
325325
MIDDLEWARE_STARLITE_SEND = "middleware.starlite.send"
326326
OPENAI_CHAT_COMPLETIONS_CREATE = "ai.chat_completions.create.openai"
327327
OPENAI_EMBEDDINGS_CREATE = "ai.embeddings.create.openai"
328+
HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE = (
329+
"ai.chat_completions.create.huggingface_hub"
330+
)
328331
LANGCHAIN_PIPELINE = "ai.pipeline.langchain"
329332
LANGCHAIN_RUN = "ai.run.langchain"
330333
LANGCHAIN_TOOL = "ai.tool.langchain"

sentry_sdk/integrations/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def iter_default_integrations(with_auto_enabling_integrations):
8585
"sentry_sdk.integrations.graphene.GrapheneIntegration",
8686
"sentry_sdk.integrations.httpx.HttpxIntegration",
8787
"sentry_sdk.integrations.huey.HueyIntegration",
88+
"sentry_sdk.integrations.huggingface_hub.HuggingfaceHubIntegration",
8889
"sentry_sdk.integrations.langchain.LangchainIntegration",
8990
"sentry_sdk.integrations.loguru.LoguruIntegration",
9091
"sentry_sdk.integrations.openai.OpenAIIntegration",
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
from functools import wraps
2+
3+
from sentry_sdk import consts
4+
from sentry_sdk.ai.monitoring import record_token_usage
5+
from sentry_sdk.ai.utils import set_data_normalized
6+
from sentry_sdk.consts import SPANDATA
7+
8+
from typing import Any, Iterable, Callable
9+
10+
import sentry_sdk
11+
from sentry_sdk.scope import should_send_default_pii
12+
from sentry_sdk.integrations import DidNotEnable, Integration
13+
from sentry_sdk.utils import (
14+
capture_internal_exceptions,
15+
event_from_exception,
16+
ensure_integration_enabled,
17+
)
18+
19+
try:
20+
import huggingface_hub.inference._client
21+
22+
from huggingface_hub import ChatCompletionStreamOutput, TextGenerationOutput
23+
except ImportError:
24+
raise DidNotEnable("Huggingface not installed")
25+
26+
27+
class HuggingfaceHubIntegration(Integration):
    """Sentry integration for the ``huggingface_hub`` inference client.

    At setup time this monkey-patches
    ``huggingface_hub.inference._client.InferenceClient.text_generation``
    so every text-generation call is wrapped by ``_wrap_text_generation``.
    """

    # Identifier under which the SDK registers and looks up this integration.
    identifier = "huggingface_hub"

    def __init__(self, include_prompts=True):
        # type: (HuggingfaceHubIntegration, bool) -> None
        # When False, prompts and responses are never attached to spans
        # (the wrapper also requires should_send_default_pii() to be true).
        self.include_prompts = include_prompts

    @staticmethod
    def setup_once():
        # type: () -> None
        # Replace the client's text_generation method with the traced wrapper.
        huggingface_hub.inference._client.InferenceClient.text_generation = (
            _wrap_text_generation(
                huggingface_hub.inference._client.InferenceClient.text_generation
            )
        )
42+
43+
44+
def _capture_exception(exc):
    # type: (Any) -> None
    """Report *exc* to Sentry as an unhandled huggingface_hub error."""
    client = sentry_sdk.get_client()
    event, hint = event_from_exception(
        exc,
        client_options=client.options,
        mechanism={"type": "huggingface_hub", "handled": False},
    )
    sentry_sdk.capture_event(event, hint=hint)
52+
53+
54+
def _wrap_text_generation(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Wrap ``InferenceClient.text_generation`` with Sentry tracing.

    The wrapper starts a span for each call, records the model id, the
    streaming flag and (when PII sending and ``include_prompts`` allow it)
    the prompt and generated text, plus token usage when the client reports
    it. Exceptions from the wrapped call are captured and re-raised.

    Fixes vs. the original: the two plain-string ``"ai.responses"`` keys are
    replaced with ``SPANDATA.AI_RESPONSES`` so the span-data key is
    consistent with the streaming branches, and the ``list[str]`` runtime
    annotations are converted to comment-style hints to match the file's
    Python-2-compatible typing convention (runtime ``list[str]`` requires
    Python >= 3.9).
    """

    @wraps(f)
    @ensure_integration_enabled(HuggingfaceHubIntegration, f)
    def new_text_generation(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        # Normalize the prompt into kwargs so it can be read uniformly below.
        if "prompt" in kwargs:
            prompt = kwargs["prompt"]
        elif len(args) >= 2:
            # Positional call: args[0] is `self`, args[1] is the prompt.
            kwargs["prompt"] = args[1]
            prompt = kwargs["prompt"]
            args = (args[0],) + args[2:]
        else:
            # invalid call, let it return error
            return f(*args, **kwargs)

        model = kwargs.get("model")
        streaming = kwargs.get("stream")

        # The span is entered/exited manually (not via `with`) because the
        # streaming branches must keep it open until the returned iterator
        # is exhausted.
        span = sentry_sdk.start_span(
            op=consts.OP.HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE,
            description="Text Generation",
        )
        span.__enter__()
        try:
            res = f(*args, **kwargs)
        except Exception as e:
            _capture_exception(e)
            span.__exit__(None, None, None)
            raise e from None

        integration = sentry_sdk.get_client().get_integration(
            HuggingfaceHubIntegration
        )

        with capture_internal_exceptions():
            if should_send_default_pii() and integration.include_prompts:
                set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompt)

            set_data_normalized(span, SPANDATA.AI_MODEL_ID, model)
            set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)

            # Non-streaming, plain-string response.
            if isinstance(res, str):
                if should_send_default_pii() and integration.include_prompts:
                    set_data_normalized(
                        span,
                        SPANDATA.AI_RESPONSES,
                        [res],
                    )
                span.__exit__(None, None, None)
                return res

            # Non-streaming, detailed response (details=True, stream=False).
            if isinstance(res, TextGenerationOutput):
                if should_send_default_pii() and integration.include_prompts:
                    set_data_normalized(
                        span,
                        SPANDATA.AI_RESPONSES,
                        [res.generated_text],
                    )
                if res.details is not None and res.details.generated_tokens > 0:
                    record_token_usage(
                        span, total_tokens=res.details.generated_tokens
                    )
                span.__exit__(None, None, None)
                return res

            if not isinstance(res, Iterable):
                # we only know how to deal with strings and iterables, ignore
                set_data_normalized(span, "unknown_response", True)
                span.__exit__(None, None, None)
                return res

            if kwargs.get("details", False):
                # res is Iterable[TextGenerationStreamOutput]
                def new_details_iterator():
                    # type: () -> Iterable[ChatCompletionStreamOutput]
                    with capture_internal_exceptions():
                        tokens_used = 0
                        data_buf = []  # type: list[str]
                        for x in res:
                            # hasattr checks guard against stream chunks that
                            # omit token text or details.
                            if hasattr(x, "token") and hasattr(x.token, "text"):
                                data_buf.append(x.token.text)
                            if hasattr(x, "details") and hasattr(
                                x.details, "generated_tokens"
                            ):
                                tokens_used = x.details.generated_tokens
                            yield x
                        if (
                            len(data_buf) > 0
                            and should_send_default_pii()
                            and integration.include_prompts
                        ):
                            set_data_normalized(
                                span, SPANDATA.AI_RESPONSES, "".join(data_buf)
                            )
                        if tokens_used > 0:
                            record_token_usage(span, total_tokens=tokens_used)
                        # Close the span only after the stream is drained.
                        span.__exit__(None, None, None)

                return new_details_iterator()
            else:
                # res is Iterable[str]

                def new_iterator():
                    # type: () -> Iterable[str]
                    data_buf = []  # type: list[str]
                    with capture_internal_exceptions():
                        for s in res:
                            if isinstance(s, str):
                                data_buf.append(s)
                            yield s
                        if (
                            len(data_buf) > 0
                            and should_send_default_pii()
                            and integration.include_prompts
                        ):
                            set_data_normalized(
                                span, SPANDATA.AI_RESPONSES, "".join(data_buf)
                            )
                        # Close the span only after the stream is drained.
                        span.__exit__(None, None, None)

                return new_iterator()

    return new_text_generation

sentry_sdk/integrations/langchain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def count_tokens(s):
6363

6464
# To avoid double collecting tokens, we do *not* measure
6565
# token counts for models for which we have an explicit integration
66-
NO_COLLECT_TOKEN_MODELS = ["openai-chat"]
66+
NO_COLLECT_TOKEN_MODELS = ["openai-chat"] # TODO add huggingface and anthropic
6767

6868

6969
class LangchainIntegration(Integration):

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def get_file_text(file_name):
6060
"grpcio": ["grpcio>=1.21.1"],
6161
"httpx": ["httpx>=0.16.0"],
6262
"huey": ["huey>=2"],
63+
"huggingface_hub": ["huggingface_hub>=0.22"],
6364
"langchain": ["langchain>=0.0.210"],
6465
"loguru": ["loguru>=0.5"],
6566
"openai": ["openai>=1.0.0", "tiktoken>=0.3.0"],
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import pytest
2+
3+
pytest.importorskip("huggingface_hub")

0 commit comments

Comments
 (0)