Skip to content

Commit 39228c9

Browse files
authored
Add message shortcuts SystemMessage, UserMessage, AssistantMessage (#3)
1 parent 55eb707 commit 39228c9

File tree

10 files changed

+90
-71
lines changed

10 files changed

+90
-71
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,15 @@ Note that this will yield partial growing message, not message chunks, for examp
7272

7373
### Passing chat history and temperature
7474

75-
You can pass `list[any_llm_client.Message]` instead of `str` as the first argument, and set `temperature`:
75+
You can pass list of messages instead of `str` as the first argument, and set `temperature`:
7676

7777
```python
7878
async with (
7979
any_llm_client.get_client(config) as client,
8080
client.stream_llm_partial_messages(
8181
messages=[
82-
any_llm_client.Message(role="system", text="Ты — опытный ассистент"),
83-
any_llm_client.Message(role="user", text="Кек, чо как вообще на нарах?"),
82+
any_llm_client.SystemMessage("Ты — опытный ассистент"),
83+
any_llm_client.UserMessage("Кек, чо как вообще на нарах?"),
8484
],
8585
temperature=1.0,
8686
) as partial_messages,

any_llm_client/__init__.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,38 @@
11
from any_llm_client.clients.mock import MockLLMClient, MockLLMConfig
22
from any_llm_client.clients.openai import OpenAIClient, OpenAIConfig
33
from any_llm_client.clients.yandexgpt import YandexGPTClient, YandexGPTConfig
4-
from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, MessageRole, OutOfTokensOrSymbolsError
4+
from any_llm_client.core import (
5+
AssistantMessage,
6+
LLMClient,
7+
LLMConfig,
8+
LLMError,
9+
Message,
10+
MessageRole,
11+
OutOfTokensOrSymbolsError,
12+
SystemMessage,
13+
UserMessage,
14+
)
515
from any_llm_client.main import AnyLLMConfig, get_client
616
from any_llm_client.retry import RequestRetryConfig
717

818

919
__all__ = [
20+
"AnyLLMConfig",
21+
"AssistantMessage",
1022
"LLMClient",
1123
"LLMConfig",
1224
"LLMError",
1325
"Message",
1426
"MessageRole",
15-
"OutOfTokensOrSymbolsError",
1627
"MockLLMClient",
1728
"MockLLMConfig",
1829
"OpenAIClient",
1930
"OpenAIConfig",
31+
"OutOfTokensOrSymbolsError",
32+
"RequestRetryConfig",
33+
"SystemMessage",
34+
"UserMessage",
2035
"YandexGPTClient",
2136
"YandexGPTConfig",
2237
"get_client",
23-
"AnyLLMConfig",
24-
"RequestRetryConfig",
2538
]

any_llm_client/clients/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
class OpenAIConfig(LLMConfig):
2323
if typing.TYPE_CHECKING:
24-
url: str # pragma: no cover
24+
url: str
2525
else:
2626
url: pydantic.HttpUrl
2727
auth_token: str | None = pydantic.Field(default_factory=lambda: os.environ.get(OPENAI_AUTH_TOKEN_ENV_NAME))

any_llm_client/clients/yandexgpt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
class YandexGPTConfig(LLMConfig):
2323
if typing.TYPE_CHECKING:
24-
url: str = "https://llm.api.cloud.yandex.net/foundationModels/v1/completion" # pragma: no cover
24+
url: str = "https://llm.api.cloud.yandex.net/foundationModels/v1/completion"
2525
else:
2626
url: pydantic.HttpUrl = "https://llm.api.cloud.yandex.net/foundationModels/v1/completion"
2727
auth_header: str = pydantic.Field( # type: ignore[assignment]

any_llm_client/core.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,40 @@
1010
MessageRole = typing.Literal["system", "user", "assistant"]
1111

1212

13-
class Message(pydantic.BaseModel):
13+
@pydantic.dataclasses.dataclass(kw_only=True)
14+
class Message:
1415
role: MessageRole
1516
text: str
1617

1718

19+
if typing.TYPE_CHECKING:
20+
21+
@pydantic.dataclasses.dataclass
22+
class SystemMessage(Message):
23+
role: typing.Literal["system"] = pydantic.Field("system", init=False)
24+
text: str
25+
26+
@pydantic.dataclasses.dataclass
27+
class UserMessage(Message):
28+
role: typing.Literal["user"] = pydantic.Field("user", init=False)
29+
text: str
30+
31+
@pydantic.dataclasses.dataclass
32+
class AssistantMessage(Message):
33+
role: typing.Literal["assistant"] = pydantic.Field("assistant", init=False)
34+
text: str
35+
else:
36+
37+
def SystemMessage(text: str) -> Message: # noqa: N802
38+
return Message(role="system", text=text)
39+
40+
def UserMessage(text: str) -> Message: # noqa: N802
41+
return Message(role="user", text=text)
42+
43+
def AssistantMessage(text: str) -> Message: # noqa: N802
44+
return Message(role="assistant", text=text)
45+
46+
1847
@dataclasses.dataclass
1948
class LLMError(Exception):
2049
response_content: bytes

any_llm_client/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def get_client(
1818
*,
1919
request_retry: RequestRetryConfig | None = None,
2020
**httpx_kwargs: typing.Any, # noqa: ANN401
21-
) -> LLMClient: ... # pragma: no cover
21+
) -> LLMClient: ...
2222
else:
2323

2424
@functools.singledispatch

examples/openai_stream_advanced.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ async def main() -> None:
1111
any_llm_client.get_client(config) as client,
1212
client.stream_llm_partial_messages(
1313
messages=[
14-
any_llm_client.Message(role="system", text="Ты — опытный ассистент"),
15-
any_llm_client.Message(role="user", text="Кек, чо как вообще на нарах?"),
14+
any_llm_client.SystemMessage("Ты — опытный ассистент"),
15+
any_llm_client.UserMessage("Кек, чо как вообще на нарах?"),
1616
],
1717
temperature=1.0,
1818
) as partial_messages,

pyproject.toml

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,7 @@ dev = [
3232
"pytest-cov",
3333
"pytest",
3434
]
35-
lint = [
36-
{include="dev"},
37-
"auto-typing-final",
38-
"mypy",
39-
"ruff",
40-
]
35+
lint = [{ include-group = "dev" }, "auto-typing-final", "mypy", "ruff"]
4136

4237
[build-system]
4338
requires = ["hatchling", "hatch-vcs"]
@@ -73,8 +68,6 @@ ignore = [
7368
"D213",
7469
"G004",
7570
"FA",
76-
"ANN101",
77-
"ANN102",
7871
"COM812",
7972
"ISC001",
8073
]
@@ -94,4 +87,4 @@ addopts = "--cov"
9487
[tool.coverage.report]
9588
skip_covered = true
9689
show_missing = true
97-
exclude_also = ["if TYPE_CHECKING:"]
90+
exclude_also = ["if typing.TYPE_CHECKING:"]

tests/test_openai_client.py

Lines changed: 31 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -154,69 +154,60 @@ class TestOpenAIMessageAlternation:
154154
("messages", "expected_result"),
155155
[
156156
([], []),
157-
([any_llm_client.Message(role="system", text="")], []),
158-
([any_llm_client.Message(role="system", text=" ")], []),
159-
([any_llm_client.Message(role="user", text="")], []),
160-
([any_llm_client.Message(role="assistant", text="")], []),
161-
([any_llm_client.Message(role="system", text=""), any_llm_client.Message(role="user", text="")], []),
162-
([any_llm_client.Message(role="system", text=""), any_llm_client.Message(role="assistant", text="")], []),
157+
([any_llm_client.SystemMessage("")], []),
158+
([any_llm_client.SystemMessage(" ")], []),
159+
([any_llm_client.UserMessage("")], []),
160+
([any_llm_client.AssistantMessage("")], []),
161+
([any_llm_client.SystemMessage(""), any_llm_client.UserMessage("")], []),
162+
([any_llm_client.SystemMessage(""), any_llm_client.AssistantMessage("")], []),
163163
(
164164
[
165-
any_llm_client.Message(role="system", text=""),
166-
any_llm_client.Message(role="user", text=""),
167-
any_llm_client.Message(role="assistant", text=""),
168-
any_llm_client.Message(role="assistant", text=""),
169-
any_llm_client.Message(role="user", text=""),
170-
any_llm_client.Message(role="assistant", text=""),
165+
any_llm_client.SystemMessage(""),
166+
any_llm_client.UserMessage(""),
167+
any_llm_client.AssistantMessage(""),
168+
any_llm_client.AssistantMessage(""),
169+
any_llm_client.UserMessage(""),
170+
any_llm_client.AssistantMessage(""),
171171
],
172172
[],
173173
),
174+
([any_llm_client.SystemMessage("Be nice")], [ChatCompletionsMessage(role="user", content="Be nice")]),
174175
(
175-
[any_llm_client.Message(role="system", text="Be nice")],
176-
[ChatCompletionsMessage(role="user", content="Be nice")],
177-
),
178-
(
179-
[
180-
any_llm_client.Message(role="user", text="Hi there"),
181-
any_llm_client.Message(role="assistant", text="Hi! How can I help you?"),
182-
],
176+
[any_llm_client.UserMessage("Hi there"), any_llm_client.AssistantMessage("Hi! How can I help you?")],
183177
[
184178
ChatCompletionsMessage(role="user", content="Hi there"),
185179
ChatCompletionsMessage(role="assistant", content="Hi! How can I help you?"),
186180
],
187181
),
188182
(
189183
[
190-
any_llm_client.Message(role="system", text=""),
191-
any_llm_client.Message(role="user", text="Hi there"),
192-
any_llm_client.Message(role="assistant", text="Hi! How can I help you?"),
184+
any_llm_client.SystemMessage(""),
185+
any_llm_client.UserMessage("Hi there"),
186+
any_llm_client.AssistantMessage("Hi! How can I help you?"),
193187
],
194188
[
195189
ChatCompletionsMessage(role="user", content="Hi there"),
196190
ChatCompletionsMessage(role="assistant", content="Hi! How can I help you?"),
197191
],
198192
),
199193
(
200-
[
201-
any_llm_client.Message(role="system", text="Be nice"),
202-
any_llm_client.Message(role="user", text="Hi there"),
203-
],
194+
[any_llm_client.SystemMessage("Be nice"), any_llm_client.UserMessage("Hi there")],
204195
[ChatCompletionsMessage(role="user", content="Be nice\n\nHi there")],
205196
),
206197
(
207198
[
208-
any_llm_client.Message(role="system", text="Be nice"),
209-
any_llm_client.Message(role="assistant", text="Hi!"),
210-
any_llm_client.Message(role="assistant", text="I'm your answer to everything."),
211-
any_llm_client.Message(role="assistant", text="How can I help you?"),
212-
any_llm_client.Message(role="user", text="Hi there"),
213-
any_llm_client.Message(role="user", text=""),
214-
any_llm_client.Message(role="user", text="Why is the sky blue?"),
215-
any_llm_client.Message(role="assistant", text=" "),
216-
any_llm_client.Message(role="assistant", text="Well..."),
217-
any_llm_client.Message(role="assistant", text=""),
218-
any_llm_client.Message(role="assistant", text=" \n "),
219-
any_llm_client.Message(role="user", text="Hmmm..."),
199+
any_llm_client.SystemMessage("Be nice"),
200+
any_llm_client.AssistantMessage("Hi!"),
201+
any_llm_client.AssistantMessage("I'm your answer to everything."),
202+
any_llm_client.AssistantMessage("How can I help you?"),
203+
any_llm_client.UserMessage("Hi there"),
204+
any_llm_client.UserMessage(""),
205+
any_llm_client.UserMessage("Why is the sky blue?"),
206+
any_llm_client.AssistantMessage(" "),
207+
any_llm_client.AssistantMessage("Well..."),
208+
any_llm_client.AssistantMessage(""),
209+
any_llm_client.AssistantMessage(" \n "),
210+
any_llm_client.UserMessage("Hmmm..."),
220211
],
221212
[
222213
ChatCompletionsMessage(role="user", content="Be nice"),
@@ -244,10 +235,7 @@ def test_without_alternation(self) -> None:
244235
OpenAIConfigFactory.build(force_user_assistant_message_alternation=False)
245236
)
246237
assert client._prepare_messages( # noqa: SLF001
247-
[
248-
any_llm_client.Message(role="system", text="Be nice"),
249-
any_llm_client.Message(role="user", text="Hi there"),
250-
]
238+
[any_llm_client.SystemMessage("Be nice"), any_llm_client.UserMessage("Hi there")]
251239
) == [
252240
ChatCompletionsMessage(role="system", content="Be nice"),
253241
ChatCompletionsMessage(role="user", content="Hi there"),

tests/test_yandexgpt_client.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
2121
200,
2222
json=YandexGPTResponse(
2323
result=YandexGPTResult(
24-
alternatives=[
25-
YandexGPTAlternative(message=any_llm_client.Message(role="assistant", text=expected_result))
26-
]
24+
alternatives=[YandexGPTAlternative(message=any_llm_client.AssistantMessage(expected_result))]
2725
)
2826
).model_dump(mode="json"),
2927
)
@@ -55,9 +53,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
5553
"\n".join(
5654
YandexGPTResponse(
5755
result=YandexGPTResult(
58-
alternatives=[
59-
YandexGPTAlternative(message=any_llm_client.Message(role="assistant", text=one_text))
60-
]
56+
alternatives=[YandexGPTAlternative(message=any_llm_client.AssistantMessage(one_text))]
6157
)
6258
).model_dump_json()
6359
for one_text in expected_result

0 commit comments

Comments (0)