Commit 353850c

add detailed token usage info for LangServe UI
1 parent d9be320 commit 353850c

1 file changed: +62 −42 lines

jupyter_ai_jupyternaut/jupyternaut/chat_models.py

Lines changed: 62 additions & 42 deletions
@@ -17,6 +17,7 @@
     Sequence,
     Tuple,
     Type,
+    TYPE_CHECKING,
     Union,
 )

@@ -64,6 +65,8 @@
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from litellm import ModelResponseStream, Usage
 
 class ChatLiteLLMException(Exception):
     """Error with the `LiteLLM I/O` library"""
@@ -466,30 +469,10 @@ def _stream(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs, "stream": True}
-        params["stream_options"] = self.stream_options
-        default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(
-            messages=message_dicts, run_manager=run_manager, **params
-        ):
-            usage_metadata = None
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if "usage" in chunk and chunk["usage"]:
-                usage_metadata = _create_usage_metadata(chunk["usage"])
-            if len(chunk["choices"]) == 0:
-                continue
-            delta = chunk["choices"][0]["delta"]
-            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            if usage_metadata and isinstance(chunk, AIMessageChunk):
-                chunk.usage_metadata = usage_metadata
-
-            default_chunk_class = chunk.__class__
-            cg_chunk = ChatGenerationChunk(message=chunk)
-            if run_manager:
-                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
-            yield cg_chunk
+        # deleting this method minimizes code duplication.
+        # we can run `_astream()` in a `ThreadPoolExecutor` if we need to
+        # implement this method in the future.
+        raise NotImplementedError()
 
     async def _astream(
         self,
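
The deleted `_stream` body now just raises `NotImplementedError`; the comment suggests the synchronous path could later be recovered by running `_astream()` on a worker thread. A minimal, hedged sketch of that idea (the helper name `iter_over_async` is illustrative and not part of this commit):

import asyncio
import queue
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncIterator, Iterator, TypeVar

T = TypeVar("T")
_SENTINEL = object()


def iter_over_async(agen: AsyncIterator[T]) -> Iterator[T]:
    """Drain an async iterator from synchronous code via a single worker thread."""
    items: queue.Queue = queue.Queue()

    async def _drain() -> None:
        try:
            async for item in agen:
                items.put(item)
        finally:
            # Always signal completion, even if the async side raised.
            items.put(_SENTINEL)

    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(asyncio.run, _drain())
        while (item := items.get()) is not _SENTINEL:
            yield item
        future.result()  # re-raise any exception from the async generator
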
@@ -501,25 +484,41 @@ async def _astream(
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
         params["stream_options"] = self.stream_options
+
+        # This local variable hints the type of successive chunks when a
+        # new chunk differs from the previous one in type.
+        # (unsure if this is required)
         default_chunk_class = AIMessageChunk
-        async for chunk in await self.acompletion_with_retry(
+
+        async for _untyped_chunk in await self.acompletion_with_retry(
             messages=message_dicts, run_manager=run_manager, **params
         ):
-            usage_metadata = None
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if "usage" in chunk and chunk["usage"]:
-                usage_metadata = _create_usage_metadata(chunk["usage"])
-            if len(chunk["choices"]) == 0:
+            # LiteLLM chunk
+            litellm_chunk: ModelResponseStream = _untyped_chunk
+            # LiteLLM usage metadata
+            litellm_usage: Usage | None = getattr(litellm_chunk, 'usage', None)
+
+            # Continue (do nothing) if the chunk is empty
+            if len(litellm_chunk.choices) == 0:
                 continue
-            delta = chunk["choices"][0]["delta"]
-            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            if usage_metadata and isinstance(chunk, AIMessageChunk):
-                chunk.usage_metadata = usage_metadata
-            default_chunk_class = chunk.__class__
-            cg_chunk = ChatGenerationChunk(message=chunk)
+
+            # Extract delta from chunk
+            delta = litellm_chunk.choices[0].delta
+
+            # Convert LiteLLM delta (litellm.Delta) to LangChain
+            # chunk (BaseMessageChunk)
+            message_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+
+            # Append usage metadata if it exists
+            if litellm_usage and isinstance(message_chunk, AIMessageChunk):
+                message_chunk.usage_metadata = _create_usage_metadata(litellm_usage)
+
+            # Set type of successive chunks until a new chunk changes type
+            default_chunk_class = message_chunk.__class__
+
+            cg_chunk = ChatGenerationChunk(message=message_chunk)
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+                await run_manager.on_llm_new_token(message_chunk.content, chunk=cg_chunk)
             yield cg_chunk
 
     async def _agenerate(
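
For context, a hedged sketch of how a caller could observe the usage metadata that `_astream` now attaches to the chunk carrying `usage`; only the public `astream()` interface of a LangChain chat model is assumed, and the model instance itself is left abstract:

from langchain_core.messages import AIMessageChunk


async def print_token_usage(chat_model) -> None:
    """Stream one reply and report the usage metadata attached to its chunks."""
    usage = None
    async for chunk in chat_model.astream("Summarize token accounting in one line."):
        if isinstance(chunk, AIMessageChunk) and chunk.usage_metadata:
            usage = chunk.usage_metadata
    # With `stream_options` requesting usage, the provider's final chunk carries
    # totals such as input_tokens / output_tokens / total_tokens plus the new
    # input_token_details / output_token_details breakdowns.
    print(usage)

# e.g. asyncio.run(print_token_usage(chat_model)) with any streaming-capable model
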
@@ -622,11 +621,32 @@ def _llm_type(self) -> str:
         return "litellm-chat"
 
 
-def _create_usage_metadata(token_usage: Mapping[str, Any]) -> UsageMetadata:
-    input_tokens = token_usage.get("prompt_tokens", 0)
-    output_tokens = token_usage.get("completion_tokens", 0)
+def _create_usage_metadata(usage: Usage) -> UsageMetadata:
+    """
+    Converts LiteLLM usage metadata object (`litellm.Usage`) into LangChain usage
+    metadata object (`langchain_core.messages.ai.UsageMetadata`).
+    """
+    input_tokens = usage.prompt_tokens or 0
+    input_audio_tokens = usage.prompt_tokens_details.audio_tokens or 0
+    output_tokens = usage.completion_tokens or 0
+    output_audio_tokens = usage.completion_tokens_details.audio_tokens or 0
+    output_reasoning_tokens = usage.completion_tokens_details.reasoning_tokens or 0
+    total_tokens = input_tokens + output_tokens
+
+    cache_creation_tokens = usage.prompt_tokens_details.cache_creation_tokens or 0
+    cache_read_tokens = usage.prompt_tokens_details.cached_tokens or 0
+
     return UsageMetadata(
         input_tokens=input_tokens,
         output_tokens=output_tokens,
-        total_tokens=input_tokens + output_tokens,
+        total_tokens=total_tokens,
+        input_token_details={
+            "cache_creation": cache_creation_tokens,
+            "cache_read": cache_read_tokens,
+            "audio": input_audio_tokens,
+        },
+        output_token_details={
+            "audio": output_audio_tokens,
+            "reasoning": output_reasoning_tokens,
+        }
     )
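
To make the mapping concrete, a small hedged sketch of what `_create_usage_metadata` returns, using `SimpleNamespace` stand-ins for the `litellm.Usage` fields the function reads (the token counts are invented):

from types import SimpleNamespace

fake_usage = SimpleNamespace(
    prompt_tokens=120,
    completion_tokens=48,
    prompt_tokens_details=SimpleNamespace(
        audio_tokens=0, cache_creation_tokens=0, cached_tokens=100
    ),
    completion_tokens_details=SimpleNamespace(audio_tokens=0, reasoning_tokens=16),
)

# _create_usage_metadata(fake_usage) would then return roughly:
# {
#     "input_tokens": 120,
#     "output_tokens": 48,
#     "total_tokens": 168,
#     "input_token_details": {"cache_creation": 0, "cache_read": 100, "audio": 0},
#     "output_token_details": {"audio": 0, "reasoning": 16},
# }
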
