-
Notifications
You must be signed in to change notification settings - Fork 513
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds integration for Huggingface Hub. --------- Co-authored-by: Anton Pirker <[email protected]>
- Loading branch information
1 parent
eac253a
commit 41aa99b
Showing
11 changed files
with
364 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,6 +73,7 @@ | |
"huey", | ||
"langchain", | ||
"openai", | ||
"huggingface_hub", | ||
"rq", | ||
], | ||
"Databases": [ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
from functools import wraps | ||
|
||
from sentry_sdk import consts | ||
from sentry_sdk.ai.monitoring import record_token_usage | ||
from sentry_sdk.ai.utils import set_data_normalized | ||
from sentry_sdk.consts import SPANDATA | ||
|
||
from typing import Any, Iterable, Callable | ||
|
||
import sentry_sdk | ||
from sentry_sdk.scope import should_send_default_pii | ||
from sentry_sdk.integrations import DidNotEnable, Integration | ||
from sentry_sdk.utils import ( | ||
capture_internal_exceptions, | ||
event_from_exception, | ||
ensure_integration_enabled, | ||
) | ||
|
||
try: | ||
import huggingface_hub.inference._client | ||
|
||
from huggingface_hub import ChatCompletionStreamOutput, TextGenerationOutput | ||
except ImportError: | ||
raise DidNotEnable("Huggingface not installed") | ||
|
||
|
||
class HuggingfaceHubIntegration(Integration): | ||
identifier = "huggingface_hub" | ||
|
||
def __init__(self, include_prompts=True): | ||
# type: (HuggingfaceHubIntegration, bool) -> None | ||
self.include_prompts = include_prompts | ||
|
||
@staticmethod | ||
def setup_once(): | ||
# type: () -> None | ||
huggingface_hub.inference._client.InferenceClient.text_generation = ( | ||
_wrap_text_generation( | ||
huggingface_hub.inference._client.InferenceClient.text_generation | ||
) | ||
) | ||
|
||
|
||
def _capture_exception(exc): | ||
# type: (Any) -> None | ||
event, hint = event_from_exception( | ||
exc, | ||
client_options=sentry_sdk.get_client().options, | ||
mechanism={"type": "huggingface_hub", "handled": False}, | ||
) | ||
sentry_sdk.capture_event(event, hint=hint) | ||
|
||
|
||
def _wrap_text_generation(f): | ||
# type: (Callable[..., Any]) -> Callable[..., Any] | ||
@wraps(f) | ||
@ensure_integration_enabled(HuggingfaceHubIntegration, f) | ||
def new_text_generation(*args, **kwargs): | ||
# type: (*Any, **Any) -> Any | ||
if "prompt" in kwargs: | ||
prompt = kwargs["prompt"] | ||
elif len(args) >= 2: | ||
kwargs["prompt"] = args[1] | ||
prompt = kwargs["prompt"] | ||
args = (args[0],) + args[2:] | ||
else: | ||
# invalid call, let it return error | ||
return f(*args, **kwargs) | ||
|
||
model = kwargs.get("model") | ||
streaming = kwargs.get("stream") | ||
|
||
span = sentry_sdk.start_span( | ||
op=consts.OP.HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE, | ||
description="Text Generation", | ||
) | ||
span.__enter__() | ||
try: | ||
res = f(*args, **kwargs) | ||
except Exception as e: | ||
_capture_exception(e) | ||
span.__exit__(None, None, None) | ||
raise e from None | ||
|
||
integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration) | ||
|
||
with capture_internal_exceptions(): | ||
if should_send_default_pii() and integration.include_prompts: | ||
set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompt) | ||
|
||
set_data_normalized(span, SPANDATA.AI_MODEL_ID, model) | ||
set_data_normalized(span, SPANDATA.AI_STREAMING, streaming) | ||
|
||
if isinstance(res, str): | ||
if should_send_default_pii() and integration.include_prompts: | ||
set_data_normalized( | ||
span, | ||
"ai.responses", | ||
[res], | ||
) | ||
span.__exit__(None, None, None) | ||
return res | ||
|
||
if isinstance(res, TextGenerationOutput): | ||
if should_send_default_pii() and integration.include_prompts: | ||
set_data_normalized( | ||
span, | ||
"ai.responses", | ||
[res.generated_text], | ||
) | ||
if res.details is not None and res.details.generated_tokens > 0: | ||
record_token_usage(span, total_tokens=res.details.generated_tokens) | ||
span.__exit__(None, None, None) | ||
return res | ||
|
||
if not isinstance(res, Iterable): | ||
# we only know how to deal with strings and iterables, ignore | ||
set_data_normalized(span, "unknown_response", True) | ||
span.__exit__(None, None, None) | ||
return res | ||
|
||
if kwargs.get("details", False): | ||
# res is Iterable[TextGenerationStreamOutput] | ||
def new_details_iterator(): | ||
# type: () -> Iterable[ChatCompletionStreamOutput] | ||
with capture_internal_exceptions(): | ||
tokens_used = 0 | ||
data_buf: list[str] = [] | ||
for x in res: | ||
if hasattr(x, "token") and hasattr(x.token, "text"): | ||
data_buf.append(x.token.text) | ||
if hasattr(x, "details") and hasattr( | ||
x.details, "generated_tokens" | ||
): | ||
tokens_used = x.details.generated_tokens | ||
yield x | ||
if ( | ||
len(data_buf) > 0 | ||
and should_send_default_pii() | ||
and integration.include_prompts | ||
): | ||
set_data_normalized( | ||
span, SPANDATA.AI_RESPONSES, "".join(data_buf) | ||
) | ||
if tokens_used > 0: | ||
record_token_usage(span, total_tokens=tokens_used) | ||
span.__exit__(None, None, None) | ||
|
||
return new_details_iterator() | ||
else: | ||
# res is Iterable[str] | ||
|
||
def new_iterator(): | ||
# type: () -> Iterable[str] | ||
data_buf: list[str] = [] | ||
with capture_internal_exceptions(): | ||
for s in res: | ||
if isinstance(s, str): | ||
data_buf.append(s) | ||
yield s | ||
if ( | ||
len(data_buf) > 0 | ||
and should_send_default_pii() | ||
and integration.include_prompts | ||
): | ||
set_data_normalized( | ||
span, SPANDATA.AI_RESPONSES, "".join(data_buf) | ||
) | ||
span.__exit__(None, None, None) | ||
|
||
return new_iterator() | ||
|
||
return new_text_generation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import pytest | ||
|
||
pytest.importorskip("huggingface_hub") |
Oops, something went wrong.