|
35 | 35 | UserPromptPart, |
36 | 36 | VideoUrl, |
37 | 37 | ) |
38 | | -from ..profiles import ModelProfile |
| 38 | +from ..profiles import ModelProfile, ModelProfileSpec |
39 | 39 | from ..providers import Provider, infer_provider |
40 | 40 | from ..settings import ModelSettings |
41 | 41 | from ..tools import ToolDefinition |
@@ -121,20 +121,26 @@ def __init__( |
121 | 121 | model_name: str, |
122 | 122 | *, |
123 | 123 | provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface', |
| 124 | + profile: ModelProfileSpec | None = None, |
| 125 | + settings: ModelSettings | None = None, |
124 | 126 | ): |
125 | 127 | """Initialize a Hugging Face model. |
126 | 128 |
|
127 | 129 | Args: |
128 | 130 | model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending). |
129 | 131 | provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an |
130 | 132 | instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used. |
| 133 | + profile: The model profile to use. Defaults to a profile picked by the provider based on the model name. |
| 134 | + settings: Model-specific settings that will be used as defaults for this model. |
131 | 135 | """ |
132 | 136 | self._model_name = model_name |
133 | 137 | self._provider = provider |
134 | 138 | if isinstance(provider, str): |
135 | 139 | provider = infer_provider(provider) |
136 | 140 | self.client = provider.client |
137 | 141 |
|
| 142 | + super().__init__(settings=settings, profile=profile or provider.model_profile) |
| 143 | + |
138 | 144 | async def request( |
139 | 145 | self, |
140 | 146 | messages: list[ModelMessage], |
@@ -444,11 +450,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: |
444 | 450 |
|
445 | 451 | # Handle the text part of the response |
446 | 452 | content = choice.delta.content |
447 | | - if content: |
| 453 | + if content is not None: |
448 | 454 | maybe_event = self._parts_manager.handle_text_delta( |
449 | 455 | vendor_part_id='content', |
450 | 456 | content=content, |
451 | 457 | thinking_tags=self._model_profile.thinking_tags, |
| 458 | + ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, |
452 | 459 | ) |
453 | 460 | if maybe_event is not None: # pragma: no branch |
454 | 461 | yield maybe_event |
|
0 commit comments