diff --git a/singlestoredb/ai/chat.py b/singlestoredb/ai/chat.py index f3419c28..cc94b699 100644 --- a/singlestoredb/ai/chat.py +++ b/singlestoredb/ai/chat.py @@ -1,4 +1,5 @@ import os +from collections.abc import Generator from typing import Any from typing import Callable from typing import Optional @@ -7,6 +8,7 @@ import httpx from singlestoredb import manage_workspaces +from singlestoredb.management.inference_api import InferenceAPIInfo try: from langchain_openai import ChatOpenAI @@ -31,20 +33,81 @@ def SingleStoreChatFactory( model_name: str, - api_key: Optional[str] = None, + api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None, streaming: bool = True, http_client: Optional[httpx.Client] = None, + obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None, obo_token_getter: Optional[Callable[[], Optional[str]]] = None, + base_url: Optional[str] = None, + hosting_platform: Optional[str] = None, + timeout: Optional[float] = None, **kwargs: Any, ) -> Union[ChatOpenAI, ChatBedrockConverse]: """Return a chat model instance (ChatOpenAI or ChatBedrockConverse). """ - inference_api_manager = ( - manage_workspaces().organizations.current.inference_apis - ) - info = inference_api_manager.get(model_name=model_name) - token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN') - token = api_key if api_key is not None else token_env + # Handle api_key and obo_token as callable functions + if callable(api_key): + api_key_getter_fn = api_key + else: + def api_key_getter_fn() -> Optional[str]: + if api_key is None: + return os.environ.get('SINGLESTOREDB_USER_TOKEN') + return api_key + + if obo_token_getter is not None: + obo_token_getter_fn = obo_token_getter + else: + if callable(obo_token): + obo_token_getter_fn = obo_token + else: + def obo_token_getter_fn() -> Optional[str]: + return obo_token + + # handle model info + if base_url is None: + base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL') + if hosting_platform is None: + hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM') + if base_url is None or hosting_platform is None: + inference_api_manager = ( + manage_workspaces().organizations.current.inference_apis + ) + info = inference_api_manager.get(model_name=model_name) + else: + info = InferenceAPIInfo( + service_id='', + model_name=model_name, + name='', + connection_url=base_url, + project_id='', + hosting_platform=hosting_platform, + ) + if base_url is not None: + info.connection_url = base_url + if hosting_platform is not None: + info.hosting_platform = hosting_platform + + # Extract timeouts from http_client if provided + t = http_client.timeout if http_client is not None else None + connect_timeout = None + read_timeout = None + if t is not None: + if isinstance(t, httpx.Timeout): + if t.connect is not None: + connect_timeout = float(t.connect) + if t.read is not None: + read_timeout = float(t.read) + if connect_timeout is None and read_timeout is not None: + connect_timeout = read_timeout + if read_timeout is None and connect_timeout is not None: + read_timeout = connect_timeout + elif isinstance(t, (int, float)): + connect_timeout = float(t) + read_timeout = float(t) + if timeout is not None: + connect_timeout = timeout + read_timeout = timeout + t = httpx.Timeout(timeout) if info.hosting_platform == 'Amazon': # Instantiate Bedrock client @@ -52,23 +115,6 @@ def SingleStoreChatFactory( 'signature_version': UNSIGNED, 'retries': {'max_attempts': 1, 'mode': 'standard'}, } - # Extract timeouts from http_client if provided - t = http_client.timeout if http_client is not None else None - connect_timeout = None - read_timeout = None - if t is not None: - if isinstance(t, httpx.Timeout): - if t.connect is not None: - connect_timeout = float(t.connect) - if t.read is not None: - read_timeout = float(t.read) - if connect_timeout is None and read_timeout is not None: - connect_timeout = read_timeout - if read_timeout is None and connect_timeout is not None: - read_timeout = connect_timeout - elif isinstance(t, (int, float)): - connect_timeout = float(t) - read_timeout = float(t) if read_timeout is not None: cfg_kwargs['read_timeout'] = read_timeout if connect_timeout is not None: @@ -86,12 +132,14 @@ def SingleStoreChatFactory( def _inject_headers(request: Any, **_ignored: Any) -> None: """Inject dynamic auth/OBO headers prior to Bedrock sending.""" - if obo_token_getter is not None: - obo_val = obo_token_getter() + if api_key_getter_fn is not None: + token_val = api_key_getter_fn() + if token_val: + request.headers['Authorization'] = f'Bearer {token_val}' + if obo_token_getter_fn is not None: + obo_val = obo_token_getter_fn() if obo_val: request.headers['X-S2-OBO'] = obo_val - if token: - request.headers['Authorization'] = f'Bearer {token}' request.headers.pop('X-Amz-Date', None) request.headers.pop('X-Amz-Security-Token', None) @@ -124,15 +172,39 @@ def _inject_headers(request: Any, **_ignored: Any) -> None: **kwargs, ) + class OpenAIAuth(httpx.Auth): + def auth_flow( + self, request: httpx.Request, + ) -> Generator[httpx.Request, None, None]: + if api_key_getter_fn is not None: + token_val = api_key_getter_fn() + if token_val: + request.headers['Authorization'] = f'Bearer {token_val}' + if obo_token_getter_fn is not None: + obo_val = obo_token_getter_fn() + if obo_val: + request.headers['X-S2-OBO'] = obo_val + yield request + + if t is not None: + http_client = httpx.Client( + timeout=t, + auth=OpenAIAuth(), + ) + else: + http_client = httpx.Client( + timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout + auth=OpenAIAuth(), + ) + # OpenAI / Azure OpenAI path openai_kwargs = dict( base_url=info.connection_url, - api_key=token, + api_key='placeholder', model=model_name, streaming=streaming, ) - if http_client is not None: - openai_kwargs['http_client'] = http_client + openai_kwargs['http_client'] = http_client return ChatOpenAI( **openai_kwargs, **kwargs, diff --git a/singlestoredb/ai/embeddings.py b/singlestoredb/ai/embeddings.py index 2449d94f..e85f26a7 100644 --- a/singlestoredb/ai/embeddings.py +++ b/singlestoredb/ai/embeddings.py @@ -1,4 +1,5 @@ import os +from collections.abc import Generator from typing import Any from typing import Callable from typing import Optional @@ -7,6 +8,7 @@ import httpx from singlestoredb import manage_workspaces +from singlestoredb.management.inference_api import InferenceAPIInfo try: from langchain_openai import OpenAIEmbeddings @@ -31,19 +33,80 @@ def SingleStoreEmbeddingsFactory( model_name: str, - api_key: Optional[str] = None, + api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None, http_client: Optional[httpx.Client] = None, + obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None, obo_token_getter: Optional[Callable[[], Optional[str]]] = None, + base_url: Optional[str] = None, + hosting_platform: Optional[str] = None, + timeout: Optional[float] = None, **kwargs: Any, ) -> Union[OpenAIEmbeddings, BedrockEmbeddings]: """Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings). """ - inference_api_manager = ( - manage_workspaces().organizations.current.inference_apis - ) - info = inference_api_manager.get(model_name=model_name) - token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN') - token = api_key if api_key is not None else token_env + # Handle api_key and obo_token as callable functions + if callable(api_key): + api_key_getter_fn = api_key + else: + def api_key_getter_fn() -> Optional[str]: + if api_key is None: + return os.environ.get('SINGLESTOREDB_USER_TOKEN') + return api_key + + if obo_token_getter is not None: + obo_token_getter_fn = obo_token_getter + else: + if callable(obo_token): + obo_token_getter_fn = obo_token + else: + def obo_token_getter_fn() -> Optional[str]: + return obo_token + + # handle model info + if base_url is None: + base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL') + if hosting_platform is None: + hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM') + if base_url is None or hosting_platform is None: + inference_api_manager = ( + manage_workspaces().organizations.current.inference_apis + ) + info = inference_api_manager.get(model_name=model_name) + else: + info = InferenceAPIInfo( + service_id='', + model_name=model_name, + name='', + connection_url=base_url, + project_id='', + hosting_platform=hosting_platform, + ) + if base_url is not None: + info.connection_url = base_url + if hosting_platform is not None: + info.hosting_platform = hosting_platform + + # Extract timeouts from http_client if provided + t = http_client.timeout if http_client is not None else None + connect_timeout = None + read_timeout = None + if t is not None: + if isinstance(t, httpx.Timeout): + if t.connect is not None: + connect_timeout = float(t.connect) + if t.read is not None: + read_timeout = float(t.read) + if connect_timeout is None and read_timeout is not None: + connect_timeout = read_timeout + if read_timeout is None and connect_timeout is not None: + read_timeout = connect_timeout + elif isinstance(t, (int, float)): + connect_timeout = float(t) + read_timeout = float(t) + if timeout is not None: + connect_timeout = timeout + read_timeout = timeout + t = httpx.Timeout(timeout) if info.hosting_platform == 'Amazon': # Instantiate Bedrock client @@ -51,23 +114,6 @@ def SingleStoreEmbeddingsFactory( 'signature_version': UNSIGNED, 'retries': {'max_attempts': 1, 'mode': 'standard'}, } - # Extract timeouts from http_client if provided - t = http_client.timeout if http_client is not None else None - connect_timeout = None - read_timeout = None - if t is not None: - if isinstance(t, httpx.Timeout): - if t.connect is not None: - connect_timeout = float(t.connect) - if t.read is not None: - read_timeout = float(t.read) - if connect_timeout is None and read_timeout is not None: - connect_timeout = read_timeout - if read_timeout is None and connect_timeout is not None: - read_timeout = connect_timeout - elif isinstance(t, (int, float)): - connect_timeout = float(t) - read_timeout = float(t) if read_timeout is not None: cfg_kwargs['read_timeout'] = read_timeout if connect_timeout is not None: @@ -85,12 +131,14 @@ def SingleStoreEmbeddingsFactory( def _inject_headers(request: Any, **_ignored: Any) -> None: """Inject dynamic auth/OBO headers prior to Bedrock sending.""" - if obo_token_getter is not None: - obo_val = obo_token_getter() + if api_key_getter_fn is not None: + token_val = api_key_getter_fn() + if token_val: + request.headers['Authorization'] = f'Bearer {token_val}' + if obo_token_getter_fn is not None: + obo_val = obo_token_getter_fn() if obo_val: request.headers['X-S2-OBO'] = obo_val - if token: - request.headers['Authorization'] = f'Bearer {token}' request.headers.pop('X-Amz-Date', None) request.headers.pop('X-Amz-Security-Token', None) @@ -114,14 +162,38 @@ def _inject_headers(request: Any, **_ignored: Any) -> None: **kwargs, ) + class OpenAIAuth(httpx.Auth): + def auth_flow( + self, request: httpx.Request, + ) -> Generator[httpx.Request, None, None]: + if api_key_getter_fn is not None: + token_val = api_key_getter_fn() + if token_val: + request.headers['Authorization'] = f'Bearer {token_val}' + if obo_token_getter_fn is not None: + obo_val = obo_token_getter_fn() + if obo_val: + request.headers['X-S2-OBO'] = obo_val + yield request + + if t is not None: + http_client = httpx.Client( + timeout=t, + auth=OpenAIAuth(), + ) + else: + http_client = httpx.Client( + timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout + auth=OpenAIAuth(), + ) + # OpenAI / Azure OpenAI path openai_kwargs = dict( base_url=info.connection_url, - api_key=token, + api_key='placeholder', model=model_name, ) - if http_client is not None: - openai_kwargs['http_client'] = http_client + openai_kwargs['http_client'] = http_client return OpenAIEmbeddings( **openai_kwargs, **kwargs,