-
Notifications
You must be signed in to change notification settings - Fork 22
fix: Make auth tokens resolved dynamically per request. #108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e5e72a5
0f83705
8492a9d
03736bd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import os | ||
| from collections.abc import Generator | ||
| from typing import Any | ||
| from typing import Callable | ||
| from typing import Optional | ||
|
|
@@ -7,6 +8,7 @@ | |
| import httpx | ||
|
|
||
| from singlestoredb import manage_workspaces | ||
| from singlestoredb.management.inference_api import InferenceAPIInfo | ||
|
|
||
| try: | ||
| from langchain_openai import OpenAIEmbeddings | ||
|
|
@@ -31,43 +33,87 @@ | |
|
|
||
| def SingleStoreEmbeddingsFactory( | ||
| model_name: str, | ||
| api_key: Optional[str] = None, | ||
| api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None, | ||
| http_client: Optional[httpx.Client] = None, | ||
| obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None, | ||
| obo_token_getter: Optional[Callable[[], Optional[str]]] = None, | ||
| base_url: Optional[str] = None, | ||
| hosting_platform: Optional[str] = None, | ||
| timeout: Optional[float] = None, | ||
| **kwargs: Any, | ||
| ) -> Union[OpenAIEmbeddings, BedrockEmbeddings]: | ||
| """Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings). | ||
| """ | ||
| inference_api_manager = ( | ||
| manage_workspaces().organizations.current.inference_apis | ||
| ) | ||
| info = inference_api_manager.get(model_name=model_name) | ||
| token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN') | ||
| token = api_key if api_key is not None else token_env | ||
| # Handle api_key and obo_token as callable functions | ||
| if callable(api_key): | ||
| api_key_getter_fn = api_key | ||
| else: | ||
| def api_key_getter_fn() -> Optional[str]: | ||
| if api_key is None: | ||
| return os.environ.get('SINGLESTOREDB_USER_TOKEN') | ||
| return api_key | ||
|
|
||
| if obo_token_getter is not None: | ||
| obo_token_getter_fn = obo_token_getter | ||
| else: | ||
| if callable(obo_token): | ||
| obo_token_getter_fn = obo_token | ||
| else: | ||
| def obo_token_getter_fn() -> Optional[str]: | ||
| return obo_token | ||
|
|
||
| # handle model info | ||
| if base_url is None: | ||
| base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL') | ||
| if hosting_platform is None: | ||
| hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM') | ||
| if base_url is None or hosting_platform is None: | ||
| inference_api_manager = ( | ||
| manage_workspaces().organizations.current.inference_apis | ||
| ) | ||
| info = inference_api_manager.get(model_name=model_name) | ||
| else: | ||
| info = InferenceAPIInfo( | ||
| service_id='', | ||
| model_name=model_name, | ||
| name='', | ||
| connection_url=base_url, | ||
| project_id='', | ||
| hosting_platform=hosting_platform, | ||
| ) | ||
| if base_url is not None: | ||
| info.connection_url = base_url | ||
| if hosting_platform is not None: | ||
| info.hosting_platform = hosting_platform | ||
|
|
||
| # Extract timeouts from http_client if provided | ||
| t = http_client.timeout if http_client is not None else None | ||
| connect_timeout = None | ||
| read_timeout = None | ||
| if t is not None: | ||
| if isinstance(t, httpx.Timeout): | ||
| if t.connect is not None: | ||
| connect_timeout = float(t.connect) | ||
| if t.read is not None: | ||
| read_timeout = float(t.read) | ||
| if connect_timeout is None and read_timeout is not None: | ||
| connect_timeout = read_timeout | ||
| if read_timeout is None and connect_timeout is not None: | ||
| read_timeout = connect_timeout | ||
| elif isinstance(t, (int, float)): | ||
| connect_timeout = float(t) | ||
| read_timeout = float(t) | ||
| if timeout is not None: | ||
| connect_timeout = timeout | ||
| read_timeout = timeout | ||
| t = httpx.Timeout(timeout) | ||
|
Comment on lines
+48
to
+109
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code looks nearly identical to the code in SingleStoreChatFactory. Is there a reason they can't be combined into a single function that is called to do these operations instead of duplicating code?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will do this refactoring as part of cleaning up the code along with the deprecated parameters. |
||
|
|
||
| if info.hosting_platform == 'Amazon': | ||
| # Instantiate Bedrock client | ||
| cfg_kwargs = { | ||
| 'signature_version': UNSIGNED, | ||
| 'retries': {'max_attempts': 1, 'mode': 'standard'}, | ||
| } | ||
| # Extract timeouts from http_client if provided | ||
| t = http_client.timeout if http_client is not None else None | ||
| connect_timeout = None | ||
| read_timeout = None | ||
| if t is not None: | ||
| if isinstance(t, httpx.Timeout): | ||
| if t.connect is not None: | ||
| connect_timeout = float(t.connect) | ||
| if t.read is not None: | ||
| read_timeout = float(t.read) | ||
| if connect_timeout is None and read_timeout is not None: | ||
| connect_timeout = read_timeout | ||
| if read_timeout is None and connect_timeout is not None: | ||
| read_timeout = connect_timeout | ||
| elif isinstance(t, (int, float)): | ||
| connect_timeout = float(t) | ||
| read_timeout = float(t) | ||
| if read_timeout is not None: | ||
| cfg_kwargs['read_timeout'] = read_timeout | ||
| if connect_timeout is not None: | ||
|
|
@@ -85,12 +131,14 @@ def SingleStoreEmbeddingsFactory( | |
|
|
||
| def _inject_headers(request: Any, **_ignored: Any) -> None: | ||
| """Inject dynamic auth/OBO headers prior to Bedrock sending.""" | ||
| if obo_token_getter is not None: | ||
| obo_val = obo_token_getter() | ||
| if api_key_getter_fn is not None: | ||
| token_val = api_key_getter_fn() | ||
| if token_val: | ||
| request.headers['Authorization'] = f'Bearer {token_val}' | ||
| if obo_token_getter_fn is not None: | ||
| obo_val = obo_token_getter_fn() | ||
| if obo_val: | ||
| request.headers['X-S2-OBO'] = obo_val | ||
| if token: | ||
| request.headers['Authorization'] = f'Bearer {token}' | ||
| request.headers.pop('X-Amz-Date', None) | ||
| request.headers.pop('X-Amz-Security-Token', None) | ||
|
|
||
|
|
@@ -114,14 +162,38 @@ def _inject_headers(request: Any, **_ignored: Any) -> None: | |
| **kwargs, | ||
| ) | ||
|
|
||
| class OpenAIAuth(httpx.Auth): | ||
| def auth_flow( | ||
| self, request: httpx.Request, | ||
| ) -> Generator[httpx.Request, None, None]: | ||
| if api_key_getter_fn is not None: | ||
| token_val = api_key_getter_fn() | ||
| if token_val: | ||
| request.headers['Authorization'] = f'Bearer {token_val}' | ||
| if obo_token_getter_fn is not None: | ||
| obo_val = obo_token_getter_fn() | ||
| if obo_val: | ||
| request.headers['X-S2-OBO'] = obo_val | ||
| yield request | ||
|
|
||
| if t is not None: | ||
| http_client = httpx.Client( | ||
| timeout=t, | ||
| auth=OpenAIAuth(), | ||
| ) | ||
| else: | ||
| http_client = httpx.Client( | ||
| timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout | ||
| auth=OpenAIAuth(), | ||
| ) | ||
|
|
||
| # OpenAI / Azure OpenAI path | ||
| openai_kwargs = dict( | ||
| base_url=info.connection_url, | ||
| api_key=token, | ||
| api_key='placeholder', | ||
| model=model_name, | ||
| ) | ||
| if http_client is not None: | ||
| openai_kwargs['http_client'] = http_client | ||
| openai_kwargs['http_client'] = http_client | ||
| return OpenAIEmbeddings( | ||
| **openai_kwargs, | ||
| **kwargs, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.