Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 13 additions & 59 deletions singlestoredb/ai/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from collections.abc import Generator
from typing import Any
from typing import Callable
from typing import Optional
Expand Down Expand Up @@ -33,36 +32,16 @@

def SingleStoreChatFactory(
model_name: str,
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
api_key: Optional[str] = None,
streaming: bool = True,
http_client: Optional[httpx.Client] = None,
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
base_url: Optional[str] = None,
hosting_platform: Optional[str] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> Union[ChatOpenAI, ChatBedrockConverse]:
"""Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
"""
# Handle api_key and obo_token as callable functions
if callable(api_key):
api_key_getter_fn = api_key
else:
def api_key_getter_fn() -> Optional[str]:
if api_key is None:
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
return api_key

if obo_token_getter is not None:
obo_token_getter_fn = obo_token_getter
else:
if callable(obo_token):
obo_token_getter_fn = obo_token
else:
def obo_token_getter_fn() -> Optional[str]:
return obo_token

# handle model info
if base_url is None:
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
Expand Down Expand Up @@ -104,10 +83,6 @@ def obo_token_getter_fn() -> Optional[str]:
elif isinstance(t, (int, float)):
connect_timeout = float(t)
read_timeout = float(t)
if timeout is not None:
connect_timeout = timeout
read_timeout = timeout
t = httpx.Timeout(timeout)

if info.hosting_platform == 'Amazon':
# Instantiate Bedrock client
Expand All @@ -132,12 +107,12 @@ def obo_token_getter_fn() -> Optional[str]:

def _inject_headers(request: Any, **_ignored: Any) -> None:
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
if api_key_getter_fn is not None:
token_val = api_key_getter_fn()
if token_val:
request.headers['Authorization'] = f'Bearer {token_val}'
if obo_token_getter_fn is not None:
obo_val = obo_token_getter_fn()
token_env_val = os.environ.get('SINGLESTOREDB_USER_TOKEN')
token_val = api_key if api_key is not None else token_env_val
if token_val:
request.headers['Authorization'] = f'Bearer {token_val}'
if obo_token_getter is not None:
obo_val = obo_token_getter()
if obo_val:
request.headers['X-S2-OBO'] = obo_val
request.headers.pop('X-Amz-Date', None)
Expand Down Expand Up @@ -172,39 +147,18 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
**kwargs,
)

class OpenAIAuth(httpx.Auth):
def auth_flow(
self, request: httpx.Request,
) -> Generator[httpx.Request, None, None]:
if api_key_getter_fn is not None:
token_val = api_key_getter_fn()
if token_val:
request.headers['Authorization'] = f'Bearer {token_val}'
if obo_token_getter_fn is not None:
obo_val = obo_token_getter_fn()
if obo_val:
request.headers['X-S2-OBO'] = obo_val
yield request

if t is not None:
http_client = httpx.Client(
timeout=t,
auth=OpenAIAuth(),
)
else:
http_client = httpx.Client(
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
auth=OpenAIAuth(),
)

# OpenAI / Azure OpenAI path
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
token = api_key if api_key is not None else token_env

openai_kwargs = dict(
base_url=info.connection_url,
api_key='placeholder',
api_key=token,
model=model_name,
streaming=streaming,
)
openai_kwargs['http_client'] = http_client
if http_client is not None:
openai_kwargs['http_client'] = http_client
return ChatOpenAI(
**openai_kwargs,
**kwargs,
Expand Down
72 changes: 13 additions & 59 deletions singlestoredb/ai/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from collections.abc import Generator
from typing import Any
from typing import Callable
from typing import Optional
Expand Down Expand Up @@ -33,35 +32,15 @@

def SingleStoreEmbeddingsFactory(
model_name: str,
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
api_key: Optional[str] = None,
http_client: Optional[httpx.Client] = None,
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
base_url: Optional[str] = None,
hosting_platform: Optional[str] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
"""Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
"""
# Handle api_key and obo_token as callable functions
if callable(api_key):
api_key_getter_fn = api_key
else:
def api_key_getter_fn() -> Optional[str]:
if api_key is None:
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
return api_key

if obo_token_getter is not None:
obo_token_getter_fn = obo_token_getter
else:
if callable(obo_token):
obo_token_getter_fn = obo_token
else:
def obo_token_getter_fn() -> Optional[str]:
return obo_token

# handle model info
if base_url is None:
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
Expand Down Expand Up @@ -103,10 +82,6 @@ def obo_token_getter_fn() -> Optional[str]:
elif isinstance(t, (int, float)):
connect_timeout = float(t)
read_timeout = float(t)
if timeout is not None:
connect_timeout = timeout
read_timeout = timeout
t = httpx.Timeout(timeout)

if info.hosting_platform == 'Amazon':
# Instantiate Bedrock client
Expand All @@ -131,12 +106,12 @@ def obo_token_getter_fn() -> Optional[str]:

def _inject_headers(request: Any, **_ignored: Any) -> None:
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
if api_key_getter_fn is not None:
token_val = api_key_getter_fn()
if token_val:
request.headers['Authorization'] = f'Bearer {token_val}'
if obo_token_getter_fn is not None:
obo_val = obo_token_getter_fn()
token_env_val = os.environ.get('SINGLESTOREDB_USER_TOKEN')
token_val = api_key if api_key is not None else token_env_val
if token_val:
request.headers['Authorization'] = f'Bearer {token_val}'
if obo_token_getter is not None:
obo_val = obo_token_getter()
if obo_val:
request.headers['X-S2-OBO'] = obo_val
request.headers.pop('X-Amz-Date', None)
Expand All @@ -162,38 +137,17 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
**kwargs,
)

class OpenAIAuth(httpx.Auth):
def auth_flow(
self, request: httpx.Request,
) -> Generator[httpx.Request, None, None]:
if api_key_getter_fn is not None:
token_val = api_key_getter_fn()
if token_val:
request.headers['Authorization'] = f'Bearer {token_val}'
if obo_token_getter_fn is not None:
obo_val = obo_token_getter_fn()
if obo_val:
request.headers['X-S2-OBO'] = obo_val
yield request

if t is not None:
http_client = httpx.Client(
timeout=t,
auth=OpenAIAuth(),
)
else:
http_client = httpx.Client(
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
auth=OpenAIAuth(),
)

# OpenAI / Azure OpenAI path
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
token = api_key if api_key is not None else token_env

openai_kwargs = dict(
base_url=info.connection_url,
api_key='placeholder',
api_key=token,
model=model_name,
)
openai_kwargs['http_client'] = http_client
if http_client is not None:
openai_kwargs['http_client'] = http_client
return OpenAIEmbeddings(
**openai_kwargs,
**kwargs,
Expand Down