diff --git a/singlestoredb/ai/chat.py b/singlestoredb/ai/chat.py index cc94b699..4878c96c 100644 --- a/singlestoredb/ai/chat.py +++ b/singlestoredb/ai/chat.py @@ -1,5 +1,4 @@ import os -from collections.abc import Generator from typing import Any from typing import Callable from typing import Optional @@ -33,36 +32,16 @@ def SingleStoreChatFactory( model_name: str, - api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None, + api_key: Optional[str] = None, streaming: bool = True, http_client: Optional[httpx.Client] = None, - obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None, obo_token_getter: Optional[Callable[[], Optional[str]]] = None, base_url: Optional[str] = None, hosting_platform: Optional[str] = None, - timeout: Optional[float] = None, **kwargs: Any, ) -> Union[ChatOpenAI, ChatBedrockConverse]: """Return a chat model instance (ChatOpenAI or ChatBedrockConverse). """ - # Handle api_key and obo_token as callable functions - if callable(api_key): - api_key_getter_fn = api_key - else: - def api_key_getter_fn() -> Optional[str]: - if api_key is None: - return os.environ.get('SINGLESTOREDB_USER_TOKEN') - return api_key - - if obo_token_getter is not None: - obo_token_getter_fn = obo_token_getter - else: - if callable(obo_token): - obo_token_getter_fn = obo_token - else: - def obo_token_getter_fn() -> Optional[str]: - return obo_token - # handle model info if base_url is None: base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL') @@ -104,10 +83,6 @@ def obo_token_getter_fn() -> Optional[str]: elif isinstance(t, (int, float)): connect_timeout = float(t) read_timeout = float(t) - if timeout is not None: - connect_timeout = timeout - read_timeout = timeout - t = httpx.Timeout(timeout) if info.hosting_platform == 'Amazon': # Instantiate Bedrock client @@ -132,12 +107,12 @@ def obo_token_getter_fn() -> Optional[str]: def _inject_headers(request: Any, **_ignored: Any) -> None: """Inject dynamic auth/OBO headers prior to Bedrock sending.""" - if api_key_getter_fn is not None: - token_val = api_key_getter_fn() - if token_val: - request.headers['Authorization'] = f'Bearer {token_val}' - if obo_token_getter_fn is not None: - obo_val = obo_token_getter_fn() + token_env_val = os.environ.get('SINGLESTOREDB_USER_TOKEN') + token_val = api_key if api_key is not None else token_env_val + if token_val: + request.headers['Authorization'] = f'Bearer {token_val}' + if obo_token_getter is not None: + obo_val = obo_token_getter() if obo_val: request.headers['X-S2-OBO'] = obo_val request.headers.pop('X-Amz-Date', None) @@ -172,39 +147,18 @@ def _inject_headers(request: Any, **_ignored: Any) -> None: **kwargs, ) - class OpenAIAuth(httpx.Auth): - def auth_flow( - self, request: httpx.Request, - ) -> Generator[httpx.Request, None, None]: - if api_key_getter_fn is not None: - token_val = api_key_getter_fn() - if token_val: - request.headers['Authorization'] = f'Bearer {token_val}' - if obo_token_getter_fn is not None: - obo_val = obo_token_getter_fn() - if obo_val: - request.headers['X-S2-OBO'] = obo_val - yield request - - if t is not None: - http_client = httpx.Client( - timeout=t, - auth=OpenAIAuth(), - ) - else: - http_client = httpx.Client( - timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout - auth=OpenAIAuth(), - ) - # OpenAI / Azure OpenAI path + token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN') + token = api_key if api_key is not None else token_env + openai_kwargs = dict( base_url=info.connection_url, - api_key='placeholder', + api_key=token, model=model_name, streaming=streaming, ) - openai_kwargs['http_client'] = http_client + if http_client is not None: + openai_kwargs['http_client'] = http_client return ChatOpenAI( **openai_kwargs, **kwargs, diff --git a/singlestoredb/ai/embeddings.py b/singlestoredb/ai/embeddings.py index e85f26a7..fe23331c 100644 --- a/singlestoredb/ai/embeddings.py +++ b/singlestoredb/ai/embeddings.py @@ -1,5 +1,4 @@ import os -from collections.abc import Generator from typing import Any from typing import Callable from typing import Optional @@ -33,35 +32,15 @@ def SingleStoreEmbeddingsFactory( model_name: str, - api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None, + api_key: Optional[str] = None, http_client: Optional[httpx.Client] = None, - obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None, obo_token_getter: Optional[Callable[[], Optional[str]]] = None, base_url: Optional[str] = None, hosting_platform: Optional[str] = None, - timeout: Optional[float] = None, **kwargs: Any, ) -> Union[OpenAIEmbeddings, BedrockEmbeddings]: """Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings). """ - # Handle api_key and obo_token as callable functions - if callable(api_key): - api_key_getter_fn = api_key - else: - def api_key_getter_fn() -> Optional[str]: - if api_key is None: - return os.environ.get('SINGLESTOREDB_USER_TOKEN') - return api_key - - if obo_token_getter is not None: - obo_token_getter_fn = obo_token_getter - else: - if callable(obo_token): - obo_token_getter_fn = obo_token - else: - def obo_token_getter_fn() -> Optional[str]: - return obo_token - # handle model info if base_url is None: base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL') @@ -103,10 +82,6 @@ def obo_token_getter_fn() -> Optional[str]: elif isinstance(t, (int, float)): connect_timeout = float(t) read_timeout = float(t) - if timeout is not None: - connect_timeout = timeout - read_timeout = timeout - t = httpx.Timeout(timeout) if info.hosting_platform == 'Amazon': # Instantiate Bedrock client @@ -131,12 +106,12 @@ def obo_token_getter_fn() -> Optional[str]: def _inject_headers(request: Any, **_ignored: Any) -> None: """Inject dynamic auth/OBO headers prior to Bedrock sending.""" - if api_key_getter_fn is not None: - token_val = api_key_getter_fn() - if token_val: - request.headers['Authorization'] = f'Bearer {token_val}' - if obo_token_getter_fn is not None: - obo_val = obo_token_getter_fn() + token_env_val = os.environ.get('SINGLESTOREDB_USER_TOKEN') + token_val = api_key if api_key is not None else token_env_val + if token_val: + request.headers['Authorization'] = f'Bearer {token_val}' + if obo_token_getter is not None: + obo_val = obo_token_getter() if obo_val: request.headers['X-S2-OBO'] = obo_val request.headers.pop('X-Amz-Date', None) @@ -162,38 +137,17 @@ def _inject_headers(request: Any, **_ignored: Any) -> None: **kwargs, ) - class OpenAIAuth(httpx.Auth): - def auth_flow( - self, request: httpx.Request, - ) -> Generator[httpx.Request, None, None]: - if api_key_getter_fn is not None: - token_val = api_key_getter_fn() - if token_val: - request.headers['Authorization'] = f'Bearer {token_val}' - if obo_token_getter_fn is not None: - obo_val = obo_token_getter_fn() - if obo_val: - request.headers['X-S2-OBO'] = obo_val - yield request - - if t is not None: - http_client = httpx.Client( - timeout=t, - auth=OpenAIAuth(), - ) - else: - http_client = httpx.Client( - timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout - auth=OpenAIAuth(), - ) - # OpenAI / Azure OpenAI path + token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN') + token = api_key if api_key is not None else token_env + openai_kwargs = dict( base_url=info.connection_url, - api_key='placeholder', + api_key=token, model=model_name, ) - openai_kwargs['http_client'] = http_client + if http_client is not None: + openai_kwargs['http_client'] = http_client return OpenAIEmbeddings( **openai_kwargs, **kwargs,