From d9be32066cf5d655b57123034cb13a1e6517e699 Mon Sep 17 00:00:00 2001
From: "David L. Qiu"
Date: Wed, 3 Dec 2025 17:35:35 -0800
Subject: [PATCH 1/3] enable ephemeral prompt caching by default

---
 jupyter_ai_jupyternaut/jupyternaut/chat_models.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/jupyter_ai_jupyternaut/jupyternaut/chat_models.py b/jupyter_ai_jupyternaut/jupyternaut/chat_models.py
index 4f8248a..644bf13 100644
--- a/jupyter_ai_jupyternaut/jupyternaut/chat_models.py
+++ b/jupyter_ai_jupyternaut/jupyternaut/chat_models.py
@@ -343,7 +343,17 @@ async def acompletion_with_retry(
 
         @retry_decorator
         async def _completion_with_retry(**kwargs: Any) -> Any:
-            return await self.client.acompletion(**kwargs)
+            return await self.client.acompletion(
+                **kwargs,
+                # Enables ephemeral prompt caching of the last system message by
+                # default.
+                cache_control_injection_points=[
+                    {
+                        "location": "message",
+                        "role": "system",
+                    }
+                ],
+            )
 
         return await _completion_with_retry(**kwargs)
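
The `cache_control_injection_points` argument added above is forwarded to LiteLLM, which attaches an ephemeral `cache_control` block to the matching system message before calling the provider (see https://docs.litellm.ai/docs/tutorials/prompt_caching). A minimal standalone sketch of the same call, assuming a prompt-caching-capable provider and a configured API key; the model name and messages here are illustrative only and are not part of the patch:

    import asyncio

    import litellm

    async def main() -> None:
        # Mirror the default injected by _completion_with_retry(): cache the
        # system message ephemerally on providers that support prompt caching.
        response = await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet-20241022",  # illustrative model ID
            messages=[
                {"role": "system", "content": "You are Jupyternaut, a Jupyter assistant."},
                {"role": "user", "content": "Hello!"},
            ],
            cache_control_injection_points=[
                {"location": "message", "role": "system"},
            ],
        )
        # Cache writes/hits show up in the provider-reported usage block.
        print(response.usage)

    asyncio.run(main())
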
From 353850c4ce2f1753ced6fb5be8861f26f01ac8f1 Mon Sep 17 00:00:00 2001
From: "David L. Qiu"
Date: Wed, 3 Dec 2025 17:55:38 -0800
Subject: [PATCH 2/3] add detailed token usage info for LangServe UI

---
 .../jupyternaut/chat_models.py | 104 +++++++++++-------
 1 file changed, 62 insertions(+), 42 deletions(-)

diff --git a/jupyter_ai_jupyternaut/jupyternaut/chat_models.py b/jupyter_ai_jupyternaut/jupyternaut/chat_models.py
index 644bf13..f929e46 100644
--- a/jupyter_ai_jupyternaut/jupyternaut/chat_models.py
+++ b/jupyter_ai_jupyternaut/jupyternaut/chat_models.py
@@ -17,6 +17,7 @@
     Sequence,
     Tuple,
     Type,
+    TYPE_CHECKING,
     Union,
 )
 
@@ -64,6 +65,8 @@
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from litellm import ModelResponseStream, Usage
 
 class ChatLiteLLMException(Exception):
     """Error with the `LiteLLM I/O` library"""
@@ -466,30 +469,10 @@ def _stream(
         self,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs, "stream": True}
-        params["stream_options"] = self.stream_options
-        default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(
-            messages=message_dicts, run_manager=run_manager, **params
-        ):
-            usage_metadata = None
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if "usage" in chunk and chunk["usage"]:
-                usage_metadata = _create_usage_metadata(chunk["usage"])
-            if len(chunk["choices"]) == 0:
-                continue
-            delta = chunk["choices"][0]["delta"]
-            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            if usage_metadata and isinstance(chunk, AIMessageChunk):
-                chunk.usage_metadata = usage_metadata
-
-            default_chunk_class = chunk.__class__
-            cg_chunk = ChatGenerationChunk(message=chunk)
-            if run_manager:
-                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
-            yield cg_chunk
+        # The implementation was removed to minimize code duplication.
+        # If a synchronous stream is needed in the future, this method can be
+        # implemented by running `_astream()` in a `ThreadPoolExecutor`.
+        raise NotImplementedError()
 
     async def _astream(
         self,
@@ -501,25 +484,41 @@ async def _astream(
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
         params["stream_options"] = self.stream_options
+
+        # This local variable tracks the message chunk type to use for
+        # successive chunks; it is updated when a chunk's type changes.
+        # (unsure if this is still required)
         default_chunk_class = AIMessageChunk
-        async for chunk in await self.acompletion_with_retry(
+
+        async for _untyped_chunk in await self.acompletion_with_retry(
             messages=message_dicts, run_manager=run_manager, **params
         ):
-            usage_metadata = None
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if "usage" in chunk and chunk["usage"]:
-                usage_metadata = _create_usage_metadata(chunk["usage"])
-            if len(chunk["choices"]) == 0:
+            # LiteLLM chunk
+            litellm_chunk: ModelResponseStream = _untyped_chunk
+            # LiteLLM usage metadata
+            litellm_usage: Usage | None = getattr(litellm_chunk, "usage", None)
+
+            # Continue (do nothing) if the chunk is empty
+            if len(litellm_chunk.choices) == 0:
                 continue
-            delta = chunk["choices"][0]["delta"]
-            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            if usage_metadata and isinstance(chunk, AIMessageChunk):
-                chunk.usage_metadata = usage_metadata
-            default_chunk_class = chunk.__class__
-            cg_chunk = ChatGenerationChunk(message=chunk)
+
+            # Extract delta from chunk
+            delta = litellm_chunk.choices[0].delta
+
+            # Convert LiteLLM delta (litellm.Delta) to LangChain
+            # chunk (BaseMessageChunk)
+            message_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+
+            # Append usage metadata if it exists
+            if litellm_usage and isinstance(message_chunk, AIMessageChunk):
+                message_chunk.usage_metadata = _create_usage_metadata(litellm_usage)
+
+            # Set type of successive chunks until a new chunk changes type
+            default_chunk_class = message_chunk.__class__
+
+            cg_chunk = ChatGenerationChunk(message=message_chunk)
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+                await run_manager.on_llm_new_token(message_chunk.content, chunk=cg_chunk)
             yield cg_chunk
 
     async def _agenerate(
         self,
@@ -622,11 +621,32 @@ def _llm_type(self) -> str:
         return "litellm-chat"
 
 
-def _create_usage_metadata(token_usage: Mapping[str, Any]) -> UsageMetadata:
-    input_tokens = token_usage.get("prompt_tokens", 0)
-    output_tokens = token_usage.get("completion_tokens", 0)
+def _create_usage_metadata(usage: Usage) -> UsageMetadata:
+    """
+    Converts a LiteLLM usage metadata object (`litellm.Usage`) into a LangChain
+    usage metadata object (`langchain_core.messages.ai.UsageMetadata`).
+    """
+    input_tokens = usage.prompt_tokens or 0
+    input_audio_tokens = usage.prompt_tokens_details.audio_tokens or 0
+    output_tokens = usage.completion_tokens or 0
+    output_audio_tokens = usage.completion_tokens_details.audio_tokens or 0
+    output_reasoning_tokens = usage.completion_tokens_details.reasoning_tokens or 0
+    total_tokens = input_tokens + output_tokens
+
+    cache_creation_tokens = usage.prompt_tokens_details.cache_creation_tokens or 0
+    cache_read_tokens = usage.prompt_tokens_details.cached_tokens or 0
+
     return UsageMetadata(
         input_tokens=input_tokens,
         output_tokens=output_tokens,
-        total_tokens=input_tokens + output_tokens,
+        total_tokens=total_tokens,
+        input_token_details={
+            "cache_creation": cache_creation_tokens,
+            "cache_read": cache_read_tokens,
+            "audio": input_audio_tokens,
+        },
+        output_token_details={
+            "audio": output_audio_tokens,
+            "reasoning": output_reasoning_tokens,
+        }
     )
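
The `input_token_details` and `output_token_details` keys written by `_create_usage_metadata()` follow `langchain_core.messages.ai.UsageMetadata`, which is what LangChain-aware frontends read to break down cached versus uncached tokens. A small sketch of the shape the patched `_astream()` attaches to the final streamed `AIMessageChunk`; all token counts below are made-up illustrative values:

    from langchain_core.messages import AIMessageChunk
    from langchain_core.messages.ai import UsageMetadata

    # Illustrative values only; real numbers come from litellm.Usage at runtime.
    usage: UsageMetadata = {
        "input_tokens": 1450,
        "output_tokens": 250,
        "total_tokens": 1700,
        "input_token_details": {
            "cache_creation": 1024,  # tokens written to the prompt cache
            "cache_read": 0,         # tokens served from the prompt cache
            "audio": 0,
        },
        "output_token_details": {
            "audio": 0,
            "reasoning": 0,
        },
    }

    chunk = AIMessageChunk(content="", usage_metadata=usage)
    print(chunk.usage_metadata["input_token_details"]["cache_read"])
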
Qiu" Date: Fri, 5 Dec 2025 15:49:17 -0800 Subject: [PATCH 3/3] disable prompt caching when using Bedrock Invoke API --- .../jupyternaut/chat_models.py | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/jupyter_ai_jupyternaut/jupyternaut/chat_models.py b/jupyter_ai_jupyternaut/jupyternaut/chat_models.py index f929e46..61c0b98 100644 --- a/jupyter_ai_jupyternaut/jupyternaut/chat_models.py +++ b/jupyter_ai_jupyternaut/jupyternaut/chat_models.py @@ -344,18 +344,29 @@ async def acompletion_with_retry( """Use tenacity to retry the async completion call.""" retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + # Enables ephemeral prompt caching of the last system message by + # default when passed to `litellm.acompletion()`. + # + # See: https://docs.litellm.ai/docs/tutorials/prompt_caching + cache_control_kwargs = { + "cache_control_injection_points": [ + { "location": "message", "role": "system" } + ] + } + + # Disable ephemeral prompt caching on Amazon Bedrock when the + # InvokeModel API is used instead of Converse API. This is motivated by + # an upstream bug in LiteLLM that has yet to be patched. + # + # See: github.com/BerriAI/litellm/issues/17479 + if self.model.startswith("bedrock/") and not self.model.startswith("bedrock/converse/"): + cache_control_kwargs = {} + @retry_decorator async def _completion_with_retry(**kwargs: Any) -> Any: return await self.client.acompletion( **kwargs, - # Enables ephemeral prompt caching of the last system message by - # default. - cache_control_injection_points=[ - { - "location": "message", - "role": "system", - } - ], + **cache_control_kwargs, ) return await _completion_with_retry(**kwargs)