20 changes: 14 additions & 6 deletions any_llm_client/clients/openai.py
@@ -11,7 +11,15 @@
 import pydantic
 import typing_extensions
 
-from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, MessageRole, OutOfTokensOrSymbolsError
+from any_llm_client.core import (
+    LLMClient,
+    LLMConfig,
+    LLMError,
+    Message,
+    MessageRole,
+    OutOfTokensOrSymbolsError,
+    UserMessage,
+)
 from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
 from any_llm_client.retry import RequestRetryConfig
 
@@ -44,7 +52,7 @@ class ChatCompletionsRequest(pydantic.BaseModel):
 
 
 class OneStreamingChoiceDelta(pydantic.BaseModel):
-    role: typing.Literal["assistant"] | None = None
+    role: typing.Literal[MessageRole.assistant] | None = None
     content: str | None = None
 
 
@@ -67,16 +75,16 @@ class ChatCompletionsNotStreamingResponse(pydantic.BaseModel):
 def _make_user_assistant_alternate_messages(
     messages: typing.Iterable[ChatCompletionsMessage],
 ) -> typing.Iterable[ChatCompletionsMessage]:
-    current_message_role: MessageRole = "user"
+    current_message_role = MessageRole.user
     current_message_content_chunks = []
 
     for one_message in messages:
         if not one_message.content.strip():
             continue
 
         if (
-            one_message.role in {"system", "user"} and current_message_role == "user"
-        ) or one_message.role == current_message_role == "assistant":
+            one_message.role in {MessageRole.system, MessageRole.user} and current_message_role == MessageRole.user
+        ) or one_message.role == current_message_role == MessageRole.assistant:
             current_message_content_chunks.append(one_message.content)
         else:
             if current_message_content_chunks:
@@ -122,7 +130,7 @@ def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
         )
 
     def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletionsMessage]:
-        messages = [Message(role="user", text=messages)] if isinstance(messages, str) else messages
+        messages = [UserMessage(messages)] if isinstance(messages, str) else messages
         initial_messages: typing.Final = (
             ChatCompletionsMessage(role=one_message.role, content=one_message.text) for one_message in messages
         )
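Review note (not part of the diff): the hunks above show only fragments of _make_user_assistant_alternate_messages. Below is a simplified, inferred sketch of the alternation rule the tests exercise: consecutive system/user content folds into the current user turn, consecutive assistant content folds into the current assistant turn, and chunks are joined with blank lines. The flush-and-toggle step is an assumption reconstructed from the visible hunks and the test expectations, not the PR's exact code.

# Simplified sketch of the message-alternation rule (inferred, not the PR's exact code).
import enum


class MessageRole(str, enum.Enum):  # mirrors any_llm_client.core.MessageRole
    system = "system"
    user = "user"
    assistant = "assistant"


def alternate(messages: list[tuple[MessageRole, str]]) -> list[tuple[MessageRole, str]]:
    result: list[tuple[MessageRole, str]] = []
    current_role = MessageRole.user
    chunks: list[str] = []

    for role, text in messages:
        if not text.strip():  # empty content is dropped, as in the real helper
            continue

        same_user_turn = role in {MessageRole.system, MessageRole.user} and current_role == MessageRole.user
        same_assistant_turn = role == current_role == MessageRole.assistant
        if same_user_turn or same_assistant_turn:
            chunks.append(text)
        else:
            if chunks:  # flush the finished turn before starting the next one
                result.append((current_role, "\n\n".join(chunks)))
            current_role = MessageRole.assistant if current_role == MessageRole.user else MessageRole.user
            chunks = [text]

    if chunks:
        result.append((current_role, "\n\n".join(chunks)))
    return result


# Matches the parametrized expectation in the tests below:
assert alternate([(MessageRole.system, "Be nice"), (MessageRole.user, "Hi there")]) == [
    (MessageRole.user, "Be nice\n\nHi there")
]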
4 changes: 2 additions & 2 deletions any_llm_client/clients/yandexgpt.py
@@ -10,7 +10,7 @@
 import pydantic
 import typing_extensions
 
-from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError
+from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError, UserMessage
 from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
 from any_llm_client.retry import RequestRetryConfig
 
@@ -98,7 +98,7 @@ def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
     def _prepare_payload(
         self, *, messages: str | list[Message], temperature: float = 0.2, stream: bool
     ) -> dict[str, typing.Any]:
-        messages = [Message(role="user", text=messages)] if isinstance(messages, str) else messages
+        messages = [UserMessage(messages)] if isinstance(messages, str) else messages
         return YandexGPTRequest(
             modelUri=f"gpt://{self.config.folder_id}/{self.config.model_name}/{self.config.model_version}",
             completionOptions=YandexGPTCompletionOptions(
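Note (not part of the diff): both clients now route the string-prompt shortcut through the UserMessage convenience constructor from any_llm_client.core. At runtime UserMessage is a plain function returning a Message, so the old and new spellings build equal values (assuming standard dataclass equality semantics):

# Sketch: the runtime branch of core.py makes the two spellings equivalent.
from any_llm_client.core import Message, MessageRole, UserMessage

assert UserMessage("Hi there") == Message(role=MessageRole.user, text="Hi there")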
18 changes: 11 additions & 7 deletions any_llm_client/core.py
@@ -1,13 +1,17 @@
 import contextlib
 import dataclasses
+import enum
 import types
 import typing
 
 import pydantic
 import typing_extensions
 
 
-MessageRole = typing.Literal["system", "user", "assistant"]
+class MessageRole(str, enum.Enum):
+    system = "system"
+    user = "user"
+    assistant = "assistant"
 
 
 @pydantic.dataclasses.dataclass(kw_only=True)
@@ -20,28 +24,28 @@ class Message:
 
     @pydantic.dataclasses.dataclass
     class SystemMessage(Message):
-        role: typing.Literal["system"] = pydantic.Field("system", init=False)
+        role: typing.Literal[MessageRole.system] = pydantic.Field(MessageRole.system, init=False)
         text: str
 
     @pydantic.dataclasses.dataclass
     class UserMessage(Message):
-        role: typing.Literal["user"] = pydantic.Field("user", init=False)
+        role: typing.Literal[MessageRole.user] = pydantic.Field(MessageRole.user, init=False)
         text: str
 
     @pydantic.dataclasses.dataclass
     class AssistantMessage(Message):
-        role: typing.Literal["assistant"] = pydantic.Field("assistant", init=False)
+        role: typing.Literal[MessageRole.assistant] = pydantic.Field(MessageRole.assistant, init=False)
         text: str
 else:
 
     def SystemMessage(text: str) -> Message:  # noqa: N802
-        return Message(role="system", text=text)
+        return Message(role=MessageRole.system, text=text)
 
     def UserMessage(text: str) -> Message:  # noqa: N802
-        return Message(role="user", text=text)
+        return Message(role=MessageRole.user, text=text)
 
     def AssistantMessage(text: str) -> Message:  # noqa: N802
-        return Message(role="assistant", text=text)
+        return Message(role=MessageRole.assistant, text=text)
 
 
 @dataclasses.dataclass
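Note (not part of the diff): because MessageRole now subclasses str, its members compare equal to the raw strings and serialize to them, so the wire format of request payloads should be unchanged by this refactor. A minimal sketch of that behavior, assuming pydantic v2; ChatMessage is a hypothetical stand-in, not one of the library's models:

import enum

import pydantic


class MessageRole(str, enum.Enum):
    system = "system"
    user = "user"
    assistant = "assistant"


class ChatMessage(pydantic.BaseModel):  # stand-in for the clients' message models
    role: MessageRole
    content: str


# str subclassing keeps plain-string comparisons working...
assert MessageRole.user == "user"

# ...and the enum dumps to its value, so JSON payloads look exactly as before.
assert ChatMessage(role=MessageRole.assistant, content="Hi!").model_dump(mode="json") == {
    "role": "assistant",
    "content": "Hi!",
}

# Raw strings coming back from an API still validate into enum members.
assert ChatMessage.model_validate({"role": "user", "content": "Hey"}).role is MessageRole.user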
45 changes: 29 additions & 16 deletions tests/test_openai_client.py
@@ -28,7 +28,11 @@ async def test_ok(self, faker: faker.Faker) -> None:
             200,
             json=ChatCompletionsNotStreamingResponse(
                 choices=[
-                    OneNotStreamingChoice(message=ChatCompletionsMessage(role="assistant", content=expected_result))
+                    OneNotStreamingChoice(
+                        message=ChatCompletionsMessage(
+                            role=any_llm_client.MessageRole.assistant, content=expected_result
+                        )
+                    )
                 ]
             ).model_dump(mode="json"),
         )
@@ -55,11 +59,11 @@ async def test_fails_without_alternatives(self) -> None:
 class TestOpenAIRequestLLMPartialResponses:
     async def test_ok(self, faker: faker.Faker) -> None:
         generated_messages: typing.Final = [
-            OneStreamingChoiceDelta(role="assistant"),
+            OneStreamingChoiceDelta(role=any_llm_client.MessageRole.assistant),
             OneStreamingChoiceDelta(content="H"),
             OneStreamingChoiceDelta(content="i"),
             OneStreamingChoiceDelta(content=" t"),
-            OneStreamingChoiceDelta(role="assistant", content="here"),
+            OneStreamingChoiceDelta(role=any_llm_client.MessageRole.assistant, content="here"),
             OneStreamingChoiceDelta(),
             OneStreamingChoiceDelta(content=". How is you"),
             OneStreamingChoiceDelta(content="r day?"),
@@ -171,12 +175,17 @@ class TestOpenAIMessageAlternation:
                 ],
                 [],
             ),
-            ([any_llm_client.SystemMessage("Be nice")], [ChatCompletionsMessage(role="user", content="Be nice")]),
+            (
+                [any_llm_client.SystemMessage("Be nice")],
+                [ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Be nice")],
+            ),
             (
                 [any_llm_client.UserMessage("Hi there"), any_llm_client.AssistantMessage("Hi! How can I help you?")],
                 [
-                    ChatCompletionsMessage(role="user", content="Hi there"),
-                    ChatCompletionsMessage(role="assistant", content="Hi! How can I help you?"),
+                    ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Hi there"),
+                    ChatCompletionsMessage(
+                        role=any_llm_client.MessageRole.assistant, content="Hi! How can I help you?"
+                    ),
                 ],
             ),
             (
@@ -186,13 +195,15 @@
                     any_llm_client.AssistantMessage("Hi! How can I help you?"),
                 ],
                 [
-                    ChatCompletionsMessage(role="user", content="Hi there"),
-                    ChatCompletionsMessage(role="assistant", content="Hi! How can I help you?"),
+                    ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Hi there"),
+                    ChatCompletionsMessage(
+                        role=any_llm_client.MessageRole.assistant, content="Hi! How can I help you?"
+                    ),
                 ],
             ),
             (
                 [any_llm_client.SystemMessage("Be nice"), any_llm_client.UserMessage("Hi there")],
-                [ChatCompletionsMessage(role="user", content="Be nice\n\nHi there")],
+                [ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Be nice\n\nHi there")],
             ),
             (
                 [
@@ -210,14 +221,16 @@
                     any_llm_client.UserMessage("Hmmm..."),
                 ],
                 [
-                    ChatCompletionsMessage(role="user", content="Be nice"),
+                    ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Be nice"),
                     ChatCompletionsMessage(
-                        role="assistant",
+                        role=any_llm_client.MessageRole.assistant,
                         content="Hi!\n\nI'm your answer to everything.\n\nHow can I help you?",
                     ),
-                    ChatCompletionsMessage(role="user", content="Hi there\n\nWhy is the sky blue?"),
-                    ChatCompletionsMessage(role="assistant", content="Well..."),
-                    ChatCompletionsMessage(role="user", content="Hmmm..."),
+                    ChatCompletionsMessage(
+                        role=any_llm_client.MessageRole.user, content="Hi there\n\nWhy is the sky blue?"
+                    ),
+                    ChatCompletionsMessage(role=any_llm_client.MessageRole.assistant, content="Well..."),
+                    ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Hmmm..."),
                 ],
             ),
         ],
@@ -237,6 +250,6 @@ def test_without_alternation(self) -> None:
         assert client._prepare_messages(  # noqa: SLF001
             [any_llm_client.SystemMessage("Be nice"), any_llm_client.UserMessage("Hi there")]
         ) == [
-            ChatCompletionsMessage(role="system", content="Be nice"),
-            ChatCompletionsMessage(role="user", content="Hi there"),
+            ChatCompletionsMessage(role=any_llm_client.MessageRole.system, content="Be nice"),
+            ChatCompletionsMessage(role=any_llm_client.MessageRole.user, content="Hi there"),
         ]