11import os
2+ from collections .abc import Generator
23from typing import Any
34from typing import Callable
45from typing import Optional
78import httpx
89
910from singlestoredb import manage_workspaces
11+ from singlestoredb .management .inference_api import InferenceAPIInfo
1012
1113try :
1214 from langchain_openai import OpenAIEmbeddings
3133
3234def SingleStoreEmbeddingsFactory (
3335 model_name : str ,
34- api_key : Optional [str ] = None ,
36+ api_key : Optional [Union [ Optional [ str ], Callable [[], Optional [ str ]]] ] = None ,
3537 http_client : Optional [httpx .Client ] = None ,
38+ obo_token : Optional [Union [Optional [str ], Callable [[], Optional [str ]]]] = None ,
3639 obo_token_getter : Optional [Callable [[], Optional [str ]]] = None ,
40+ base_url : Optional [str ] = None ,
41+ hosting_platform : Optional [str ] = None ,
42+ timeout : Optional [float ] = None ,
3743 ** kwargs : Any ,
3844) -> Union [OpenAIEmbeddings , BedrockEmbeddings ]:
3945 """Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
4046 """
41- inference_api_manager = (
42- manage_workspaces ().organizations .current .inference_apis
43- )
44- info = inference_api_manager .get (model_name = model_name )
45- token_env = os .environ .get ('SINGLESTOREDB_USER_TOKEN' )
46- token = api_key if api_key is not None else token_env
47+ # Handle api_key and obo_token as callable functions
48+ if callable (api_key ):
49+ api_key_getter_fn = api_key
50+ else :
51+ def api_key_getter_fn () -> Optional [str ]:
52+ if api_key is None :
53+ return os .environ .get ('SINGLESTOREDB_USER_TOKEN' )
54+ return api_key
55+
56+ if obo_token_getter is not None :
57+ obo_token_getter_fn = obo_token_getter
58+ else :
59+ if callable (obo_token ):
60+ obo_token_getter_fn = obo_token
61+ else :
62+ def obo_token_getter_fn () -> Optional [str ]:
63+ return obo_token
64+
65+ # handle model info
66+ if base_url is None :
67+ base_url = os .environ .get ('SINGLESTOREDB_INFERENCE_API_BASE_URL' )
68+ if hosting_platform is None :
69+ hosting_platform = os .environ .get ('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM' )
70+ if base_url is None or hosting_platform is None :
71+ inference_api_manager = (
72+ manage_workspaces ().organizations .current .inference_apis
73+ )
74+ info = inference_api_manager .get (model_name = model_name )
75+ else :
76+ info = InferenceAPIInfo (
77+ service_id = '' ,
78+ model_name = model_name ,
79+ name = '' ,
80+ connection_url = base_url ,
81+ project_id = '' ,
82+ hosting_platform = hosting_platform ,
83+ )
84+ if base_url is not None :
85+ info .connection_url = base_url
86+ if hosting_platform is not None :
87+ info .hosting_platform = hosting_platform
88+
89+ # Extract timeouts from http_client if provided
90+ t = http_client .timeout if http_client is not None else None
91+ connect_timeout = None
92+ read_timeout = None
93+ if t is not None :
94+ if isinstance (t , httpx .Timeout ):
95+ if t .connect is not None :
96+ connect_timeout = float (t .connect )
97+ if t .read is not None :
98+ read_timeout = float (t .read )
99+ if connect_timeout is None and read_timeout is not None :
100+ connect_timeout = read_timeout
101+ if read_timeout is None and connect_timeout is not None :
102+ read_timeout = connect_timeout
103+ elif isinstance (t , (int , float )):
104+ connect_timeout = float (t )
105+ read_timeout = float (t )
106+ if timeout is not None :
107+ connect_timeout = timeout
108+ read_timeout = timeout
109+ t = httpx .Timeout (timeout )
47110
48111 if info .hosting_platform == 'Amazon' :
49112 # Instantiate Bedrock client
50113 cfg_kwargs = {
51114 'signature_version' : UNSIGNED ,
52115 'retries' : {'max_attempts' : 1 , 'mode' : 'standard' },
53116 }
54- # Extract timeouts from http_client if provided
55- t = http_client .timeout if http_client is not None else None
56- connect_timeout = None
57- read_timeout = None
58- if t is not None :
59- if isinstance (t , httpx .Timeout ):
60- if t .connect is not None :
61- connect_timeout = float (t .connect )
62- if t .read is not None :
63- read_timeout = float (t .read )
64- if connect_timeout is None and read_timeout is not None :
65- connect_timeout = read_timeout
66- if read_timeout is None and connect_timeout is not None :
67- read_timeout = connect_timeout
68- elif isinstance (t , (int , float )):
69- connect_timeout = float (t )
70- read_timeout = float (t )
71117 if read_timeout is not None :
72118 cfg_kwargs ['read_timeout' ] = read_timeout
73119 if connect_timeout is not None :
@@ -85,12 +131,14 @@ def SingleStoreEmbeddingsFactory(
85131
86132 def _inject_headers (request : Any , ** _ignored : Any ) -> None :
87133 """Inject dynamic auth/OBO headers prior to Bedrock sending."""
88- if obo_token_getter is not None :
89- obo_val = obo_token_getter ()
134+ if api_key_getter_fn is not None :
135+ token_val = api_key_getter_fn ()
136+ if token_val :
137+ request .headers ['Authorization' ] = f'Bearer { token_val } '
138+ if obo_token_getter_fn is not None :
139+ obo_val = obo_token_getter_fn ()
90140 if obo_val :
91141 request .headers ['X-S2-OBO' ] = obo_val
92- if token :
93- request .headers ['Authorization' ] = f'Bearer { token } '
94142 request .headers .pop ('X-Amz-Date' , None )
95143 request .headers .pop ('X-Amz-Security-Token' , None )
96144
@@ -114,14 +162,38 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
114162 ** kwargs ,
115163 )
116164
165+ class OpenAIAuth (httpx .Auth ):
166+ def auth_flow (
167+ self , request : httpx .Request ,
168+ ) -> Generator [httpx .Request , None , None ]:
169+ if api_key_getter_fn is not None :
170+ token_val = api_key_getter_fn ()
171+ if token_val :
172+ request .headers ['Authorization' ] = f'Bearer { token_val } '
173+ if obo_token_getter_fn is not None :
174+ obo_val = obo_token_getter_fn ()
175+ if obo_val :
176+ request .headers ['X-S2-OBO' ] = obo_val
177+ yield request
178+
179+ if t is not None :
180+ http_client = httpx .Client (
181+ timeout = t ,
182+ auth = OpenAIAuth (),
183+ )
184+ else :
185+ http_client = httpx .Client (
186+ timeout = httpx .Timeout (timeout = 600 , connect = 5.0 ), # default OpenAI timeout
187+ auth = OpenAIAuth (),
188+ )
189+
117190 # OpenAI / Azure OpenAI path
118191 openai_kwargs = dict (
119192 base_url = info .connection_url ,
120- api_key = token ,
193+ api_key = 'placeholder' ,
121194 model = model_name ,
122195 )
123- if http_client is not None :
124- openai_kwargs ['http_client' ] = http_client
196+ openai_kwargs ['http_client' ] = http_client
125197 return OpenAIEmbeddings (
126198 ** openai_kwargs ,
127199 ** kwargs ,
0 commit comments