2222
2323"""
2424
25+ import random
2526import time
2627
2728try :
3132except ImportError :
3233 REDIS_IS_AVAILABLE = False
3334
34- from typing import List , Dict
35+ from typing import List , Dict , Optional , Any
3536
3637from llm_router_api .base .constants import REDIS_PORT , REDIS_HOST
3738from llm_router_api .base .lb .strategy import ChooseProviderStrategyI
@@ -120,37 +121,58 @@ def __init__(
120121
121122 self ._clear_buffers ()
122123
123- def get_provider (self , model_name : str , providers : List [Dict ]) -> Dict :
124+ def get_provider (
125+ self ,
126+ model_name : str ,
127+ providers : List [Dict ],
128+ options : Optional [Dict [str , Any ]] = None ,
129+ ) -> Dict | None :
124130 """
125- Acquire the first available provider for *model_name*.
131+ Acquire a provider for *model_name* from the supplied ``providers`` list .
126132
127- The method repeatedly attempts to acquire a lock on each provider in the
128- order supplied by *providers*. If a provider is successfully marked as
129- chosen in Redis, the provider dictionary is returned with an additional
130- ``"__chosen_field"`` entry that records the Redis hash field used for the
131- lock. The call blocks until a provider is obtained or *self.timeout*
132- seconds have elapsed, in which case a :class:`TimeoutError` is raised.
133+ The method attempts to lock a provider using a Redis‑backed
134+ atomic Lua script. If ``options`` contains ``{\" random_choice\" : True}``,
135+ the selection is performed on a shuffled copy of ``providers``; otherwise
136+ providers are examined in the order they appear in the list.
133137
134138 Parameters
135139 ----------
136140 model_name : str
137- The name of the model for which a provider is required.
141+ Identifier of the model for which a provider is required.
138142 providers : List[Dict]
139- A list of provider configuration dictionaries.
143+ A list of provider configuration dictionaries. Each dictionary must
144+ contain the information required by :meth:`_provider_field` to build a
145+ unique Redis hash field name.
146+ options : dict, optional
147+ Additional flags that influence the acquisition strategy. Currently
148+ supported keys:
149+ ``random_choice`` (bool) – when ``True`` the provider is chosen at
150+ random; defaults to ``False``.
140151
141152 Returns
142153 -------
143- Dict
144- The selected provider configuration, augmented with a
145- ``"__chosen_field"`` key.
154+ dict | None
155+ The chosen provider dictionary with an extra ``"__chosen_field"``
156+ entry indicating the Redis hash field that was locked. Returns
157+ ``None`` if ``providers`` is empty.
146158
147159 Raises
148160 ------
149161 TimeoutError
150- If no provider becomes available within the configured timeout.
151- RuntimeError
152- If Redis is not available.
162+ Raised when no provider can be locked within the ``timeout`` period
163+ configured for the strategy instance.
164+
165+ Notes
166+ -----
167+ * The method creates the Redis hash (``model:<model_name>``) and initial
168+ ``false`` fields if they do not already exist.
169+ * The lock is represented by the value ``'true'`` in the hash field.
170+ * Call :meth:`put_provider` to release the lock once the provider is no
171+ longer needed.
153172 """
173+ if not providers :
174+ return None
175+
154176 redis_key = self ._get_redis_key (model_name )
155177 start_time = time .time ()
156178
@@ -159,29 +181,47 @@ def get_provider(self, model_name: str, providers: List[Dict]) -> Dict:
159181 for p in providers :
160182 self .redis_client .hset (redis_key , self ._provider_field (p ), "false" )
161183
184+ # self._print_provider_status(redis_key, providers)
185+
186+ is_random = options and options .get ("random_choice" , False )
187+
162188 while True :
163189 if time .time () - start_time > self .timeout :
164190 raise TimeoutError (
165191 f"No available provider found for model '{ model_name } ' "
166192 f"within { self .timeout } seconds"
167193 )
168194
169- for provider in providers :
170- provider_field = self ._provider_field (provider )
171- try :
172- ok = int (
173- self ._acquire_script (keys = [redis_key ], args = [provider_field ])
174- )
175- if ok == 1 :
176- provider ["__chosen_field" ] = provider_field
177- return provider
178- except Exception :
179- time .sleep (self .check_interval )
180- continue
181-
195+ if is_random :
196+ provider = self ._try_acquire_random_provider (
197+ redis_key = redis_key , providers = providers
198+ )
199+ if provider :
200+ return provider
201+ else :
202+ for provider in providers :
203+ provider_field = self ._provider_field (provider )
204+ try :
205+ ok = int (
206+ self ._acquire_script (
207+ keys = [redis_key ], args = [provider_field ]
208+ )
209+ )
210+ if ok == 1 :
211+ provider ["__chosen_field" ] = provider_field
212+ return provider
213+ except Exception :
214+ time .sleep (self .check_interval )
215+ continue
216+ # -------------------------------------------------------------
182217 time .sleep (self .check_interval )
183218
184- def put_provider (self , model_name : str , provider : Dict ) -> None :
219+ def put_provider (
220+ self ,
221+ model_name : str ,
222+ provider : Dict ,
223+ options : Optional [Dict [str , Any ]] = None ,
224+ ) -> None :
185225 """
186226 Release a previously acquired provider back to the pool.
187227
@@ -196,6 +236,8 @@ def put_provider(self, model_name: str, provider: Dict) -> None:
196236 The model name associated with the provider.
197237 provider : Dict
198238 The provider dictionary that was returned by :meth:`get_provider`.
239+ options: Dict[str, Any], default: None
240+ Additional options passed to the chosen provider.
199241 """
200242 redis_key = self ._get_redis_key (model_name )
201243 provider_field = self ._provider_field (provider )
@@ -206,6 +248,78 @@ def put_provider(self, model_name: str, provider: Dict) -> None:
206248
207249 provider .pop ("__chosen_field" , None )
208250
251+ def _try_acquire_random_provider (
252+ self , redis_key : str , providers : List [Dict ]
253+ ) -> Optional [Dict ]:
254+ """
255+ Attempt to lock a provider chosen at random.
256+
257+ The method works in three stages:
258+
259+ 1. **Shuffle** – a shallow copy of ``providers`` is shuffled so that each
260+ provider has an equal probability of being tried first. The original
261+ list is left untouched.
262+ 2. **Atomic acquisition** – each shuffled provider is passed to the
263+ ``_acquire_script`` Lua script which atomically sets the corresponding
264+ Redis hash field to ``'true'`` *only if* it is currently ``'false'`` or
265+ missing. The first provider for which the script returns ``1`` is
266+ considered successfully acquired.
267+ 3. **Fallback** – if none of the providers can be locked (e.g., all are
268+ currently in use), the method falls back to the *first* provider in the
269+ original ``providers`` list, marks its ``"__chosen_field"`` for
270+ consistency, and returns it. This fallback mirrors the behaviour of
271+ the non‑random acquisition path and ensures the caller always receives
272+ a provider dictionary (or ``None`` when ``providers`` is empty).
273+
274+ Parameters
275+ ----------
276+ redis_key : str
277+ The Redis hash key associated with the model (e.g., ``model:<name>``).
278+ providers : List[Dict]
279+ A list of provider configuration dictionaries. Each dictionary must
280+ contain sufficient information for :meth:`_provider_field` to generate
281+ a unique field name within the Redis hash.
282+
283+ Returns
284+ -------
285+ Optional[Dict]
286+ The selected provider dictionary with an additional ``"__chosen_field"``
287+ entry indicating the Redis hash field that was locked. Returns ``None``
288+ only when the input ``providers`` list is empty.
289+
290+ Raises
291+ ------
292+ Exception
293+ Propagates any unexpected exceptions raised by the Lua script execution;
294+ callers may catch these to implement retry or logging logic.
295+
296+ Notes
297+ -----
298+ * The random selection is *non‑deterministic* on each call; however, the
299+ fallback to the first provider ensures deterministic behaviour when
300+ all providers are currently busy.
301+ * The method does **not** block; it returns immediately after trying all
302+ shuffled providers.
303+ """
304+ shuffled = providers [:]
305+ random .shuffle (shuffled )
306+ for provider in shuffled :
307+ provider_field = self ._provider_field (provider )
308+ try :
309+ ok = int (
310+ self ._acquire_script (keys = [redis_key ], args = [provider_field ])
311+ )
312+ if ok == 1 :
313+ provider ["__chosen_field" ] = provider_field
314+ return provider
315+ except Exception :
316+ continue
317+
318+ provider = providers [0 ]
319+ provider_field = self ._provider_field (provider )
320+ provider ["__chosen_field" ] = provider_field
321+ return provider
322+
209323 def _get_redis_key (self , model_name : str ) -> str :
210324 """
211325 Return Redis key prefix for a given model.
@@ -319,3 +433,31 @@ def _clear_buffers(self) -> None:
319433 self ._initialize_providers (
320434 model_name = model_name , providers = providers
321435 )
436+
437+ def _print_provider_status (self , redis_key : str , providers : List [Dict ]) -> None :
438+ """
439+ Print the lock status of each provider stored in the Redis hash
440+ ``redis_key``. Uses emojis for a quick visual cue:
441+
442+ * 🟢 – provider is free (`'false'` or missing)
443+ * 🔴 – provider is currently taken (`'true'`)
444+
445+ The output is formatted in a table‑like layout for readability.
446+ """
447+ try :
448+ # Retrieve the entire hash; missing fields default to None
449+ hash_data = self .redis_client .hgetall (redis_key )
450+ except Exception as exc :
451+ print (f"[⚠️] Could not read Redis key '{ redis_key } ': { exc } " )
452+ return
453+
454+ print ("\n Provider lock status:" )
455+ print ("-" * 40 )
456+ for provider in providers :
457+ field = self ._provider_field (provider )
458+ status = hash_data .get (field , "false" )
459+ icon = "🔴" if status == "true" else "🟢"
460+ # Show a short identifier for the provider (fallback to field)
461+ provider_id = provider .get ("id" ) or provider .get ("name" ) or field
462+ print (f"{ icon } { provider_id :<30} [{ field } ]" )
463+ print ("-" * 40 )
0 commit comments