11 changes: 6 additions & 5 deletions README.md
@@ -144,7 +144,8 @@ client = any_llm_client.OpenAIClient(config, ...)

#### Timeouts, proxy & other HTTP settings

Pass custom [HTTPX](https://www.python-httpx.org) client:

Pass custom [HTTPX](https://www.python-httpx.org) kwargs to `any_llm_client.get_client()`:

```python
import httpx
@@ -154,14 +155,14 @@ import any_llm_client

async with any_llm_client.get_client(
...,
httpx_client=httpx.AsyncClient(
mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
timeout=httpx.Timeout(None, connect=5.0),
),
mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
timeout=httpx.Timeout(None, connect=5.0),
) as client:
...
```

The default timeout is `httpx.Timeout(None, connect=5.0)` (5 seconds on connect; unlimited on read, write, and pool).
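
For illustration, overriding just the timeout while keeping the other defaults might look like this (a sketch against the `get_client()` call shown above; the config argument is elided exactly as in the earlier example, and the timeout values are arbitrary):

```python
import httpx

import any_llm_client

async with any_llm_client.get_client(
    ...,  # your LLM config, as in the example above
    timeout=httpx.Timeout(10.0, connect=3.0),  # forwarded to the internal httpx.AsyncClient
) as client:
    ...
```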

#### Retries

By default, requests are retried 3 times on HTTP status errors. You can change the retry behaviour by supplying the `request_retry` parameter:
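
(The README's own example is collapsed in this diff view. As an illustrative sketch only — `attempts` is an assumed field name on `RequestRetryConfig`, and its re-export at the package root is an assumption too; neither is confirmed by this diff:)

```python
import any_llm_client

async with any_llm_client.get_client(
    ...,  # your LLM config
    # `attempts` is an assumed field name, shown for illustration only
    request_retry=any_llm_client.RequestRetryConfig(attempts=5),
) as client:
    ...
```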
7 changes: 4 additions & 3 deletions any_llm_client/clients/openai.py
@@ -11,7 +11,7 @@
import typing_extensions

from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, MessageRole, OutOfTokensOrSymbolsError
from any_llm_client.http import make_http_request, make_streaming_http_request
from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
from any_llm_client.retry import RequestRetryConfig


@@ -101,12 +101,13 @@ class OpenAIClient(LLMClient):
def __init__(
self,
config: OpenAIConfig,
httpx_client: httpx.AsyncClient | None = None,
*,
request_retry: RequestRetryConfig | None = None,
**httpx_kwargs: typing.Any, # noqa: ANN401
) -> None:
self.config = config
self.httpx_client = httpx_client or httpx.AsyncClient()
self.request_retry = request_retry or RequestRetryConfig()
self.httpx_client = get_http_client_from_kwargs(httpx_kwargs)

def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
return self.httpx_client.build_request(
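
With this change, anything not consumed by the named parameters is forwarded to the internal `httpx.AsyncClient`. A minimal sketch of direct construction, mirroring the test suite's use of a mock transport (`config` construction is elided here, as in the README's own `any_llm_client.OpenAIClient(config, ...)` example):

```python
import httpx

import any_llm_client

# `config` is an OpenAIConfig instance; its construction is elided here
client = any_llm_client.OpenAIClient(
    config,
    request_retry=None,  # None falls back to RequestRetryConfig()
    transport=httpx.MockTransport(lambda _: httpx.Response(200)),  # forwarded to httpx.AsyncClient
)
```

`YandexGPTClient` below gets the identical treatment.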
7 changes: 4 additions & 3 deletions any_llm_client/clients/yandexgpt.py
@@ -10,7 +10,7 @@
import typing_extensions

from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError
from any_llm_client.http import make_http_request, make_streaming_http_request
from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
from any_llm_client.retry import RequestRetryConfig


@@ -70,12 +70,13 @@ class YandexGPTClient(LLMClient):
def __init__(
self,
config: YandexGPTConfig,
httpx_client: httpx.AsyncClient | None = None,
*,
request_retry: RequestRetryConfig | None = None,
**httpx_kwargs: typing.Any, # noqa: ANN401
) -> None:
self.config = config
self.httpx_client = httpx_client or httpx.AsyncClient()
self.request_retry = request_retry or RequestRetryConfig()
self.httpx_client = get_http_client_from_kwargs(httpx_kwargs)

def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
headers: typing.Final = {"x-data-logging-enabled": "false"}
9 changes: 9 additions & 0 deletions any_llm_client/http.py
@@ -8,6 +8,15 @@
from any_llm_client.retry import RequestRetryConfig


DEFAULT_HTTP_TIMEOUT: typing.Final = httpx.Timeout(None, connect=5.0)


def get_http_client_from_kwargs(kwargs: dict[str, typing.Any]) -> httpx.AsyncClient:
kwargs_with_defaults: typing.Final = kwargs.copy()
kwargs_with_defaults.setdefault("timeout", DEFAULT_HTTP_TIMEOUT)
return httpx.AsyncClient(**kwargs_with_defaults)


async def make_http_request(
*,
httpx_client: httpx.AsyncClient,
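
The helper copies the caller's kwargs before applying the default timeout, so the original dict is never mutated — exactly the behaviour the new tests below pin down. A quick sketch, reusing the proxy mount from the README example:

```python
import httpx

from any_llm_client.http import DEFAULT_HTTP_TIMEOUT, get_http_client_from_kwargs

kwargs = {"mounts": {"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")}}
client = get_http_client_from_kwargs(kwargs)

assert client.timeout == DEFAULT_HTTP_TIMEOUT  # default was applied
assert "timeout" not in kwargs  # the caller's dict is left untouched
```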
16 changes: 7 additions & 9 deletions any_llm_client/main.py
@@ -1,8 +1,6 @@
import functools
import typing

import httpx

from any_llm_client.clients.mock import MockLLMClient, MockLLMConfig
from any_llm_client.clients.openai import OpenAIClient, OpenAIConfig
from any_llm_client.clients.yandexgpt import YandexGPTClient, YandexGPTConfig
@@ -18,43 +16,43 @@
def get_client(
config: AnyLLMConfig,
*,
httpx_client: httpx.AsyncClient | None = None,
request_retry: RequestRetryConfig | None = None,
**httpx_kwargs: typing.Any, # noqa: ANN401
) -> LLMClient: ... # pragma: no cover
else:

@functools.singledispatch
def get_client(
config: typing.Any, # noqa: ANN401, ARG001
*,
httpx_client: httpx.AsyncClient | None = None, # noqa: ARG001
request_retry: RequestRetryConfig | None = None, # noqa: ARG001
**httpx_kwargs: typing.Any, # noqa: ANN401, ARG001
) -> LLMClient:
raise AssertionError("unknown LLM config type")

@get_client.register
def _(
config: YandexGPTConfig,
*,
httpx_client: httpx.AsyncClient | None = None,
request_retry: RequestRetryConfig | None = None,
**httpx_kwargs: typing.Any, # noqa: ANN401
) -> LLMClient:
return YandexGPTClient(config=config, httpx_client=httpx_client, request_retry=request_retry)
return YandexGPTClient(config=config, request_retry=request_retry, **httpx_kwargs)

@get_client.register
def _(
config: OpenAIConfig,
*,
httpx_client: httpx.AsyncClient | None = None,
request_retry: RequestRetryConfig | None = None,
**httpx_kwargs: typing.Any, # noqa: ANN401
) -> LLMClient:
return OpenAIClient(config=config, httpx_client=httpx_client, request_retry=request_retry)
return OpenAIClient(config=config, request_retry=request_retry, **httpx_kwargs)

@get_client.register
def _(
config: MockLLMConfig,
*,
httpx_client: httpx.AsyncClient | None = None, # noqa: ARG001
request_retry: RequestRetryConfig | None = None, # noqa: ARG001
**httpx_kwargs: typing.Any, # noqa: ANN401, ARG001
) -> LLMClient:
return MockLLMClient(config=config)
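
`main.py` keeps a common pattern here: a single fully typed signature is declared for type checkers only (the surrounding `if typing.TYPE_CHECKING:` sits outside this hunk), while at runtime `functools.singledispatch` selects the implementation by config type. A self-contained sketch of that dispatch mechanism, with toy types rather than the library's own:

```python
import functools
import typing


@functools.singledispatch
def get_greeting(config: typing.Any) -> str:
    raise AssertionError("unknown config type")


@get_greeting.register
def _(config: int) -> str:
    return f"numeric config {config}"


@get_greeting.register
def _(config: str) -> str:
    return f"string config {config!r}"


print(get_greeting(42))    # numeric config 42
print(get_greeting("hi"))  # string config 'hi'
```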
27 changes: 27 additions & 0 deletions tests/test_http.py
@@ -0,0 +1,27 @@
import copy
import typing

import httpx

from any_llm_client.http import DEFAULT_HTTP_TIMEOUT, get_http_client_from_kwargs


class TestGetHttpClientFromKwargs:
def test_http_timeout_is_added(self) -> None:
original_kwargs: typing.Final = {"mounts": {"http://": None}}
passed_kwargs: typing.Final = copy.deepcopy(original_kwargs)

result: typing.Final = get_http_client_from_kwargs(passed_kwargs)

assert result.timeout == DEFAULT_HTTP_TIMEOUT
assert original_kwargs == passed_kwargs

def test_http_timeout_is_not_modified_if_set(self) -> None:
timeout: typing.Final = httpx.Timeout(7, connect=5, read=3)
original_kwargs: typing.Final = {"mounts": {"http://": None}, "timeout": timeout}
passed_kwargs: typing.Final = copy.deepcopy(original_kwargs)

result: typing.Final = get_http_client_from_kwargs(passed_kwargs)

assert result.timeout == timeout
assert original_kwargs == passed_kwargs
7 changes: 2 additions & 5 deletions tests/test_mock_client.py
@@ -1,5 +1,4 @@
import typing
from unittest import mock

from polyfactory.factories.pydantic_factory import ModelFactory

@@ -12,7 +11,7 @@ class MockLLMConfigFactory(ModelFactory[any_llm_client.MockLLMConfig]): ...

async def test_mock_client_request_llm_message_returns_config_value() -> None:
config: typing.Final = MockLLMConfigFactory.build()
response: typing.Final = await any_llm_client.get_client(config, httpx_client=mock.Mock()).request_llm_message(
response: typing.Final = await any_llm_client.get_client(config).request_llm_message(
**LLMFuncRequestFactory.build()
)
assert response == config.response_message
@@ -21,8 +20,6 @@ async def test_mock_client_request_llm_message_returns_config_value() -> None:
async def test_mock_client_request_llm_partial_responses_returns_config_value() -> None:
config: typing.Final = MockLLMConfigFactory.build()
response: typing.Final = await consume_llm_partial_responses(
any_llm_client.get_client(config, httpx_client=mock.Mock()).stream_llm_partial_messages(
**LLMFuncRequestFactory.build()
)
any_llm_client.get_client(config).stream_llm_partial_messages(**LLMFuncRequestFactory.build())
)
assert response == config.stream_messages
26 changes: 8 additions & 18 deletions tests/test_openai_client.py
@@ -1,5 +1,4 @@
import typing
from unittest import mock

import faker
import httpx
@@ -35,8 +34,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
)

result: typing.Final = await any_llm_client.get_client(
OpenAIConfigFactory.build(),
httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
).request_llm_message(**LLMFuncRequestFactory.build())

assert result == expected_result
@@ -47,8 +45,7 @@ async def test_fails_without_alternatives(self) -> None:
json=ChatCompletionsNotStreamingResponse.model_construct(choices=[]).model_dump(mode="json"),
)
client: typing.Final = any_llm_client.get_client(
OpenAIConfigFactory.build(),
httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
)

with pytest.raises(pydantic.ValidationError):
@@ -89,10 +86,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
response: typing.Final = httpx.Response(
200, headers={"Content-Type": "text/event-stream"}, content=response_content
)
client: typing.Final = any_llm_client.get_client(
config,
httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
)
client: typing.Final = any_llm_client.get_client(config, transport=httpx.MockTransport(lambda _: response))

result: typing.Final = await consume_llm_partial_responses(client.stream_llm_partial_messages(**func_request))

@@ -106,8 +100,7 @@ async def test_fails_without_alternatives(self) -> None:
200, headers={"Content-Type": "text/event-stream"}, content=response_content
)
client: typing.Final = any_llm_client.get_client(
OpenAIConfigFactory.build(),
httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
)

with pytest.raises(pydantic.ValidationError):
@@ -119,8 +112,7 @@ class TestOpenAILLMErrors:
@pytest.mark.parametrize("status_code", [400, 500])
async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
client: typing.Final = any_llm_client.get_client(
OpenAIConfigFactory.build(),
httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: httpx.Response(status_code))
)

coroutine: typing.Final = (
@@ -144,8 +136,7 @@ async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes | None) -> None:
response: typing.Final = httpx.Response(400, content=content)
client: typing.Final = any_llm_client.get_client(
OpenAIConfigFactory.build(),
httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
)

coroutine: typing.Final = (
@@ -244,14 +235,13 @@ def test_with_alternation(
self, messages: list[any_llm_client.Message], expected_result: list[ChatCompletionsMessage]
) -> None:
client: typing.Final = any_llm_client.OpenAIClient(
config=OpenAIConfigFactory.build(force_user_assistant_message_alternation=True), httpx_client=mock.Mock()
OpenAIConfigFactory.build(force_user_assistant_message_alternation=True)
)
assert client._prepare_messages(messages) == expected_result # noqa: SLF001

def test_without_alternation(self) -> None:
client: typing.Final = any_llm_client.OpenAIClient(
config=OpenAIConfigFactory.build(force_user_assistant_message_alternation=False),
httpx_client=mock.Mock(),
OpenAIConfigFactory.build(force_user_assistant_message_alternation=False)
)
assert client._prepare_messages( # noqa: SLF001
[
4 changes: 1 addition & 3 deletions tests/test_unknown_client.py
@@ -1,5 +1,3 @@
from unittest import mock

import faker
import pytest

@@ -8,4 +6,4 @@

def test_unknown_client_raises_assertion_error(faker: faker.Faker) -> None:
with pytest.raises(AssertionError):
any_llm_client.get_client(faker.pyobject(), httpx_client=mock.Mock()) # type: ignore[arg-type]
any_llm_client.get_client(faker.pyobject()) # type: ignore[arg-type]
17 changes: 6 additions & 11 deletions tests/test_yandexgpt_client.py
@@ -29,8 +29,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
)

result: typing.Final = await any_llm_client.get_client(
YandexGPTConfigFactory.build(),
httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
).request_llm_message(**LLMFuncRequestFactory.build())

assert result == expected_result
@@ -40,8 +39,7 @@ async def test_fails_without_alternatives(self) -> None:
200, json=YandexGPTResponse(result=YandexGPTResult.model_construct(alternatives=[])).model_dump(mode="json")
)
client: typing.Final = any_llm_client.get_client(
YandexGPTConfigFactory.build(),
httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
)

with pytest.raises(pydantic.ValidationError):
@@ -70,7 +68,7 @@ async def test_ok(self, faker: faker.Faker) -> None:

result: typing.Final = await consume_llm_partial_responses(
any_llm_client.get_client(
config, httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response))
config, transport=httpx.MockTransport(lambda _: response)
).stream_llm_partial_messages(**func_request)
)

@@ -83,8 +81,7 @@ async def test_fails_without_alternatives(self) -> None:
response: typing.Final = httpx.Response(200, content=response_content)

client: typing.Final = any_llm_client.get_client(
YandexGPTConfigFactory.build(),
httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
)

with pytest.raises(pydantic.ValidationError):
@@ -96,8 +93,7 @@ class TestYandexGPTLLMErrors:
@pytest.mark.parametrize("status_code", [400, 500])
async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
client: typing.Final = any_llm_client.get_client(
YandexGPTConfigFactory.build(),
httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: httpx.Response(status_code))
)

coroutine: typing.Final = (
@@ -121,8 +117,7 @@ async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
async def test_fails_with_out_of_tokens_error(self, stream: bool, response_content: bytes | None) -> None:
response: typing.Final = httpx.Response(400, content=response_content)
client: typing.Final = any_llm_client.get_client(
YandexGPTConfigFactory.build(),
httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
)

coroutine: typing.Final = (