Commit 2b3ceb9: Allow to pass extra data to LLM (#8)
Parent: c545e43

7 files changed: 69 additions, 15 deletions

README.md

Lines changed: 17 additions & 4 deletions
@@ -99,7 +99,9 @@ config = any_llm_client.MockLLMConfig(
     response_message=...,
     stream_messages=["Hi!"],
 )
-client = any_llm_client.get_client(config, ...)
+
+async with any_llm_client.get_client(config, ...) as client:
+    ...
 ```
 
 #### Configuration with environment variables
@@ -131,7 +133,9 @@ os.environ["LLM_MODEL"] = """{
     "model_name": "qwen2.5-coder:1.5b"
 }"""
 settings = Settings()
-client = any_llm_client.get_client(settings.llm_model, ...)
+
+async with any_llm_client.get_client(settings.llm_model, ...) as client:
+    ...
 ```
 
 Combining with environment variables from previous section, you can keep LLM model configuration and secrets separate.
@@ -146,7 +150,9 @@ config = any_llm_client.OpenAIConfig(
     auth_token=os.environ["OPENAI_API_KEY"],
     model_name="gpt-4o-mini",
 )
-client = any_llm_client.OpenAIClient(config, ...)
+
+async with any_llm_client.OpenAIClient(config, ...) as client:
+    ...
 ```
 
 #### Errors
@@ -179,5 +185,12 @@ Default timeout is `httpx.Timeout(None, connect=5.0)` (5 seconds on connect, unl
 By default, requests are retried 3 times on HTTP status errors. You can change the retry behaviour by supplying `request_retry` parameter:
 
 ```python
-client = any_llm_client.get_client(..., request_retry=any_llm_client.RequestRetryConfig(attempts=5, ...))
+async with any_llm_client.get_client(..., request_retry=any_llm_client.RequestRetryConfig(attempts=5, ...)) as client:
+    ...
+```
+
+#### Passing extra data to LLM
+
+```python
+await client.request_llm_message("Кек, чо как вообще на нарах?", extra={"best_of": 3})
 ```
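
Reading the new README section together with the earlier examples, a runnable end-to-end sketch of the feature looks roughly like this. It uses the mock client so nothing external is called, assumes `get_client` needs no arguments beyond the config (the README elides the rest as `...`), and `best_of` is only an illustrative provider-specific field:

```python
import asyncio

import any_llm_client


async def main() -> None:
    # The mock client returns canned responses, so no real LLM is involved.
    config = any_llm_client.MockLLMConfig(response_message="Hi!", stream_messages=["Hi!"])

    async with any_llm_client.get_client(config) as client:  # further get_client arguments elided
        # `extra` is forwarded into the provider request payload; the mock client simply ignores it.
        print(await client.request_llm_message("Hello!", extra={"best_of": 3}))

        async with client.stream_llm_partial_messages("Hello!", extra={"best_of": 3}) as partial_messages:
            async for partial_message in partial_messages:
                print(partial_message)


asyncio.run(main())
```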

any_llm_client/clients/mock.py

Lines changed: 9 additions & 1 deletion
@@ -19,7 +19,13 @@ class MockLLMConfig(LLMConfig):
 class MockLLMClient(LLMClient):
     config: MockLLMConfig
 
-    async def request_llm_message(self, messages: str | list[Message], temperature: float = 0.2) -> str:  # noqa: ARG002
+    async def request_llm_message(
+        self,
+        messages: str | list[Message],  # noqa: ARG002
+        *,
+        temperature: float = 0.2,  # noqa: ARG002
+        extra: dict[str, typing.Any] | None = None,  # noqa: ARG002
+    ) -> str:
         return self.config.response_message
 
     async def _iter_config_stream_messages(self) -> typing.AsyncIterable[str]:
@@ -30,7 +36,9 @@ async def _iter_config_stream_messages(self) -> typing.AsyncIterable[str]:
     async def stream_llm_partial_messages(
         self,
         messages: str | list[Message],  # noqa: ARG002
+        *,
         temperature: float = 0.2,  # noqa: ARG002
+        extra: dict[str, typing.Any] | None = None,  # noqa: ARG002
     ) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
         yield self._iter_config_stream_messages()

any_llm_client/clients/openai.py

Lines changed: 7 additions & 2 deletions
@@ -45,6 +45,7 @@ class ChatCompletionsMessage(pydantic.BaseModel):
 
 
 class ChatCompletionsRequest(pydantic.BaseModel):
+    model_config = pydantic.ConfigDict(extra="allow")
     stream: bool
     model: str
     messages: list[ChatCompletionsMessage]
@@ -140,12 +141,15 @@ def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletio
             else list(initial_messages)
         )
 
-    async def request_llm_message(self, messages: str | list[Message], temperature: float = 0.2) -> str:
+    async def request_llm_message(
+        self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
+    ) -> str:
         payload: typing.Final = ChatCompletionsRequest(
             stream=False,
             model=self.config.model_name,
             messages=self._prepare_messages(messages),
             temperature=temperature,
+            **extra or {},
         ).model_dump(mode="json")
         try:
             response: typing.Final = await make_http_request(
@@ -173,13 +177,14 @@ async def _iter_partial_responses(self, response: httpx.Response) -> typing.Asyn
 
     @contextlib.asynccontextmanager
     async def stream_llm_partial_messages(
-        self, messages: str | list[Message], temperature: float = 0.2
+        self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
    ) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
         payload: typing.Final = ChatCompletionsRequest(
             stream=True,
             model=self.config.model_name,
             messages=self._prepare_messages(messages),
             temperature=temperature,
+            **extra or {},
         ).model_dump(mode="json")
         try:
             async with make_streaming_http_request(
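
Two details make this merging work: `extra="allow"` lets `ChatCompletionsRequest` keep constructor keywords it does not declare, and `**extra or {}` parses as `**(extra or {})`, so a `None` value splices in nothing. A self-contained sketch of that pydantic behaviour with a stand-in model (not the library's class; field names and values are illustrative):

```python
import typing

import pydantic


class DemoRequest(pydantic.BaseModel):
    """Stand-in for ChatCompletionsRequest: unknown constructor keywords are kept."""

    model_config = pydantic.ConfigDict(extra="allow")
    model: str
    temperature: float


def build_payload(extra: dict[str, typing.Any] | None = None) -> dict[str, typing.Any]:
    # `**extra or {}` is evaluated as `**(extra or {})`, so extra=None is safe.
    return DemoRequest(model="gpt-4o-mini", temperature=0.2, **extra or {}).model_dump(mode="json")


print(build_payload())  # {'model': 'gpt-4o-mini', 'temperature': 0.2}
print(build_payload({"best_of": 3}))  # {'model': 'gpt-4o-mini', 'temperature': 0.2, 'best_of': 3}
```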

any_llm_client/clients/yandexgpt.py

Lines changed: 18 additions & 6 deletions
@@ -43,7 +43,7 @@ class YandexGPTCompletionOptions(pydantic.BaseModel):
 
 
 class YandexGPTRequest(pydantic.BaseModel):
-    model_config = pydantic.ConfigDict(protected_namespaces=())
+    model_config = pydantic.ConfigDict(protected_namespaces=(), extra="allow")
     model_uri: str = pydantic.Field(alias="modelUri")
     completion_options: YandexGPTCompletionOptions = pydantic.Field(alias="completionOptions")
     messages: list[Message]
@@ -96,7 +96,12 @@ def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
         )
 
     def _prepare_payload(
-        self, *, messages: str | list[Message], temperature: float = 0.2, stream: bool
+        self,
+        *,
+        messages: str | list[Message],
+        temperature: float = 0.2,
+        stream: bool,
+        extra: dict[str, typing.Any] | None,
     ) -> dict[str, typing.Any]:
         messages = [UserMessage(messages)] if isinstance(messages, str) else messages
         return YandexGPTRequest(
@@ -105,10 +110,15 @@ def _prepare_payload(
                 stream=stream, temperature=temperature, maxTokens=self.config.max_tokens
             ),
             messages=messages,
+            **extra or {},
         ).model_dump(mode="json", by_alias=True)
 
-    async def request_llm_message(self, messages: str | list[Message], temperature: float = 0.2) -> str:
-        payload: typing.Final = self._prepare_payload(messages=messages, temperature=temperature, stream=False)
+    async def request_llm_message(
+        self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
+    ) -> str:
+        payload: typing.Final = self._prepare_payload(
+            messages=messages, temperature=temperature, stream=False, extra=extra
+        )
 
         try:
             response: typing.Final = await make_http_request(
@@ -128,9 +138,11 @@ async def _iter_completion_messages(self, response: httpx.Response) -> typing.As
 
     @contextlib.asynccontextmanager
     async def stream_llm_partial_messages(
-        self, messages: str | list[Message], temperature: float = 0.2
+        self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
    ) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
-        payload: typing.Final = self._prepare_payload(messages=messages, temperature=temperature, stream=True)
+        payload: typing.Final = self._prepare_payload(
+            messages=messages, temperature=temperature, stream=True, extra=extra
+        )
 
         try:
             async with make_streaming_http_request(
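
`YandexGPTRequest` dumps with `by_alias=True`, so declared fields come out under their aliases (`modelUri`, `completionOptions`) while extra keys pass through under exactly the names the caller supplies. A stand-in sketch of that interplay (not the library's model; the URI value is made up):

```python
import typing

import pydantic


class DemoYandexRequest(pydantic.BaseModel):
    """Stand-in: declared fields are aliased, extra keys pass through unchanged."""

    # protected_namespaces=() silences the "model_" namespace warning, mirroring the library's config.
    model_config = pydantic.ConfigDict(protected_namespaces=(), extra="allow")
    model_uri: str = pydantic.Field(alias="modelUri")


def build_payload(extra: dict[str, typing.Any] | None = None) -> dict[str, typing.Any]:
    return DemoYandexRequest(modelUri="gpt://my-folder/yandexgpt/latest", **extra or {}).model_dump(
        mode="json", by_alias=True
    )


print(build_payload({"best_of": 3}))  # {'modelUri': 'gpt://my-folder/yandexgpt/latest', 'best_of': 3}
```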

any_llm_client/core.py

Lines changed: 2 additions & 2 deletions
@@ -68,12 +68,12 @@ class LLMConfig(pydantic.BaseModel):
 @dataclasses.dataclass(slots=True, init=False)
 class LLMClient(typing.Protocol):
     async def request_llm_message(
-        self, messages: str | list[Message], *, temperature: float = 0.2
+        self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
     ) -> str: ...  # raises LLMError
 
     @contextlib.asynccontextmanager
     def stream_llm_partial_messages(
-        self, messages: str | list[Message], temperature: float = 0.2
+        self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
     ) -> typing.AsyncIterator[typing.AsyncIterable[str]]: ...  # raises LLMError
 
     async def __aenter__(self) -> typing_extensions.Self: ...
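
Because `extra` (like `temperature`) is keyword-only with a `None` default in the protocol and in every client, existing call sites that never mention it keep working unchanged; only callers that want provider-specific fields opt in. A small illustrative helper (the function and its argument names are hypothetical, not part of the library):

```python
import typing

from any_llm_client.core import LLMClient


async def ask(client: LLMClient, prompt: str, extra: dict[str, typing.Any] | None = None) -> str:
    # Callers that do not care about provider-specific fields can keep passing nothing;
    # `extra` simply stays None and the request payload is built as before.
    return await client.request_llm_message(prompt, temperature=0.2, extra=extra)
```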

tests/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ def _deactivate_retries() -> None:
 class LLMFuncRequest(typing.TypedDict):
     messages: str | list[any_llm_client.Message]
     temperature: float
+    extra: dict[str, typing.Any] | None
 
 
 class LLMFuncRequestFactory(TypedDictFactory[LLMFuncRequest]): ...

tests/test_static.py

Lines changed: 15 additions & 0 deletions
@@ -2,9 +2,14 @@
 import typing
 
 import faker
+import pydantic
+import pytest
 import stamina
+from polyfactory.factories.pydantic_factory import ModelFactory
 
 import any_llm_client
+from any_llm_client.clients.openai import ChatCompletionsRequest
+from any_llm_client.clients.yandexgpt import YandexGPTRequest
 from tests.conftest import LLMFuncRequest
 
 
@@ -40,3 +45,13 @@ def test_llm_func_request_has_same_annotations_as_llm_client_methods() -> None:
         annotations.pop(one_ignored_prop)
 
     assert all(annotations == all_annotations[0] for annotations in all_annotations)
+
+
+@pytest.mark.parametrize("model_type", [YandexGPTRequest, ChatCompletionsRequest])
+def test_dumped_llm_request_payload_dump_has_extra_data(model_type: type[pydantic.BaseModel]) -> None:
+    extra: typing.Final = {"hi": "there", "hi-hi": "there-there"}
+    generated_data: typing.Final = ModelFactory.create_factory(model_type).build(**extra).model_dump(by_alias=True)  # type: ignore[arg-type]
+    dumped_model: typing.Final = model_type(**{**generated_data, **extra}).model_dump(mode="json", by_alias=True)
+
+    assert dumped_model["hi"] == "there"
+    assert dumped_model["hi-hi"] == "there-there"
