Commit a1b7bb4

Ask for **httpx_kwargs, not httpx_client (#1)
1 parent 15340b1 commit a1b7bb4

10 files changed (+74, −57 lines)

README.md

Lines changed: 6 additions & 5 deletions
@@ -144,7 +144,8 @@ client = any_llm_client.OpenAIClient(config, ...)
 
 #### Timeouts, proxy & other HTTP settings
 
-Pass custom [HTTPX](https://www.python-httpx.org) client:
+
+Pass custom [HTTPX](https://www.python-httpx.org) kwargs to `any_llm_client.get_client()`:
 
 ```python
 import httpx
@@ -154,14 +155,14 @@ import any_llm_client
 
 async with any_llm_client.get_client(
     ...,
-    httpx_client=httpx.AsyncClient(
-        mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
-        timeout=httpx.Timeout(None, connect=5.0),
-    ),
+    mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
+    timeout=httpx.Timeout(None, connect=5.0),
 ) as client:
     ...
 ```
 
+Default timeout is `httpx.Timeout(None, connect=5.0)` (5 seconds on connect, unlimited on read, write or pool).
+
 #### Retries
 
 By default, requests are retried 3 times on HTTP status errors. You can change the retry behaviour by supplying `request_retry` parameter:
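
Since `get_client()` now forwards every extra keyword argument to `httpx.AsyncClient` unchanged, any other HTTPX client setting travels the same way. A minimal sketch of that forwarding, mirroring the README's own `...` placeholder style — the `headers` value is illustrative, not part of this commit:

```python
import any_llm_client

async with any_llm_client.get_client(
    ...,
    headers={"User-Agent": "my-app/0.1"},  # illustrative kwarg, forwarded to httpx.AsyncClient as-is
) as client:
    ...
```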

any_llm_client/clients/openai.py

Lines changed: 4 additions & 3 deletions
@@ -11,7 +11,7 @@
 import typing_extensions
 
 from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, MessageRole, OutOfTokensOrSymbolsError
-from any_llm_client.http import make_http_request, make_streaming_http_request
+from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
 from any_llm_client.retry import RequestRetryConfig
 
 
@@ -101,12 +101,13 @@ class OpenAIClient(LLMClient):
     def __init__(
         self,
         config: OpenAIConfig,
-        httpx_client: httpx.AsyncClient | None = None,
+        *,
         request_retry: RequestRetryConfig | None = None,
+        **httpx_kwargs: typing.Any,  # noqa: ANN401
     ) -> None:
         self.config = config
-        self.httpx_client = httpx_client or httpx.AsyncClient()
         self.request_retry = request_retry or RequestRetryConfig()
+        self.httpx_client = get_http_client_from_kwargs(httpx_kwargs)
 
     def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
         return self.httpx_client.build_request(

any_llm_client/clients/yandexgpt.py

Lines changed: 4 additions & 3 deletions
@@ -10,7 +10,7 @@
 import typing_extensions
 
 from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError
-from any_llm_client.http import make_http_request, make_streaming_http_request
+from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
 from any_llm_client.retry import RequestRetryConfig
 
 
@@ -70,12 +70,13 @@ class YandexGPTClient(LLMClient):
     def __init__(
         self,
         config: YandexGPTConfig,
-        httpx_client: httpx.AsyncClient | None = None,
+        *,
         request_retry: RequestRetryConfig | None = None,
+        **httpx_kwargs: typing.Any,  # noqa: ANN401
     ) -> None:
         self.config = config
-        self.httpx_client = httpx_client or httpx.AsyncClient()
         self.request_retry = request_retry or RequestRetryConfig()
+        self.httpx_client = get_http_client_from_kwargs(httpx_kwargs)
 
     def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
         headers: typing.Final = {"x-data-logging-enabled": "false"}
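
With this change, `OpenAIClient` and `YandexGPTClient` share the same constructor shape: the config is the only positional argument, `request_retry` is keyword-only, and everything else lands in `**httpx_kwargs`. A sketch of direct construction under those rules — the timeout value is illustrative:

```python
import httpx

from any_llm_client.clients.openai import OpenAIClient, OpenAIConfig
from any_llm_client.retry import RequestRetryConfig


def build_client(config: OpenAIConfig) -> OpenAIClient:
    # config stays positional; request_retry is keyword-only;
    # any remaining kwarg is handed to httpx.AsyncClient.
    return OpenAIClient(
        config,
        request_retry=RequestRetryConfig(),
        timeout=httpx.Timeout(10, connect=3.0),  # illustrative value
    )
```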

any_llm_client/http.py

Lines changed: 9 additions & 0 deletions
@@ -8,6 +8,15 @@
 from any_llm_client.retry import RequestRetryConfig
 
 
+DEFAULT_HTTP_TIMEOUT: typing.Final = httpx.Timeout(None, connect=5.0)
+
+
+def get_http_client_from_kwargs(kwargs: dict[str, typing.Any]) -> httpx.AsyncClient:
+    kwargs_with_defaults: typing.Final = kwargs.copy()
+    kwargs_with_defaults.setdefault("timeout", DEFAULT_HTTP_TIMEOUT)
+    return httpx.AsyncClient(**kwargs_with_defaults)
+
+
 async def make_http_request(
     *,
     httpx_client: httpx.AsyncClient,
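
`get_http_client_from_kwargs` copies the dict before calling `setdefault`, so the default timeout is applied without mutating the caller's kwargs. A quick demonstration of those semantics — the `base_url` value is illustrative:

```python
from any_llm_client.http import DEFAULT_HTTP_TIMEOUT, get_http_client_from_kwargs

kwargs = {"base_url": "https://example.com"}  # illustrative httpx.AsyncClient kwarg
client = get_http_client_from_kwargs(kwargs)

assert client.timeout == DEFAULT_HTTP_TIMEOUT  # default applied when no timeout is given
assert "timeout" not in kwargs  # the caller's dict is left untouched
```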

any_llm_client/main.py

Lines changed: 7 additions & 9 deletions
@@ -1,8 +1,6 @@
 import functools
 import typing
 
-import httpx
-
 from any_llm_client.clients.mock import MockLLMClient, MockLLMConfig
 from any_llm_client.clients.openai import OpenAIClient, OpenAIConfig
 from any_llm_client.clients.yandexgpt import YandexGPTClient, YandexGPTConfig
@@ -18,43 +16,43 @@
     def get_client(
         config: AnyLLMConfig,
         *,
-        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
+        **httpx_kwargs: typing.Any,  # noqa: ANN401
     ) -> LLMClient: ...  # pragma: no cover
 else:
 
     @functools.singledispatch
     def get_client(
         config: typing.Any,  # noqa: ANN401, ARG001
         *,
-        httpx_client: httpx.AsyncClient | None = None,  # noqa: ARG001
         request_retry: RequestRetryConfig | None = None,  # noqa: ARG001
+        **httpx_kwargs: typing.Any,  # noqa: ANN401, ARG001
     ) -> LLMClient:
         raise AssertionError("unknown LLM config type")
 
     @get_client.register
     def _(
         config: YandexGPTConfig,
         *,
-        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
+        **httpx_kwargs: typing.Any,  # noqa: ANN401
     ) -> LLMClient:
-        return YandexGPTClient(config=config, httpx_client=httpx_client, request_retry=request_retry)
+        return YandexGPTClient(config=config, request_retry=request_retry, **httpx_kwargs)
 
     @get_client.register
     def _(
         config: OpenAIConfig,
         *,
-        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
+        **httpx_kwargs: typing.Any,  # noqa: ANN401
     ) -> LLMClient:
-        return OpenAIClient(config=config, httpx_client=httpx_client, request_retry=request_retry)
+        return OpenAIClient(config=config, request_retry=request_retry, **httpx_kwargs)
 
     @get_client.register
     def _(
         config: MockLLMConfig,
         *,
-        httpx_client: httpx.AsyncClient | None = None,  # noqa: ARG001
         request_retry: RequestRetryConfig | None = None,  # noqa: ARG001
+        **httpx_kwargs: typing.Any,  # noqa: ANN401, ARG001
    ) -> LLMClient:
        return MockLLMClient(config=config)
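
`get_client` remains a `functools.singledispatch` function: the config type selects the client class, and the fallback still raises for unknown config types while now accepting (and ignoring) `**httpx_kwargs`. A small sketch of the dispatch behaviour, matching what `tests/test_unknown_client.py` exercises:

```python
import any_llm_client

# An unrecognized config type hits the singledispatch fallback,
# no matter which httpx kwargs are supplied.
try:
    any_llm_client.get_client(object(), timeout=None)  # type: ignore[arg-type]
except AssertionError as error:
    print(error)  # "unknown LLM config type"
```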

tests/test_http.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+import copy
+import typing
+
+import httpx
+
+from any_llm_client.http import DEFAULT_HTTP_TIMEOUT, get_http_client_from_kwargs
+
+
+class TestGetHttpClientFromKwargs:
+    def test_http_timeout_is_added(self) -> None:
+        original_kwargs: typing.Final = {"mounts": {"http://": None}}
+        passed_kwargs: typing.Final = copy.deepcopy(original_kwargs)
+
+        result: typing.Final = get_http_client_from_kwargs(passed_kwargs)
+
+        assert result.timeout == DEFAULT_HTTP_TIMEOUT
+        assert original_kwargs == passed_kwargs
+
+    def test_http_timeout_is_not_modified_if_set(self) -> None:
+        timeout: typing.Final = httpx.Timeout(7, connect=5, read=3)
+        original_kwargs: typing.Final = {"mounts": {"http://": None}, "timeout": timeout}
+        passed_kwargs: typing.Final = copy.deepcopy(original_kwargs)
+
+        result: typing.Final = get_http_client_from_kwargs(passed_kwargs)
+
+        assert result.timeout == timeout
+        assert original_kwargs == passed_kwargs

tests/test_mock_client.py

Lines changed: 2 additions & 5 deletions
@@ -1,5 +1,4 @@
 import typing
-from unittest import mock
 
 from polyfactory.factories.pydantic_factory import ModelFactory
 
@@ -12,7 +11,7 @@ class MockLLMConfigFactory(ModelFactory[any_llm_client.MockLLMConfig]): ...
 
 async def test_mock_client_request_llm_message_returns_config_value() -> None:
     config: typing.Final = MockLLMConfigFactory.build()
-    response: typing.Final = await any_llm_client.get_client(config, httpx_client=mock.Mock()).request_llm_message(
+    response: typing.Final = await any_llm_client.get_client(config).request_llm_message(
         **LLMFuncRequestFactory.build()
     )
     assert response == config.response_message
@@ -21,8 +20,6 @@ async def test_mock_client_request_llm_message_returns_config_value() -> None:
 async def test_mock_client_request_llm_partial_responses_returns_config_value() -> None:
     config: typing.Final = MockLLMConfigFactory.build()
     response: typing.Final = await consume_llm_partial_responses(
-        any_llm_client.get_client(config, httpx_client=mock.Mock()).stream_llm_partial_messages(
-            **LLMFuncRequestFactory.build()
-        )
+        any_llm_client.get_client(config).stream_llm_partial_messages(**LLMFuncRequestFactory.build())
     )
     assert response == config.stream_messages

tests/test_openai_client.py

Lines changed: 8 additions & 18 deletions
@@ -1,5 +1,4 @@
 import typing
-from unittest import mock
 
 import faker
 import httpx
@@ -35,8 +34,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
         )
 
         result: typing.Final = await any_llm_client.get_client(
-            OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         ).request_llm_message(**LLMFuncRequestFactory.build())
 
         assert result == expected_result
@@ -47,8 +45,7 @@ async def test_fails_without_alternatives(self) -> None:
             json=ChatCompletionsNotStreamingResponse.model_construct(choices=[]).model_dump(mode="json"),
         )
         client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -89,10 +86,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
         response: typing.Final = httpx.Response(
             200, headers={"Content-Type": "text/event-stream"}, content=response_content
         )
-        client: typing.Final = any_llm_client.get_client(
-            config,
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
-        )
+        client: typing.Final = any_llm_client.get_client(config, transport=httpx.MockTransport(lambda _: response))
 
         result: typing.Final = await consume_llm_partial_responses(client.stream_llm_partial_messages(**func_request))
 
@@ -106,8 +100,7 @@ async def test_fails_without_alternatives(self) -> None:
             200, headers={"Content-Type": "text/event-stream"}, content=response_content
         )
         client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -119,8 +112,7 @@ class TestOpenAILLMErrors:
     @pytest.mark.parametrize("status_code", [400, 500])
     async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
         client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
+            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: httpx.Response(status_code))
         )
 
         coroutine: typing.Final = (
@@ -144,8 +136,7 @@ async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
     async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes | None) -> None:
         response: typing.Final = httpx.Response(400, content=content)
         client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
        )
 
        coroutine: typing.Final = (
@@ -244,14 +235,13 @@ def test_with_alternation(
         self, messages: list[any_llm_client.Message], expected_result: list[ChatCompletionsMessage]
     ) -> None:
         client: typing.Final = any_llm_client.OpenAIClient(
-            config=OpenAIConfigFactory.build(force_user_assistant_message_alternation=True), httpx_client=mock.Mock()
+            OpenAIConfigFactory.build(force_user_assistant_message_alternation=True)
        )
        assert client._prepare_messages(messages) == expected_result  # noqa: SLF001
 
     def test_without_alternation(self) -> None:
         client: typing.Final = any_llm_client.OpenAIClient(
-            config=OpenAIConfigFactory.build(force_user_assistant_message_alternation=False),
-            httpx_client=mock.Mock(),
+            OpenAIConfigFactory.build(force_user_assistant_message_alternation=False)
         )
         assert client._prepare_messages(  # noqa: SLF001
             [

tests/test_unknown_client.py

Lines changed: 1 addition & 3 deletions
@@ -1,5 +1,3 @@
-from unittest import mock
-
 import faker
 import pytest
 
@@ -8,4 +6,4 @@
 
 def test_unknown_client_raises_assertion_error(faker: faker.Faker) -> None:
     with pytest.raises(AssertionError):
-        any_llm_client.get_client(faker.pyobject(), httpx_client=mock.Mock())  # type: ignore[arg-type]
+        any_llm_client.get_client(faker.pyobject())  # type: ignore[arg-type]

tests/test_yandexgpt_client.py

Lines changed: 6 additions & 11 deletions
@@ -29,8 +29,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
         )
 
         result: typing.Final = await any_llm_client.get_client(
-            YandexGPTConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         ).request_llm_message(**LLMFuncRequestFactory.build())
 
         assert result == expected_result
@@ -40,8 +39,7 @@ async def test_fails_without_alternatives(self) -> None:
             200, json=YandexGPTResponse(result=YandexGPTResult.model_construct(alternatives=[])).model_dump(mode="json")
         )
         client: typing.Final = any_llm_client.get_client(
-            YandexGPTConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -70,7 +68,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
 
         result: typing.Final = await consume_llm_partial_responses(
             any_llm_client.get_client(
-                config, httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response))
+                config, transport=httpx.MockTransport(lambda _: response)
             ).stream_llm_partial_messages(**func_request)
         )
 
@@ -83,8 +81,7 @@ async def test_fails_without_alternatives(self) -> None:
         response: typing.Final = httpx.Response(200, content=response_content)
 
         client: typing.Final = any_llm_client.get_client(
-            YandexGPTConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
        )
 
        with pytest.raises(pydantic.ValidationError):
@@ -96,8 +93,7 @@ class TestYandexGPTLLMErrors:
     @pytest.mark.parametrize("status_code", [400, 500])
     async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
         client: typing.Final = any_llm_client.get_client(
-            YandexGPTConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
+            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: httpx.Response(status_code))
         )
 
         coroutine: typing.Final = (
@@ -121,8 +117,7 @@ async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
     async def test_fails_with_out_of_tokens_error(self, stream: bool, response_content: bytes | None) -> None:
         response: typing.Final = httpx.Response(400, content=response_content)
         client: typing.Final = any_llm_client.get_client(
-            YandexGPTConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
         )
 
         coroutine: typing.Final = (
