From 8e74c99663622419bbc79cb9fe4d575ff137359d Mon Sep 17 00:00:00 2001
From: Lev Vereshchagin
Date: Fri, 22 Nov 2024 14:02:21 +0300
Subject: [PATCH] Make MessageRole an enum

---
 any_llm_client/clients/openai.py    | 20 +++++++++----
 any_llm_client/clients/yandexgpt.py |  4 +--
 any_llm_client/core.py              | 18 +++++++-----
 tests/test_openai_client.py         | 45 +++++++++++++++++++----------
 4 files changed, 56 insertions(+), 31 deletions(-)

diff --git a/any_llm_client/clients/openai.py b/any_llm_client/clients/openai.py
index 7dfbc07..332b46c 100644
--- a/any_llm_client/clients/openai.py
+++ b/any_llm_client/clients/openai.py
@@ -11,7 +11,15 @@
 import pydantic
 import typing_extensions
 
-from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, MessageRole, OutOfTokensOrSymbolsError
+from any_llm_client.core import (
+    LLMClient,
+    LLMConfig,
+    LLMError,
+    Message,
+    MessageRole,
+    OutOfTokensOrSymbolsError,
+    UserMessage,
+)
 from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
 from any_llm_client.retry import RequestRetryConfig
 
@@ -44,7 +52,7 @@ class ChatCompletionsRequest(pydantic.BaseModel):
 
 
 class OneStreamingChoiceDelta(pydantic.BaseModel):
-    role: typing.Literal["assistant"] | None = None
+    role: typing.Literal[MessageRole.assistant] | None = None
     content: str | None = None
 
 
@@ -67,7 +75,7 @@ class ChatCompletionsNotStreamingResponse(pydantic.BaseModel):
 def _make_user_assistant_alternate_messages(
     messages: typing.Iterable[ChatCompletionsMessage],
 ) -> typing.Iterable[ChatCompletionsMessage]:
-    current_message_role: MessageRole = "user"
+    current_message_role = MessageRole.user
     current_message_content_chunks = []
 
     for one_message in messages:
@@ -75,8 +83,8 @@
             continue
 
         if (
-            one_message.role in {"system", "user"} and current_message_role == "user"
-        ) or one_message.role == current_message_role == "assistant":
+            one_message.role in {MessageRole.system, MessageRole.user} and current_message_role == MessageRole.user
+        ) or one_message.role == current_message_role == MessageRole.assistant:
             current_message_content_chunks.append(one_message.content)
         else:
             if current_message_content_chunks:
@@ -122,7 +130,7 @@ def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
         )
 
     def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletionsMessage]:
-        messages = [Message(role="user", text=messages)] if isinstance(messages, str) else messages
+        messages = [UserMessage(messages)] if isinstance(messages, str) else messages
         initial_messages: typing.Final = (
             ChatCompletionsMessage(role=one_message.role, content=one_message.text) for one_message in messages
         )
diff --git a/any_llm_client/clients/yandexgpt.py b/any_llm_client/clients/yandexgpt.py
index 7472d95..b684800 100644
--- a/any_llm_client/clients/yandexgpt.py
+++ b/any_llm_client/clients/yandexgpt.py
@@ -10,7 +10,7 @@
 import pydantic
 import typing_extensions
 
-from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError
+from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError, UserMessage
 from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
 from any_llm_client.retry import RequestRetryConfig
 
@@ -98,7 +98,7 @@ def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
     def _prepare_payload(
         self, *, messages: str | list[Message], temperature: float = 0.2, stream: bool
     ) -> dict[str, typing.Any]:
-        messages = [Message(role="user", text=messages)] if isinstance(messages, str) else messages
+        messages = [UserMessage(messages)] if isinstance(messages, str) else messages
         return YandexGPTRequest(
             modelUri=f"gpt://{self.config.folder_id}/{self.config.model_name}/{self.config.model_version}",
             completionOptions=YandexGPTCompletionOptions(
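A note on the two client diffs above: `typing.Literal[...]` accepts enum members as well as plain strings, and because `MessageRole` subclasses `str`, the serialized request body stays identical to what the old string literals produced. A minimal standalone sketch of that behaviour (not part of the patch; assumes pydantic v2, and `ChatMessage` is a stand-in for `ChatCompletionsMessage`):

    import enum
    import typing

    import pydantic


    class MessageRole(str, enum.Enum):
        system = "system"
        user = "user"
        assistant = "assistant"


    class ChatMessage(pydantic.BaseModel):  # stand-in for ChatCompletionsMessage
        # Literal of an enum member: only this exact member (or None) validates.
        role: typing.Literal[MessageRole.assistant] | None = None
        content: str | None = None


    # A str-based enum compares and serializes as a plain string, so the JSON
    # payload matches the old typing.Literal["assistant"] wire format exactly.
    message = ChatMessage(role=MessageRole.assistant, content="Hi!")
    assert message.model_dump(mode="json") == {"role": "assistant", "content": "Hi!"}
    assert MessageRole.assistant == "assistant"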
diff --git a/any_llm_client/core.py b/any_llm_client/core.py
index 97dbe8a..b071ec9 100644
--- a/any_llm_client/core.py
+++ b/any_llm_client/core.py
@@ -1,5 +1,6 @@
 import contextlib
 import dataclasses
+import enum
 import types
 import typing
 
@@ -7,7 +8,10 @@
 import typing_extensions
 
 
-MessageRole = typing.Literal["system", "user", "assistant"]
+class MessageRole(str, enum.Enum):
+    system = "system"
+    user = "user"
+    assistant = "assistant"
 
 
 @pydantic.dataclasses.dataclass(kw_only=True)
@@ -20,28 +24,28 @@ class Message:
     @pydantic.dataclasses.dataclass
     class SystemMessage(Message):
-        role: typing.Literal["system"] = pydantic.Field("system", init=False)
+        role: typing.Literal[MessageRole.system] = pydantic.Field(MessageRole.system, init=False)
         text: str
 
     @pydantic.dataclasses.dataclass
     class UserMessage(Message):
-        role: typing.Literal["user"] = pydantic.Field("user", init=False)
+        role: typing.Literal[MessageRole.user] = pydantic.Field(MessageRole.user, init=False)
         text: str
 
     @pydantic.dataclasses.dataclass
     class AssistantMessage(Message):
-        role: typing.Literal["assistant"] = pydantic.Field("assistant", init=False)
+        role: typing.Literal[MessageRole.assistant] = pydantic.Field(MessageRole.assistant, init=False)
         text: str
 
 else:
 
     def SystemMessage(text: str) -> Message:  # noqa: N802
-        return Message(role="system", text=text)
+        return Message(role=MessageRole.system, text=text)
 
     def UserMessage(text: str) -> Message:  # noqa: N802
-        return Message(role="user", text=text)
+        return Message(role=MessageRole.user, text=text)
 
     def AssistantMessage(text: str) -> Message:  # noqa: N802
-        return Message(role="assistant", text=text)
+        return Message(role=MessageRole.assistant, text=text)
 
 
 @dataclasses.dataclass
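The `if typing.TYPE_CHECKING:` / `else:` split in core.py gives type checkers dataclass subclasses whose `role` field is pinned with `pydantic.Field(..., init=False)`, while at runtime the same names are cheap factory functions returning a plain `Message`. A condensed sketch of the runtime half (not part of the patch; stdlib dataclasses stand in for the pydantic ones):

    import dataclasses
    import enum
    import typing


    class MessageRole(str, enum.Enum):
        system = "system"
        user = "user"
        assistant = "assistant"


    @dataclasses.dataclass(kw_only=True)
    class Message:
        role: MessageRole
        text: str


    if typing.TYPE_CHECKING:
        # Here the patch declares `class UserMessage(Message)` with a narrowed
        # `role` so checkers can reason about it; omitted from this sketch.
        pass
    else:

        def UserMessage(text: str) -> Message:  # noqa: N802 -- factory posing as a class
            return Message(role=MessageRole.user, text=text)


    assert UserMessage("Hi there").role is MessageRole.user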
diff --git a/tests/test_openai_client.py b/tests/test_openai_client.py
index 7c9f419..c563623 100644
--- a/tests/test_openai_client.py
+++ b/tests/test_openai_client.py
@@ -28,7 +28,11 @@ async def test_ok(self, faker: faker.Faker) -> None:
             200,
             json=ChatCompletionsNotStreamingResponse(
                 choices=[
-                    OneNotStreamingChoice(message=ChatCompletionsMessage(role="assistant", content=expected_result))
+                    OneNotStreamingChoice(
+                        message=ChatCompletionsMessage(
+                            role=any_llm_client.MessageRole.assistant, content=expected_result
+                        )
+                    )
                 ]
             ).model_dump(mode="json"),
         )
@@ -55,11 +59,11 @@ async def test_fails_without_alternatives(self) -> None:
 class TestOpenAIRequestLLMPartialResponses:
     async def test_ok(self, faker: faker.Faker) -> None:
         generated_messages: typing.Final = [
-            OneStreamingChoiceDelta(role="assistant"),
+            OneStreamingChoiceDelta(role=any_llm_client.MessageRole.assistant),
             OneStreamingChoiceDelta(content="H"),
             OneStreamingChoiceDelta(content="i"),
             OneStreamingChoiceDelta(content=" t"),
-            OneStreamingChoiceDelta(role="assistant", content="here"),
+            OneStreamingChoiceDelta(role=any_llm_client.MessageRole.assistant, content="here"),
             OneStreamingChoiceDelta(),
             OneStreamingChoiceDelta(content=". How is you"),
             OneStreamingChoiceDelta(content="r day?"),
@@ -171,12 +175,17 @@ class TestOpenAIMessageAlternation:
             ],
             [],
         ),
-        ([any_llm_client.SystemMessage("Be nice")], [ChatCompletionsMessage(role="user", content="Be nice")]),
+        (
+            [any_llm_client.SystemMessage("Be nice")],
+            [ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Be nice")],
+        ),
         (
             [any_llm_client.UserMessage("Hi there"), any_llm_client.AssistantMessage("Hi! How can I help you?")],
             [
-                ChatCompletionsMessage(role="user", content="Hi there"),
-                ChatCompletionsMessage(role="assistant", content="Hi! How can I help you?"),
+                ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Hi there"),
+                ChatCompletionsMessage(
+                    role=any_llm_client.MessageRole.assistant, content="Hi! How can I help you?"
+                ),
             ],
         ),
         (
@@ -186,13 +195,15 @@ class TestOpenAIMessageAlternation:
             [
                 any_llm_client.SystemMessage("Be nice"),
                 any_llm_client.UserMessage("Hi there"),
                 any_llm_client.AssistantMessage("Hi! How can I help you?"),
             ],
             [
-                ChatCompletionsMessage(role="user", content="Hi there"),
-                ChatCompletionsMessage(role="assistant", content="Hi! How can I help you?"),
+                ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Hi there"),
+                ChatCompletionsMessage(
+                    role=any_llm_client.MessageRole.assistant, content="Hi! How can I help you?"
+                ),
             ],
         ),
         (
             [any_llm_client.SystemMessage("Be nice"), any_llm_client.UserMessage("Hi there")],
-            [ChatCompletionsMessage(role="user", content="Be nice\n\nHi there")],
+            [ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Be nice\n\nHi there")],
         ),
         (
             [
@@ -210,14 +221,16 @@ class TestOpenAIMessageAlternation:
                 any_llm_client.SystemMessage("Be nice"),
                 any_llm_client.AssistantMessage("Hi!"),
                 any_llm_client.AssistantMessage("I'm your answer to everything."),
                 any_llm_client.AssistantMessage("How can I help you?"),
                 any_llm_client.UserMessage("Hi there"),
                 any_llm_client.UserMessage("Why is the sky blue?"),
                 any_llm_client.AssistantMessage("Well..."),
                 any_llm_client.UserMessage("Hmmm..."),
             ],
             [
-                ChatCompletionsMessage(role="user", content="Be nice"),
+                ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Be nice"),
                 ChatCompletionsMessage(
-                    role="assistant",
+                    role=any_llm_client.MessageRole.assistant,
                     content="Hi!\n\nI'm your answer to everything.\n\nHow can I help you?",
                 ),
-                ChatCompletionsMessage(role="user", content="Hi there\n\nWhy is the sky blue?"),
-                ChatCompletionsMessage(role="assistant", content="Well..."),
-                ChatCompletionsMessage(role="user", content="Hmmm..."),
+                ChatCompletionsMessage(
+                    role=any_llm_client.MessageRole.user, content="Hi there\n\nWhy is the sky blue?"
+                ),
+                ChatCompletionsMessage(role=any_llm_client.MessageRole.assistant, content="Well..."),
+                ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Hmmm..."),
             ],
         ),
     ],
@@ -237,6 +250,6 @@ def test_without_alternation(self) -> None:
         assert client._prepare_messages(  # noqa: SLF001
             [any_llm_client.SystemMessage("Be nice"), any_llm_client.UserMessage("Hi there")]
         ) == [
-            ChatCompletionsMessage(role="system", content="Be nice"),
-            ChatCompletionsMessage(role="user", content="Hi there"),
+            ChatCompletionsMessage(role=any_llm_client.MessageRole.system, content="Be nice"),
+            ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Hi there"),
         ]
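The expectations in TestOpenAIMessageAlternation encode the merging rule behind `_make_user_assistant_alternate_messages`: system messages count toward the user side, and consecutive messages on the same side are joined with "\n\n" so that user and assistant turns strictly alternate. A condensed re-implementation for illustration (not part of the patch; `ChatMessage` is a stand-in for `ChatCompletionsMessage`):

    import dataclasses
    import enum
    import typing


    class MessageRole(str, enum.Enum):
        system = "system"
        user = "user"
        assistant = "assistant"


    @dataclasses.dataclass
    class ChatMessage:  # stand-in for ChatCompletionsMessage
        role: MessageRole
        content: str


    def make_alternating(messages: typing.Iterable[ChatMessage]) -> typing.Iterator[ChatMessage]:
        current_role = MessageRole.user
        chunks: list[str] = []
        for one in messages:
            if not one.content:  # empty messages are dropped, as in the tests
                continue
            if (
                one.role in {MessageRole.system, MessageRole.user} and current_role == MessageRole.user
            ) or one.role == current_role == MessageRole.assistant:
                chunks.append(one.content)  # same side as before: accumulate
            else:
                if chunks:
                    yield ChatMessage(role=current_role, content="\n\n".join(chunks))
                chunks = [one.content]
                # system prompts are folded into the user side
                current_role = MessageRole.user if one.role == MessageRole.system else one.role
        if chunks:
            yield ChatMessage(role=current_role, content="\n\n".join(chunks))


    # Mirrors one of the parametrized cases above:
    assert list(
        make_alternating([ChatMessage(MessageRole.system, "Be nice"), ChatMessage(MessageRole.user, "Hi there")])
    ) == [ChatMessage(MessageRole.user, "Be nice\n\nHi there")]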