Skip to content

Commit 82e37ac

Browse files
Create separate class.
1 parent a8bae85 commit 82e37ac

File tree

3 files changed

+38
-58
lines changed

3 files changed

+38
-58
lines changed

singlestoredb/ai/chat.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import os
2-
from collections.abc import Generator
32
from typing import Any
43
from typing import Callable
54
from typing import Optional
65
from typing import Union
76

87
import httpx
8+
from utils import SingleStoreOpenAIAuth
99

1010
from singlestoredb import manage_workspaces
1111
from singlestoredb.management.inference_api import InferenceAPIInfo
@@ -172,41 +172,15 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
172172
**kwargs,
173173
)
174174

175-
class OpenAIAuth(httpx.Auth):
176-
def __init__(
177-
self,
178-
api_key_getter: Optional[Callable[[], Optional[str]]] = None,
179-
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
180-
) -> None:
181-
self.api_key_getter = api_key_getter
182-
self.obo_token_getter = obo_token_getter
183-
184-
def auth_flow(
185-
self, request: httpx.Request,
186-
) -> Generator[httpx.Request, None, None]:
187-
print(f'[DEBUG] auth_flow called for {request.method} {request.url}')
188-
if self.api_key_getter is not None:
189-
token_val = self.api_key_getter()
190-
print(f"[DEBUG] api_key_getter: {token_val if token_val else 'None'}...")
191-
if token_val:
192-
request.headers['Authorization'] = f'Bearer {token_val}'
193-
if self.obo_token_getter is not None:
194-
obo_val = self.obo_token_getter()
195-
print(f"[DEBUG] obo_token_getter: {obo_val if obo_val else 'None'}...")
196-
if obo_val:
197-
request.headers['X-S2-OBO'] = obo_val
198-
print(f'[DEBUG] Final headers: {dict(request.headers)}')
199-
yield request
200-
201175
if t is not None:
202176
http_client = httpx.Client(
203177
timeout=t,
204-
auth=OpenAIAuth(api_key_getter_fn, obo_token_getter_fn),
178+
auth=SingleStoreOpenAIAuth(api_key_getter_fn, obo_token_getter_fn),
205179
)
206180
else:
207181
http_client = httpx.Client(
208182
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
209-
auth=OpenAIAuth(api_key_getter_fn, obo_token_getter_fn),
183+
auth=SingleStoreOpenAIAuth(api_key_getter_fn, obo_token_getter_fn),
210184
)
211185

212186
# OpenAI / Azure OpenAI path

singlestoredb/ai/debugv2.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import os
2-
from collections.abc import Generator
32
from typing import Any
43
from typing import Callable
54
from typing import Optional
65
from typing import Union
76

87
import httpx
8+
from utils import SingleStoreOpenAIAuth
99

1010
from singlestoredb import manage_workspaces
1111
from singlestoredb.management.inference_api import InferenceAPIInfo
@@ -172,41 +172,15 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
172172
**kwargs,
173173
)
174174

175-
class OpenAIAuth(httpx.Auth):
176-
def __init__(
177-
self,
178-
api_key_getter: Optional[Callable[[], Optional[str]]] = None,
179-
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
180-
) -> None:
181-
self.api_key_getter = api_key_getter
182-
self.obo_token_getter = obo_token_getter
183-
184-
def auth_flow(
185-
self, request: httpx.Request,
186-
) -> Generator[httpx.Request, None, None]:
187-
print(f'[DEBUG] auth_flow called for {request.method} {request.url}')
188-
if self.api_key_getter is not None:
189-
token_val = self.api_key_getter()
190-
print(f"[DEBUG] api_key_getter: {token_val if token_val else 'None'}...")
191-
if token_val:
192-
request.headers['Authorization'] = f'Bearer {token_val}'
193-
if self.obo_token_getter is not None:
194-
obo_val = self.obo_token_getter()
195-
print(f"[DEBUG] obo_token_getter: {obo_val if obo_val else 'None'}...")
196-
if obo_val:
197-
request.headers['X-S2-OBO'] = obo_val
198-
print(f'[DEBUG] Final headers: {dict(request.headers)}')
199-
yield request
200-
201175
if t is not None:
202176
http_client = httpx.Client(
203177
timeout=t,
204-
auth=OpenAIAuth(api_key_getter_fn, obo_token_getter_fn),
178+
auth=SingleStoreOpenAIAuth(api_key_getter_fn, obo_token_getter_fn),
205179
)
206180
else:
207181
http_client = httpx.Client(
208182
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
209-
auth=OpenAIAuth(api_key_getter_fn, obo_token_getter_fn),
183+
auth=SingleStoreOpenAIAuth(api_key_getter_fn, obo_token_getter_fn),
210184
)
211185

212186
# OpenAI / Azure OpenAI path

singlestoredb/ai/utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from collections.abc import Generator
2+
from typing import Callable
3+
from typing import Optional
4+
5+
import httpx
6+
7+
8+
class SingleStoreOpenAIAuth(httpx.Auth):
9+
def __init__(
10+
self,
11+
api_key_getter: Optional[Callable[[], Optional[str]]] = None,
12+
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
13+
) -> None:
14+
self.api_key_getter = api_key_getter
15+
self.obo_token_getter = obo_token_getter
16+
17+
def auth_flow(
18+
self, request: httpx.Request,
19+
) -> Generator[httpx.Request, None, None]:
20+
print(f'[DEBUG] auth_flow called for {request.method} {request.url}')
21+
if self.api_key_getter is not None:
22+
token_val = self.api_key_getter()
23+
print(f"[DEBUG] api_key_getter: {token_val if token_val else 'None'}...")
24+
if token_val:
25+
request.headers['Authorization'] = f'Bearer {token_val}'
26+
if self.obo_token_getter is not None:
27+
obo_val = self.obo_token_getter()
28+
print(f"[DEBUG] obo_token_getter: {obo_val if obo_val else 'None'}...")
29+
if obo_val:
30+
request.headers['X-S2-OBO'] = obo_val
31+
print(f'[DEBUG] Final headers: {dict(request.headers)}')
32+
yield request

0 commit comments

Comments
 (0)