|
1 | 1 | import contextlib |
2 | 2 | import dataclasses |
| 3 | +import os |
3 | 4 | import types |
4 | 5 | import typing |
5 | 6 | from http import HTTPStatus |
|
14 | 15 | from any_llm_client.retry import RequestRetryConfig |
15 | 16 |
|
16 | 17 |
|
| 18 | +YANDEXGPT_AUTH_HEADER_ENV_NAME: typing.Final = "ANY_LLM_CLIENT_YANDEXGPT_AUTH_HEADER" |
| 19 | +YANDEXGPT_FOLDER_ID_ENV_NAME: typing.Final = "ANY_LLM_CLIENT_YANDEXGPT_FOLDER_ID" |
| 20 | + |
| 21 | + |
17 | 22 | class YandexGPTConfig(LLMConfig): |
18 | 23 | if typing.TYPE_CHECKING: |
19 | 24 | url: str = "https://llm.api.cloud.yandex.net/foundationModels/v1/completion" # pragma: no cover |
20 | 25 | else: |
21 | 26 | url: pydantic.HttpUrl = "https://llm.api.cloud.yandex.net/foundationModels/v1/completion" |
22 | | - auth_header: str | None = None |
23 | | - folder_id: str | None = None |
| 27 | + auth_header: str = pydantic.Field( # type: ignore[assignment] |
| 28 | + default_factory=lambda: os.environ.get(YANDEXGPT_AUTH_HEADER_ENV_NAME), validate_default=True |
| 29 | + ) |
| 30 | + folder_id: str = pydantic.Field( # type: ignore[assignment] |
| 31 | + default_factory=lambda: os.environ.get(YANDEXGPT_FOLDER_ID_ENV_NAME), validate_default=True |
| 32 | + ) |
24 | 33 | model_name: str |
25 | 34 | model_version: str = "latest" |
26 | 35 | max_tokens: int = 7400 |
@@ -79,10 +88,12 @@ def __init__( |
79 | 88 | self.httpx_client = get_http_client_from_kwargs(httpx_kwargs) |
80 | 89 |
|
81 | 90 | def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request: |
82 | | - headers: typing.Final = {"x-data-logging-enabled": "false"} |
83 | | - if self.config.auth_header: |
84 | | - headers["Authorization"] = self.config.auth_header |
85 | | - return self.httpx_client.build_request(method="POST", url=str(self.config.url), json=payload, headers=headers) |
| 91 | + return self.httpx_client.build_request( |
| 92 | + method="POST", |
| 93 | + url=str(self.config.url), |
| 94 | + json=payload, |
| 95 | + headers={"Authorization": self.config.auth_header, "x-data-logging-enabled": "false"}, |
| 96 | + ) |
86 | 97 |
|
87 | 98 | def _prepare_payload( |
88 | 99 | self, *, messages: str | list[Message], temperature: float = 0.2, stream: bool |
|
0 commit comments