Merged

21 commits
a9b6611
feat: Introduce SingleStoreChat wrapper that interchangeably uses Ope…
mgiannakopoulos Sep 17, 2025
0e69c15
Provide option for getting 'X-S2-OBO' token for every request.
mgiannakopoulos Sep 17, 2025
6e96e96
Provide headers that indicate passthrough Amazon bedrock requests.
mgiannakopoulos Sep 17, 2025
c23f3fa
Hardcode dummy credentials and region info for ChatBedrockConverse cl…
mgiannakopoulos Sep 17, 2025
f2ccfbf
Rename 'region' parameter to 'region_name'.
mgiannakopoulos Sep 17, 2025
a2728dc
Expose 'streaming' parameter, setting its opposite value to 'disable_st…
mgiannakopoulos Sep 17, 2025
a0a5197
Set the default value for 'streaming' parameter to True.
mgiannakopoulos Sep 17, 2025
3ea9400
Remove the cache option.
mgiannakopoulos Sep 17, 2025
321fd9b
Remove unsupported kwargs from Bedrock calls.
mgiannakopoulos Sep 17, 2025
8ff4dfe
Replace composition wrapper with a factory method.
mgiannakopoulos Sep 18, 2025
f376141
Pass bedrock runtime client as client parameter.
mgiannakopoulos Sep 18, 2025
84ad037
Also pass the 'X-BEDROCK-CONVERSE' headers that indicate that the req…
mgiannakopoulos Sep 18, 2025
40de94e
Remove some Amazon-specific headers, along with validation; remove X-…
mgiannakopoulos Sep 18, 2025
48b79d7
Remove commented out code; set max retries to 1 for Amazon Bedrock mo…
mgiannakopoulos Sep 19, 2025
02b1f47
Use 'Union' return type to satisfy pre-commit checks for python versi…
mgiannakopoulos Sep 19, 2025
5db3fbf
Also expose the hostingPlatform for InferenceAPIInfo.
mgiannakopoulos Sep 19, 2025
2fdd49b
Do not use model prefix, rely on hosting platform.
mgiannakopoulos Sep 20, 2025
054c8e2
Introduce SingleStoreEmbeddingsFactory; small fixes.
mgiannakopoulos Sep 20, 2025
540100d
Fix openai langchain library.
mgiannakopoulos Sep 22, 2025
80b487b
Remove any comments that expose internal implementation details.
mgiannakopoulos Sep 22, 2025
ca059a3
Minor fixes.
mgiannakopoulos Sep 22, 2025
127 changes: 123 additions & 4 deletions singlestoredb/ai/chat.py
@@ -1,5 +1,10 @@
import os
from typing import Any
from typing import Callable
from typing import Optional
from typing import Union

import httpx

from singlestoredb.fusion.handlers.utils import get_workspace_manager

@@ -11,30 +16,144 @@
'Please install it with `pip install langchain_openai`.',
)

try:
from langchain_aws import ChatBedrockConverse
except ImportError:
raise ImportError(
'Could not import langchain-aws python package. '
'Please install it with `pip install langchain-aws`.',
)

import boto3
from botocore import UNSIGNED
from botocore.config import Config


class SingleStoreChatOpenAI(ChatOpenAI):
def __init__(self, model_name: str, **kwargs: Any):
def __init__(self, model_name: str, api_key: Optional[str] = None, **kwargs: Any):
        inference_api_manager = (
            get_workspace_manager().organizations.current.inference_apis
        )
        info = inference_api_manager.get(model_name=model_name)
token = (
api_key
if api_key is not None
else os.environ.get('SINGLESTOREDB_USER_TOKEN')
)
super().__init__(
base_url=info.connection_url,
api_key=os.environ.get('SINGLESTOREDB_USER_TOKEN'),
api_key=token,
model=model_name,
**kwargs,
)


class SingleStoreChat(ChatOpenAI):
def __init__(self, model_name: str, **kwargs: Any):
def __init__(self, model_name: str, api_key: Optional[str] = None, **kwargs: Any):
        inference_api_manager = (
            get_workspace_manager().organizations.current.inference_apis
        )
        info = inference_api_manager.get(model_name=model_name)
token = (
api_key
if api_key is not None
else os.environ.get('SINGLESTOREDB_USER_TOKEN')
)
super().__init__(
base_url=info.connection_url,
api_key=os.environ.get('SINGLESTOREDB_USER_TOKEN'),
api_key=token,
model=model_name,
**kwargs,
)


def SingleStoreChatFactory(
model_name: str,
api_key: Optional[str] = None,
streaming: bool = True,
http_client: Optional[httpx.Client] = None,
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
**kwargs: Any,
) -> Union[ChatOpenAI, ChatBedrockConverse]:
"""Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
"""
inference_api_manager = (
get_workspace_manager().organizations.current.inference_apis
)
info = inference_api_manager.get(model_name=model_name)
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
token = api_key if api_key is not None else token_env

if info.hosting_platform == 'Amazon':
# Instantiate Bedrock client
cfg_kwargs = {
'signature_version': UNSIGNED,
'retries': {'max_attempts': 1, 'mode': 'standard'},
}
if http_client is not None and http_client.timeout is not None:
cfg_kwargs['read_timeout'] = http_client.timeout
cfg_kwargs['connect_timeout'] = http_client.timeout

cfg = Config(**cfg_kwargs)
client = boto3.client(
'bedrock-runtime',
endpoint_url=info.connection_url,
region_name='us-east-1',
aws_access_key_id='placeholder',
aws_secret_access_key='placeholder',
config=cfg,
)

def _inject_headers(request: Any, **_ignored: Any) -> None:
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
if obo_token_getter is not None:
obo_val = obo_token_getter()
if obo_val:
request.headers['X-S2-OBO'] = obo_val
if token:
request.headers['Authorization'] = f'Bearer {token}'
request.headers.pop('X-Amz-Date', None)
request.headers.pop('X-Amz-Security-Token', None)

emitter = client._endpoint._event_emitter
emitter.register_first(
'before-send.bedrock-runtime.Converse',
_inject_headers,
)
emitter.register_first(
'before-send.bedrock-runtime.ConverseStream',
_inject_headers,
)
emitter.register_first(
'before-send.bedrock-runtime.InvokeModel',
_inject_headers,
)
emitter.register_first(
'before-send.bedrock-runtime.InvokeModelWithResponseStream',
_inject_headers,
)

return ChatBedrockConverse(
model_id=model_name,
endpoint_url=info.connection_url,
region_name='us-east-1',
aws_access_key_id='placeholder',
aws_secret_access_key='placeholder',
disable_streaming=not streaming,
client=client,
**kwargs,
)

# OpenAI / Azure OpenAI path
openai_kwargs = dict(
base_url=info.connection_url,
api_key=token,
model=model_name,
streaming=streaming,
)
if http_client is not None:
openai_kwargs['http_client'] = http_client
return ChatOpenAI(
**openai_kwargs,
**kwargs,
)
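For illustration, a minimal usage sketch of the new factory — a sketch under assumptions, not part of this diff: the model name and the OBO token getter below are hypothetical, and the factory returns ChatBedrockConverse for Amazon-hosted models and ChatOpenAI otherwise.

import httpx

from singlestoredb.ai.chat import SingleStoreChatFactory


def get_obo_token() -> str:
    # Hypothetical helper: fetch a per-request on-behalf-of token
    # from your own auth layer.
    return 'obo-token-from-auth-layer'


# The factory resolves the endpoint via the workspace manager and picks
# the backend from the endpoint's hosting_platform.
chat = SingleStoreChatFactory(
    model_name='anthropic.claude-3-sonnet',  # hypothetical model name
    streaming=True,
    http_client=httpx.Client(timeout=30.0),
    obo_token_getter=get_obo_token,
)
print(chat.invoke('Hello!').content)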
98 changes: 98 additions & 0 deletions singlestoredb/ai/embeddings.py
@@ -1,5 +1,10 @@
import os
from typing import Any
from typing import Callable
from typing import Optional
from typing import Union

import httpx

from singlestoredb.fusion.handlers.utils import get_workspace_manager

@@ -11,6 +16,18 @@
'Please install it with `pip install langchain_openai`.',
)

try:
from langchain_aws import BedrockEmbeddings
except ImportError:
raise ImportError(
'Could not import langchain-aws python package. '
'Please install it with `pip install langchain-aws`.',
)

import boto3
from botocore import UNSIGNED
from botocore.config import Config


class SingleStoreEmbeddings(OpenAIEmbeddings):

@@ -25,3 +42,84 @@ def __init__(self, model_name: str, **kwargs: Any):
model=model_name,
**kwargs,
)


def SingleStoreEmbeddingsFactory(
model_name: str,
api_key: Optional[str] = None,
http_client: Optional[httpx.Client] = None,
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
**kwargs: Any,
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
"""Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
"""
inference_api_manager = (
get_workspace_manager().organizations.current.inference_apis
)
info = inference_api_manager.get(model_name=model_name)
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
token = api_key if api_key is not None else token_env

if info.hosting_platform == 'Amazon':
# Instantiate Bedrock client
cfg_kwargs = {
'signature_version': UNSIGNED,
'retries': {'max_attempts': 1, 'mode': 'standard'},
}
if http_client is not None and http_client.timeout is not None:
cfg_kwargs['read_timeout'] = http_client.timeout
cfg_kwargs['connect_timeout'] = http_client.timeout

cfg = Config(**cfg_kwargs)
client = boto3.client(
'bedrock-runtime',
endpoint_url=info.connection_url,
region_name='us-east-1',
aws_access_key_id='placeholder',
aws_secret_access_key='placeholder',
config=cfg,
)

def _inject_headers(request: Any, **_ignored: Any) -> None:
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
if obo_token_getter is not None:
obo_val = obo_token_getter()
if obo_val:
request.headers['X-S2-OBO'] = obo_val
if token:
request.headers['Authorization'] = f'Bearer {token}'
request.headers.pop('X-Amz-Date', None)
request.headers.pop('X-Amz-Security-Token', None)

emitter = client._endpoint._event_emitter
emitter.register_first(
'before-send.bedrock-runtime.InvokeModel',
_inject_headers,
)
emitter.register_first(
'before-send.bedrock-runtime.InvokeModelWithResponseStream',
_inject_headers,
)

return BedrockEmbeddings(
model_id=model_name,
endpoint_url=info.connection_url,
region_name='us-east-1',
aws_access_key_id='placeholder',
aws_secret_access_key='placeholder',
client=client,
**kwargs,
)

# OpenAI / Azure OpenAI path
openai_kwargs = dict(
base_url=info.connection_url,
api_key=token,
model=model_name,
)
if http_client is not None:
openai_kwargs['http_client'] = http_client
return OpenAIEmbeddings(
**openai_kwargs,
**kwargs,
)
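Similarly, a hedged sketch for the embeddings factory (the model name is hypothetical; BedrockEmbeddings is returned for Amazon-hosted models, OpenAIEmbeddings otherwise):

from singlestoredb.ai.embeddings import SingleStoreEmbeddingsFactory

embeddings = SingleStoreEmbeddingsFactory(
    model_name='amazon.titan-embed-text-v2',  # hypothetical model name
)
# embed_documents is available on both OpenAIEmbeddings and BedrockEmbeddings.
vectors = embeddings.embed_documents(['hello world', 'goodbye world'])
print(len(vectors), len(vectors[0]))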
4 changes: 4 additions & 0 deletions singlestoredb/management/inference_api.py
@@ -23,6 +23,7 @@ class InferenceAPIInfo(object):
name: str
connection_url: str
project_id: str
hosting_platform: str

def __init__(
self,
@@ -31,12 +32,14 @@ def __init__(
name: str,
connection_url: str,
project_id: str,
hosting_platform: str,
):
self.service_id = service_id
self.connection_url = connection_url
self.model_name = model_name
self.name = name
self.project_id = project_id
self.hosting_platform = hosting_platform

@classmethod
def from_dict(
@@ -62,6 +65,7 @@ def from_dict(
model_name=obj['modelName'],
name=obj['name'],
connection_url=obj['connectionURL'],
hosting_platform=obj['hostingPlatform'],
)
return out

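To show how the new field flows through, a small construction example (all values below are placeholders; hosting_platform is the field this PR adds, and it is what the chat/embeddings factories branch on):

from singlestoredb.management.inference_api import InferenceAPIInfo

# Placeholder values for illustration only.
info = InferenceAPIInfo(
    service_id='svc-123',
    model_name='gpt-4o',
    name='my-endpoint',
    connection_url='https://inference.example.com/v1',
    project_id='proj-456',
    hosting_platform='Amazon',
)
assert info.hosting_platform == 'Amazon'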