Skip to content

Commit 8fd5a41

Browse files
fix: Make auth tokens resolved dynamically per request. (#108)
* fix: Make auth tokens resolved dynamically per request. * Apply same changes to ChatFactory as well. * Handle timeout properly without causing regressions until we remove accepting http_client parameter. * Fixes; prepare for deprecation of http_client and obo_token_getter params.
1 parent 4adf1c5 commit 8fd5a41

File tree

2 files changed

+206
-62
lines changed

2 files changed

+206
-62
lines changed

singlestoredb/ai/chat.py

Lines changed: 103 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from collections.abc import Generator
23
from typing import Any
34
from typing import Callable
45
from typing import Optional
@@ -7,6 +8,7 @@
78
import httpx
89

910
from singlestoredb import manage_workspaces
11+
from singlestoredb.management.inference_api import InferenceAPIInfo
1012

1113
try:
1214
from langchain_openai import ChatOpenAI
@@ -31,44 +33,88 @@
3133

3234
def SingleStoreChatFactory(
3335
model_name: str,
34-
api_key: Optional[str] = None,
36+
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
3537
streaming: bool = True,
3638
http_client: Optional[httpx.Client] = None,
39+
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
3740
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
41+
base_url: Optional[str] = None,
42+
hosting_platform: Optional[str] = None,
43+
timeout: Optional[float] = None,
3844
**kwargs: Any,
3945
) -> Union[ChatOpenAI, ChatBedrockConverse]:
4046
"""Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
4147
"""
42-
inference_api_manager = (
43-
manage_workspaces().organizations.current.inference_apis
44-
)
45-
info = inference_api_manager.get(model_name=model_name)
46-
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
47-
token = api_key if api_key is not None else token_env
48+
# Handle api_key and obo_token as callable functions
49+
if callable(api_key):
50+
api_key_getter_fn = api_key
51+
else:
52+
def api_key_getter_fn() -> Optional[str]:
53+
if api_key is None:
54+
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
55+
return api_key
56+
57+
if obo_token_getter is not None:
58+
obo_token_getter_fn = obo_token_getter
59+
else:
60+
if callable(obo_token):
61+
obo_token_getter_fn = obo_token
62+
else:
63+
def obo_token_getter_fn() -> Optional[str]:
64+
return obo_token
65+
66+
# handle model info
67+
if base_url is None:
68+
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
69+
if hosting_platform is None:
70+
hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
71+
if base_url is None or hosting_platform is None:
72+
inference_api_manager = (
73+
manage_workspaces().organizations.current.inference_apis
74+
)
75+
info = inference_api_manager.get(model_name=model_name)
76+
else:
77+
info = InferenceAPIInfo(
78+
service_id='',
79+
model_name=model_name,
80+
name='',
81+
connection_url=base_url,
82+
project_id='',
83+
hosting_platform=hosting_platform,
84+
)
85+
if base_url is not None:
86+
info.connection_url = base_url
87+
if hosting_platform is not None:
88+
info.hosting_platform = hosting_platform
89+
90+
# Extract timeouts from http_client if provided
91+
t = http_client.timeout if http_client is not None else None
92+
connect_timeout = None
93+
read_timeout = None
94+
if t is not None:
95+
if isinstance(t, httpx.Timeout):
96+
if t.connect is not None:
97+
connect_timeout = float(t.connect)
98+
if t.read is not None:
99+
read_timeout = float(t.read)
100+
if connect_timeout is None and read_timeout is not None:
101+
connect_timeout = read_timeout
102+
if read_timeout is None and connect_timeout is not None:
103+
read_timeout = connect_timeout
104+
elif isinstance(t, (int, float)):
105+
connect_timeout = float(t)
106+
read_timeout = float(t)
107+
if timeout is not None:
108+
connect_timeout = timeout
109+
read_timeout = timeout
110+
t = httpx.Timeout(timeout)
48111

49112
if info.hosting_platform == 'Amazon':
50113
# Instantiate Bedrock client
51114
cfg_kwargs = {
52115
'signature_version': UNSIGNED,
53116
'retries': {'max_attempts': 1, 'mode': 'standard'},
54117
}
55-
# Extract timeouts from http_client if provided
56-
t = http_client.timeout if http_client is not None else None
57-
connect_timeout = None
58-
read_timeout = None
59-
if t is not None:
60-
if isinstance(t, httpx.Timeout):
61-
if t.connect is not None:
62-
connect_timeout = float(t.connect)
63-
if t.read is not None:
64-
read_timeout = float(t.read)
65-
if connect_timeout is None and read_timeout is not None:
66-
connect_timeout = read_timeout
67-
if read_timeout is None and connect_timeout is not None:
68-
read_timeout = connect_timeout
69-
elif isinstance(t, (int, float)):
70-
connect_timeout = float(t)
71-
read_timeout = float(t)
72118
if read_timeout is not None:
73119
cfg_kwargs['read_timeout'] = read_timeout
74120
if connect_timeout is not None:
@@ -86,12 +132,14 @@ def SingleStoreChatFactory(
86132

87133
def _inject_headers(request: Any, **_ignored: Any) -> None:
88134
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
89-
if obo_token_getter is not None:
90-
obo_val = obo_token_getter()
135+
if api_key_getter_fn is not None:
136+
token_val = api_key_getter_fn()
137+
if token_val:
138+
request.headers['Authorization'] = f'Bearer {token_val}'
139+
if obo_token_getter_fn is not None:
140+
obo_val = obo_token_getter_fn()
91141
if obo_val:
92142
request.headers['X-S2-OBO'] = obo_val
93-
if token:
94-
request.headers['Authorization'] = f'Bearer {token}'
95143
request.headers.pop('X-Amz-Date', None)
96144
request.headers.pop('X-Amz-Security-Token', None)
97145

@@ -124,15 +172,39 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
124172
**kwargs,
125173
)
126174

175+
class OpenAIAuth(httpx.Auth):
176+
def auth_flow(
177+
self, request: httpx.Request,
178+
) -> Generator[httpx.Request, None, None]:
179+
if api_key_getter_fn is not None:
180+
token_val = api_key_getter_fn()
181+
if token_val:
182+
request.headers['Authorization'] = f'Bearer {token_val}'
183+
if obo_token_getter_fn is not None:
184+
obo_val = obo_token_getter_fn()
185+
if obo_val:
186+
request.headers['X-S2-OBO'] = obo_val
187+
yield request
188+
189+
if t is not None:
190+
http_client = httpx.Client(
191+
timeout=t,
192+
auth=OpenAIAuth(),
193+
)
194+
else:
195+
http_client = httpx.Client(
196+
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
197+
auth=OpenAIAuth(),
198+
)
199+
127200
# OpenAI / Azure OpenAI path
128201
openai_kwargs = dict(
129202
base_url=info.connection_url,
130-
api_key=token,
203+
api_key='placeholder',
131204
model=model_name,
132205
streaming=streaming,
133206
)
134-
if http_client is not None:
135-
openai_kwargs['http_client'] = http_client
207+
openai_kwargs['http_client'] = http_client
136208
return ChatOpenAI(
137209
**openai_kwargs,
138210
**kwargs,

singlestoredb/ai/embeddings.py

Lines changed: 103 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from collections.abc import Generator
23
from typing import Any
34
from typing import Callable
45
from typing import Optional
@@ -7,6 +8,7 @@
78
import httpx
89

910
from singlestoredb import manage_workspaces
11+
from singlestoredb.management.inference_api import InferenceAPIInfo
1012

1113
try:
1214
from langchain_openai import OpenAIEmbeddings
@@ -31,43 +33,87 @@
3133

3234
def SingleStoreEmbeddingsFactory(
3335
model_name: str,
34-
api_key: Optional[str] = None,
36+
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
3537
http_client: Optional[httpx.Client] = None,
38+
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
3639
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
40+
base_url: Optional[str] = None,
41+
hosting_platform: Optional[str] = None,
42+
timeout: Optional[float] = None,
3743
**kwargs: Any,
3844
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
3945
"""Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
4046
"""
41-
inference_api_manager = (
42-
manage_workspaces().organizations.current.inference_apis
43-
)
44-
info = inference_api_manager.get(model_name=model_name)
45-
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
46-
token = api_key if api_key is not None else token_env
47+
# Handle api_key and obo_token as callable functions
48+
if callable(api_key):
49+
api_key_getter_fn = api_key
50+
else:
51+
def api_key_getter_fn() -> Optional[str]:
52+
if api_key is None:
53+
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
54+
return api_key
55+
56+
if obo_token_getter is not None:
57+
obo_token_getter_fn = obo_token_getter
58+
else:
59+
if callable(obo_token):
60+
obo_token_getter_fn = obo_token
61+
else:
62+
def obo_token_getter_fn() -> Optional[str]:
63+
return obo_token
64+
65+
# handle model info
66+
if base_url is None:
67+
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
68+
if hosting_platform is None:
69+
hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
70+
if base_url is None or hosting_platform is None:
71+
inference_api_manager = (
72+
manage_workspaces().organizations.current.inference_apis
73+
)
74+
info = inference_api_manager.get(model_name=model_name)
75+
else:
76+
info = InferenceAPIInfo(
77+
service_id='',
78+
model_name=model_name,
79+
name='',
80+
connection_url=base_url,
81+
project_id='',
82+
hosting_platform=hosting_platform,
83+
)
84+
if base_url is not None:
85+
info.connection_url = base_url
86+
if hosting_platform is not None:
87+
info.hosting_platform = hosting_platform
88+
89+
# Extract timeouts from http_client if provided
90+
t = http_client.timeout if http_client is not None else None
91+
connect_timeout = None
92+
read_timeout = None
93+
if t is not None:
94+
if isinstance(t, httpx.Timeout):
95+
if t.connect is not None:
96+
connect_timeout = float(t.connect)
97+
if t.read is not None:
98+
read_timeout = float(t.read)
99+
if connect_timeout is None and read_timeout is not None:
100+
connect_timeout = read_timeout
101+
if read_timeout is None and connect_timeout is not None:
102+
read_timeout = connect_timeout
103+
elif isinstance(t, (int, float)):
104+
connect_timeout = float(t)
105+
read_timeout = float(t)
106+
if timeout is not None:
107+
connect_timeout = timeout
108+
read_timeout = timeout
109+
t = httpx.Timeout(timeout)
47110

48111
if info.hosting_platform == 'Amazon':
49112
# Instantiate Bedrock client
50113
cfg_kwargs = {
51114
'signature_version': UNSIGNED,
52115
'retries': {'max_attempts': 1, 'mode': 'standard'},
53116
}
54-
# Extract timeouts from http_client if provided
55-
t = http_client.timeout if http_client is not None else None
56-
connect_timeout = None
57-
read_timeout = None
58-
if t is not None:
59-
if isinstance(t, httpx.Timeout):
60-
if t.connect is not None:
61-
connect_timeout = float(t.connect)
62-
if t.read is not None:
63-
read_timeout = float(t.read)
64-
if connect_timeout is None and read_timeout is not None:
65-
connect_timeout = read_timeout
66-
if read_timeout is None and connect_timeout is not None:
67-
read_timeout = connect_timeout
68-
elif isinstance(t, (int, float)):
69-
connect_timeout = float(t)
70-
read_timeout = float(t)
71117
if read_timeout is not None:
72118
cfg_kwargs['read_timeout'] = read_timeout
73119
if connect_timeout is not None:
@@ -85,12 +131,14 @@ def SingleStoreEmbeddingsFactory(
85131

86132
def _inject_headers(request: Any, **_ignored: Any) -> None:
87133
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
88-
if obo_token_getter is not None:
89-
obo_val = obo_token_getter()
134+
if api_key_getter_fn is not None:
135+
token_val = api_key_getter_fn()
136+
if token_val:
137+
request.headers['Authorization'] = f'Bearer {token_val}'
138+
if obo_token_getter_fn is not None:
139+
obo_val = obo_token_getter_fn()
90140
if obo_val:
91141
request.headers['X-S2-OBO'] = obo_val
92-
if token:
93-
request.headers['Authorization'] = f'Bearer {token}'
94142
request.headers.pop('X-Amz-Date', None)
95143
request.headers.pop('X-Amz-Security-Token', None)
96144

@@ -114,14 +162,38 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
114162
**kwargs,
115163
)
116164

165+
class OpenAIAuth(httpx.Auth):
166+
def auth_flow(
167+
self, request: httpx.Request,
168+
) -> Generator[httpx.Request, None, None]:
169+
if api_key_getter_fn is not None:
170+
token_val = api_key_getter_fn()
171+
if token_val:
172+
request.headers['Authorization'] = f'Bearer {token_val}'
173+
if obo_token_getter_fn is not None:
174+
obo_val = obo_token_getter_fn()
175+
if obo_val:
176+
request.headers['X-S2-OBO'] = obo_val
177+
yield request
178+
179+
if t is not None:
180+
http_client = httpx.Client(
181+
timeout=t,
182+
auth=OpenAIAuth(),
183+
)
184+
else:
185+
http_client = httpx.Client(
186+
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
187+
auth=OpenAIAuth(),
188+
)
189+
117190
# OpenAI / Azure OpenAI path
118191
openai_kwargs = dict(
119192
base_url=info.connection_url,
120-
api_key=token,
193+
api_key='placeholder',
121194
model=model_name,
122195
)
123-
if http_client is not None:
124-
openai_kwargs['http_client'] = http_client
196+
openai_kwargs['http_client'] = http_client
125197
return OpenAIEmbeddings(
126198
**openai_kwargs,
127199
**kwargs,

0 commit comments

Comments
 (0)