127 changes: 84 additions & 43 deletions jupyter_ai_jupyternaut/jupyternaut/chat_models.py
@@ -17,6 +17,7 @@
Sequence,
Tuple,
Type,
TYPE_CHECKING,
Union,
)

@@ -64,6 +65,8 @@

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from litellm import ModelResponseStream, Usage

class ChatLiteLLMException(Exception):
"""Error with the `LiteLLM I/O` library"""
@@ -341,9 +344,30 @@ async def acompletion_with_retry(
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        # When passed to `litellm.acompletion()`, these kwargs enable
        # ephemeral prompt caching of the last system message by default.
        #
        # See: https://docs.litellm.ai/docs/tutorials/prompt_caching
cache_control_kwargs = {
"cache_control_injection_points": [
{ "location": "message", "role": "system" }
]
}
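        # For illustration (an assumption based on the LiteLLM docs linked
        # above, not part of this PR): with this injection point, LiteLLM
        # rewrites the matching system message roughly as
        #
        #     {
        #         "role": "system",
        #         "content": [{
        #             "type": "text",
        #             "text": "...",
        #             "cache_control": {"type": "ephemeral"},
        #         }],
        #     }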

        # Disable ephemeral prompt caching on Amazon Bedrock when the
        # InvokeModel API is used instead of the Converse API. This works
        # around an upstream bug in LiteLLM that has yet to be patched.
        #
        # See: github.com/BerriAI/litellm/issues/17479
        if self.model.startswith("bedrock/") and not self.model.startswith("bedrock/converse/"):
            cache_control_kwargs = {}
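        # For example (illustrative model IDs, not from the PR):
        # "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0" is routed
        # through InvokeModel, so caching is disabled for it, while
        # "bedrock/converse/anthropic.claude-3-5-sonnet-20240620-v1:0" uses
        # the Converse API and keeps caching enabled.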

@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
return await self.client.acompletion(**kwargs)
return await self.client.acompletion(
**kwargs,
**cache_control_kwargs,
)

return await _completion_with_retry(**kwargs)

@@ -456,30 +480,10 @@ def _stream(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
params["stream_options"] = self.stream_options
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
usage_metadata = None
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if "usage" in chunk and chunk["usage"]:
usage_metadata = _create_usage_metadata(chunk["usage"])
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
if usage_metadata and isinstance(chunk, AIMessageChunk):
chunk.usage_metadata = usage_metadata

default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
yield cg_chunk
        # This method is intentionally left unimplemented to minimize code
        # duplication with `_astream()`. If a synchronous implementation is
        # needed in the future, `_astream()` can be driven from this method
        # (e.g. via a `ThreadPoolExecutor`); see the sketch below.
        raise NotImplementedError()
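        # A minimal sketch of that future fallback (an untested assumption,
        # not part of this PR; it drives `_astream()` from a dedicated event
        # loop rather than a `ThreadPoolExecutor`, assumes `import asyncio`,
        # and assumes no loop is already running in the calling thread):
        #
        #     def _stream(self, messages, stop=None, run_manager=None, **kwargs):
        #         agen = self._astream(messages, stop=stop, run_manager=run_manager, **kwargs)
        #         loop = asyncio.new_event_loop()
        #         try:
        #             while True:
        #                 try:
        #                     yield loop.run_until_complete(agen.__anext__())
        #                 except StopAsyncIteration:
        #                     break
        #         finally:
        #             loop.close()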

async def _astream(
self,
@@ -491,25 +495,41 @@ async def _astream(
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
params["stream_options"] = self.stream_options

        # This local variable tracks the chunk type so that each delta is
        # parsed as the same message type as the previous chunk, updating
        # whenever a new chunk changes type. (Unsure if this is required.)
default_chunk_class = AIMessageChunk
async for chunk in await self.acompletion_with_retry(

async for _untyped_chunk in await self.acompletion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
usage_metadata = None
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if "usage" in chunk and chunk["usage"]:
usage_metadata = _create_usage_metadata(chunk["usage"])
if len(chunk["choices"]) == 0:
            # Each streamed chunk from LiteLLM is a `ModelResponseStream`.
            litellm_chunk: ModelResponseStream = _untyped_chunk
            # Usage metadata, if any, is typically attached only to the final chunk.
            litellm_usage: Usage | None = getattr(litellm_chunk, "usage", None)
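            # For illustration (the shape is an assumption based on LiteLLM's
            # streaming types, not from the PR), a typical chunk resembles:
            #
            #     ModelResponseStream(
            #         choices=[StreamingChoices(delta=Delta(content="Hello"))],
            #         usage=None,  # populated only on the final chunk when
            #                      # `stream_options={"include_usage": True}`
            #     )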

            # Skip chunks that contain no choices
if len(litellm_chunk.choices) == 0:
continue
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
if usage_metadata and isinstance(chunk, AIMessageChunk):
chunk.usage_metadata = usage_metadata
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)

# Extract delta from chunk
delta = litellm_chunk.choices[0].delta

            # Convert the LiteLLM delta (`litellm.Delta`) to a LangChain
            # message chunk (`BaseMessageChunk`)
message_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)

            # Attach usage metadata if it exists
if litellm_usage and isinstance(message_chunk, AIMessageChunk):
message_chunk.usage_metadata = _create_usage_metadata(litellm_usage)

            # Update the default type used to parse successive chunks
default_chunk_class = message_chunk.__class__

cg_chunk = ChatGenerationChunk(message=message_chunk)
if run_manager:
await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
await run_manager.on_llm_new_token(message_chunk.content, chunk=cg_chunk)
yield cg_chunk

async def _agenerate(
@@ -612,11 +632,32 @@ def _llm_type(self) -> str:
return "litellm-chat"


def _create_usage_metadata(token_usage: Mapping[str, Any]) -> UsageMetadata:
input_tokens = token_usage.get("prompt_tokens", 0)
output_tokens = token_usage.get("completion_tokens", 0)
def _create_usage_metadata(usage: "Usage") -> UsageMetadata:
    """
    Converts a LiteLLM usage object (`litellm.Usage`) into a LangChain usage
    object (`langchain_core.messages.ai.UsageMetadata`).
    """
    # `Usage` is quoted in the signature above because it is only imported
    # under `TYPE_CHECKING`.
    #
    # `prompt_tokens_details` and `completion_tokens_details` may be `None`
    # for some providers, so guard attribute access on them.
    prompt_details = usage.prompt_tokens_details
    completion_details = usage.completion_tokens_details

    input_tokens = usage.prompt_tokens or 0
    input_audio_tokens = (prompt_details.audio_tokens or 0) if prompt_details else 0
    output_tokens = usage.completion_tokens or 0
    output_audio_tokens = (completion_details.audio_tokens or 0) if completion_details else 0
    output_reasoning_tokens = (completion_details.reasoning_tokens or 0) if completion_details else 0
    total_tokens = input_tokens + output_tokens

    cache_creation_tokens = (prompt_details.cache_creation_tokens or 0) if prompt_details else 0
    cache_read_tokens = (prompt_details.cached_tokens or 0) if prompt_details else 0

return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
total_tokens=total_tokens,
input_token_details={
"cache_creation": cache_creation_tokens,
"cache_read": cache_read_tokens,
"audio": input_audio_tokens,
},
output_token_details={
"audio": output_audio_tokens,
"reasoning": output_reasoning_tokens,
}
)
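# For illustration (hypothetical numbers, not from the PR): a `Usage` of
#
#     Usage(prompt_tokens=1200, completion_tokens=80,
#           prompt_tokens_details=PromptTokensDetailsWrapper(cached_tokens=1000))
#
# converts to
#
#     UsageMetadata(input_tokens=1200, output_tokens=80, total_tokens=1280,
#                   input_token_details={"cache_creation": 0, "cache_read": 1000, "audio": 0},
#                   output_token_details={"audio": 0, "reasoning": 0})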