diff --git a/README.md b/README.md
index f7325e5..f75da8b 100644
--- a/README.md
+++ b/README.md
@@ -144,7 +144,8 @@ client = any_llm_client.OpenAIClient(config, ...)
 
 #### Timeouts, proxy & other HTTP settings
 
-Pass custom [HTTPX](https://www.python-httpx.org) client:
+
+Pass custom [HTTPX](https://www.python-httpx.org) kwargs to `any_llm_client.get_client()`:
 
 ```python
 import httpx
@@ -154,14 +155,14 @@ import any_llm_client
 
 async with any_llm_client.get_client(
     ...,
-    httpx_client=httpx.AsyncClient(
-        mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
-        timeout=httpx.Timeout(None, connect=5.0),
-    ),
+    mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
+    timeout=httpx.Timeout(None, connect=5.0),
 ) as client:
     ...
 ```
 
+Default timeout is `httpx.Timeout(None, connect=5.0)` (5 seconds on connect, unlimited on read, write or pool).
+
 #### Retries
 
 By default, requests are retried 3 times on HTTP status errors. You can change the retry behaviour by supplying `request_retry` parameter:
diff --git a/any_llm_client/clients/openai.py b/any_llm_client/clients/openai.py
index 1a1d9d6..62f009c 100644
--- a/any_llm_client/clients/openai.py
+++ b/any_llm_client/clients/openai.py
@@ -11,7 +11,7 @@
 import typing_extensions
 
 from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, MessageRole, OutOfTokensOrSymbolsError
-from any_llm_client.http import make_http_request, make_streaming_http_request
+from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
 from any_llm_client.retry import RequestRetryConfig
 
 
@@ -101,12 +101,13 @@ class OpenAIClient(LLMClient):
     def __init__(
         self,
         config: OpenAIConfig,
-        httpx_client: httpx.AsyncClient | None = None,
+        *,
         request_retry: RequestRetryConfig | None = None,
+        **httpx_kwargs: typing.Any,  # noqa: ANN401
     ) -> None:
         self.config = config
-        self.httpx_client = httpx_client or httpx.AsyncClient()
         self.request_retry = request_retry or RequestRetryConfig()
+        self.httpx_client = get_http_client_from_kwargs(httpx_kwargs)
 
     def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
         return self.httpx_client.build_request(
diff --git a/any_llm_client/clients/yandexgpt.py b/any_llm_client/clients/yandexgpt.py
index dcc27a3..42bfd2e 100644
--- a/any_llm_client/clients/yandexgpt.py
+++ b/any_llm_client/clients/yandexgpt.py
@@ -10,7 +10,7 @@
 import typing_extensions
 
 from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError
-from any_llm_client.http import make_http_request, make_streaming_http_request
+from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
 from any_llm_client.retry import RequestRetryConfig
 
 
@@ -70,12 +70,13 @@ class YandexGPTClient(LLMClient):
     def __init__(
         self,
         config: YandexGPTConfig,
-        httpx_client: httpx.AsyncClient | None = None,
+        *,
         request_retry: RequestRetryConfig | None = None,
+        **httpx_kwargs: typing.Any,  # noqa: ANN401
     ) -> None:
         self.config = config
-        self.httpx_client = httpx_client or httpx.AsyncClient()
         self.request_retry = request_retry or RequestRetryConfig()
+        self.httpx_client = get_http_client_from_kwargs(httpx_kwargs)
 
     def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
         headers: typing.Final = {"x-data-logging-enabled": "false"}
diff --git a/any_llm_client/http.py b/any_llm_client/http.py
index 1bc5d34..7d05030 100644
--- a/any_llm_client/http.py
+++ b/any_llm_client/http.py
@@ -8,6 +8,15 @@
 from any_llm_client.retry import RequestRetryConfig
 
 
+DEFAULT_HTTP_TIMEOUT: typing.Final = httpx.Timeout(None, connect=5.0)
+
+
+def get_http_client_from_kwargs(kwargs: dict[str, typing.Any]) -> httpx.AsyncClient:
+    kwargs_with_defaults: typing.Final = kwargs.copy()
+    kwargs_with_defaults.setdefault("timeout", DEFAULT_HTTP_TIMEOUT)
+    return httpx.AsyncClient(**kwargs_with_defaults)
+
+
 async def make_http_request(
     *,
     httpx_client: httpx.AsyncClient,
diff --git a/any_llm_client/main.py b/any_llm_client/main.py
index ce8c3b5..8163fd3 100644
--- a/any_llm_client/main.py
+++ b/any_llm_client/main.py
@@ -1,8 +1,6 @@
 import functools
 import typing
 
-import httpx
-
 from any_llm_client.clients.mock import MockLLMClient, MockLLMConfig
 from any_llm_client.clients.openai import OpenAIClient, OpenAIConfig
 from any_llm_client.clients.yandexgpt import YandexGPTClient, YandexGPTConfig
@@ -18,8 +16,8 @@
     def get_client(
         config: AnyLLMConfig,
         *,
-        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
+        **httpx_kwargs: typing.Any,  # noqa: ANN401
     ) -> LLMClient: ...  # pragma: no cover
 
 else:
@@ -27,8 +25,8 @@
     def get_client(
         config: typing.Any,  # noqa: ANN401, ARG001
         *,
-        httpx_client: httpx.AsyncClient | None = None,  # noqa: ARG001
         request_retry: RequestRetryConfig | None = None,  # noqa: ARG001
+        **httpx_kwargs: typing.Any,  # noqa: ANN401, ARG001
     ) -> LLMClient:
         raise AssertionError("unknown LLM config type")
 
@@ -36,25 +34,25 @@ def get_client(
     def _(
         config: YandexGPTConfig,
         *,
-        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
+        **httpx_kwargs: typing.Any,  # noqa: ANN401
     ) -> LLMClient:
-        return YandexGPTClient(config=config, httpx_client=httpx_client, request_retry=request_retry)
+        return YandexGPTClient(config=config, request_retry=request_retry, **httpx_kwargs)
 
     @get_client.register
     def _(
         config: OpenAIConfig,
         *,
-        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
+        **httpx_kwargs: typing.Any,  # noqa: ANN401
     ) -> LLMClient:
-        return OpenAIClient(config=config, httpx_client=httpx_client, request_retry=request_retry)
+        return OpenAIClient(config=config, request_retry=request_retry, **httpx_kwargs)
 
     @get_client.register
     def _(
         config: MockLLMConfig,
         *,
-        httpx_client: httpx.AsyncClient | None = None,  # noqa: ARG001
         request_retry: RequestRetryConfig | None = None,  # noqa: ARG001
+        **httpx_kwargs: typing.Any,  # noqa: ANN401, ARG001
     ) -> LLMClient:
         return MockLLMClient(config=config)
diff --git a/tests/test_http.py b/tests/test_http.py
new file mode 100644
index 0000000..0269270
--- /dev/null
+++ b/tests/test_http.py
@@ -0,0 +1,27 @@
+import copy
+import typing
+
+import httpx
+
+from any_llm_client.http import DEFAULT_HTTP_TIMEOUT, get_http_client_from_kwargs
+
+
+class TestGetHttpClientFromKwargs:
+    def test_http_timeout_is_added(self) -> None:
+        original_kwargs: typing.Final = {"mounts": {"http://": None}}
+        passed_kwargs: typing.Final = copy.deepcopy(original_kwargs)
+
+        result: typing.Final = get_http_client_from_kwargs(passed_kwargs)
+
+        assert result.timeout == DEFAULT_HTTP_TIMEOUT
+        assert original_kwargs == passed_kwargs
+
+    def test_http_timeout_is_not_modified_if_set(self) -> None:
+        timeout: typing.Final = httpx.Timeout(7, connect=5, read=3)
+        original_kwargs: typing.Final = {"mounts": {"http://": None}, "timeout": timeout}
+        passed_kwargs: typing.Final = copy.deepcopy(original_kwargs)
+
+        result: typing.Final = get_http_client_from_kwargs(passed_kwargs)
+
+        assert result.timeout == timeout
+        assert original_kwargs == passed_kwargs
diff --git a/tests/test_mock_client.py b/tests/test_mock_client.py
index 23717b9..959a80e 100644
--- a/tests/test_mock_client.py
+++ b/tests/test_mock_client.py
@@ -1,5 +1,4 @@
 import typing
-from unittest import mock
 
 from polyfactory.factories.pydantic_factory import ModelFactory
 
@@ -12,7 +11,7 @@ class MockLLMConfigFactory(ModelFactory[any_llm_client.MockLLMConfig]): ...
 
 async def test_mock_client_request_llm_message_returns_config_value() -> None:
     config: typing.Final = MockLLMConfigFactory.build()
-    response: typing.Final = await any_llm_client.get_client(config, httpx_client=mock.Mock()).request_llm_message(
+    response: typing.Final = await any_llm_client.get_client(config).request_llm_message(
         **LLMFuncRequestFactory.build()
     )
     assert response == config.response_message
@@ -21,8 +20,6 @@ async def test_mock_client_request_llm_message_returns_config_value() -> None:
 async def test_mock_client_request_llm_partial_responses_returns_config_value() -> None:
     config: typing.Final = MockLLMConfigFactory.build()
     response: typing.Final = await consume_llm_partial_responses(
-        any_llm_client.get_client(config, httpx_client=mock.Mock()).stream_llm_partial_messages(
-            **LLMFuncRequestFactory.build()
-        )
+        any_llm_client.get_client(config).stream_llm_partial_messages(**LLMFuncRequestFactory.build())
     )
     assert response == config.stream_messages
diff --git a/tests/test_openai_client.py b/tests/test_openai_client.py
index b23c851..efa7d1e 100644
--- a/tests/test_openai_client.py
+++ b/tests/test_openai_client.py
@@ -1,5 +1,4 @@
 import typing
-from unittest import mock
 
 import faker
 import httpx
@@ -35,8 +34,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
         )
 
         result: typing.Final = await any_llm_client.get_client(
-            OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         ).request_llm_message(**LLMFuncRequestFactory.build())
 
         assert result == expected_result
@@ -47,8 +45,7 @@ async def test_fails_without_alternatives(self) -> None:
             json=ChatCompletionsNotStreamingResponse.model_construct(choices=[]).model_dump(mode="json"),
         )
         client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -89,10 +86,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
         response: typing.Final = httpx.Response(
             200, headers={"Content-Type": "text/event-stream"}, content=response_content
         )
-        client: typing.Final = any_llm_client.get_client(
-            config,
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
-        )
+        client: typing.Final = any_llm_client.get_client(config, transport=httpx.MockTransport(lambda _: response))
 
         result: typing.Final = await consume_llm_partial_responses(client.stream_llm_partial_messages(**func_request))
 
@@ -106,8 +100,7 @@ async def test_fails_without_alternatives(self) -> None:
             200, headers={"Content-Type": "text/event-stream"}, content=response_content
         )
         client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -119,8 +112,7 @@ class TestOpenAILLMErrors:
     @pytest.mark.parametrize("status_code", [400, 500])
     async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
         client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
+            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: httpx.Response(status_code))
         )
 
         coroutine: typing.Final = (
@@ -144,8 +136,7 @@ async def test_fails_with_unknown_error(self, stream: bool, status_code: int) ->
     async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes | None) -> None:
         response: typing.Final = httpx.Response(400, content=content)
         client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
        )
 
         coroutine: typing.Final = (
@@ -244,14 +235,13 @@ def test_with_alternation(
         self, messages: list[any_llm_client.Message], expected_result: list[ChatCompletionsMessage]
     ) -> None:
         client: typing.Final = any_llm_client.OpenAIClient(
-            config=OpenAIConfigFactory.build(force_user_assistant_message_alternation=True), httpx_client=mock.Mock()
+            OpenAIConfigFactory.build(force_user_assistant_message_alternation=True)
         )
         assert client._prepare_messages(messages) == expected_result  # noqa: SLF001
 
     def test_without_alternation(self) -> None:
         client: typing.Final = any_llm_client.OpenAIClient(
-            config=OpenAIConfigFactory.build(force_user_assistant_message_alternation=False),
-            httpx_client=mock.Mock(),
+            OpenAIConfigFactory.build(force_user_assistant_message_alternation=False)
         )
         assert client._prepare_messages(  # noqa: SLF001
             [
diff --git a/tests/test_unknown_client.py b/tests/test_unknown_client.py
index 6320593..f2abb57 100644
--- a/tests/test_unknown_client.py
+++ b/tests/test_unknown_client.py
@@ -1,5 +1,3 @@
-from unittest import mock
-
 import faker
 import pytest
 
@@ -8,4 +6,4 @@
 
 def test_unknown_client_raises_assertion_error(faker: faker.Faker) -> None:
     with pytest.raises(AssertionError):
-        any_llm_client.get_client(faker.pyobject(), httpx_client=mock.Mock())  # type: ignore[arg-type]
+        any_llm_client.get_client(faker.pyobject())  # type: ignore[arg-type]
diff --git a/tests/test_yandexgpt_client.py b/tests/test_yandexgpt_client.py
index a4c23d4..3f45b73 100644
--- a/tests/test_yandexgpt_client.py
+++ b/tests/test_yandexgpt_client.py
@@ -29,8 +29,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
         )
 
         result: typing.Final = await any_llm_client.get_client(
-            YandexGPTConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         ).request_llm_message(**LLMFuncRequestFactory.build())
 
         assert result == expected_result
@@ -40,8 +39,7 @@ async def test_fails_without_alternatives(self) -> None:
             200, json=YandexGPTResponse(result=YandexGPTResult.model_construct(alternatives=[])).model_dump(mode="json")
         )
         client: typing.Final = any_llm_client.get_client(
-            YandexGPTConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -70,7 +68,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
 
         result: typing.Final = await consume_llm_partial_responses(
             any_llm_client.get_client(
-                config, httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response))
+                config, transport=httpx.MockTransport(lambda _: response)
             ).stream_llm_partial_messages(**func_request)
         )
 
@@ -83,8 +81,7 @@ async def test_fails_without_alternatives(self) -> None:
         response: typing.Final = httpx.Response(200, content=response_content)
 
         client: typing.Final = any_llm_client.get_client(
-            YandexGPTConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -96,8 +93,7 @@ class TestYandexGPTLLMErrors:
     @pytest.mark.parametrize("status_code", [400, 500])
     async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
         client: typing.Final = any_llm_client.get_client(
-            YandexGPTConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
+            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: httpx.Response(status_code))
         )
 
         coroutine: typing.Final = (
@@ -121,8 +117,7 @@ async def test_fails_with_unknown_error(self, stream: bool, status_code: int) ->
     async def test_fails_with_out_of_tokens_error(self, stream: bool, response_content: bytes | None) -> None:
         response: typing.Final = httpx.Response(400, content=response_content)
         client: typing.Final = any_llm_client.get_client(
-            YandexGPTConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         )
 
         coroutine: typing.Final = (
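
Taken together, the hunks above swap the explicit `httpx_client` argument for pass-through `**httpx_kwargs`: each client now builds its own `httpx.AsyncClient` via `get_http_client_from_kwargs`, which fills in `DEFAULT_HTTP_TIMEOUT` only when the caller did not set one, and never mutates the caller's kwargs (it works on a copy via `setdefault`). A minimal sketch of that behaviour, mirroring the assertions in the new `tests/test_http.py` — the `main` wrapper and `asyncio.run` call are illustrative only, not part of the diff:

```python
import asyncio

import httpx

from any_llm_client.http import DEFAULT_HTTP_TIMEOUT, get_http_client_from_kwargs


async def main() -> None:
    # No timeout passed: the helper applies the default
    # (connect=5.0 seconds; read, write, and pool unlimited).
    default_client = get_http_client_from_kwargs({})
    assert default_client.timeout == DEFAULT_HTTP_TIMEOUT

    # An explicit timeout is left untouched thanks to setdefault().
    custom_timeout = httpx.Timeout(10.0, connect=1.0)
    custom_client = get_http_client_from_kwargs({"timeout": custom_timeout})
    assert custom_client.timeout == custom_timeout

    await default_client.aclose()
    await custom_client.aclose()


asyncio.run(main())
```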