Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 5 additions & 33 deletions singlestoredb/ai/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from collections.abc import Generator
from typing import Any
from typing import Callable
from typing import Optional
Expand Down Expand Up @@ -40,7 +39,6 @@ def SingleStoreChatFactory(
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
base_url: Optional[str] = None,
hosting_platform: Optional[str] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> Union[ChatOpenAI, ChatBedrockConverse]:
"""Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
Expand Down Expand Up @@ -104,10 +102,6 @@ def obo_token_getter_fn() -> Optional[str]:
elif isinstance(t, (int, float)):
connect_timeout = float(t)
read_timeout = float(t)
if timeout is not None:
connect_timeout = timeout
read_timeout = timeout
t = httpx.Timeout(timeout)

if info.hosting_platform == 'Amazon':
# Instantiate Bedrock client
Expand Down Expand Up @@ -172,39 +166,17 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
**kwargs,
)

class OpenAIAuth(httpx.Auth):
def auth_flow(
self, request: httpx.Request,
) -> Generator[httpx.Request, None, None]:
if api_key_getter_fn is not None:
token_val = api_key_getter_fn()
if token_val:
request.headers['Authorization'] = f'Bearer {token_val}'
if obo_token_getter_fn is not None:
obo_val = obo_token_getter_fn()
if obo_val:
request.headers['X-S2-OBO'] = obo_val
yield request

if t is not None:
http_client = httpx.Client(
timeout=t,
auth=OpenAIAuth(),
)
else:
http_client = httpx.Client(
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
auth=OpenAIAuth(),
)

# OpenAI / Azure OpenAI path
token = api_key_getter_fn()

openai_kwargs = dict(
base_url=info.connection_url,
api_key='placeholder',
api_key=token,
model=model_name,
streaming=streaming,
)
openai_kwargs['http_client'] = http_client
if http_client is not None:
openai_kwargs['http_client'] = http_client
return ChatOpenAI(
**openai_kwargs,
**kwargs,
Expand Down
38 changes: 5 additions & 33 deletions singlestoredb/ai/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from collections.abc import Generator
from typing import Any
from typing import Callable
from typing import Optional
Expand Down Expand Up @@ -39,7 +38,6 @@ def SingleStoreEmbeddingsFactory(
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
base_url: Optional[str] = None,
hosting_platform: Optional[str] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
"""Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
Expand Down Expand Up @@ -103,10 +101,6 @@ def obo_token_getter_fn() -> Optional[str]:
elif isinstance(t, (int, float)):
connect_timeout = float(t)
read_timeout = float(t)
if timeout is not None:
connect_timeout = timeout
read_timeout = timeout
t = httpx.Timeout(timeout)

if info.hosting_platform == 'Amazon':
# Instantiate Bedrock client
Expand Down Expand Up @@ -162,38 +156,16 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
**kwargs,
)

class OpenAIAuth(httpx.Auth):
def auth_flow(
self, request: httpx.Request,
) -> Generator[httpx.Request, None, None]:
if api_key_getter_fn is not None:
token_val = api_key_getter_fn()
if token_val:
request.headers['Authorization'] = f'Bearer {token_val}'
if obo_token_getter_fn is not None:
obo_val = obo_token_getter_fn()
if obo_val:
request.headers['X-S2-OBO'] = obo_val
yield request

if t is not None:
http_client = httpx.Client(
timeout=t,
auth=OpenAIAuth(),
)
else:
http_client = httpx.Client(
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
auth=OpenAIAuth(),
)

# OpenAI / Azure OpenAI path
token = api_key_getter_fn()

openai_kwargs = dict(
base_url=info.connection_url,
api_key='placeholder',
api_key=token,
model=model_name,
)
openai_kwargs['http_client'] = http_client
if http_client is not None:
openai_kwargs['http_client'] = http_client
return OpenAIEmbeddings(
**openai_kwargs,
**kwargs,
Expand Down