Skip to content

Commit e496d96

Browse files
fix: Impersonating JWT (obo) header not propagated correctly for OpenAI requests. (#109)
* fix: Headers not propagated correctly for OpenAI requests. * Expose the experimental SingleStoreChatFactoryDebug factory. * Change package name. * Pass the functions as internal class properties. * Introduce debug messages. * Do not pass explicitly the api_key. * Rename file and factory method. * Create separate class. * Rename from debugv2 to debugv3. * Fix imports. * Fix 'auth_flow' call. * Remove explicit types. * Populate OBO token ONLY dynamically. * Switch to old implementation for OpenAI client. * Remove 'timeout', go to the previous way of initializing the OpenAI clients; 'http_client' should be configured outside of the factory method. * Final fixes.
1 parent 711d823 commit e496d96

File tree

2 files changed

+26
-118
lines changed

2 files changed

+26
-118
lines changed

singlestoredb/ai/chat.py

Lines changed: 13 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
from collections.abc import Generator
32
from typing import Any
43
from typing import Callable
54
from typing import Optional
@@ -33,36 +32,16 @@
3332

3433
def SingleStoreChatFactory(
3534
model_name: str,
36-
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
35+
api_key: Optional[str] = None,
3736
streaming: bool = True,
3837
http_client: Optional[httpx.Client] = None,
39-
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
4038
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
4139
base_url: Optional[str] = None,
4240
hosting_platform: Optional[str] = None,
43-
timeout: Optional[float] = None,
4441
**kwargs: Any,
4542
) -> Union[ChatOpenAI, ChatBedrockConverse]:
4643
"""Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
4744
"""
48-
# Handle api_key and obo_token as callable functions
49-
if callable(api_key):
50-
api_key_getter_fn = api_key
51-
else:
52-
def api_key_getter_fn() -> Optional[str]:
53-
if api_key is None:
54-
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
55-
return api_key
56-
57-
if obo_token_getter is not None:
58-
obo_token_getter_fn = obo_token_getter
59-
else:
60-
if callable(obo_token):
61-
obo_token_getter_fn = obo_token
62-
else:
63-
def obo_token_getter_fn() -> Optional[str]:
64-
return obo_token
65-
6645
# handle model info
6746
if base_url is None:
6847
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
@@ -104,10 +83,6 @@ def obo_token_getter_fn() -> Optional[str]:
10483
elif isinstance(t, (int, float)):
10584
connect_timeout = float(t)
10685
read_timeout = float(t)
107-
if timeout is not None:
108-
connect_timeout = timeout
109-
read_timeout = timeout
110-
t = httpx.Timeout(timeout)
11186

11287
if info.hosting_platform == 'Amazon':
11388
# Instantiate Bedrock client
@@ -132,12 +107,12 @@ def obo_token_getter_fn() -> Optional[str]:
132107

133108
def _inject_headers(request: Any, **_ignored: Any) -> None:
134109
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
135-
if api_key_getter_fn is not None:
136-
token_val = api_key_getter_fn()
137-
if token_val:
138-
request.headers['Authorization'] = f'Bearer {token_val}'
139-
if obo_token_getter_fn is not None:
140-
obo_val = obo_token_getter_fn()
110+
token_env_val = os.environ.get('SINGLESTOREDB_USER_TOKEN')
111+
token_val = api_key if api_key is not None else token_env_val
112+
if token_val:
113+
request.headers['Authorization'] = f'Bearer {token_val}'
114+
if obo_token_getter is not None:
115+
obo_val = obo_token_getter()
141116
if obo_val:
142117
request.headers['X-S2-OBO'] = obo_val
143118
request.headers.pop('X-Amz-Date', None)
@@ -172,39 +147,18 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
172147
**kwargs,
173148
)
174149

175-
class OpenAIAuth(httpx.Auth):
176-
def auth_flow(
177-
self, request: httpx.Request,
178-
) -> Generator[httpx.Request, None, None]:
179-
if api_key_getter_fn is not None:
180-
token_val = api_key_getter_fn()
181-
if token_val:
182-
request.headers['Authorization'] = f'Bearer {token_val}'
183-
if obo_token_getter_fn is not None:
184-
obo_val = obo_token_getter_fn()
185-
if obo_val:
186-
request.headers['X-S2-OBO'] = obo_val
187-
yield request
188-
189-
if t is not None:
190-
http_client = httpx.Client(
191-
timeout=t,
192-
auth=OpenAIAuth(),
193-
)
194-
else:
195-
http_client = httpx.Client(
196-
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
197-
auth=OpenAIAuth(),
198-
)
199-
200150
# OpenAI / Azure OpenAI path
151+
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
152+
token = api_key if api_key is not None else token_env
153+
201154
openai_kwargs = dict(
202155
base_url=info.connection_url,
203-
api_key='placeholder',
156+
api_key=token,
204157
model=model_name,
205158
streaming=streaming,
206159
)
207-
openai_kwargs['http_client'] = http_client
160+
if http_client is not None:
161+
openai_kwargs['http_client'] = http_client
208162
return ChatOpenAI(
209163
**openai_kwargs,
210164
**kwargs,

singlestoredb/ai/embeddings.py

Lines changed: 13 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
from collections.abc import Generator
32
from typing import Any
43
from typing import Callable
54
from typing import Optional
@@ -33,35 +32,15 @@
3332

3433
def SingleStoreEmbeddingsFactory(
3534
model_name: str,
36-
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
35+
api_key: Optional[str] = None,
3736
http_client: Optional[httpx.Client] = None,
38-
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
3937
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
4038
base_url: Optional[str] = None,
4139
hosting_platform: Optional[str] = None,
42-
timeout: Optional[float] = None,
4340
**kwargs: Any,
4441
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
4542
"""Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
4643
"""
47-
# Handle api_key and obo_token as callable functions
48-
if callable(api_key):
49-
api_key_getter_fn = api_key
50-
else:
51-
def api_key_getter_fn() -> Optional[str]:
52-
if api_key is None:
53-
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
54-
return api_key
55-
56-
if obo_token_getter is not None:
57-
obo_token_getter_fn = obo_token_getter
58-
else:
59-
if callable(obo_token):
60-
obo_token_getter_fn = obo_token
61-
else:
62-
def obo_token_getter_fn() -> Optional[str]:
63-
return obo_token
64-
6544
# handle model info
6645
if base_url is None:
6746
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
@@ -103,10 +82,6 @@ def obo_token_getter_fn() -> Optional[str]:
10382
elif isinstance(t, (int, float)):
10483
connect_timeout = float(t)
10584
read_timeout = float(t)
106-
if timeout is not None:
107-
connect_timeout = timeout
108-
read_timeout = timeout
109-
t = httpx.Timeout(timeout)
11085

11186
if info.hosting_platform == 'Amazon':
11287
# Instantiate Bedrock client
@@ -131,12 +106,12 @@ def obo_token_getter_fn() -> Optional[str]:
131106

132107
def _inject_headers(request: Any, **_ignored: Any) -> None:
133108
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
134-
if api_key_getter_fn is not None:
135-
token_val = api_key_getter_fn()
136-
if token_val:
137-
request.headers['Authorization'] = f'Bearer {token_val}'
138-
if obo_token_getter_fn is not None:
139-
obo_val = obo_token_getter_fn()
109+
token_env_val = os.environ.get('SINGLESTOREDB_USER_TOKEN')
110+
token_val = api_key if api_key is not None else token_env_val
111+
if token_val:
112+
request.headers['Authorization'] = f'Bearer {token_val}'
113+
if obo_token_getter is not None:
114+
obo_val = obo_token_getter()
140115
if obo_val:
141116
request.headers['X-S2-OBO'] = obo_val
142117
request.headers.pop('X-Amz-Date', None)
@@ -162,38 +137,17 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
162137
**kwargs,
163138
)
164139

165-
class OpenAIAuth(httpx.Auth):
166-
def auth_flow(
167-
self, request: httpx.Request,
168-
) -> Generator[httpx.Request, None, None]:
169-
if api_key_getter_fn is not None:
170-
token_val = api_key_getter_fn()
171-
if token_val:
172-
request.headers['Authorization'] = f'Bearer {token_val}'
173-
if obo_token_getter_fn is not None:
174-
obo_val = obo_token_getter_fn()
175-
if obo_val:
176-
request.headers['X-S2-OBO'] = obo_val
177-
yield request
178-
179-
if t is not None:
180-
http_client = httpx.Client(
181-
timeout=t,
182-
auth=OpenAIAuth(),
183-
)
184-
else:
185-
http_client = httpx.Client(
186-
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
187-
auth=OpenAIAuth(),
188-
)
189-
190140
# OpenAI / Azure OpenAI path
141+
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
142+
token = api_key if api_key is not None else token_env
143+
191144
openai_kwargs = dict(
192145
base_url=info.connection_url,
193-
api_key='placeholder',
146+
api_key=token,
194147
model=model_name,
195148
)
196-
openai_kwargs['http_client'] = http_client
149+
if http_client is not None:
150+
openai_kwargs['http_client'] = http_client
197151
return OpenAIEmbeddings(
198152
**openai_kwargs,
199153
**kwargs,

0 commit comments

Comments
 (0)