Skip to content

Commit 26abd17

Browse files
Pass the functions as internal class properties.
1 parent ad2ea69 commit 26abd17

File tree

2 files changed

+48
-24
lines changed

2 files changed

+48
-24
lines changed

singlestoredb/ai/chat.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from collections.abc import Generator
23
from typing import Any
34
from typing import Callable
45
from typing import Optional
@@ -171,26 +172,37 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
171172
**kwargs,
172173
)
173174

174-
def inject_auth_headers(request: httpx.Request) -> None:
175-
"""Inject dynamic auth/OBO headers before request is sent."""
176-
if api_key_getter_fn is not None:
177-
token_val = api_key_getter_fn()
178-
if token_val:
179-
request.headers['Authorization'] = f'Bearer {token_val}'
180-
if obo_token_getter_fn is not None:
181-
obo_val = obo_token_getter_fn()
182-
if obo_val:
183-
request.headers['X-S2-OBO'] = obo_val
175+
class OpenAIAuth(httpx.Auth):
176+
def __init__(
177+
self,
178+
api_key_getter: Optional[Callable[[], Optional[str]]] = None,
179+
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
180+
) -> None:
181+
self.api_key_getter = api_key_getter
182+
self.obo_token_getter = obo_token_getter
183+
184+
def auth_flow(
185+
self, request: httpx.Request,
186+
) -> Generator[httpx.Request, None, None]:
187+
if self.api_key_getter is not None:
188+
token_val = self.api_key_getter()
189+
if token_val:
190+
request.headers['Authorization'] = f'Bearer {token_val}'
191+
if self.obo_token_getter is not None:
192+
obo_val = self.obo_token_getter()
193+
if obo_val:
194+
request.headers['X-S2-OBO'] = obo_val
195+
yield request
184196

185197
if t is not None:
186198
http_client = httpx.Client(
187199
timeout=t,
188-
event_hooks={'request': [inject_auth_headers]},
200+
auth=OpenAIAuth(api_key_getter_fn, obo_token_getter_fn),
189201
)
190202
else:
191203
http_client = httpx.Client(
192204
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
193-
event_hooks={'request': [inject_auth_headers]},
205+
auth=OpenAIAuth(api_key_getter_fn, obo_token_getter_fn),
194206
)
195207

196208
# OpenAI / Azure OpenAI path

singlestoredb/ai/debug.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from collections.abc import Generator
23
from typing import Any
34
from typing import Callable
45
from typing import Optional
@@ -171,26 +172,37 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
171172
**kwargs,
172173
)
173174

174-
def inject_auth_headers(request: httpx.Request) -> None:
175-
"""Inject dynamic auth/OBO headers before request is sent."""
176-
if api_key_getter_fn is not None:
177-
token_val = api_key_getter_fn()
178-
if token_val:
179-
request.headers['Authorization'] = f'Bearer {token_val}'
180-
if obo_token_getter_fn is not None:
181-
obo_val = obo_token_getter_fn()
182-
if obo_val:
183-
request.headers['X-S2-OBO'] = obo_val
175+
class OpenAIAuth(httpx.Auth):
176+
def __init__(
177+
self,
178+
api_key_getter: Optional[Callable[[], Optional[str]]] = None,
179+
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
180+
) -> None:
181+
self.api_key_getter = api_key_getter
182+
self.obo_token_getter = obo_token_getter
183+
184+
def auth_flow(
185+
self, request: httpx.Request,
186+
) -> Generator[httpx.Request, None, None]:
187+
if self.api_key_getter is not None:
188+
token_val = self.api_key_getter()
189+
if token_val:
190+
request.headers['Authorization'] = f'Bearer {token_val}'
191+
if self.obo_token_getter is not None:
192+
obo_val = self.obo_token_getter()
193+
if obo_val:
194+
request.headers['X-S2-OBO'] = obo_val
195+
yield request
184196

185197
if t is not None:
186198
http_client = httpx.Client(
187199
timeout=t,
188-
event_hooks={'request': [inject_auth_headers]},
200+
auth=OpenAIAuth(api_key_getter_fn, obo_token_getter_fn),
189201
)
190202
else:
191203
http_client = httpx.Client(
192204
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
193-
event_hooks={'request': [inject_auth_headers]},
205+
auth=OpenAIAuth(api_key_getter_fn, obo_token_getter_fn),
194206
)
195207

196208
# OpenAI / Azure OpenAI path

0 commit comments

Comments
 (0)