Skip to content

Commit 629b8a8

Browse files
authored
Make MessageRole an enum (#4)
1 parent 39228c9 commit 629b8a8

File tree

4 files changed

+56
-31
lines changed

4 files changed

+56
-31
lines changed

any_llm_client/clients/openai.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,15 @@
1111
import pydantic
1212
import typing_extensions
1313

14-
from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, MessageRole, OutOfTokensOrSymbolsError
14+
from any_llm_client.core import (
15+
LLMClient,
16+
LLMConfig,
17+
LLMError,
18+
Message,
19+
MessageRole,
20+
OutOfTokensOrSymbolsError,
21+
UserMessage,
22+
)
1523
from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
1624
from any_llm_client.retry import RequestRetryConfig
1725

@@ -44,7 +52,7 @@ class ChatCompletionsRequest(pydantic.BaseModel):
4452

4553

4654
class OneStreamingChoiceDelta(pydantic.BaseModel):
47-
role: typing.Literal["assistant"] | None = None
55+
role: typing.Literal[MessageRole.assistant] | None = None
4856
content: str | None = None
4957

5058

@@ -67,16 +75,16 @@ class ChatCompletionsNotStreamingResponse(pydantic.BaseModel):
6775
def _make_user_assistant_alternate_messages(
6876
messages: typing.Iterable[ChatCompletionsMessage],
6977
) -> typing.Iterable[ChatCompletionsMessage]:
70-
current_message_role: MessageRole = "user"
78+
current_message_role = MessageRole.user
7179
current_message_content_chunks = []
7280

7381
for one_message in messages:
7482
if not one_message.content.strip():
7583
continue
7684

7785
if (
78-
one_message.role in {"system", "user"} and current_message_role == "user"
79-
) or one_message.role == current_message_role == "assistant":
86+
one_message.role in {MessageRole.system, MessageRole.user} and current_message_role == MessageRole.user
87+
) or one_message.role == current_message_role == MessageRole.assistant:
8088
current_message_content_chunks.append(one_message.content)
8189
else:
8290
if current_message_content_chunks:
@@ -122,7 +130,7 @@ def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
122130
)
123131

124132
def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletionsMessage]:
125-
messages = [Message(role="user", text=messages)] if isinstance(messages, str) else messages
133+
messages = [UserMessage(messages)] if isinstance(messages, str) else messages
126134
initial_messages: typing.Final = (
127135
ChatCompletionsMessage(role=one_message.role, content=one_message.text) for one_message in messages
128136
)

any_llm_client/clients/yandexgpt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pydantic
1111
import typing_extensions
1212

13-
from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError
13+
from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError, UserMessage
1414
from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
1515
from any_llm_client.retry import RequestRetryConfig
1616

@@ -98,7 +98,7 @@ def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
9898
def _prepare_payload(
9999
self, *, messages: str | list[Message], temperature: float = 0.2, stream: bool
100100
) -> dict[str, typing.Any]:
101-
messages = [Message(role="user", text=messages)] if isinstance(messages, str) else messages
101+
messages = [UserMessage(messages)] if isinstance(messages, str) else messages
102102
return YandexGPTRequest(
103103
modelUri=f"gpt://{self.config.folder_id}/{self.config.model_name}/{self.config.model_version}",
104104
completionOptions=YandexGPTCompletionOptions(

any_llm_client/core.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import contextlib
22
import dataclasses
3+
import enum
34
import types
45
import typing
56

67
import pydantic
78
import typing_extensions
89

910

10-
MessageRole = typing.Literal["system", "user", "assistant"]
11+
class MessageRole(str, enum.Enum):
12+
system = "system"
13+
user = "user"
14+
assistant = "assistant"
1115

1216

1317
@pydantic.dataclasses.dataclass(kw_only=True)
@@ -20,28 +24,28 @@ class Message:
2024

2125
@pydantic.dataclasses.dataclass
2226
class SystemMessage(Message):
23-
role: typing.Literal["system"] = pydantic.Field("system", init=False)
27+
role: typing.Literal[MessageRole.system] = pydantic.Field(MessageRole.system, init=False)
2428
text: str
2529

2630
@pydantic.dataclasses.dataclass
2731
class UserMessage(Message):
28-
role: typing.Literal["user"] = pydantic.Field("user", init=False)
32+
role: typing.Literal[MessageRole.user] = pydantic.Field(MessageRole.user, init=False)
2933
text: str
3034

3135
@pydantic.dataclasses.dataclass
3236
class AssistantMessage(Message):
33-
role: typing.Literal["assistant"] = pydantic.Field("assistant", init=False)
37+
role: typing.Literal[MessageRole.assistant] = pydantic.Field(MessageRole.assistant, init=False)
3438
text: str
3539
else:
3640

3741
def SystemMessage(text: str) -> Message: # noqa: N802
38-
return Message(role="system", text=text)
42+
return Message(role=MessageRole.system, text=text)
3943

4044
def UserMessage(text: str) -> Message: # noqa: N802
41-
return Message(role="user", text=text)
45+
return Message(role=MessageRole.user, text=text)
4246

4347
def AssistantMessage(text: str) -> Message: # noqa: N802
44-
return Message(role="assistant", text=text)
48+
return Message(role=MessageRole.assistant, text=text)
4549

4650

4751
@dataclasses.dataclass

tests/test_openai_client.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ async def test_ok(self, faker: faker.Faker) -> None:
2828
200,
2929
json=ChatCompletionsNotStreamingResponse(
3030
choices=[
31-
OneNotStreamingChoice(message=ChatCompletionsMessage(role="assistant", content=expected_result))
31+
OneNotStreamingChoice(
32+
message=ChatCompletionsMessage(
33+
role=any_llm_client.MessageRole.assistant, content=expected_result
34+
)
35+
)
3236
]
3337
).model_dump(mode="json"),
3438
)
@@ -55,11 +59,11 @@ async def test_fails_without_alternatives(self) -> None:
5559
class TestOpenAIRequestLLMPartialResponses:
5660
async def test_ok(self, faker: faker.Faker) -> None:
5761
generated_messages: typing.Final = [
58-
OneStreamingChoiceDelta(role="assistant"),
62+
OneStreamingChoiceDelta(role=any_llm_client.MessageRole.assistant),
5963
OneStreamingChoiceDelta(content="H"),
6064
OneStreamingChoiceDelta(content="i"),
6165
OneStreamingChoiceDelta(content=" t"),
62-
OneStreamingChoiceDelta(role="assistant", content="here"),
66+
OneStreamingChoiceDelta(role=any_llm_client.MessageRole.assistant, content="here"),
6367
OneStreamingChoiceDelta(),
6468
OneStreamingChoiceDelta(content=". How is you"),
6569
OneStreamingChoiceDelta(content="r day?"),
@@ -171,12 +175,17 @@ class TestOpenAIMessageAlternation:
171175
],
172176
[],
173177
),
174-
([any_llm_client.SystemMessage("Be nice")], [ChatCompletionsMessage(role="user", content="Be nice")]),
178+
(
179+
[any_llm_client.SystemMessage("Be nice")],
180+
[ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Be nice")],
181+
),
175182
(
176183
[any_llm_client.UserMessage("Hi there"), any_llm_client.AssistantMessage("Hi! How can I help you?")],
177184
[
178-
ChatCompletionsMessage(role="user", content="Hi there"),
179-
ChatCompletionsMessage(role="assistant", content="Hi! How can I help you?"),
185+
ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Hi there"),
186+
ChatCompletionsMessage(
187+
role=any_llm_client.MessageRole.assistant, content="Hi! How can I help you?"
188+
),
180189
],
181190
),
182191
(
@@ -186,13 +195,15 @@ class TestOpenAIMessageAlternation:
186195
any_llm_client.AssistantMessage("Hi! How can I help you?"),
187196
],
188197
[
189-
ChatCompletionsMessage(role="user", content="Hi there"),
190-
ChatCompletionsMessage(role="assistant", content="Hi! How can I help you?"),
198+
ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Hi there"),
199+
ChatCompletionsMessage(
200+
role=any_llm_client.MessageRole.assistant, content="Hi! How can I help you?"
201+
),
191202
],
192203
),
193204
(
194205
[any_llm_client.SystemMessage("Be nice"), any_llm_client.UserMessage("Hi there")],
195-
[ChatCompletionsMessage(role="user", content="Be nice\n\nHi there")],
206+
[ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Be nice\n\nHi there")],
196207
),
197208
(
198209
[
@@ -210,14 +221,16 @@ class TestOpenAIMessageAlternation:
210221
any_llm_client.UserMessage("Hmmm..."),
211222
],
212223
[
213-
ChatCompletionsMessage(role="user", content="Be nice"),
224+
ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Be nice"),
214225
ChatCompletionsMessage(
215-
role="assistant",
226+
role=any_llm_client.MessageRole.assistant,
216227
content="Hi!\n\nI'm your answer to everything.\n\nHow can I help you?",
217228
),
218-
ChatCompletionsMessage(role="user", content="Hi there\n\nWhy is the sky blue?"),
219-
ChatCompletionsMessage(role="assistant", content="Well..."),
220-
ChatCompletionsMessage(role="user", content="Hmmm..."),
229+
ChatCompletionsMessage(
230+
role=any_llm_client.MessageRole.user, content="Hi there\n\nWhy is the sky blue?"
231+
),
232+
ChatCompletionsMessage(role=any_llm_client.MessageRole.assistant, content="Well..."),
233+
ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Hmmm..."),
221234
],
222235
),
223236
],
@@ -237,6 +250,6 @@ def test_without_alternation(self) -> None:
237250
assert client._prepare_messages( # noqa: SLF001
238251
[any_llm_client.SystemMessage("Be nice"), any_llm_client.UserMessage("Hi there")]
239252
) == [
240-
ChatCompletionsMessage(role="system", content="Be nice"),
241-
ChatCompletionsMessage(role="user", content="Hi there"),
253+
ChatCompletionsMessage(role=any_llm_client.MessageRole.system, content="Be nice"),
254+
ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Hi there"),
242255
]

0 commit comments

Comments
 (0)