Skip to content

Commit c93ff7d

Browse files
fix: One more attempt resolving dynamically the tokens.
1 parent 711d823 commit c93ff7d

File tree

3 files changed

+223
-0
lines changed

3 files changed

+223
-0
lines changed

singlestoredb/ai/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .chat import SingleStoreChatFactory # noqa: F401
2+
from .debugv1 import SingleStoreChatFactoryDebugV1 # noqa: F401
23
from .embeddings import SingleStoreEmbeddingsFactory # noqa: F401

singlestoredb/ai/debugv1.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import os
2+
from typing import Any
3+
from typing import Callable
4+
from typing import Optional
5+
from typing import Union
6+
7+
import httpx
8+
9+
from singlestoredb import manage_workspaces
10+
from singlestoredb.ai.utils import SingleStoreOpenAIAuth
11+
from singlestoredb.management.inference_api import InferenceAPIInfo
12+
13+
try:
14+
from langchain_openai import ChatOpenAI
15+
except ImportError:
16+
raise ImportError(
17+
'Could not import langchain_openai python package. '
18+
'Please install it with `pip install langchain_openai`.',
19+
)
20+
21+
try:
22+
from langchain_aws import ChatBedrockConverse
23+
except ImportError:
24+
raise ImportError(
25+
'Could not import langchain-aws python package. '
26+
'Please install it with `pip install langchain-aws`.',
27+
)
28+
29+
import boto3
30+
from botocore import UNSIGNED
31+
from botocore.config import Config
32+
33+
34+
def SingleStoreChatFactoryDebugV1(
35+
model_name: str,
36+
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
37+
streaming: bool = True,
38+
http_client: Optional[httpx.Client] = None,
39+
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
40+
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
41+
base_url: Optional[str] = None,
42+
hosting_platform: Optional[str] = None,
43+
timeout: Optional[float] = None,
44+
**kwargs: Any,
45+
) -> Union[ChatOpenAI, ChatBedrockConverse]:
46+
"""Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
47+
"""
48+
# Handle api_key and obo_token as callable functions
49+
if callable(api_key):
50+
api_key_getter_fn = api_key
51+
else:
52+
def api_key_getter_fn() -> Optional[str]:
53+
if api_key is None:
54+
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
55+
return api_key
56+
57+
if obo_token_getter is not None:
58+
obo_token_getter_fn = obo_token_getter
59+
else:
60+
if callable(obo_token):
61+
obo_token_getter_fn = obo_token
62+
else:
63+
def obo_token_getter_fn() -> Optional[str]:
64+
return obo_token
65+
66+
# handle model info
67+
if base_url is None:
68+
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
69+
if hosting_platform is None:
70+
hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
71+
if base_url is None or hosting_platform is None:
72+
inference_api_manager = (
73+
manage_workspaces().organizations.current.inference_apis
74+
)
75+
info = inference_api_manager.get(model_name=model_name)
76+
else:
77+
info = InferenceAPIInfo(
78+
service_id='',
79+
model_name=model_name,
80+
name='',
81+
connection_url=base_url,
82+
project_id='',
83+
hosting_platform=hosting_platform,
84+
)
85+
if base_url is not None:
86+
info.connection_url = base_url
87+
if hosting_platform is not None:
88+
info.hosting_platform = hosting_platform
89+
90+
# Extract timeouts from http_client if provided
91+
t = http_client.timeout if http_client is not None else None
92+
connect_timeout = None
93+
read_timeout = None
94+
if t is not None:
95+
if isinstance(t, httpx.Timeout):
96+
if t.connect is not None:
97+
connect_timeout = float(t.connect)
98+
if t.read is not None:
99+
read_timeout = float(t.read)
100+
if connect_timeout is None and read_timeout is not None:
101+
connect_timeout = read_timeout
102+
if read_timeout is None and connect_timeout is not None:
103+
read_timeout = connect_timeout
104+
elif isinstance(t, (int, float)):
105+
connect_timeout = float(t)
106+
read_timeout = float(t)
107+
if timeout is not None:
108+
connect_timeout = timeout
109+
read_timeout = timeout
110+
t = httpx.Timeout(timeout)
111+
112+
if info.hosting_platform == 'Amazon':
113+
# Instantiate Bedrock client
114+
cfg_kwargs = {
115+
'signature_version': UNSIGNED,
116+
'retries': {'max_attempts': 1, 'mode': 'standard'},
117+
}
118+
if read_timeout is not None:
119+
cfg_kwargs['read_timeout'] = read_timeout
120+
if connect_timeout is not None:
121+
cfg_kwargs['connect_timeout'] = connect_timeout
122+
123+
cfg = Config(**cfg_kwargs)
124+
client = boto3.client(
125+
'bedrock-runtime',
126+
endpoint_url=info.connection_url,
127+
region_name='us-east-1',
128+
aws_access_key_id='placeholder',
129+
aws_secret_access_key='placeholder',
130+
config=cfg,
131+
)
132+
133+
def _inject_headers(request: Any, **_ignored: Any) -> None:
134+
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
135+
if api_key_getter_fn is not None:
136+
token_val = api_key_getter_fn()
137+
if token_val:
138+
request.headers['Authorization'] = f'Bearer {token_val}'
139+
if obo_token_getter_fn is not None:
140+
obo_val = obo_token_getter_fn()
141+
if obo_val:
142+
request.headers['X-S2-OBO'] = obo_val
143+
request.headers.pop('X-Amz-Date', None)
144+
request.headers.pop('X-Amz-Security-Token', None)
145+
146+
emitter = client._endpoint._event_emitter
147+
emitter.register_first(
148+
'before-send.bedrock-runtime.Converse',
149+
_inject_headers,
150+
)
151+
emitter.register_first(
152+
'before-send.bedrock-runtime.ConverseStream',
153+
_inject_headers,
154+
)
155+
emitter.register_first(
156+
'before-send.bedrock-runtime.InvokeModel',
157+
_inject_headers,
158+
)
159+
emitter.register_first(
160+
'before-send.bedrock-runtime.InvokeModelWithResponseStream',
161+
_inject_headers,
162+
)
163+
164+
return ChatBedrockConverse(
165+
model_id=model_name,
166+
endpoint_url=info.connection_url,
167+
region_name='us-east-1',
168+
aws_access_key_id='placeholder',
169+
aws_secret_access_key='placeholder',
170+
disable_streaming=not streaming,
171+
client=client,
172+
**kwargs,
173+
)
174+
175+
# OpenAI / Azure OpenAI path
176+
if t is not None:
177+
http_client_internal = httpx.Client(
178+
timeout=t,
179+
auth=SingleStoreOpenAIAuth(obo_token_getter_fn),
180+
)
181+
else:
182+
http_client_internal = httpx.Client(
183+
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
184+
auth=SingleStoreOpenAIAuth(obo_token_getter_fn),
185+
)
186+
187+
token = api_key_getter_fn()
188+
openai_kwargs = dict(
189+
base_url=info.connection_url,
190+
api_key=token,
191+
model=model_name,
192+
streaming=streaming,
193+
)
194+
openai_kwargs['http_client'] = http_client_internal
195+
return ChatOpenAI(
196+
**openai_kwargs,
197+
**kwargs,
198+
)

singlestoredb/ai/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import logging
2+
3+
import httpx
4+
5+
logger = logging.getLogger(__name__)
6+
7+
8+
class SingleStoreOpenAIAuth(httpx.Auth):
9+
def __init__(self, obo_token_getter): # type: ignore[no-untyped-def]
10+
self.obo_token_getter_fn = obo_token_getter
11+
12+
def auth_flow(self, request: httpx.Request): # type: ignore[no-untyped-def]
13+
logger.info(f'auth_flow called for request to {request.url}')
14+
if self.obo_token_getter_fn is not None:
15+
logger.debug('obo_token_getter_fn is set, attempting to get token')
16+
obo_val = self.obo_token_getter_fn()
17+
if obo_val:
18+
logger.info('OBO token retrieved successfully, adding X-S2-OBO header')
19+
request.headers['X-S2-OBO'] = obo_val
20+
else:
21+
logger.warning('obo_token_getter_fn returned empty/None value')
22+
else:
23+
logger.debug('obo_token_getter_fn is None, skipping OBO token')
24+
yield request

0 commit comments

Comments
 (0)