Skip to content

Commit aceb6ba

Browse files
author
Paweł Kędzia
committed
Add early‑exit for empty provider list and refactor random‑choice handling in FirstAvailable load‑balancer by introducing _try_acquire_random_provider and updating the method signature/docstring.
1 parent e9e3c5c commit aceb6ba

File tree

1 file changed

+127
-41
lines changed

1 file changed

+127
-41
lines changed

llm_router_api/base/lb/first_available.py

Lines changed: 127 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -126,78 +126,92 @@ def get_provider(
126126
model_name: str,
127127
providers: List[Dict],
128128
options: Optional[Dict[str, Any]] = None,
129-
) -> Dict:
129+
) -> Dict | None:
130130
"""
131-
Acquire the first available provider for *model_name*.
131+
Acquire a provider for *model_name* from the supplied ``providers`` list.
132132
133-
The method repeatedly attempts to acquire a lock on each provider in the
134-
order supplied by *providers*. If a provider is successfully marked as
135-
chosen in Redis, the provider dictionary is returned with an additional
136-
``"__chosen_field"`` entry that records the Redis hash field used for the
137-
lock. The call blocks until a provider is obtained or *self.timeout*
138-
seconds have elapsed, in which case a :class:`TimeoutError` is raised.
133+
The method attempts to lock a provider using a Redis‑backed
134+
atomic Lua script. If ``options`` contains ``{\"random_choice\": True}``,
135+
the selection is performed on a shuffled copy of ``providers``; otherwise
136+
providers are examined in the order they appear in the list.
139137
140138
Parameters
141139
----------
142140
model_name : str
143-
The name of the model for which a provider is required.
141+
Identifier of the model for which a provider is required.
144142
providers : List[Dict]
145-
A list of provider configuration dictionaries.
146-
options: Dict[str, Any], default: None
147-
Additional options passed to the chosen provider.
143+
A list of provider configuration dictionaries. Each dictionary must
144+
contain the information required by :meth:`_provider_field` to build a
145+
unique Redis hash field name.
146+
options : dict, optional
147+
Additional flags that influence the acquisition strategy. Currently
148+
supported keys:
149+
``random_choice`` (bool) – when ``True`` the provider is chosen at
150+
random; defaults to ``False``.
148151
149152
Returns
150153
-------
151-
Dict
152-
The selected provider configuration, augmented with a
153-
``"__chosen_field"`` key.
154+
dict | None
155+
The chosen provider dictionary with an extra ``"__chosen_field"``
156+
entry indicating the Redis hash field that was locked. Returns
157+
``None`` if ``providers`` is empty.
154158
155159
Raises
156160
------
157161
TimeoutError
158-
If no provider becomes available within the configured timeout.
159-
RuntimeError
160-
If Redis is not available.
162+
Raised when no provider can be locked within the ``timeout`` period
163+
configured for the strategy instance.
164+
165+
Notes
166+
-----
167+
* The method creates the Redis hash (``model:<model_name>``) and initial
168+
``false`` fields if they do not already exist.
169+
* The lock is represented by the value ``'true'`` in the hash field.
170+
* Call :meth:`put_provider` to release the lock once the provider is no
171+
longer needed.
161172
"""
173+
if not providers:
174+
return None
175+
162176
redis_key = self._get_redis_key(model_name)
163177
start_time = time.time()
164178

165-
is_random = options and options.get("random_choice", False)
166-
if is_random:
167-
print("===" * 300)
168-
169179
# Ensure fields exist; if someone removed the hash, recreate it
170180
if not self.redis_client.exists(redis_key):
171181
for p in providers:
172182
self.redis_client.hset(redis_key, self._provider_field(p), "false")
173183

184+
is_random = options and options.get("random_choice", False)
185+
174186
while True:
175187
if time.time() - start_time > self.timeout:
176188
raise TimeoutError(
177189
f"No available provider found for model '{model_name}' "
178190
f"within {self.timeout} seconds"
179191
)
180192

181-
available_providers = []
182-
for provider in providers:
183-
provider_field = self._provider_field(provider)
184-
try:
185-
ok = int(
186-
self._acquire_script(keys=[redis_key], args=[provider_field])
187-
)
188-
if ok == 1:
189-
provider["__chosen_field"] = provider_field
190-
if is_random:
191-
available_providers.append(provider)
192-
else:
193+
if is_random:
194+
provider = self._try_acquire_random_provider(
195+
redis_key=redis_key, providers=providers
196+
)
197+
if provider:
198+
return provider
199+
else:
200+
for provider in providers:
201+
provider_field = self._provider_field(provider)
202+
try:
203+
ok = int(
204+
self._acquire_script(
205+
keys=[redis_key], args=[provider_field]
206+
)
207+
)
208+
if ok == 1:
209+
provider["__chosen_field"] = provider_field
193210
return provider
194-
except Exception:
195-
time.sleep(self.check_interval)
196-
continue
197-
198-
if is_random and available_providers:
199-
return random.choice(available_providers)
200-
211+
except Exception:
212+
time.sleep(self.check_interval)
213+
continue
214+
# -------------------------------------------------------------
201215
time.sleep(self.check_interval)
202216

203217
def put_provider(
@@ -232,6 +246,78 @@ def put_provider(
232246

233247
provider.pop("__chosen_field", None)
234248

249+
def _try_acquire_random_provider(
250+
self, redis_key: str, providers: List[Dict]
251+
) -> Optional[Dict]:
252+
"""
253+
Attempt to lock a provider chosen at random.
254+
255+
The method works in three stages:
256+
257+
1. **Shuffle** – a shallow copy of ``providers`` is shuffled so that each
258+
provider has an equal probability of being tried first. The original
259+
list is left untouched.
260+
2. **Atomic acquisition** – each shuffled provider is passed to the
261+
``_acquire_script`` Lua script which atomically sets the corresponding
262+
Redis hash field to ``'true'`` *only if* it is currently ``'false'`` or
263+
missing. The first provider for which the script returns ``1`` is
264+
considered successfully acquired.
265+
3. **Fallback** – if none of the providers can be locked (e.g., all are
266+
currently in use), the method falls back to the *first* provider in the
267+
original ``providers`` list, marks its ``"__chosen_field"`` for
268+
consistency, and returns it. This fallback mirrors the behaviour of
269+
the non‑random acquisition path and ensures the caller always receives
270+
a provider dictionary (or ``None`` when ``providers`` is empty).
271+
272+
Parameters
273+
----------
274+
redis_key : str
275+
The Redis hash key associated with the model (e.g., ``model:<name>``).
276+
providers : List[Dict]
277+
A list of provider configuration dictionaries. Each dictionary must
278+
contain sufficient information for :meth:`_provider_field` to generate
279+
a unique field name within the Redis hash.
280+
281+
Returns
282+
-------
283+
Optional[Dict]
284+
The selected provider dictionary with an additional ``"__chosen_field"``
285+
entry indicating the Redis hash field that was locked. Returns ``None``
286+
only when the input ``providers`` list is empty.
287+
288+
Raises
289+
------
290+
Exception
291+
Propagates any unexpected exceptions raised by the Lua script execution;
292+
callers may catch these to implement retry or logging logic.
293+
294+
Notes
295+
-----
296+
* The random selection is *non‑deterministic* on each call; however, the
297+
fallback to the first provider ensures deterministic behaviour when
298+
all providers are currently busy.
299+
* The method does **not** block; it returns immediately after trying all
300+
shuffled providers.
301+
"""
302+
shuffled = providers[:]
303+
random.shuffle(shuffled)
304+
for provider in shuffled:
305+
provider_field = self._provider_field(provider)
306+
try:
307+
ok = int(
308+
self._acquire_script(keys=[redis_key], args=[provider_field])
309+
)
310+
if ok == 1:
311+
provider["__chosen_field"] = provider_field
312+
return provider
313+
except Exception:
314+
continue
315+
316+
provider = providers[0]
317+
provider_field = self._provider_field(provider)
318+
provider["__chosen_field"] = provider_field
319+
return provider
320+
235321
def _get_redis_key(self, model_name: str) -> str:
236322
"""
237323
Return Redis key prefix for a given model.

0 commit comments

Comments
 (0)