Skip to content

Commit 9148d57

Browse files
author
Paweł Kędzia
committed
Merge branch 'features/lb'
2 parents 7795dae + d2cb7c5 commit 9148d57

File tree

10 files changed

+423
-112
lines changed

10 files changed

+423
-112
lines changed

llm_router_api/base/lb/balanced.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Dict
1+
from typing import List, Dict, Optional, Any
22
from collections import defaultdict
33

44
from llm_router_api.base.lb.strategy import ChooseProviderStrategyI
@@ -13,7 +13,12 @@ def __init__(self, models_config_path: str) -> None:
1313
lambda: defaultdict(int)
1414
)
1515

16-
def get_provider(self, model_name: str, providers: List[Dict]) -> Dict:
16+
def get_provider(
17+
self,
18+
model_name: str,
19+
providers: List[Dict],
20+
options: Optional[Dict[str, Any]] = None,
21+
) -> Dict:
1722
if not providers:
1823
raise ValueError(f"No providers configured for model '{model_name}'")
1924

llm_router_api/base/lb/chooser.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ def __strategy_from_name(
128128

129129
return _cls(models_config_path=models_config_path)
130130

131-
def get_provider(self, model_name: str, providers: List[Dict]) -> Dict:
131+
def get_provider(
132+
self, model_name: str, providers: List[Dict], options: Optional[Dict] = None
133+
) -> Dict:
132134
"""
133135
Choose a provider for *model_name* from *providers* using the configured strategy.
134136
@@ -141,6 +143,8 @@ def get_provider(self, model_name: str, providers: List[Dict]) -> Dict:
141143
The name of the model for which a provider is required.
142144
providers : List[Dict]
143145
A list of provider configuration dictionaries.
146+
options: Optional[Dict], Default is ``None``.
147+
Additional options forwarded unchanged to ``self.strategy.get_provider``.
144148
145149
Returns
146150
-------
@@ -154,7 +158,13 @@ def get_provider(self, model_name: str, providers: List[Dict]) -> Dict:
154158
"""
155159
if not providers:
156160
raise RuntimeError(f"{model_name} does not have any providers!")
157-
return self.strategy.get_provider(model_name, providers)
161+
return self.strategy.get_provider(
162+
model_name=model_name, providers=providers, options=options
163+
)
158164

159-
def put_provider(self, model_name: str, provider: Dict) -> None:
160-
self.strategy.put_provider(model_name, provider)
165+
def put_provider(
    self, model_name: str, provider: Dict, options: Optional[Dict] = None
) -> None:
    """
    Hand a provider back to the configured strategy.

    This is a thin pass-through: all three arguments are forwarded by
    keyword to ``self.strategy.put_provider`` and nothing is returned.

    Parameters
    ----------
    model_name : str
        Name of the model the provider was used for.
    provider : Dict
        The provider configuration dictionary being returned.
    options : Optional[Dict], default ``None``
        Extra options forwarded unchanged to the strategy.
    """
    release = self.strategy.put_provider
    release(model_name=model_name, provider=provider, options=options)

llm_router_api/base/lb/first_available.py

Lines changed: 173 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
2323
"""
2424

25+
import random
2526
import time
2627

2728
try:
@@ -31,7 +32,7 @@
3132
except ImportError:
3233
REDIS_IS_AVAILABLE = False
3334

34-
from typing import List, Dict
35+
from typing import List, Dict, Optional, Any
3536

3637
from llm_router_api.base.constants import REDIS_PORT, REDIS_HOST
3738
from llm_router_api.base.lb.strategy import ChooseProviderStrategyI
@@ -120,37 +121,58 @@ def __init__(
120121

121122
self._clear_buffers()
122123

123-
def get_provider(self, model_name: str, providers: List[Dict]) -> Dict:
124+
def get_provider(
125+
self,
126+
model_name: str,
127+
providers: List[Dict],
128+
options: Optional[Dict[str, Any]] = None,
129+
) -> Dict | None:
124130
"""
125-
Acquire the first available provider for *model_name*.
131+
Acquire a provider for *model_name* from the supplied ``providers`` list.
126132
127-
The method repeatedly attempts to acquire a lock on each provider in the
128-
order supplied by *providers*. If a provider is successfully marked as
129-
chosen in Redis, the provider dictionary is returned with an additional
130-
``"__chosen_field"`` entry that records the Redis hash field used for the
131-
lock. The call blocks until a provider is obtained or *self.timeout*
132-
seconds have elapsed, in which case a :class:`TimeoutError` is raised.
133+
The method attempts to lock a provider using a Redis‑backed
134+
atomic Lua script. If ``options`` contains ``{\"random_choice\": True}``,
135+
the selection is performed on a shuffled copy of ``providers``; otherwise
136+
providers are examined in the order they appear in the list.
133137
134138
Parameters
135139
----------
136140
model_name : str
137-
The name of the model for which a provider is required.
141+
Identifier of the model for which a provider is required.
138142
providers : List[Dict]
139-
A list of provider configuration dictionaries.
143+
A list of provider configuration dictionaries. Each dictionary must
144+
contain the information required by :meth:`_provider_field` to build a
145+
unique Redis hash field name.
146+
options : dict, optional
147+
Additional flags that influence the acquisition strategy. Currently
148+
supported keys:
149+
``random_choice`` (bool) – when ``True`` the provider is chosen at
150+
random; defaults to ``False``.
140151
141152
Returns
142153
-------
143-
Dict
144-
The selected provider configuration, augmented with a
145-
``"__chosen_field"`` key.
154+
dict | None
155+
The chosen provider dictionary with an extra ``"__chosen_field"``
156+
entry indicating the Redis hash field that was locked. Returns
157+
``None`` if ``providers`` is empty.
146158
147159
Raises
148160
------
149161
TimeoutError
150-
If no provider becomes available within the configured timeout.
151-
RuntimeError
152-
If Redis is not available.
162+
Raised when no provider can be locked within the ``timeout`` period
163+
configured for the strategy instance.
164+
165+
Notes
166+
-----
167+
* The method creates the Redis hash (``model:<model_name>``) and initial
168+
``false`` fields if they do not already exist.
169+
* The lock is represented by the value ``'true'`` in the hash field.
170+
* Call :meth:`put_provider` to release the lock once the provider is no
171+
longer needed.
153172
"""
173+
if not providers:
174+
return None
175+
154176
redis_key = self._get_redis_key(model_name)
155177
start_time = time.time()
156178

@@ -159,29 +181,47 @@ def get_provider(self, model_name: str, providers: List[Dict]) -> Dict:
159181
for p in providers:
160182
self.redis_client.hset(redis_key, self._provider_field(p), "false")
161183

184+
# self._print_provider_status(redis_key, providers)
185+
186+
is_random = options and options.get("random_choice", False)
187+
162188
while True:
163189
if time.time() - start_time > self.timeout:
164190
raise TimeoutError(
165191
f"No available provider found for model '{model_name}' "
166192
f"within {self.timeout} seconds"
167193
)
168194

169-
for provider in providers:
170-
provider_field = self._provider_field(provider)
171-
try:
172-
ok = int(
173-
self._acquire_script(keys=[redis_key], args=[provider_field])
174-
)
175-
if ok == 1:
176-
provider["__chosen_field"] = provider_field
177-
return provider
178-
except Exception:
179-
time.sleep(self.check_interval)
180-
continue
181-
195+
if is_random:
196+
provider = self._try_acquire_random_provider(
197+
redis_key=redis_key, providers=providers
198+
)
199+
if provider:
200+
return provider
201+
else:
202+
for provider in providers:
203+
provider_field = self._provider_field(provider)
204+
try:
205+
ok = int(
206+
self._acquire_script(
207+
keys=[redis_key], args=[provider_field]
208+
)
209+
)
210+
if ok == 1:
211+
provider["__chosen_field"] = provider_field
212+
return provider
213+
except Exception:
214+
time.sleep(self.check_interval)
215+
continue
216+
# -------------------------------------------------------------
182217
time.sleep(self.check_interval)
183218

184-
def put_provider(self, model_name: str, provider: Dict) -> None:
219+
def put_provider(
220+
self,
221+
model_name: str,
222+
provider: Dict,
223+
options: Optional[Dict[str, Any]] = None,
224+
) -> None:
185225
"""
186226
Release a previously acquired provider back to the pool.
187227
@@ -196,6 +236,8 @@ def put_provider(self, model_name: str, provider: Dict) -> None:
196236
The model name associated with the provider.
197237
provider : Dict
198238
The provider dictionary that was returned by :meth:`get_provider`.
239+
options: Dict[str, Any], default: None
240+
Additional options passed to the chosen provider.
199241
"""
200242
redis_key = self._get_redis_key(model_name)
201243
provider_field = self._provider_field(provider)
@@ -206,6 +248,78 @@ def put_provider(self, model_name: str, provider: Dict) -> None:
206248

207249
provider.pop("__chosen_field", None)
208250

251+
def _try_acquire_random_provider(
    self, redis_key: str, providers: List[Dict]
) -> Optional[Dict]:
    """
    Attempt to lock one provider, trying the candidates in random order.

    A shuffled shallow copy of ``providers`` is walked and each candidate's
    field name (from :meth:`_provider_field`) is offered to
    ``self._acquire_script``. The first candidate for which the script
    returns ``1`` is tagged with ``"__chosen_field"`` and returned. The
    caller's list is never reordered.

    If no candidate can be locked, ``None`` is returned. The method never
    hands out a provider whose lock it does not hold: doing so would let two
    callers use the same provider at once, and the later ``put_provider``
    call would release a lock owned by someone else. ``get_provider`` treats
    a falsy result as "nothing free yet" and retries after
    ``check_interval`` until its timeout expires.

    Parameters
    ----------
    redis_key : str
        Key passed to the acquire script as its only ``keys`` entry.
    providers : List[Dict]
        Provider configuration dictionaries. Each must be accepted by
        :meth:`_provider_field`.

    Returns
    -------
    Optional[Dict]
        The locked provider (the same dict object from ``providers``, with
        ``"__chosen_field"`` set), or ``None`` when ``providers`` is empty
        or no provider could be locked on this pass.

    Notes
    -----
    * The method does not sleep or block; it makes a single pass over the
      shuffled candidates.
    * An exception raised by the acquire script for one candidate is
      treated as "not acquired" and the next candidate is tried
      (best-effort, same as the ordered acquisition path).
    """
    if not providers:
        return None

    # Shuffle a copy so the caller's ordering stays intact.
    candidates = providers[:]
    random.shuffle(candidates)

    for provider in candidates:
        provider_field = self._provider_field(provider)
        try:
            acquired = (
                int(self._acquire_script(keys=[redis_key], args=[provider_field]))
                == 1
            )
        except Exception:
            # Best-effort: a failing script call only disqualifies this
            # candidate; the remaining ones are still tried.
            continue
        if acquired:
            provider["__chosen_field"] = provider_field
            return provider

    # Nothing could be locked: report that instead of returning an
    # unlocked provider, so the caller can wait and retry.
    return None
322+
209323
def _get_redis_key(self, model_name: str) -> str:
210324
"""
211325
Return Redis key prefix for a given model.
@@ -319,3 +433,31 @@ def _clear_buffers(self) -> None:
319433
self._initialize_providers(
320434
model_name=model_name, providers=providers
321435
)
436+
437+
def _print_provider_status(self, redis_key: str, providers: List[Dict]) -> None:
    """
    Print a small table showing which providers are currently locked.

    The whole hash stored under ``redis_key`` is fetched once with
    ``hgetall``. Each provider's field (from :meth:`_provider_field`) is
    then looked up in that result:

    * 🔴 – the stored value equals ``"true"``
    * 🟢 – any other value, or the field is absent (treated as ``"false"``)

    If reading the hash raises, a one-line warning is printed and the
    method returns without printing the table.

    NOTE(review): the comparison is against the ``str`` ``"true"``, so it
    presumably relies on the Redis client decoding responses – confirm.
    """
    rule = "-" * 40

    try:
        # One round trip for the whole hash; absent fields are handled by
        # the ``.get`` default below.
        locks = self.redis_client.hgetall(redis_key)
    except Exception as exc:
        print(f"[⚠️] Could not read Redis key '{redis_key}': {exc}")
        return

    print("\nProvider lock status:")
    print(rule)
    for entry in providers:
        field_name = self._provider_field(entry)
        taken = locks.get(field_name, "false") == "true"
        marker = "🔴" if taken else "🟢"
        # Prefer a human-friendly identifier; fall back to the field name.
        label = entry.get("id") or entry.get("name") or field_name
        print(f"{marker} {label:<30} [{field_name}]")
    print(rule)

llm_router_api/base/lb/strategy.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Dict
1+
from typing import List, Dict, Any, Optional
22
from abc import ABC, abstractmethod
33

44
from llm_router_api.base.model_config import ApiModelConfig
@@ -34,7 +34,12 @@ def _provider_key(self, provider_cfg: Dict) -> str:
3434
return _pk
3535

3636
@abstractmethod
37-
def get_provider(self, model_name: str, providers: List[Dict]) -> Dict:
37+
def get_provider(
38+
self,
39+
model_name: str,
40+
providers: List[Dict],
41+
options: Optional[Dict[str, Any]] = None,
42+
) -> Dict:
3843
"""
3944
Choose a provider for the given model.
4045
@@ -44,6 +49,8 @@ def get_provider(self, model_name: str, providers: List[Dict]) -> Dict:
4449
Name of the model for which a provider is required.
4550
providers: List[Dict]
4651
List of provider configuration dictionaries.
52+
options: Dict[str, Any], default: None
53+
Options passed to the chosen provider.
4754
4855
Returns
4956
-------
@@ -52,7 +59,12 @@ def get_provider(self, model_name: str, providers: List[Dict]) -> Dict:
5259
"""
5360
raise NotImplementedError
5461

55-
def put_provider(self, model_name: str, provider: Dict) -> None:
62+
def put_provider(
63+
self,
64+
model_name: str,
65+
provider: Dict,
66+
options: Optional[Dict[str, Any]] = None,
67+
) -> None:
5668
"""
5769
Notify the strategy that a provider has been used.
5870
@@ -70,5 +82,7 @@ def put_provider(self, model_name: str, provider: Dict) -> None:
7082
Name of the model for which the provider was used.
7183
provider : Dict
7284
The provider configuration dictionary that was used.
85+
options: Dict[str, Any], default: None
86+
Options passed to the chosen provider.
7387
"""
7488
pass

0 commit comments

Comments
 (0)