22 changes: 19 additions & 3 deletions servers/fai-lambda/fai-chat/src/llm/anthropic.py
@@ -9,6 +9,10 @@

from ..tools.models import Tool
from .base import LLMProvider
from .caching import (
CacheConfig,
ProviderCapabilities,
)
from .models import (
LLMMessage,
LLMMetrics,
@@ -25,10 +29,12 @@ def __init__(
api_key: str,
temperature: float = 0.0,
max_tokens: int = 4096,
cache_config: CacheConfig | None = None,
):
self._model_id = model_id
self._temperature = temperature
self._max_tokens = max_tokens
self._cache_config = cache_config
self._client = AsyncAnthropic(api_key=api_key)

@property
@@ -39,15 +45,25 @@ def model_id(self) -> str:
def provider_name(self) -> str:
return "anthropic"

def _extract_system_and_messages(self, messages: list[LLMMessage]) -> tuple[str | None, list[dict[str, Any]]]:
@property
def capabilities(self) -> ProviderCapabilities:
return ProviderCapabilities(supports_system_prompt_caching=True)

def _extract_system_and_messages(
self, messages: list[LLMMessage]
) -> tuple[str | list[dict[str, Any]] | None, list[dict[str, Any]]]:
system_messages = [msg for msg in messages if msg.role.value == "system"]
user_assistant_messages = [msg for msg in messages if msg.role.value != "system"]

system_prompt = None
system_prompt: str | list[dict[str, Any]] | None = None
if system_messages:
system_prompt = "\n\n".join(
combined_text = "\n\n".join(
msg.content if isinstance(msg.content, str) else str(msg.content) for msg in system_messages
)
if self._cache_config and self._cache_config.enabled:
system_prompt = [{"type": "text", "text": combined_text, "cache_control": {"type": "ephemeral"}}]
else:
system_prompt = combined_text

return system_prompt, [msg.to_dict() for msg in user_assistant_messages]

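With caching enabled, the Anthropic provider wraps the combined system text in a content-block list so a cache breakpoint can be attached. A minimal sketch of the resulting system value, assuming the rest of the request is unchanged (the text itself is a placeholder):

    # Sketch only: the block shape mirrors the diff above and the Anthropic
    # Messages API; the text value is illustrative.
    system_blocks = [
        {
            "type": "text",
            "text": "You are a helpful assistant ...",  # combined system messages
            "cache_control": {"type": "ephemeral"},      # cache breakpoint
        }
    ]
    # Passed as the `system` argument of client.messages.create(...).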
3 changes: 3 additions & 0 deletions servers/fai-lambda/fai-chat/src/llm/anthropic_factory.py
@@ -3,6 +3,7 @@
import os

from .anthropic import AnthropicProvider
from .caching import CacheConfig
from .models import ModelId
from .provider_factory import ProviderFactory

@@ -22,6 +23,7 @@ def create(
model: str,
temperature: float = 0.0,
max_tokens: int = 4096,
cache_config: CacheConfig | None = None,
) -> AnthropicProvider | None:
if model not in ANTHROPIC_MODEL_CONFIGS:
return None
@@ -35,6 +37,7 @@
api_key=self._api_key,
temperature=temperature,
max_tokens=max_tokens,
cache_config=cache_config,
)

def is_available(self) -> bool:
5 changes: 5 additions & 0 deletions servers/fai-lambda/fai-chat/src/llm/base.py
@@ -7,6 +7,7 @@
from collections.abc import AsyncGenerator

from ..tools.models import Tool
from .caching import ProviderCapabilities
from .models import (
LLMMessage,
LLMResponse,
@@ -42,3 +43,7 @@ def model_id(self) -> str:
def provider_name(self) -> str:
"""Return the provider name."""
pass

@property
def capabilities(self) -> ProviderCapabilities:
return ProviderCapabilities()
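The base class ships a conservative default, so existing providers keep working without changes; a provider opts into caching by overriding the property. A minimal illustrative subclass (every name except LLMProvider and ProviderCapabilities is hypothetical):

    from .base import LLMProvider
    from .caching import ProviderCapabilities

    class MyCachingProvider(LLMProvider):
        # ... required abstract methods omitted for brevity ...

        @property
        def capabilities(self) -> ProviderCapabilities:
            # Advertise that this provider understands cache breakpoints.
            return ProviderCapabilities(supports_system_prompt_caching=True)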
16 changes: 14 additions & 2 deletions servers/fai-lambda/fai-chat/src/llm/bedrock.py
@@ -9,6 +9,10 @@

from ..tools.models import Tool
from .base import LLMProvider
from .caching import (
CacheConfig,
ProviderCapabilities,
)
from .models import (
LLMMessage,
LLMMetrics,
@@ -27,13 +31,15 @@ def __init__(
max_tokens: int = 4096,
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
cache_config: CacheConfig | None = None,
):
self._model_id = model_id
self._temperature = temperature
self._max_tokens = max_tokens
self._region = region
self._aws_access_key_id = aws_access_key_id
self._aws_secret_access_key = aws_secret_access_key
self._cache_config = cache_config
self._session = None

@property
@@ -44,6 +50,10 @@ def model_id(self) -> str:
def provider_name(self) -> str:
return "bedrock"

@property
def capabilities(self) -> ProviderCapabilities:
return ProviderCapabilities(supports_system_prompt_caching=True)

def _get_session(self) -> aioboto3.Session:
if self._session is None:
if self._aws_access_key_id and self._aws_secret_access_key:
@@ -58,16 +68,18 @@ def _get_session(self) -> aioboto3.Session:

def _extract_system_and_messages(
self, messages: list[LLMMessage]
) -> tuple[list[dict[str, str]] | None, list[dict[str, Any]]]:
) -> tuple[list[dict[str, Any]] | None, list[dict[str, Any]]]:
system_messages = [msg for msg in messages if msg.role.value == "system"]
user_assistant_messages = [msg for msg in messages if msg.role.value != "system"]

system_blocks = None
system_blocks: list[dict[str, Any]] | None = None
if system_messages:
system_blocks = []
for msg in system_messages:
text_content = msg.content if isinstance(msg.content, str) else str(msg.content)
system_blocks.append({"text": text_content})
if self._cache_config and self._cache_config.enabled:
system_blocks.append({"cachePoint": {"type": "default"}})

bedrock_messages: list[dict[str, Any]] = []
for msg in user_assistant_messages:
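For Bedrock, the Converse API takes the cache marker as a separate cachePoint entry appended after the system text blocks, rather than a field on the text block itself. A sketch of the resulting system list for a single system message with caching enabled (text is a placeholder):

    # Sketch of the Converse-style system blocks produced above.
    system_blocks = [
        {"text": "You are a helpful assistant ..."},  # one entry per system message
        {"cachePoint": {"type": "default"}},          # everything before this point is cacheable
    ]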
3 changes: 3 additions & 0 deletions servers/fai-lambda/fai-chat/src/llm/bedrock_factory.py
@@ -3,6 +3,7 @@
import os

from .bedrock import BedrockProvider
from .caching import CacheConfig
from .models import ModelId
from .provider_factory import ProviderFactory

@@ -36,6 +37,7 @@ def create(
model: str,
temperature: float = 0.0,
max_tokens: int = 4096,
cache_config: CacheConfig | None = None,
) -> BedrockProvider | None:
if model not in BEDROCK_MODEL_CONFIGS:
return None
@@ -51,6 +53,7 @@
max_tokens=max_tokens,
aws_access_key_id=self._aws_access_key_id,
aws_secret_access_key=self._aws_secret_access_key,
cache_config=cache_config,
)

def is_available(self) -> bool:
11 changes: 11 additions & 0 deletions servers/fai-lambda/fai-chat/src/llm/caching.py
@@ -0,0 +1,11 @@
from dataclasses import dataclass


@dataclass(frozen=True)
class CacheConfig:
enabled: bool = True


@dataclass(frozen=True)
class ProviderCapabilities:
supports_system_prompt_caching: bool = False
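The two dataclasses split the concern in half: CacheConfig is what a caller asks for, ProviderCapabilities is what a provider reports (base.py defaults to no caching support). A hypothetical caller-side check, assuming an already-constructed provider instance:

    from .caching import CacheConfig  # relative import, matching the diff

    cache_config = CacheConfig(enabled=True)

    # `provider` is any LLMProvider; providers that cannot cache simply ignore
    # the request (Cohere additionally logs a warning).
    if cache_config.enabled and not provider.capabilities.supports_system_prompt_caching:
        # Fall back to an uncached system prompt.
        ...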
17 changes: 17 additions & 0 deletions servers/fai-lambda/fai-chat/src/llm/cohere.py
@@ -1,6 +1,7 @@
"""Cohere provider implementation."""

import json
import logging
import time
from collections.abc import AsyncGenerator
from typing import Any
@@ -22,6 +23,10 @@

from ..tools.models import Tool
from .base import LLMProvider
from .caching import (
CacheConfig,
ProviderCapabilities,
)
from .models import (
LLMMessage,
LLMMetrics,
@@ -31,6 +36,8 @@
StreamEventType,
)

logger = logging.getLogger(__name__)


class CohereProvider(LLMProvider):
def __init__(
@@ -39,11 +46,17 @@ def __init__(
api_key: str,
temperature: float = 0.0,
max_tokens: int = 4096,
cache_config: CacheConfig | None = None,
):
self._model_id = model_id
self._temperature = temperature
self._max_tokens = max_tokens
self._client = cohere.AsyncClientV2(api_key=api_key)
if cache_config and cache_config.enabled:
logger.warning(
"System prompt caching requested but Cohere provider does not support caching. "
"Continuing without caching."
)

@property
def model_id(self) -> str:
@@ -53,6 +66,10 @@ def model_id(self) -> str:
def provider_name(self) -> str:
return "cohere"

@property
def capabilities(self) -> ProviderCapabilities:
return ProviderCapabilities(supports_system_prompt_caching=False)

def _format_messages(self, messages: list[LLMMessage]) -> list[Any]:
cohere_messages: list[Any] = []
for msg in messages:
3 changes: 3 additions & 0 deletions servers/fai-lambda/fai-chat/src/llm/cohere_factory.py
@@ -2,6 +2,7 @@

import os

from .caching import CacheConfig
from .cohere import CohereProvider
from .models import ModelId
from .provider_factory import ProviderFactory
@@ -20,6 +21,7 @@ def create(
model: str,
temperature: float = 0.0,
max_tokens: int = 4096,
cache_config: CacheConfig | None = None,
) -> CohereProvider | None:
if model not in COHERE_MODEL_CONFIGS:
return None
@@ -33,6 +35,7 @@
api_key=self._api_key,
temperature=temperature,
max_tokens=max_tokens,
cache_config=cache_config,
)

def is_available(self) -> bool:
11 changes: 10 additions & 1 deletion servers/fai-lambda/fai-chat/src/llm/factory.py
@@ -9,6 +9,7 @@
from .anthropic_factory import AnthropicProviderFactory
from .base import LLMProvider
from .bedrock_factory import BedrockProviderFactory
from .caching import CacheConfig
from .cohere_factory import CohereProviderFactory
from .fallback import FallbackProvider
from .models import ModelId
@@ -23,6 +24,7 @@ def _create_llm_provider(
temperature: float = 0.0,
max_tokens: int = 4096,
provider_preference: list[Literal["bedrock", "anthropic", "cohere"]] | None = None,
cache_config: CacheConfig | None = None,
) -> LLMProvider:
model_id = _resolve_model_id(model)
provider_preference = provider_preference or ["bedrock", "anthropic", "cohere"]
@@ -46,7 +48,12 @@
continue

for alias in ordered_models:
provider = factory.create(model=alias, temperature=temperature, max_tokens=max_tokens)
provider = factory.create(
model=alias,
temperature=temperature,
max_tokens=max_tokens,
cache_config=cache_config,
)
if provider:
providers.append(provider)

@@ -81,11 +88,13 @@ def get_llm_provider(
"anthropic",
"cohere",
),
cache_config: CacheConfig | None = None,
) -> LLMProvider:
provider_list: list[Literal["bedrock", "anthropic", "cohere"]] = [*provider_preference_tuple]
return _create_llm_provider(
model=model,
temperature=temperature,
max_tokens=max_tokens,
provider_preference=provider_list,
cache_config=cache_config,
)
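Because cache_config is threaded through every factory, callers only have to opt in once at the top level. A usage sketch: only cache_config and the provider-preference default are visible in this diff, so the other arguments shown here (model alias, temperature, max_tokens) are assumptions:

    from .caching import CacheConfig
    from .factory import get_llm_provider

    provider = get_llm_provider(
        model="claude-sonnet",  # hypothetical alias, resolved via _resolve_model_id
        temperature=0.0,
        max_tokens=4096,
        cache_config=CacheConfig(enabled=True),
    )
    # Providers that support caching (Anthropic, Bedrock) add a cache breakpoint;
    # Cohere logs a warning and continues uncached.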
2 changes: 2 additions & 0 deletions servers/fai-lambda/fai-chat/src/llm/provider_factory.py
@@ -6,6 +6,7 @@
)

from .base import LLMProvider
from .caching import CacheConfig


class ProviderFactory(ABC):
@@ -15,6 +16,7 @@ def create(
model: str,
temperature: float = 0.0,
max_tokens: int = 4096,
cache_config: CacheConfig | None = None,
) -> LLMProvider | None:
pass
