
Commit 1be9f06

Author: Paweł Kędzia
Merge branch 'features/refactoring'
2 parents dccb17d + 02530c2

File tree: 18 files changed, +582 −0 lines changed

.version

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
0.0.1

README.md

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
# llm_router_services

## Overview

`llm_router_services` provides **HTTP services** that implement the core functionality used by the LLM‑Router’s plugin system. The services expose guardrail and masking capabilities through Flask applications that can be called by the corresponding plugins in `llm_router_plugins`.

Key components:

| Sub‑package | Primary purpose |
|-------------|-----------------|
| **guardrails/** | Hosts the NASK‑PIB guardrail service (`nask_pib_guard_app.py`). It receives a JSON payload, chunks the text, runs a Hugging‑Face classification pipeline, and returns a safety verdict (`safe` flag + detailed per‑chunk results). |
| **maskers/** | Contains the **BANonymizer** (`banonymizer.py`) – a lightweight Flask service that performs token‑classification‑based anonymisation of input text. |
| **run_*.sh** scripts | Convenience wrappers to start the services (Gunicorn for the guardrail, plain Flask for the anonymiser). |
| **requirements‑gpu.txt** | Lists heavy dependencies (e.g., `transformers`) required for GPU‑accelerated inference. |

The services are **stateless**; they load their models once at start‑up and then serve requests over HTTP.

---

*Happy masking and safe routing!*
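For a quick sanity check, the guardrail service can be exercised over HTTP. The sketch below is illustrative only: the host, port, route suffix, and payload shape are assumptions; only the `/api/guardrails` prefix (`SERVICES_API_PREFIX`) and the `safe`/`detailed` response keys appear in this commit.

```python
# Hypothetical client call: host, port, route suffix, and payload shape are
# assumptions; only the "/api/guardrails" prefix (SERVICES_API_PREFIX) and
# the "safe"/"detailed" response keys come from this commit.
import requests

payload = {"text": "Some user input to screen"}  # assumed payload shape
resp = requests.post(
    "http://localhost:8000/api/guardrails/classify",  # assumed endpoint
    json=payload,
    timeout=30,
)
verdict = resp.json()
print(verdict["safe"])      # overall boolean verdict
print(verdict["detailed"])  # per-chunk labels and scores
```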

llm_router_services/__init__.py

Whitespace-only changes.

llm_router_services/guardrails/__init__.py

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
SERVICES_API_PREFIX = "/api/guardrails"

llm_router_services/guardrails/inference/__init__.py

Whitespace-only changes.
llm_router_services/guardrails/inference/base.py

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
from typing import Any, Dict
from abc import ABC, abstractmethod


class GuardrailBase(ABC):
    """Common interface for all guardrail models."""

    @abstractmethod
    def classify_chunks(self, payload: Dict[Any, Any]) -> Dict[str, Any]:
        """Classify the supplied payload and return a result dictionary."""
        pass
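To illustrate the contract, here is a minimal hypothetical implementation (not part of this commit); it only mirrors the result shape that `TextClassificationGuardrail` returns:

```python
from typing import Any, Dict

from llm_router_services.guardrails.inference.base import GuardrailBase


class AlwaysSafeGuardrail(GuardrailBase):
    """Toy guardrail, for illustration only: marks every payload as safe."""

    def classify_chunks(self, payload: Dict[Any, Any]) -> Dict[str, Any]:
        # Same result shape as TextClassificationGuardrail returns.
        return {"safe": True, "detailed": []}
```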
llm_router_services/guardrails/inference/config.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
from __future__ import annotations

from abc import ABC, abstractmethod


class GuardrailModelConfig(ABC):
    """
    Abstract base class that defines the configuration interface required by a
    guardrail model. Concrete implementations must provide the three fields
    used by :class:`TextClassificationGuardrail`:

    * ``pipeline_batch_size`` – size of batches sent to the HF pipeline.
    * ``min_score_for_safe`` – threshold below which a “SAFE” label is treated as unsafe.
    * ``min_score_for_not_safe`` – threshold above which a non‑safe label is treated as genuinely unsafe.
    """

    @property
    @abstractmethod
    def pipeline_batch_size(self) -> int: ...

    @property
    @abstractmethod
    def min_score_for_safe(self) -> float: ...

    @property
    @abstractmethod
    def min_score_for_not_safe(self) -> float: ...
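A concrete configuration can be a frozen dataclass, mirroring the `GenericModelConfig` defined in `text_classification.py` below; the threshold values here are purely illustrative:

```python
from dataclasses import dataclass

from llm_router_services.guardrails.inference.config import GuardrailModelConfig


@dataclass(frozen=True)
class StrictModelConfig(GuardrailModelConfig):
    """Hypothetical config with stricter thresholds than the generic default."""

    pipeline_batch_size: int = 32
    min_score_for_safe: float = 0.9      # distrust low-confidence "safe" labels
    min_score_for_not_safe: float = 0.3  # accept even low-confidence unsafe labels
```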
Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
from __future__ import annotations

from typing import Any

from llm_router_services.guardrails.inference.base import GuardrailBase
from llm_router_services.guardrails.inference.config import GuardrailModelConfig
from llm_router_services.guardrails.inference.text_classification import (
    TextClassificationGuardrail,
)


def create(
    model_type: str,
    model_path: str,
    device: int = -1,
    *,
    config: GuardrailModelConfig | None = None,
    **kwargs: Any,
) -> GuardrailBase:
    """
    Factory that builds a concrete GuardrailBase implementation.

    Parameters
    ----------
    model_type:
        Identifier of the concrete implementation (e.g. ``"text_classification"``).
    model_path:
        Path or hub identifier of the model.
    device:
        ``-1`` → CPU, otherwise the CUDA device index.
    config:
        Optional model‑specific configuration object that implements
        :class:`GuardrailModelConfig`. If omitted, a generic default config
        is used.
    kwargs:
        Additional arguments forwarded to the concrete class.
    """
    if model_type == "text_classification":
        # ``config`` may be ``None`` – the guardrail class will fall back to a
        # generic config.
        return TextClassificationGuardrail(
            model_path=model_path,
            device=device,
            config=config,
            **kwargs,
        )
    raise ValueError(f"Unsupported guardrail model_type: {model_type}")


# Public alias expected by the Flask app
GuardrailModelFactory = create
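Wiring it together might look like the sketch below. The factory's module path is not shown in this commit, so the import path, the model identifier, and the payload are assumptions:

```python
# "factory" is an assumed module name; model path and payload are illustrative.
from llm_router_services.guardrails.inference.factory import GuardrailModelFactory

guardrail = GuardrailModelFactory(
    model_type="text_classification",
    model_path="org/hypothetical-guardrail-model",  # assumed HF hub id
    device=-1,  # CPU
)
result = guardrail.classify_chunks({"text": "Hello, world"})
print(result["safe"], len(result["detailed"]))
```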
llm_router_services/guardrails/inference/text_classification.py

Lines changed: 147 additions & 0 deletions

@@ -0,0 +1,147 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List

from transformers import pipeline, AutoTokenizer, AutoConfig

from llm_router_services.guardrails.inference.base import GuardrailBase
from llm_router_services.guardrails.inference.config import GuardrailModelConfig
from llm_router_services.guardrails.payload_handler import GuardrailPayloadExtractor


# -----------------------------------------------------------------------
# Default (generic) configuration – can be used when a model does not have a
# specialized config. It implements the GuardrailModelConfig interface.
# -----------------------------------------------------------------------
@dataclass(frozen=True)
class GenericModelConfig(GuardrailModelConfig):
    pipeline_batch_size: int = 64
    min_score_for_safe: float = 0.5
    min_score_for_not_safe: float = 0.5


class TextClassificationGuardrail(GuardrailBase):
    """
    Generic text‑classification guardrail.

    The caller supplies a concrete ``config`` object that implements
    :class:`GuardrailModelConfig`. This makes the guardrail reusable for any model.
    """

    def __init__(
        self,
        model_path: str,
        device: int = -1,
        max_tokens: int = 500,
        overlap: int = 200,
        *,
        config: GuardrailModelConfig | None = None,
    ):
        # ---------------------------------------------------------------
        # Store model‑specific thresholds & batch size
        # ---------------------------------------------------------------
        self._config = config or GenericModelConfig()

        self._overlap = overlap
        self._max_tokens = max_tokens

        # ---------------------------------------------------------------
        # Tokeniser & pipeline preparation (unchanged)
        # ---------------------------------------------------------------
        self._tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
        self._model_max_length = AutoConfig.from_pretrained(
            model_path
        ).max_position_embeddings

        if self._max_tokens > self._model_max_length:
            self._max_tokens = self._model_max_length

        self._pipeline = pipeline(
            "text-classification",
            model=model_path,
            tokenizer=self._tokenizer,
            device=device,
            truncation=True,
            max_length=self._max_tokens,
        )

    # -------------------------------------------------------------------
    # Helper: convert payload → list of strings
    # -------------------------------------------------------------------
    @staticmethod
    def _payload_to_string_list(payload: Dict[Any, Any]) -> List[str]:
        try:
            return GuardrailPayloadExtractor.extract_texts(payload)
        except (TypeError, ValueError):
            parts = [f"{str(k)}={str(v)}" for k, v in payload.items()]
            return [", ".join(parts)]

    # -------------------------------------------------------------------
    # Helper: split long texts into token‑aware chunks
    # -------------------------------------------------------------------
    def _chunk_text(self, texts: List[str]) -> List[str]:
        chunks: List[str] = []
        for text in texts:
            token_ids = self._tokenizer.encode(text, add_special_tokens=False)
            step = self._max_tokens - self._overlap
            for start in range(0, len(token_ids), step):
                end = min(start + self._max_tokens, len(token_ids))
                chunk_ids = token_ids[start:end]
                chunk_text = self._tokenizer.decode(
                    chunk_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
                chunks.append(chunk_text.strip())
                if end == len(token_ids):
                    break
        return chunks

    # -------------------------------------------------------------------
    # Public API – called from the Flask endpoint
    # -------------------------------------------------------------------
    def classify_chunks(self, payload: Dict[Any, Any]) -> Dict[str, Any]:
        texts = self._payload_to_string_list(payload)
        chunks = self._chunk_text(texts=texts)

        # Run inference in batches defined by the model config
        raw_results = self._pipeline(
            chunks, batch_size=self._config.pipeline_batch_size
        )

        # Normalise pipeline output (it can be a list of dicts or a list containing a single list)
        flat_results = [r[0] if isinstance(r, list) else r for r in raw_results]

        detailed: List[Dict[str, Any]] = []
        for idx, (chunk, classification) in enumerate(zip(chunks, flat_results)):
            label = classification.get("label", "")
            score = round(classification.get("score", 0.0), 4)
            is_safe = label.lower() == "safe"

            detailed.append(
                {
                    "chunk_index": idx,
                    "chunk_text": chunk,
                    "label": label,
                    "score": score,
                    "safe": is_safe,
                }
            )

        # ---------------------------------------------------------------
        # Overall safety decision – uses the per‑model thresholds
        # ---------------------------------------------------------------
        overall_safe = True
        for item in detailed:
            if item["safe"] and item["score"] < self._config.min_score_for_safe:
                overall_safe = False
                break
            if (
                not item["safe"]
                and item["score"] > self._config.min_score_for_not_safe
            ):
                overall_safe = False
                break

        return {"safe": overall_safe, "detailed": detailed}
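The chunking loop walks the token ids with a stride of `max_tokens - overlap`, so consecutive chunks share `overlap` tokens. A standalone sketch of that arithmetic with the class defaults (`max_tokens=500`, `overlap=200`, stride 300):

```python
# Stand-alone illustration of the stride arithmetic in _chunk_text;
# token ids are faked as a plain range.
max_tokens, overlap = 500, 200
token_ids = list(range(1100))  # pretend 1,100-token document

step = max_tokens - overlap  # 300
spans = []
for start in range(0, len(token_ids), step):
    end = min(start + max_tokens, len(token_ids))
    spans.append((start, end))
    if end == len(token_ids):
        break

print(spans)  # [(0, 500), (300, 800), (600, 1100)], 200-token overlaps
```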
