Commit 8aff227

feat: improve llm qlf context formatting, llm requests headers & api key handling

1 parent 9b5b23a · commit 8aff227

3 files changed: +128 −28 lines

edsnlp/pipes/llm/async_worker.py

Lines changed: 4 additions & 2 deletions
@@ -25,14 +25,16 @@ def instance(cls) -> "AsyncRequestWorker":
         cls._instance = AsyncRequestWorker()
         return cls._instance
 
-    def submit(self, coro: Coroutine[Any, Any, Any]) -> int:
+    def submit(
+        self, coro: Coroutine[Any, Any, Any], timeout: Optional[float] = None
+    ) -> int:
         with self._lock:
             task_id = self._next_id
             self._next_id += 1
 
         async def _wrap():
             try:
-                res = await coro
+                res = await asyncio.wait_for(coro, timeout=timeout)
                 exc = None
             except BaseException as e:  # noqa: BLE001
                 res = None
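
The added `timeout` is enforced with `asyncio.wait_for`, so a request that runs past it now fails with `asyncio.TimeoutError`, which the existing `except BaseException` branch records as the task's error. A minimal standalone sketch of the same pattern (the helper coroutine and the 0.1 s value are illustrative, not part of the worker):

```python
import asyncio


async def slow_request() -> str:
    # Stand-in for an LLM API call that takes too long to answer.
    await asyncio.sleep(10)
    return "response"


async def main() -> None:
    try:
        # Same mechanism the worker now uses: cancel the awaited coroutine
        # once `timeout` seconds have elapsed.
        result = await asyncio.wait_for(slow_request(), timeout=0.1)
    except asyncio.TimeoutError:
        result = None
    print(result)  # None: the request timed out


asyncio.run(main())
```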

edsnlp/pipes/llm/llm_span_qualifier/llm_span_qualifier.py

Lines changed: 121 additions & 24 deletions
@@ -70,6 +70,13 @@ class LlmSpanQualifier(BaseSpanAttributeClassifierComponent):
 model with an OpenAI-compatible API, such as
 [vLLM](https://github.com/vllm-project/vllm).
 
+You can store your OpenAI API key in the `OPENAI_API_KEY` environment
+variable.
+```python { .no-check }
+import os
+os.environ["OPENAI_API_KEY"] = "your_api_key_here"
+```
+
 Start a server with the model of your choice:
 
 ```bash { data-md-color-scheme="slate" }
@@ -86,16 +93,43 @@ class LlmSpanQualifier(BaseSpanAttributeClassifierComponent):
 === "Yes/no bool classification"
 
     ```python { .no-check }
-    from typing import Annotated
+    from typing import Annotated, TypedDict
     from pydantic import BeforeValidator, PlainSerializer, WithJsonSchema
     import edsnlp, edsnlp.pipes as eds
 
-    BiopsySchema = Annotated[
+    # Pydantic schema used to validate and parse the LLM response.
+    # The output will be a boolean field.
+    # Example:
+    #   ent._.biopsy_procedure → False
+    class BiopsySchema1(BaseModel):
+        biopsy_procedure: bool = Field(
+            ..., description="Is the span a biopsy procedure or not"
+        )
+
+    # Alternative schema using a TypedDict.
+    # The output will be a dict with a boolean value instead of a boolean field.
+    # Example:
+    #   ent._.biopsy_procedure → {'biopsy_procedure': False}
+    class BiopsySchema2(TypedDict):
+        biopsy_procedure: bool
+
+    # Alternative annotated schema with custom (de)serializers.
+    # This schema transforms the LLM's output into a boolean before validation.
+    # Any case-insensitive variant of "yes", "y", or "true" is interpreted as True;
+    # all other values are treated as False.
+    #
+    # When serializing to JSON, the boolean is converted back into the strings
+    # "yes" (for True) or "no" (for False).
+    # The output will be a boolean field.
+    # Example:
+    #   ent._.biopsy_procedure → False
+    BiopsySchema3 = Annotated[
         bool,
         BeforeValidator(lambda v: str(v).lower() in {"yes", "y", "true"}),
         PlainSerializer(lambda v: "yes" if v else "no", when_used="json"),
     ]
 
+
     PROMPT = """
     You are a span classifier. The user sends text where the target is
     marked with <ent>...</ent>. Answer ONLY with a JSON value: "yes" or
@@ -129,7 +163,7 @@ class LlmSpanQualifier(BaseSpanAttributeClassifierComponent):
             context_getter="sent",
             context_formatter=doc_to_xml,
             attributes=["biopsy_procedure"],
-            output_schema=BiopsySchema,
+            output_schema=BiopsySchema1,  # or BiopsySchema2 or BiopsySchema3
             examples=examples,
             max_few_shot_examples=2,
             max_concurrent_requests=4,
@@ -219,6 +253,8 @@ class CovidMentionSchema(BaseModel):
 
 <!-- blacken-docs:on -->
 
+Advanced usage
+--------
 You can also control the prompt more finely by providing a callable instead of a
 string. For example, to put few-shot examples in the system message and keep the
 span context as the user payload:
@@ -240,6 +276,33 @@ def prompt(context_text, examples):
     return messages
 ```
 
+You can also control the context formatting by providing a custom callable
+to the `context_formatter` parameter. For example, to wrap the context with
+a custom prefix and suffix as follows:
+
+```python { .no-check }
+from spacy.tokens import Doc
+
+class ContextFormatter:
+    def __init__(self, prefix: str, suffix: str):
+        self.prefix = prefix
+        self.suffix = suffix
+
+    def __call__(self, context: Doc) -> str:
+        span = context.ents[0].text if context.ents else ""
+        prefix = self.prefix.format(span=span)
+        suffix = self.suffix.format(span=span)
+        return f"{prefix}{context.text}{suffix}"
+
+context_formatter = ContextFormatter(prefix="\n## Context\n\n<<<\n",
+    suffix="\n>>>\n\n## Instruction\nDoes '{span}' correspond to a biopsy date?")
+```
+
+!!! note "`max_concurrent_requests` parameter"
+
+    We recommend setting `max_concurrent_requests` to a value greater than the
+    default of 1 to improve throughput when processing batches of documents.
+
 Parameters
 ----------
 nlp : PipelineProtocol
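
For reference, a hedged sketch of what such a callable `prompt` could look like, following the `def prompt(context_text, examples)` signature in the hunk header above and the behavior the docstring describes (few-shot examples folded into the system message, span context kept as the user payload). The instruction wording below is illustrative, not taken from the repository:

```python
def prompt(context_text, examples):
    # System message: task instruction plus the retrieved few-shot examples.
    system = (
        'You are a span classifier. Answer ONLY with a JSON value: "yes" or "no".\n'
        "\nExamples:\n"
    )
    for example_context, answer in examples:
        system += f"Input: {example_context}\nOutput: {answer}\n"
    # User message: only the formatted span context.
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": context_text},
    ]
```
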
@@ -250,14 +313,14 @@ def prompt(context_text, examples):
     Base URL of the OpenAI-compatible API.
 model : str
     Model identifier exposed by the API.
-prompt : Union[str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]]]
+prompt : Union[str, Callable[[Union[str, Doc], List[Tuple[Union[str, Doc], str]]], List[Dict[str, str]]]]
     The prompt is the main way to control the model's behavior.
     It can be either:
 
     - A string, which will be used as a system prompt.
      Few-shot examples (if any) will be provided as user/assistant
      messages before the actual user query.
-    - A callable that takes two arguments and returns a list of messages in the
+    - A callable that takes three arguments and returns a list of messages in the
      format expected by the OpenAI chat completions API.
 
      * `context`: the context text with the target span marked up
@@ -270,7 +333,7 @@ def prompt(context_text, examples):
     If `None`, the whole document text is used.
 context_formatter : Optional[Callable[[Doc], str]]
     Callable used to render the context passed to the LLM. Defaults to
-    `lambda doc: doc.text`.
+    `lambda context_getter_output: context_getter_output.text`.
 attributes : Optional[AttributesArg]
     Attributes to predict. If omitted, the keys are inferred from the provided
     schema.
@@ -289,12 +352,16 @@ def prompt(context_text, examples):
 seed : Optional[int]
     Optional seed forwarded to the API.
 max_concurrent_requests : int
-    Maximum number of concurrent span requests per document.
+    Maximum number of concurrent span requests per batch of documents.
 api_kwargs : Dict[str, Any]
     Extra keyword arguments forwarded to `chat.completions.create`.
 on_error : Literal["raise", "warn"]
     Error handling strategy. If `"raise"`, exceptions are raised. If `"warn"`,
     exceptions are logged as warnings and processing continues.
+timeout : Optional[float]
+    Optional timeout (in seconds) for each LLM request.
+default_headers : Optional[Dict[str, str]]
+    Optional default headers for the API client.
 '''  # noqa: E501
 
     def __init__(
@@ -305,11 +372,15 @@ def __init__(
         api_url: str,
         model: str,
         prompt: Union[
-            str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]]
+            str,
+            Callable[
+                [Union[str, Doc], List[Tuple[Union[str, Doc], str]]],
+                List[Dict[str, str]],
+            ],
         ],
         span_getter: Optional[SpanGetterArg] = None,
         context_getter: Optional[ContextWindow] = None,
-        context_formatter: Optional[Callable[[Doc], str]] = None,
+        context_formatter: Optional[Callable[[Doc], Union[str, Doc]]] = None,
         attributes: Optional[AttributesArg] = None,  # confit will auto cast to dict
         output_schema: Optional[
             Union[
@@ -325,6 +396,8 @@ def __init__(
         max_concurrent_requests: int = 1,
         api_kwargs: Optional[Dict[str, Any]] = None,
         on_error: Literal["raise", "warn"] = "raise",
+        timeout: Optional[float] = None,
+        default_headers: Optional[Dict[str, str]] = {"Connection": "close"},
     ):
         import openai
 
@@ -333,12 +406,15 @@ def __init__(
         self.api_url = api_url
         self.model = model
         self.prompt = prompt
+        self.timeout = timeout
         self.context_window = (
             ContextWindow.validate(context_getter)
             if context_getter is not None
             else None
         )
-        self.context_formatter = context_formatter or (lambda doc: doc.text)
+        self.context_formatter = context_formatter or (
+            lambda context_getter_output: context_getter_output.text
+        )
         self.seed = seed
         self.api_kwargs = api_kwargs or {}
         self.max_concurrent_requests = max_concurrent_requests
@@ -405,11 +481,11 @@ class Output(RootModel):
         self.bindings.append((attr_path, labels, json_key, setter, getter))
         self.attributes = {path: labels for path, labels, *_ in self.bindings}
 
-        self.examples: List[Tuple[str, str]] = []
+        self.examples: List[Tuple[Union[str, Doc], str]] = []
         for doc in examples or []:
             for span in get_spans(doc, span_getter):
                 context_doc = self._build_context_doc(span)
-                context_text = self.context_formatter(context_doc)
+                formatted_context = self.context_formatter(context_doc)
                 values: Dict[str, Any] = {}
                 for _, labels, json_key, _, getter in self.bindings:
                     if (
@@ -440,17 +516,23 @@ class Output(RootModel):
                         continue
                 else:
                     answer = json.dumps(values)
-                self.examples.append((context_text, answer))
+                self.examples.append((formatted_context, answer))
 
         self.max_few_shot_examples = max_few_shot_examples
         self.retriever = None
         self.retriever_stemmer = None
         if self.max_few_shot_examples > 0 and use_retriever is not False:
             self.build_few_shot_retriever_(self.examples)
 
-        api_key = os.getenv("OPENAI_API_KEY", "")
+        api_key = os.getenv(
+            "OPENAI_API_KEY", "EMPTY_API_KEY"
+        )  # API key should be non-empty (even when exposing local models without auth)
         self.client = openai.Client(base_url=self.api_url, api_key=api_key)
-        self._async_client = openai.AsyncOpenAI(base_url=self.api_url, api_key=api_key)
+        self._async_client = openai.AsyncOpenAI(
+            base_url=self.api_url,
+            api_key=api_key,
+            default_headers=default_headers,
+        )
 
         super().__init__(nlp=nlp, name=name, span_getter=span_getter)
 
@@ -468,24 +550,37 @@ def set_extensions(self) -> None:
         if not Span.has_extension(ext_name):
             Span.set_extension(ext_name, default=None)
 
-    def build_few_shot_retriever_(self, samples: List[Tuple[str, str]]) -> None:
+    def build_few_shot_retriever_(
+        self, samples: List[Tuple[Union[str, Doc], str]]
+    ) -> None:
         # Same BM25 strategy as llm_markup_extractor
         import bm25s
         import Stemmer
 
         lang = {"eds": "french"}.get(self.lang, self.lang)
         stemmer = Stemmer.Stemmer(lang)
         corpus = bm25s.tokenize(
-            [text for text, _ in samples], stemmer=stemmer, stopwords=lang
+            [
+                sample.text if isinstance(sample, Doc) else sample
+                for sample, _ in samples
+            ],
+            stemmer=stemmer,
+            stopwords=lang,
         )
         retriever = bm25s.BM25()
         retriever.index(corpus)
         self.retriever = retriever
         self.retriever_stemmer = stemmer
 
-    def build_prompt(self, context_text: str) -> List[Dict[str, str]]:
+    def build_prompt(self, formatted_context: Union[str, Doc]) -> List[Dict[str, str]]:
+        """Build the prompt messages for the LLM request."""
         import bm25s
 
+        if isinstance(formatted_context, Doc):
+            context_text = formatted_context.text
+        else:
+            context_text = formatted_context
+
         few_shot_examples: List[Tuple[str, str]] = []
         if self.retriever is not None:
             closest, _ = self.retriever.retrieve(
@@ -510,7 +605,7 @@ def build_prompt(self, context_text: str) -> List[Dict[str, str]]:
                 messages.append({"role": "assistant", "content": ans})
             messages.append({"role": "user", "content": context_text})
             return messages
-        return self.prompt(context_text, few_shot_examples)
+        return self.prompt(formatted_context, few_shot_examples)
 
     def _llm_request_sync(self, messages: List[Dict[str, str]]) -> str:
         call_kwargs = dict(self.api_kwargs)
@@ -662,9 +757,11 @@ def schedule() -> None:
             span = state["spans"][state["next_span"]]
             state["next_span"] += 1
             context_doc = self._build_context_doc(span)
-            context_text = self.context_formatter(context_doc)
-            messages = self.build_prompt(context_text)
-            task_id = worker.submit(self._llm_request_coro(messages))
+            formatted_context = self.context_formatter(context_doc)
+            messages = self.build_prompt(formatted_context)
+            task_id = worker.submit(
+                self._llm_request_coro(messages), timeout=self.timeout
+            )
             pending[task_id] = (state, span)
             state["pending"] += 1
             if len(pending) >= self.max_concurrent_requests:
@@ -726,8 +823,8 @@ def process(self, doc: Doc) -> Doc:
 
         for span in spans:
             context_doc = self._build_context_doc(span)
-            context_text = self.context_formatter(context_doc)
-            messages = self.build_prompt(context_text)
+            formatted_context = self.context_formatter(context_doc)
+            messages = self.build_prompt(formatted_context)
             data = None
             try:
                 raw = self._llm_request_sync(messages)
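
Taken together, the new options surface at the pipe level. Below is a hedged configuration sketch: the factory name `eds.llm_span_qualifier` is inferred from the module path, and the endpoint, model, prompt, and schema are placeholders; only `timeout` and `default_headers` are the parameters this commit introduces.

```python
import edsnlp
import edsnlp.pipes as eds
from pydantic import BaseModel, Field


class BiopsySchema(BaseModel):
    biopsy_procedure: bool = Field(..., description="Is the span a biopsy procedure or not")


nlp = edsnlp.blank("eds")
nlp.add_pipe(
    eds.llm_span_qualifier(  # hypothetical factory name, inferred from the module path
        api_url="http://localhost:8000/v1",  # placeholder OpenAI-compatible endpoint
        model="my-model",  # placeholder model identifier
        prompt="You are a span classifier. Answer ONLY with a JSON value.",
        attributes=["biopsy_procedure"],
        output_schema=BiopsySchema,
        timeout=30.0,  # new: per-request timeout, in seconds
        default_headers={"Connection": "close"},  # new: headers for the async client
    )
)
```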

tests/pipelines/llm/test_llm_span_qualifier.py

Lines changed: 3 additions & 2 deletions
@@ -220,8 +220,9 @@ def responder(*, response_format, **_):
         assert_response_schema(response_format)
         raise ValueError("Simulated error")
 
-    with mock_llm_service(responder=responder), pytest.warns(
-        UserWarning, match="request failed"
+    with (
+        mock_llm_service(responder=responder),
+        pytest.warns(UserWarning, match="request failed"),
     ):
         doc = nlp(doc)
 