@@ -70,6 +70,13 @@ class LlmSpanQualifier(BaseSpanAttributeClassifierComponent):
7070 model with an OpenAI-compatible API, such as
7171 [vLLM](https://github.com/vllm-project/vllm).
7272
73+ You can store your OpenAI API key in the `OPENAI_API_KEY` environment
74+ variable.
75+ ```python { .no-check }
76+ import os
77+ os.environ["OPENAI_API_KEY"] = "your_api_key_here"
78+ ```
79+
7380 Start a server with the model of your choice:
7481
7582 ```bash { data-md-color-scheme="slate" }
@@ -86,16 +93,43 @@ class LlmSpanQualifier(BaseSpanAttributeClassifierComponent):
8693 === "Yes/no bool classification"
8794
8895 ```python { .no-check }
89- from typing import Annotated
96+ from typing import Annotated, TypedDict
90- from pydantic import BeforeValidator, PlainSerializer, WithJsonSchema
97+ from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer, WithJsonSchema
9198 import edsnlp, edsnlp.pipes as eds
9299
93- BiopsySchema = Annotated[
100+ # Pydantic schema used to validate and parse the LLM response
101+ # The output will be a boolean field.
102+ # Example:
103+ # ent._.biopsy_procedure → False
104+ class BiopsySchema1(BaseModel):
105+     biopsy_procedure: bool = Field(
106+         ..., description="Is the span a biopsy procedure or not"
107+     )
108+
109+ # Alternative schema using a TypedDict
110+ # The output will be a dict with a boolean value instead of a boolean field.
111+ # Example:
112+ # ent._.biopsy_procedure → {'biopsy_procedure': False}
113+ class BiopsySchema2(TypedDict):
114+     biopsy_procedure: bool
115+
116+ # Alternative annotated schema with custom (de)serializers.
117+ # This schema transforms the LLM’s output into a boolean before validation.
118+ # Any case-insensitive variant of "yes", "y", or "true" is interpreted as True;
119+ # all other values are treated as False.
120+ #
121+ # When serializing to JSON, the boolean is converted back into the strings
122+ # "yes" (for True) or "no" (for False).
123+ # The output will be a boolean field.
124+ # Example:
125+ # ent._.biopsy_procedure → False
126+ BiopsySchema3 = Annotated[
94127     bool,
95128     BeforeValidator(lambda v: str(v).lower() in {"yes", "y", "true"}),
96129     PlainSerializer(lambda v: "yes" if v else "no", when_used="json"),
97130 ]
98131
132+
99133 PROMPT = """
100134 You are a span classifier. The user sends text where the target is
101135 marked with <ent>...</ent>. Answer ONLY with a JSON value: "yes" or
@@ -129,7 +163,7 @@ class LlmSpanQualifier(BaseSpanAttributeClassifierComponent):
129163 context_getter="sent",
130164 context_formatter=doc_to_xml,
131165 attributes=["biopsy_procedure"],
132- output_schema=BiopsySchema,
166+ output_schema=BiopsySchema1, # or BiopsySchema2 or BiopsySchema3
133167 examples=examples,
134168 max_few_shot_examples=2,
135169 max_concurrent_requests=4,
@@ -219,6 +253,8 @@ class CovidMentionSchema(BaseModel):
219253
220254 <!-- blacken-docs:on -->
221255
256+ Advanced usage
257+ --------------
222258 You can also control the prompt more finely by providing a callable instead of a
223259 string. For example, to put few-shot examples in the system message and keep the
224260 span context as the user payload:
@@ -240,6 +276,33 @@ def prompt(context_text, examples):
240276 return messages
241277 ```
242278
279+ You can also control how the context is rendered by providing a custom
280+ callable to the `context_formatter` parameter. For example, to wrap the
281+ context in a custom prefix and suffix:
282+
283+ ```python { .no-check }
284+ from spacy.tokens import Doc
285+
286+ class ContextFormatter:
287+     def __init__(self, prefix: str, suffix: str):
288+         self.prefix = prefix
289+         self.suffix = suffix
290+
291+     def __call__(self, context: Doc) -> str:
292+         span = context.ents[0].text if context.ents else ""
293+         prefix = self.prefix.format(span=span)
294+         suffix = self.suffix.format(span=span)
295+         return f"{prefix}{context.text}{suffix}"
296+
297+ context_formatter = ContextFormatter(prefix="\n## Context\n\n<<<\n",
298+     suffix="\n>>>\n\n## Instruction\nDoes '{span}' correspond to a biopsy procedure?")
299+ ```
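The formatter can then be passed to the pipe like any other `context_formatter`.
Here is a minimal sketch of how the two could be wired together, assuming the
component is exposed as `eds.llm_span_qualifier` and reusing `PROMPT` and
`BiopsySchema1` from the example above; the API URL and model name are
placeholders:

```python { .no-check }
import edsnlp, edsnlp.pipes as eds

nlp = edsnlp.blank("eds")
# ... add an NER pipe here that produces the spans to qualify ...
nlp.add_pipe(
    eds.llm_span_qualifier(
        api_url="http://localhost:8000/v1",  # placeholder OpenAI-compatible endpoint
        model="my-model",  # placeholder model name
        prompt=PROMPT,
        context_getter="sent",
        context_formatter=context_formatter,  # the custom formatter defined above
        attributes=["biopsy_procedure"],
        output_schema=BiopsySchema1,
    )
)
```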
300+
301+ !!! note "`max_concurrent_requests` parameter"
302+
303+     We recommend increasing `max_concurrent_requests` above its default of 1
304+     to improve throughput when processing batches of documents.
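For example, a sketch reusing the same (assumed) `eds.llm_span_qualifier`
factory and the objects defined above, with an illustrative limit of 8
concurrent requests:

```python { .no-check }
nlp.add_pipe(
    eds.llm_span_qualifier(
        api_url="http://localhost:8000/v1",  # placeholder endpoint
        model="my-model",  # placeholder model name
        prompt=PROMPT,
        attributes=["biopsy_procedure"],
        output_schema=BiopsySchema1,
        max_concurrent_requests=8,  # up to 8 in-flight LLM calls instead of the default 1
    )
)
docs = list(nlp.pipe(notes))  # notes: an iterable of raw texts to process
```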
305+
243306 Parameters
244307 ----------
245308 nlp : PipelineProtocol
@@ -250,14 +313,14 @@ def prompt(context_text, examples):
250313 Base URL of the OpenAI-compatible API.
251314 model : str
252315 Model identifier exposed by the API.
253- prompt : Union[str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]]]
316+ prompt : Union[str, Callable[[Union[str, Doc], List[Tuple[Union[str, Doc], str]]], List[Dict[str, str]]]]
254317 The prompt is the main way to control the model's behavior.
255318 It can be either:
256319
257320 - A string, which will be used as a system prompt.
258321 Few-shot examples (if any) will be provided as user/assistant
259322 messages before the actual user query.
260- - A callable that takes two arguments and returns a list of messages in the
323+ - A callable that takes two arguments and returns a list of messages in the
261324 format expected by the OpenAI chat completions API.
262325
263326 * `context`: the context text with the target span marked up
@@ -270,7 +333,7 @@ def prompt(context_text, examples):
270333 If `None`, the whole document text is used.
272- context_formatter : Optional[Callable[[Doc], str]]
335+ context_formatter : Optional[Callable[[Doc], Union[str, Doc]]]
272335 Callable used to render the context passed to the LLM. Defaults to
273- `lambda doc: doc.text`.
336+ `lambda context_getter_output: context_getter_output.text`.
274337 attributes : Optional[AttributesArg]
275338 Attributes to predict. If omitted, the keys are inferred from the provided
276339 schema.
@@ -289,12 +352,16 @@ def prompt(context_text, examples):
289352 seed : Optional[int]
290353 Optional seed forwarded to the API.
291354 max_concurrent_requests : int
292- Maximum number of concurrent span requests per document.
355+ Maximum number of concurrent span requests per batch of documents.
293356 api_kwargs : Dict[str, Any]
294357 Extra keyword arguments forwarded to `chat.completions.create`.
295358 on_error : Literal["raise", "warn"]
296359 Error handling strategy. If `"raise"`, exceptions are raised. If `"warn"`,
297360 exceptions are logged as warnings and processing continues.
361+ timeout : Optional[float]
362+ Optional timeout (in seconds) for each LLM request.
363+ default_headers : Optional[Dict[str, str]]
364+ Optional default headers for the API client. Defaults to `{"Connection": "close"}`.
298365 ''' # noqa: E501
299366
300367 def __init__(
@@ -305,11 +372,15 @@ def __init__(
305372 api_url: str,
306373 model: str,
307374 prompt: Union[
308- str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]]
375+ str,
376+ Callable[
377+ [Union[str, Doc], List[Tuple[Union[str, Doc], str]]],
378+ List[Dict[str, str]],
379+ ],
309380 ],
310381 span_getter: Optional[SpanGetterArg] = None,
311382 context_getter: Optional[ContextWindow] = None,
312- context_formatter: Optional[Callable[[Doc], str]] = None,
383+ context_formatter: Optional[Callable[[Doc], Union[str, Doc]]] = None,
313384 attributes: Optional[AttributesArg] = None,  # confit will auto cast to dict
314385 output_schema: Optional[
315386 Union[
@@ -325,6 +396,8 @@ def __init__(
325396 max_concurrent_requests: int = 1,
326397 api_kwargs: Optional[Dict[str, Any]] = None,
327398 on_error: Literal["raise", "warn"] = "raise",
399+ timeout: Optional[float] = None,
400+ default_headers: Optional[Dict[str, str]] = {"Connection": "close"},
328401 ):
329402 import openai
330403
@@ -333,12 +406,15 @@ def __init__(
333406 self.api_url = api_url
334407 self.model = model
335408 self.prompt = prompt
409+ self.timeout = timeout
336410 self.context_window = (
337411 ContextWindow.validate(context_getter)
338412 if context_getter is not None
339413 else None
340414 )
341- self.context_formatter = context_formatter or (lambda doc: doc.text)
415+ self.context_formatter = context_formatter or (
416+ lambda context_getter_output: context_getter_output.text
417+ )
342418 self.seed = seed
343419 self.api_kwargs = api_kwargs or {}
344420 self.max_concurrent_requests = max_concurrent_requests
@@ -405,11 +481,11 @@ class Output(RootModel):
405481 self.bindings.append((attr_path, labels, json_key, setter, getter))
406482 self.attributes = {path: labels for path, labels, *_ in self.bindings}
407483
408- self.examples: List[Tuple[str, str]] = []
484+ self.examples: List[Tuple[Union[str, Doc], str]] = []
409485 for doc in examples or []:
410486 for span in get_spans(doc, span_getter):
411487 context_doc = self._build_context_doc(span)
412- context_text = self.context_formatter(context_doc)
488+ formatted_context = self.context_formatter(context_doc)
413489 values: Dict[str, Any] = {}
414490 for _, labels, json_key, _, getter in self.bindings:
415491 if (
@@ -440,17 +516,23 @@ class Output(RootModel):
440516 continue
441517 else:
442518 answer = json.dumps(values)
443- self.examples.append((context_text, answer))
519+ self.examples.append((formatted_context, answer))
444520
445521 self.max_few_shot_examples = max_few_shot_examples
446522 self.retriever = None
447523 self.retriever_stemmer = None
448524 if self.max_few_shot_examples > 0 and use_retriever is not False:
449525 self.build_few_shot_retriever_(self.examples)
450526
451- api_key = os.getenv("OPENAI_API_KEY", "")
527+ api_key = os.getenv(
528+ "OPENAI_API_KEY", "EMPTY_API_KEY"
529+ )  # API key should be non-empty (even when exposing local models without auth)
452530 self.client = openai.Client(base_url=self.api_url, api_key=api_key)
453- self._async_client = openai.AsyncOpenAI(base_url=self.api_url, api_key=api_key)
531+ self._async_client = openai.AsyncOpenAI(
532+ base_url=self.api_url,
533+ api_key=api_key,
534+ default_headers=default_headers,
535+ )
454536
455537 super().__init__(nlp=nlp, name=name, span_getter=span_getter)
456538
@@ -468,24 +550,37 @@ def set_extensions(self) -> None:
468550 if not Span.has_extension(ext_name):
469551 Span.set_extension(ext_name, default=None)
470552
471- def build_few_shot_retriever_(self, samples: List[Tuple[str, str]]) -> None:
553+ def build_few_shot_retriever_(
554+ self, samples: List[Tuple[Union[str, Doc], str]]
555+ ) -> None:
472556 # Same BM25 strategy as llm_markup_extractor
473557 import bm25s
474558 import Stemmer
475559
476560 lang = {"eds": "french"}.get(self.lang, self.lang)
477561 stemmer = Stemmer.Stemmer(lang)
478562 corpus = bm25s.tokenize(
479- [text for text, _ in samples], stemmer=stemmer, stopwords=lang
563+ [
564+ sample.text if isinstance(sample, Doc) else sample
565+ for sample, _ in samples
566+ ],
567+ stemmer=stemmer,
568+ stopwords=lang,
480569 )
481570 retriever = bm25s.BM25()
482571 retriever.index(corpus)
483572 self.retriever = retriever
484573 self.retriever_stemmer = stemmer
485574
486- def build_prompt(self, context_text: str) -> List[Dict[str, str]]:
575+ def build_prompt(self, formatted_context: Union[str, Doc]) -> List[Dict[str, str]]:
576+ """Build the prompt messages for the LLM request."""
487577 import bm25s
488578
579+ if isinstance(formatted_context, Doc):
580+ context_text = formatted_context.text
581+ else:
582+ context_text = formatted_context
583+
489584 few_shot_examples: List[Tuple[str, str]] = []
490585 if self.retriever is not None:
491586 closest, _ = self.retriever.retrieve(
@@ -510,7 +605,7 @@ def build_prompt(self, context_text: str) -> List[Dict[str, str]]:
510605 messages.append({"role": "assistant", "content": ans})
511606 messages.append({"role": "user", "content": context_text})
512607 return messages
513- return self.prompt(context_text, few_shot_examples)
608+ return self.prompt(formatted_context, few_shot_examples)
514609
515610 def _llm_request_sync(self, messages: List[Dict[str, str]]) -> str:
516611 call_kwargs = dict(self.api_kwargs)
@@ -662,9 +757,11 @@ def schedule() -> None:
662757 span = state["spans"][state["next_span"]]
663758 state["next_span"] += 1
664759 context_doc = self._build_context_doc(span)
665- context_text = self.context_formatter(context_doc)
666- messages = self.build_prompt(context_text)
667- task_id = worker.submit(self._llm_request_coro(messages))
760+ formatted_context = self.context_formatter(context_doc)
761+ messages = self.build_prompt(formatted_context)
762+ task_id = worker.submit(
763+ self._llm_request_coro(messages), timeout=self.timeout
764+ )
668765 pending[task_id] = (state, span)
669766 state["pending"] += 1
670767 if len(pending) >= self.max_concurrent_requests:
@@ -726,8 +823,8 @@ def process(self, doc: Doc) -> Doc:
726823
727824 for span in spans:
728825 context_doc = self._build_context_doc(span)
729- context_text = self.context_formatter(context_doc)
730- messages = self.build_prompt(context_text)
826+ formatted_context = self.context_formatter(context_doc)
827+ messages = self.build_prompt(formatted_context)
731828 data = None
732829 try:
733830 raw = self._llm_request_sync(messages)