134 changes: 103 additions & 31 deletions singlestoredb/ai/chat.py
@@ -1,4 +1,5 @@
import os
from collections.abc import Generator
from typing import Any
from typing import Callable
from typing import Optional
@@ -7,6 +8,7 @@
import httpx

from singlestoredb import manage_workspaces
from singlestoredb.management.inference_api import InferenceAPIInfo

try:
from langchain_openai import ChatOpenAI
@@ -31,44 +33,88 @@

def SingleStoreChatFactory(
model_name: str,
api_key: Optional[str] = None,
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
streaming: bool = True,
http_client: Optional[httpx.Client] = None,
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
base_url: Optional[str] = None,
hosting_platform: Optional[str] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> Union[ChatOpenAI, ChatBedrockConverse]:
"""Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
"""
inference_api_manager = (
manage_workspaces().organizations.current.inference_apis
)
info = inference_api_manager.get(model_name=model_name)
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
token = api_key if api_key is not None else token_env
# Handle api_key and obo_token as callable functions
if callable(api_key):
api_key_getter_fn = api_key
else:
def api_key_getter_fn() -> Optional[str]:
if api_key is None:
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
return api_key

if obo_token_getter is not None:
obo_token_getter_fn = obo_token_getter
else:
if callable(obo_token):
obo_token_getter_fn = obo_token
else:
def obo_token_getter_fn() -> Optional[str]:
return obo_token

# handle model info
if base_url is None:
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
if hosting_platform is None:
hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
if base_url is None or hosting_platform is None:
inference_api_manager = (
manage_workspaces().organizations.current.inference_apis
)
info = inference_api_manager.get(model_name=model_name)
else:
info = InferenceAPIInfo(
service_id='',
model_name=model_name,
name='',
connection_url=base_url,
project_id='',
hosting_platform=hosting_platform,
)
if base_url is not None:
info.connection_url = base_url
if hosting_platform is not None:
info.hosting_platform = hosting_platform

# Extract timeouts from http_client if provided
t = http_client.timeout if http_client is not None else None
connect_timeout = None
read_timeout = None
if t is not None:
if isinstance(t, httpx.Timeout):
if t.connect is not None:
connect_timeout = float(t.connect)
if t.read is not None:
read_timeout = float(t.read)
if connect_timeout is None and read_timeout is not None:
connect_timeout = read_timeout
if read_timeout is None and connect_timeout is not None:
read_timeout = connect_timeout
elif isinstance(t, (int, float)):
connect_timeout = float(t)
read_timeout = float(t)
if timeout is not None:
connect_timeout = timeout
read_timeout = timeout
t = httpx.Timeout(timeout)

if info.hosting_platform == 'Amazon':
# Instantiate Bedrock client
cfg_kwargs = {
'signature_version': UNSIGNED,
'retries': {'max_attempts': 1, 'mode': 'standard'},
}
# Extract timeouts from http_client if provided
t = http_client.timeout if http_client is not None else None
connect_timeout = None
read_timeout = None
if t is not None:
if isinstance(t, httpx.Timeout):
if t.connect is not None:
connect_timeout = float(t.connect)
if t.read is not None:
read_timeout = float(t.read)
if connect_timeout is None and read_timeout is not None:
connect_timeout = read_timeout
if read_timeout is None and connect_timeout is not None:
read_timeout = connect_timeout
elif isinstance(t, (int, float)):
connect_timeout = float(t)
read_timeout = float(t)
if read_timeout is not None:
cfg_kwargs['read_timeout'] = read_timeout
if connect_timeout is not None:
@@ -86,12 +132,14 @@ def SingleStoreChatFactory(

def _inject_headers(request: Any, **_ignored: Any) -> None:
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
if obo_token_getter is not None:
obo_val = obo_token_getter()
if api_key_getter_fn is not None:
token_val = api_key_getter_fn()
if token_val:
request.headers['Authorization'] = f'Bearer {token_val}'
if obo_token_getter_fn is not None:
obo_val = obo_token_getter_fn()
if obo_val:
request.headers['X-S2-OBO'] = obo_val
if token:
request.headers['Authorization'] = f'Bearer {token}'
request.headers.pop('X-Amz-Date', None)
request.headers.pop('X-Amz-Security-Token', None)

@@ -124,15 +172,39 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
**kwargs,
)

class OpenAIAuth(httpx.Auth):
def auth_flow(
self, request: httpx.Request,
) -> Generator[httpx.Request, None, None]:
if api_key_getter_fn is not None:
token_val = api_key_getter_fn()
if token_val:
request.headers['Authorization'] = f'Bearer {token_val}'
if obo_token_getter_fn is not None:
obo_val = obo_token_getter_fn()
if obo_val:
request.headers['X-S2-OBO'] = obo_val
yield request

if t is not None:
http_client = httpx.Client(
timeout=t,
auth=OpenAIAuth(),
)
else:
http_client = httpx.Client(
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
auth=OpenAIAuth(),
)

# OpenAI / Azure OpenAI path
openai_kwargs = dict(
base_url=info.connection_url,
api_key=token,
api_key='placeholder',
model=model_name,
streaming=streaming,
)
if http_client is not None:
openai_kwargs['http_client'] = http_client
openai_kwargs['http_client'] = http_client
return ChatOpenAI(
**openai_kwargs,
**kwargs,
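A minimal usage sketch based on the updated `SingleStoreChatFactory` signature; the model name, the `MY_OBO_TOKEN` environment variable, and the getter helpers are placeholders for illustration, not part of the PR:

```python
import os

from singlestoredb.ai.chat import SingleStoreChatFactory


# Placeholder getters: in real use these would return freshly fetched tokens.
def get_user_token() -> str:
    return os.environ.get('SINGLESTOREDB_USER_TOKEN', '')


def get_obo_token() -> str:
    return os.environ.get('MY_OBO_TOKEN', '')  # hypothetical variable name


# api_key and obo_token_getter may now be zero-argument callables that are
# re-evaluated on each request instead of a token captured once at build time.
chat = SingleStoreChatFactory(
    model_name='example-model',  # placeholder model name
    api_key=get_user_token,
    obo_token_getter=get_obo_token,
    timeout=60.0,
)
```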
134 changes: 103 additions & 31 deletions singlestoredb/ai/embeddings.py
@@ -1,4 +1,5 @@
import os
from collections.abc import Generator
from typing import Any
from typing import Callable
from typing import Optional
Expand All @@ -7,6 +8,7 @@
import httpx

from singlestoredb import manage_workspaces
from singlestoredb.management.inference_api import InferenceAPIInfo

try:
from langchain_openai import OpenAIEmbeddings
@@ -31,43 +33,87 @@

def SingleStoreEmbeddingsFactory(
model_name: str,
api_key: Optional[str] = None,
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
http_client: Optional[httpx.Client] = None,
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
base_url: Optional[str] = None,
hosting_platform: Optional[str] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
"""Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
"""
inference_api_manager = (
manage_workspaces().organizations.current.inference_apis
)
info = inference_api_manager.get(model_name=model_name)
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
token = api_key if api_key is not None else token_env
# Handle api_key and obo_token as callable functions
if callable(api_key):
api_key_getter_fn = api_key
else:
def api_key_getter_fn() -> Optional[str]:
if api_key is None:
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
return api_key

if obo_token_getter is not None:
obo_token_getter_fn = obo_token_getter
else:
if callable(obo_token):
obo_token_getter_fn = obo_token
else:
def obo_token_getter_fn() -> Optional[str]:
return obo_token

# handle model info
if base_url is None:
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
if hosting_platform is None:
hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
if base_url is None or hosting_platform is None:
inference_api_manager = (
manage_workspaces().organizations.current.inference_apis
)
info = inference_api_manager.get(model_name=model_name)
else:
info = InferenceAPIInfo(
service_id='',
model_name=model_name,
name='',
connection_url=base_url,
project_id='',
hosting_platform=hosting_platform,
)
if base_url is not None:
info.connection_url = base_url
if hosting_platform is not None:
info.hosting_platform = hosting_platform

# Extract timeouts from http_client if provided
t = http_client.timeout if http_client is not None else None
connect_timeout = None
read_timeout = None
if t is not None:
if isinstance(t, httpx.Timeout):
if t.connect is not None:
connect_timeout = float(t.connect)
if t.read is not None:
read_timeout = float(t.read)
if connect_timeout is None and read_timeout is not None:
connect_timeout = read_timeout
if read_timeout is None and connect_timeout is not None:
read_timeout = connect_timeout
elif isinstance(t, (int, float)):
connect_timeout = float(t)
read_timeout = float(t)
if timeout is not None:
connect_timeout = timeout
read_timeout = timeout
t = httpx.Timeout(timeout)
Comment on lines +48 to +109

Collaborator: This code looks nearly identical to the code in SingleStoreChatFactory. Is there a reason they can't be combined into a single function that is called to do these operations instead of duplicating code?

Contributor Author: I will do this refactoring as part of cleaning up the code along with the deprecated parameters.
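For illustration, a minimal sketch of what that shared helper could look like once the duplicated setup is factored out; the name `_resolve_inference_config`, its signature, and its return shape are assumptions, not code from this PR:

```python
# Hypothetical helper; name, signature, and return value are illustrative only.
import os
from typing import Callable, Optional, Tuple, Union

from singlestoredb import manage_workspaces
from singlestoredb.management.inference_api import InferenceAPIInfo


def _resolve_inference_config(
    model_name: str,
    api_key: Optional[Union[str, Callable[[], Optional[str]]]] = None,
    obo_token: Optional[Union[str, Callable[[], Optional[str]]]] = None,
    obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
    base_url: Optional[str] = None,
    hosting_platform: Optional[str] = None,
) -> Tuple[Callable[[], Optional[str]], Callable[[], Optional[str]], InferenceAPIInfo]:
    """Resolve token getters and model info shared by both factories."""
    # Normalize api_key into a zero-argument getter.
    if callable(api_key):
        api_key_getter = api_key
    else:
        def api_key_getter() -> Optional[str]:
            if api_key is None:
                return os.environ.get('SINGLESTOREDB_USER_TOKEN')
            return api_key

    # Normalize the OBO token the same way, preferring the explicit getter.
    if obo_token_getter is not None:
        obo_getter = obo_token_getter
    elif callable(obo_token):
        obo_getter = obo_token
    else:
        def obo_getter() -> Optional[str]:
            return obo_token

    # Resolve model info from overrides, environment variables, or the
    # management API, mirroring the logic duplicated in both factories.
    base_url = base_url or os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
    hosting_platform = hosting_platform or os.environ.get(
        'SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM',
    )
    if base_url is None or hosting_platform is None:
        inference_api_manager = (
            manage_workspaces().organizations.current.inference_apis
        )
        info = inference_api_manager.get(model_name=model_name)
        if base_url is not None:
            info.connection_url = base_url
        if hosting_platform is not None:
            info.hosting_platform = hosting_platform
    else:
        info = InferenceAPIInfo(
            service_id='',
            model_name=model_name,
            name='',
            connection_url=base_url,
            project_id='',
            hosting_platform=hosting_platform,
        )
    return api_key_getter, obo_getter, info
```

Both SingleStoreChatFactory and SingleStoreEmbeddingsFactory could then call this once and apply only their model-specific wiring.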


if info.hosting_platform == 'Amazon':
# Instantiate Bedrock client
cfg_kwargs = {
'signature_version': UNSIGNED,
'retries': {'max_attempts': 1, 'mode': 'standard'},
}
# Extract timeouts from http_client if provided
t = http_client.timeout if http_client is not None else None
connect_timeout = None
read_timeout = None
if t is not None:
if isinstance(t, httpx.Timeout):
if t.connect is not None:
connect_timeout = float(t.connect)
if t.read is not None:
read_timeout = float(t.read)
if connect_timeout is None and read_timeout is not None:
connect_timeout = read_timeout
if read_timeout is None and connect_timeout is not None:
read_timeout = connect_timeout
elif isinstance(t, (int, float)):
connect_timeout = float(t)
read_timeout = float(t)
if read_timeout is not None:
cfg_kwargs['read_timeout'] = read_timeout
if connect_timeout is not None:
@@ -85,12 +131,14 @@ def SingleStoreEmbeddingsFactory(

def _inject_headers(request: Any, **_ignored: Any) -> None:
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
if obo_token_getter is not None:
obo_val = obo_token_getter()
if api_key_getter_fn is not None:
token_val = api_key_getter_fn()
if token_val:
request.headers['Authorization'] = f'Bearer {token_val}'
if obo_token_getter_fn is not None:
obo_val = obo_token_getter_fn()
if obo_val:
request.headers['X-S2-OBO'] = obo_val
if token:
request.headers['Authorization'] = f'Bearer {token}'
request.headers.pop('X-Amz-Date', None)
request.headers.pop('X-Amz-Security-Token', None)

@@ -114,14 +162,38 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
**kwargs,
)

class OpenAIAuth(httpx.Auth):
def auth_flow(
self, request: httpx.Request,
) -> Generator[httpx.Request, None, None]:
if api_key_getter_fn is not None:
token_val = api_key_getter_fn()
if token_val:
request.headers['Authorization'] = f'Bearer {token_val}'
if obo_token_getter_fn is not None:
obo_val = obo_token_getter_fn()
if obo_val:
request.headers['X-S2-OBO'] = obo_val
yield request

if t is not None:
http_client = httpx.Client(
timeout=t,
auth=OpenAIAuth(),
)
else:
http_client = httpx.Client(
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
auth=OpenAIAuth(),
)

# OpenAI / Azure OpenAI path
openai_kwargs = dict(
base_url=info.connection_url,
api_key=token,
api_key='placeholder',
model=model_name,
)
if http_client is not None:
openai_kwargs['http_client'] = http_client
openai_kwargs['http_client'] = http_client
return OpenAIEmbeddings(
**openai_kwargs,
**kwargs,
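A similar sketch for the embeddings factory, exercising the new base_url/hosting_platform overrides that bypass the management-API lookup; the URL, model name, and hosting-platform value are placeholders:

```python
from singlestoredb.ai.embeddings import SingleStoreEmbeddingsFactory

# Supplying base_url and hosting_platform avoids the inference-API lookup;
# both values below are placeholders for illustration only.
embeddings = SingleStoreEmbeddingsFactory(
    model_name='example-embedding-model',
    base_url='https://inference.example.com/v1',
    hosting_platform='OpenAI',  # placeholder; 'Amazon' selects the Bedrock path
    timeout=30.0,
)
vectors = embeddings.embed_documents(['hello world'])
```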