diff --git a/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py b/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py
new file mode 100644
index 00000000..8f65f055
--- /dev/null
+++ b/backend/app/alembic/versions/041_adding_blob_column_in_collection_table.py
@@ -0,0 +1,47 @@
+"""adding blob column in collection table
+
+Revision ID: 041
+Revises: 040
+Create Date: 2025-12-24 11:03:44.620424
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+revision = "041"
+down_revision = "040"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.add_column(
+        "collection",
+        sa.Column(
+            "collection_blob",
+            postgresql.JSONB(astext_type=sa.Text()),
+            nullable=True,
+            comment="Provider-specific knowledge base creation parameters (name, description, chunking params etc.)",
+        ),
+    )
+    op.alter_column(
+        "collection",
+        "llm_service_name",
+        existing_type=sa.VARCHAR(),
+        comment="Name of the LLM service",
+        existing_comment="Name of the LLM service provider",
+        existing_nullable=False,
+    )
+
+
+def downgrade():
+    op.alter_column(
+        "collection",
+        "llm_service_name",
+        existing_type=sa.VARCHAR(),
+        comment="Name of the LLM service provider",
+        existing_comment="Name of the LLM service",
+        existing_nullable=False,
+    )
+    op.drop_column("collection", "collection_blob")
diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py
index ac7e89d6..ef08fd09 100644
--- a/backend/app/models/__init__.py
+++ b/backend/app/models/__init__.py
@@ -8,9 +8,13 @@
 from .collection import (
     Collection,
+    CreateCollectionParams,
+    CreateCollectionResult,
+    CreationRequest,
     CollectionPublic,
     CollectionIDPublic,
     CollectionWithDocsPublic,
+    DeletionRequest,
 )
 from .collection_job import (
     CollectionActionType,
diff --git a/backend/app/models/collection/__init__.py b/backend/app/models/collection/__init__.py
new file mode 100644
index 00000000..e31f65bc
--- /dev/null
+++ b/backend/app/models/collection/__init__.py
@@ -0,0 +1,14 @@
+from app.models.collection.request import (
+    Collection,
+    CreationRequest,
+    DeletionRequest,
+    CallbackRequest,
+    AssistantOptions,
+    CreateCollectionParams,
+)
+from app.models.collection.response import (
+    CollectionIDPublic,
+    CollectionPublic,
+    CollectionWithDocsPublic,
+    CreateCollectionResult,
+)
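Reviewer note: because `collection_blob` is JSONB, provider parameters can be filtered server-side rather than in Python. A minimal sketch, assuming the `engine` setup from `app.core.db`; the `"name"` key is whatever the provider stored at creation time (the OpenAI provider below writes `name`, `description`, `chunking_params`, and `additional_params`):

    from sqlmodel import Session, select

    from app.core.db import engine
    from app.models import Collection

    # Find collections whose stored creation params carry a given KB name.
    with Session(engine) as session:
        stmt = select(Collection).where(
            Collection.collection_blob["name"].astext == "support-kb"
        )
        matches = session.exec(stmt).all()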
diff --git a/backend/app/models/collection.py b/backend/app/models/collection/request.py
similarity index 63%
rename from backend/app/models/collection.py
rename to backend/app/models/collection/request.py
index 57e5a17b..9f8e106b 100644
--- a/backend/app/models/collection.py
+++ b/backend/app/models/collection/request.py
@@ -3,13 +3,13 @@
 from uuid import UUID, uuid4
 
 from pydantic import HttpUrl, model_validator
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import JSONB
 from sqlmodel import Field, Relationship, SQLModel
 
 from app.core.util import now
-from app.models.document import DocumentPublic
-
-from .organization import Organization
-from .project import Project
+from app.models.organization import Organization
+from app.models.project import Project
 
 
 class Collection(SQLModel, table=True):
@@ -30,8 +30,13 @@ class Collection(SQLModel, table=True):
         nullable=False,
         sa_column_kwargs={"comment": "Name of the LLM service"},
     )
-
-    # Foreign keys
+    collection_blob: dict[str, Any] | None = Field(
+        sa_column=sa.Column(
+            JSONB,
+            nullable=True,
+            comment="Provider-specific knowledge base creation parameters (name, description, chunking params etc.)",
+        )
+    )
     organization_id: int = Field(
         foreign_key="organization.id",
         nullable=False,
@@ -44,8 +49,6 @@ class Collection(SQLModel, table=True):
         ondelete="CASCADE",
         sa_column_kwargs={"comment": "Reference to the project"},
     )
-
-    # Timestamps
     inserted_at: datetime = Field(
         default_factory=now,
         sa_column_kwargs={"comment": "Timestamp when the collection was created"},
@@ -64,27 +67,55 @@
     project: Project = Relationship(back_populates="collections")
 
 
-# Request models
-class DocumentOptions(SQLModel):
-    documents: list[UUID] = Field(
-        description="List of document IDs",
+class DocumentInput(SQLModel):
+    """Document to be added to knowledge base."""
+
+    name: str | None = Field(
+        default=None,
+        description="Display name for the document",
     )
-    batch_size: int = Field(
-        default=1,
-        description=(
-            "Number of documents to send to OpenAI in a single "
-            "transaction. See the `file_ids` parameter in the "
-            "vector store [create batch](https://platform.openai.com/docs/api-reference/vector-stores-file-batches/createBatch)."
-        ),
+    id: UUID = Field(
+        description="Reference to uploaded file/document in Kaapi",
+    )
+
+
+class CreateCollectionParams(SQLModel):
+    """Request-specific parameters for knowledge base creation."""
+
+    name: str | None = Field(
+        min_length=1,
+        description="Name of the knowledge base to create or update",
+    )
+    description: str | None = Field(
+        default=None,
+        description="Description of the knowledge base (required by Bedrock, optional for others)",
+    )
+    documents: list[DocumentInput] = Field(
+        default_factory=list,
+        description="List of documents to add to the knowledge base",
+    )
+    chunking_params: dict[str, Any] | None = Field(
+        default=None,
+        description="Chunking parameters for document processing (e.g., chunk_size, chunk_overlap)",
+    )
+    additional_params: dict[str, Any] | None = Field(
+        default=None,
+        description="Additional provider-specific parameters",
     )
 
     def model_post_init(self, __context: Any):
-        self.documents = list(set(self.documents))
+        """Deduplicate documents by id, keeping the first occurrence."""
+        seen = set()
+        unique_docs = []
+        for doc in self.documents:
+            if doc.id not in seen:
+                seen.add(doc.id)
+                unique_docs.append(doc)
+        self.documents = unique_docs
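The deduplication in `model_post_init` is order-preserving; a quick sketch of the behavior (the UUID is illustrative):

    from uuid import UUID

    from app.models import CreateCollectionParams

    doc_id = UUID("11111111-1111-1111-1111-111111111111")
    params = CreateCollectionParams(
        name="support-kb",
        documents=[
            {"name": "faq.pdf", "id": doc_id},
            {"name": "faq-copy.pdf", "id": doc_id},  # same id, dropped
        ],
    )
    assert len(params.documents) == 1  # first occurrence wins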
 
 
 class AssistantOptions(SQLModel):
     # Fields to be passed along to OpenAI. They must be a subset of
-    # parameters accepted by the OpenAI.clien.beta.assistants.create
+    # parameters accepted by the OpenAI.client.beta.assistants.create
     # API.
     model: str | None = Field(
         default=None,
@@ -139,6 +170,8 @@ def norm(x: Any) -> Any:
 
 
 class CallbackRequest(SQLModel):
+    """Optional callback configuration for async job notifications."""
+
     callback_url: HttpUrl | None = Field(
         default=None,
         description="URL to call to report endpoint status",
@@ -153,40 +186,23 @@ class ProviderOptions(SQLModel):
     )
 
 
-class CreationRequest(
-    DocumentOptions,
-    ProviderOptions,
-    AssistantOptions,
-    CallbackRequest,
-):
-    def extract_super_type(self, cls: "CreationRequest"):
-        for field_name in cls.model_fields.keys():
-            field_value = getattr(self, field_name)
-            yield (field_name, field_value)
-
-
-class DeletionRequest(CallbackRequest):
-    collection_id: UUID = Field(description="Collection to delete")
-
-
-# Response models
-
-
-class CollectionIDPublic(SQLModel):
-    id: UUID
+class CreationRequest(AssistantOptions, ProviderOptions, CallbackRequest):
+    """API request for collection creation"""
 
+    collection_params: CreateCollectionParams = Field(
+        ...,
+        description="Collection creation specific parameters (name, documents, etc.)",
+    )
+    batch_size: int = Field(
+        default=10,
+        ge=1,
+        le=500,
+        description="Number of documents to process in a single batch",
+    )
 
-class CollectionPublic(SQLModel):
-    id: UUID
-    llm_service_id: str
-    llm_service_name: str
-    project_id: int
-    organization_id: int
-    inserted_at: datetime
-    updated_at: datetime
-    deleted_at: datetime | None = None
 
+class DeletionRequest(ProviderOptions, CallbackRequest):
+    """API request for collection deletion"""
 
-class CollectionWithDocsPublic(CollectionPublic):
-    documents: list[DocumentPublic] | None = None
+    collection_id: UUID = Field(description="Collection to delete")
diff --git a/backend/app/models/collection/response.py b/backend/app/models/collection/response.py
new file mode 100644
index 00000000..f72c5ee7
--- /dev/null
+++ b/backend/app/models/collection/response.py
@@ -0,0 +1,33 @@
+from datetime import datetime
+from typing import Any
+from uuid import UUID
+
+from sqlmodel import SQLModel
+
+from app.models.document import DocumentPublic
+
+
+class CreateCollectionResult(SQLModel):
+    llm_service_id: str
+    llm_service_name: str
+    collection_blob: dict[str, Any]
+
+
+class CollectionIDPublic(SQLModel):
+    id: UUID
+
+
+class CollectionPublic(SQLModel):
+    id: UUID
+    llm_service_id: str
+    llm_service_name: str
+    project_id: int
+    organization_id: int
+
+    inserted_at: datetime
+    updated_at: datetime
+    deleted_at: datetime | None = None
+
+
+class CollectionWithDocsPublic(CollectionPublic):
+    documents: list[DocumentPublic] | None = None
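Putting the request models together, here is a sketch of the new creation payload as a client would build it. This assumes `ProviderOptions` exposes a `provider` field (implied by the `creation_request.provider` usage in the worker below); the document id and assistant options are illustrative:

    from uuid import uuid4

    from app.models import CreationRequest

    request = CreationRequest(
        provider="openai",
        collection_params={
            "name": "support-kb",
            "description": "Docs for the support bot",
            "documents": [{"name": "faq.pdf", "id": uuid4()}],
            "chunking_params": {"chunk_size": 512, "chunk_overlap": 64},
        },
        batch_size=50,
        # Supplying model + instructions is what makes the OpenAI provider
        # create an assistant on top of the vector store (see openai.py below).
        model="gpt-4o",
        instructions="Answer using the attached knowledge base.",
    )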
diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py
index ed83e4a8..1086dc71 100644
--- a/backend/app/services/collections/create_collection.py
+++ b/backend/app/services/collections/create_collection.py
@@ -6,7 +6,6 @@
 from asgi_correlation_id import correlation_id
 
 from app.core.cloud import get_cloud_storage
-from app.core.util import now
 from app.core.db import engine
 from app.crud import (
     CollectionCrud,
@@ -14,7 +13,6 @@
     DocumentCollectionCrud,
     CollectionJobCrud,
 )
-from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud
 from app.models import (
     CollectionJobStatus,
     CollectionJob,
@@ -23,18 +21,11 @@
     CollectionPublic,
     CollectionJobPublic,
 )
-from app.models.collection import (
-    CreationRequest,
-    AssistantOptions,
-)
-from app.services.collections.helpers import (
-    _backout,
-    batch_documents,
-    extract_error_message,
-    OPENAI_VECTOR_STORE,
-)
+from app.models.collection import CreationRequest
+from app.services.collections.helpers import extract_error_message
+from app.services.collections.providers.registry import get_llm_provider
 from app.celery.utils import start_low_priority_job
-from app.utils import get_openai_client, send_callback, APIResponse
+from app.utils import send_callback, APIResponse
 
 logger = logging.getLogger(__name__)
 
@@ -116,26 +107,6 @@ def build_failure_payload(collection_job: CollectionJob, error_message: str) ->
     )
 
 
-def _cleanup_remote_resources(
-    assistant,
-    assistant_crud,
-    vector_store,
-    vector_store_crud,
-) -> None:
-    """Best-effort cleanup of partially created remote resources."""
-    try:
-        if assistant is not None and assistant_crud is not None:
-            _backout(assistant_crud, assistant.id)
-        elif vector_store is not None and vector_store_crud is not None:
-            _backout(vector_store_crud, vector_store.id)
-        else:
-            logger.warning(
-                "[create_collection._backout] Skipping: no resource/crud available"
-            )
-    except Exception:
-        logger.warning("[create_collection.execute_job] Backout failed")
-
-
 def _mark_job_failed(
     project_id: int,
     job_id: str,
@@ -172,17 +143,15 @@ def execute_job(
 ) -> None:
     """
     Worker entrypoint scheduled by start_job.
-    Orchestrates: job state, client/storage init, batching, vector-store upload,
+    Orchestrates: job state, provider init, collection creation,
     optional assistant creation, collection persistence, linking, callbacks,
     and cleanup.
     """
     start_time = time.time()
 
-    # Keep references for potential backout/cleanup on failure
-    assistant = None
-    assistant_crud = None
-    vector_store = None
-    vector_store_crud = None
+    # Keep references to the provider and its result for cleanup on failure
    collection_job = None
+    result = None
+    provider = None
 
    try:
        creation_request = CreationRequest(**request)
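The worker now only sees the provider contract: `create()` returns a `CreateCollectionResult`, and `cleanup(result)` rolls back remote state when a later step fails. A minimal sketch of that contract in isolation (`FakeProvider` and its service name are illustrative, not part of this PR):

    from app.models import CreateCollectionResult


    class FakeProvider:
        def create(self, collection_request, storage, document_crud):
            # A real provider would create remote resources here.
            return CreateCollectionResult(
                llm_service_id="vs_fake",
                llm_service_name="fake_vector_store",
                collection_blob={"name": "support-kb"},
            )

        def cleanup(self, result: CreateCollectionResult) -> None:
            print(f"rolling back {result.llm_service_id}")


    provider = FakeProvider()
    result = provider.create(None, None, None)
    try:
        raise RuntimeError("simulated persistence failure")
    except RuntimeError:
        provider.cleanup(result)  # mirrors the except-branch below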
@@ -199,49 +168,32 @@
             ),
         )
 
-        client = get_openai_client(session, organization_id, project_id)
         storage = get_cloud_storage(session=session, project_id=project_id)
-
-        # Batch documents for upload, and flatten for linking/metrics later
         document_crud = DocumentCrud(session, project_id)
-        docs_batches = batch_documents(
-            document_crud,
-            creation_request.documents,
-            creation_request.batch_size,
+
+        provider = get_llm_provider(
+            session=session,
+            provider=creation_request.provider,
+            project_id=project_id,
+            organization_id=organization_id,
         )
-        flat_docs = [doc for batch in docs_batches for doc in batch]
 
-        vector_store_crud = OpenAIVectorStoreCrud(client)
-        vector_store = vector_store_crud.create()
-        list(vector_store_crud.update(vector_store.id, storage, docs_batches))
+        result = provider.create(
+            collection_request=creation_request,
+            storage=storage,
+            document_crud=document_crud,
+        )
 
-        # if with_assistant is true, create assistant backed by the vector store
-        if with_assistant:
-            assistant_crud = OpenAIAssistantCrud(client)
+        llm_service_id = result.llm_service_id
+        llm_service_name = result.llm_service_name
+        # Store the collection params (name, description, chunking_params, etc.)
+        # in the DB for future reference and to support providers with varying
+        # configurations
+        collection_blob = result.collection_blob
 
-            # Filter out None to avoid sending unset options
-            assistant_options = dict(
-                creation_request.extract_super_type(AssistantOptions)
-            )
-            assistant_options = {
-                k: v for k, v in assistant_options.items() if v is not None
-            }
-
-            assistant = assistant_crud.create(vector_store.id, **assistant_options)
-            llm_service_id = assistant.id
-            llm_service_name = assistant_options.get("model") or "assistant"
-
-            logger.info(
-                "[execute_job] Assistant created | assistant_id=%s, vector_store_id=%s",
-                assistant.id,
-                vector_store.id,
-            )
-        else:
-            # If no assistant, the collection points directly at the vector store
-            llm_service_id = vector_store.id
-            llm_service_name = OPENAI_VECTOR_STORE
-            logger.info(
-                "[execute_job] Skipping assistant creation | with_assistant=False"
+        with Session(engine) as session:
+            document_crud = DocumentCrud(session, project_id)
+            flat_docs = document_crud.read_many_by_ids(
+                [doc.id for doc in creation_request.collection_params.documents]
             )
 
         file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname}
@@ -259,6 +211,7 @@
                 organization_id=organization_id,
                 llm_service_id=llm_service_id,
                 llm_service_name=llm_service_name,
+                collection_blob=collection_blob,
             )
             collection_crud.create(collection)
             collection = collection_crud.read_one(collection.id)
@@ -299,12 +252,13 @@
                 exc_info=True,
             )
 
-            _cleanup_remote_resources(
-                assistant=assistant,
-                assistant_crud=assistant_crud,
-                vector_store=vector_store,
-                vector_store_crud=vector_store_crud,
-            )
+            if provider is not None and result is not None:
+                try:
+                    provider.cleanup(result)
+                except Exception:
+                    logger.warning(
+                        "[create_collection.execute_job] Provider cleanup failed"
+                    )
 
             collection_job = _mark_job_failed(
                 project_id=project_id,
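For reference, with the OpenAI provider the persisted blob is simply the request's collection params minus the documents list; the values here are illustrative, the keys match the `collection_blob` dict built in openai.py below:

    collection_blob = {
        "name": "support-kb",
        "description": "Docs for the support bot",
        "chunking_params": {"chunk_size": 512, "chunk_overlap": 64},
        "additional_params": None,
    }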
diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py
index ca337b79..e9570964 100644
--- a/backend/app/services/collections/delete_collection.py
+++ b/backend/app/services/collections/delete_collection.py
@@ -6,7 +6,6 @@
 
 from app.core.db import engine
 from app.crud import CollectionCrud, CollectionJobCrud
-from app.crud.rag import OpenAIAssistantCrud, OpenAIVectorStoreCrud
 from app.models import (
     CollectionJobStatus,
     CollectionJobUpdate,
@@ -15,9 +14,10 @@
     CollectionIDPublic,
 )
 from app.models.collection import DeletionRequest
-from app.services.collections.helpers import extract_error_message, OPENAI_VECTOR_STORE
+from app.services.collections.helpers import extract_error_message
+from app.services.collections.providers.registry import get_llm_provider
 from app.celery.utils import start_low_priority_job
-from app.utils import get_openai_client, send_callback, APIResponse
+from app.utils import send_callback, APIResponse
 
 logger = logging.getLogger(__name__)
 
@@ -155,7 +155,6 @@
     job_uuid = UUID(job_id)
 
     collection_job = None
-    client = None
 
     try:
         with Session(engine) as session:
@@ -169,20 +168,16 @@
                 ),
             )
 
-        client = get_openai_client(session, organization_id, project_id)
-
         collection = CollectionCrud(session, project_id).read_one(collection_id)
 
-        # Identify which external service (assistant/vector store) this collection belongs to
-        service = (collection.llm_service_name or "").strip().lower()
-        is_vector = service == OPENAI_VECTOR_STORE
-        llm_service_id = collection.llm_service_id
+        provider = get_llm_provider(
+            session=session,
+            provider=deletion_request.provider,
+            project_id=project_id,
+            organization_id=organization_id,
+        )
 
-        # Delete the corresponding OpenAI resource (vector store or assistant)
-        if is_vector:
-            OpenAIVectorStoreCrud(client).delete(llm_service_id)
-        else:
-            OpenAIAssistantCrud(client).delete(llm_service_id)
+        provider.delete(collection)
 
         with Session(engine) as session:
             CollectionCrud(session, project_id).delete_by_id(collection_id)
diff --git a/backend/app/services/collections/providers/__init__.py b/backend/app/services/collections/providers/__init__.py
new file mode 100644
index 00000000..5a9b6a55
--- /dev/null
+++ b/backend/app/services/collections/providers/__init__.py
@@ -0,0 +1,6 @@
+from app.services.collections.providers.base import BaseProvider
+from app.services.collections.providers.openai import OpenAIProvider
+from app.services.collections.providers.registry import (
+    LLMProvider,
+    get_llm_provider,
+)
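Deletion now mirrors creation: the worker resolves a provider from the registry and hands it the whole `Collection` row, so each implementation decides what to tear down. A sketch of the request a client would send, again assuming the `provider` field from `ProviderOptions` (UUID and URL illustrative):

    from uuid import UUID

    from app.models import DeletionRequest

    request = DeletionRequest(
        provider="openai",
        collection_id=UUID("22222222-2222-2222-2222-222222222222"),
        callback_url="https://example.com/hooks/collection-jobs",  # optional
    )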
diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py
new file mode 100644
index 00000000..9fb21f3e
--- /dev/null
+++ b/backend/app/services/collections/providers/base.py
@@ -0,0 +1,84 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+from app.crud import DocumentCrud
+from app.core.cloud.storage import CloudStorage
+from app.models import CreationRequest, CreateCollectionResult, Collection
+
+
+class BaseProvider(ABC):
+    """Abstract base class for collection providers.
+
+    All provider implementations (OpenAI, Bedrock, etc.) must inherit from
+    this class and implement the required methods.
+
+    Providers handle creation of knowledge bases (vector stores) and
+    optional assistant/agent creation backed by those knowledge bases.
+
+    Attributes:
+        client: The provider-specific client instance
+    """
+
+    def __init__(self, client: Any):
+        """Initialize the provider with a client.
+
+        Args:
+            client: Provider-specific client instance
+        """
+        self.client = client
+
+    @abstractmethod
+    def create(
+        self,
+        collection_request: CreationRequest,
+        storage: CloudStorage,
+        document_crud: DocumentCrud,
+    ) -> CreateCollectionResult:
+        """Create a collection with documents and optionally an assistant.
+
+        Args:
+            collection_request: Creation request carrying the collection
+                params (name, description, chunking_params, etc.), the
+                batch size, and any assistant options
+            storage: Cloud storage instance for file access
+            document_crud: DocumentCrud instance for fetching documents
+
+        Returns:
+            CreateCollectionResult containing:
+                - llm_service_id: ID of the created resource (vector store or assistant)
+                - llm_service_name: Name of the service
+                - collection_blob: All collection params except documents
+        """
+        raise NotImplementedError("Providers must implement the create method")
+
+    @abstractmethod
+    def delete(self, collection: Collection) -> None:
+        """Delete remote resources associated with a collection.
+
+        Called when a collection is being deleted and remote resources need to be cleaned up.
+
+        Args:
+            collection: Collection whose llm_service_id/llm_service_name
+                identify the remote resource to delete
+        """
+        raise NotImplementedError("Providers must implement the delete method")
+
+    @abstractmethod
+    def cleanup(self, collection_result: CreateCollectionResult) -> None:
+        """Clean up/roll back resources created during create.
+
+        Called when collection creation fails and remote resources need to be deleted.
+
+        Args:
+            collection_result: The CreateCollectionResult returned from create, containing resource IDs
+        """
+        raise NotImplementedError("Providers must implement the cleanup method")
+
+    def get_provider_name(self) -> str:
+        """Get the name of the provider.
+
+        Returns:
+            Provider name (e.g., "openai", "bedrock", "pinecone")
+        """
+        return self.__class__.__name__.replace("Provider", "").lower()
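To see how the ABC is meant to be extended, here is a hypothetical stub for a second provider. Everything Bedrock-specific is invented for illustration and is not part of this PR:

    from app.core.cloud.storage import CloudStorage
    from app.crud import DocumentCrud
    from app.models import Collection, CreateCollectionResult, CreationRequest
    from app.services.collections.providers.base import BaseProvider


    class BedrockProvider(BaseProvider):
        """Illustrative skeleton; the remote calls are placeholders."""

        def create(
            self,
            collection_request: CreationRequest,
            storage: CloudStorage,
            document_crud: DocumentCrud,
        ) -> CreateCollectionResult:
            params = collection_request.collection_params
            # A real implementation would create a Bedrock knowledge base
            # here and ingest documents from `storage` via `document_crud`.
            return CreateCollectionResult(
                llm_service_id="kb-placeholder",
                llm_service_name="bedrock_knowledge_base",
                collection_blob={
                    "name": params.name,
                    "description": params.description,
                    "chunking_params": params.chunking_params,
                    "additional_params": params.additional_params,
                },
            )

        def delete(self, collection: Collection) -> None:
            ...  # delete the KB identified by collection.llm_service_id

        def cleanup(self, result: CreateCollectionResult) -> None:
            ...  # roll back the partially created knowledge base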
diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py
new file mode 100644
index 00000000..ba734d85
--- /dev/null
+++ b/backend/app/services/collections/providers/openai.py
@@ -0,0 +1,160 @@
+import logging
+
+from openai import OpenAI
+
+from app.services.collections.providers.base import BaseProvider
+from app.crud import DocumentCrud
+from app.core.cloud.storage import CloudStorage
+from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud
+from app.services.collections.helpers import (
+    batch_documents,
+    OPENAI_VECTOR_STORE,
+    _backout,
+)
+from app.models import CreateCollectionResult, CreationRequest, Collection
+
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIProvider(BaseProvider):
+    """OpenAI-specific collection provider for vector stores and assistants."""
+
+    def __init__(self, client: OpenAI):
+        super().__init__(client)
+        self.client = client
+
+    def create(
+        self,
+        collection_request: CreationRequest,
+        storage: CloudStorage,
+        document_crud: DocumentCrud,
+    ) -> CreateCollectionResult:
+        """Create an OpenAI vector store with documents and optionally an assistant.
+
+        Args:
+            collection_request: Creation request carrying the collection
+                params, the batch size, and any assistant options
+                (model, instructions, temperature)
+            storage: Cloud storage instance for file access
+            document_crud: DocumentCrud instance for fetching documents
+
+        Returns:
+            CreateCollectionResult containing llm_service_id, llm_service_name, and collection_blob
+        """
+        try:
+            collection_params = collection_request.collection_params
+            document_ids = [doc.id for doc in collection_params.documents]
+
+            docs_batches = batch_documents(
+                document_crud,
+                document_ids,
+                collection_request.batch_size,
+            )
+
+            vector_store_crud = OpenAIVectorStoreCrud(self.client)
+            vector_store = vector_store_crud.create()
+
+            list(vector_store_crud.update(vector_store.id, storage, docs_batches))
+
+            logger.info(
+                "[OpenAIProvider.create] Vector store created | "
+                f"vector_store_id={vector_store.id}, batches={len(docs_batches)}"
+            )
+
+            collection_blob = {
+                "name": collection_params.name,
+                "description": collection_params.description,
+                "chunking_params": collection_params.chunking_params,
+                "additional_params": collection_params.additional_params,
+            }
+
+            # Create an assistant only when assistant options are present in the request
+            with_assistant = (
+                collection_request.model is not None
+                and collection_request.instructions is not None
+            )
+            if with_assistant:
+                assistant_crud = OpenAIAssistantCrud(self.client)
+
+                assistant_options = {
+                    "model": collection_request.model,
+                    "instructions": collection_request.instructions,
+                    "temperature": collection_request.temperature,
+                }
+                filtered_options = {
+                    k: v for k, v in assistant_options.items() if v is not None
+                }
+
+                assistant = assistant_crud.create(vector_store.id, **filtered_options)
+
+                logger.info(
+                    "[OpenAIProvider.create] Assistant created | "
+                    f"assistant_id={assistant.id}, vector_store_id={vector_store.id}"
+                )
+
+                return CreateCollectionResult(
+                    llm_service_id=assistant.id,
+                    llm_service_name=filtered_options.get("model", "assistant"),
+                    collection_blob=collection_blob,
+                )
+            else:
+                logger.info(
+                    "[OpenAIProvider.create] Skipping assistant creation | with_assistant=False"
+                )
+
+                return CreateCollectionResult(
+                    llm_service_id=vector_store.id,
+                    llm_service_name=OPENAI_VECTOR_STORE,
+                    collection_blob=collection_blob,
+                )
+
+        except Exception as e:
+            logger.error(
+                f"[OpenAIProvider.create] Failed to create knowledge base: {str(e)}",
+                exc_info=True,
+            )
+            raise
+
+    def delete(self, collection: Collection) -> None:
+        """Delete OpenAI resources (assistant or vector store).
+
+        Determines what to delete based on llm_service_name:
+        - If an assistant was created, delete the assistant (which also removes the vector store)
+        - If only a vector store was created, delete the vector store
+
+        Args:
+            collection: Collection that has been requested to be deleted
+        """
+        try:
+            if collection.llm_service_name != OPENAI_VECTOR_STORE:
+                OpenAIAssistantCrud(self.client).delete(collection.llm_service_id)
+                logger.info(
+                    f"[OpenAIProvider.delete] Deleted assistant | assistant_id={collection.llm_service_id}"
+                )
+            else:
+                OpenAIVectorStoreCrud(self.client).delete(collection.llm_service_id)
+                logger.info(
+                    f"[OpenAIProvider.delete] Deleted vector store | vector_store_id={collection.llm_service_id}"
+                )
+        except Exception as e:
+            logger.error(
+                f"[OpenAIProvider.delete] Failed to delete resource | "
+                f"llm_service_id={collection.llm_service_id}, error={str(e)}",
+                exc_info=True,
+            )
+            raise
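A pytest-style sketch of the assistant-inference rule above; the cruds and batching helper are patched out so nothing talks to OpenAI, and the `provider` field on the request is assumed as elsewhere:

    from unittest.mock import MagicMock, patch

    from app.models import CreationRequest
    from app.services.collections.providers.openai import OpenAIProvider


    def test_create_without_assistant_points_at_vector_store():
        request = CreationRequest(
            provider="openai",
            collection_params={"name": "support-kb", "documents": []},
        )
        with patch(
            "app.services.collections.providers.openai.batch_documents",
            return_value=[],
        ), patch(
            "app.services.collections.providers.openai.OpenAIVectorStoreCrud"
        ) as crud_cls:
            crud_cls.return_value.create.return_value = MagicMock(id="vs_123")
            crud_cls.return_value.update.return_value = iter([])

            result = OpenAIProvider(client=MagicMock()).create(
                collection_request=request,
                storage=MagicMock(),
                document_crud=MagicMock(),
            )

        # No model/instructions on the request, so no assistant is created.
        assert result.llm_service_id == "vs_123"
        assert result.collection_blob["name"] == "support-kb"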
+    def cleanup(self, result: CreateCollectionResult) -> None:
+        """Clean up OpenAI resources (assistant or vector store).
+
+        Determines what to delete based on llm_service_name:
+        - If an assistant was created, delete the assistant (which also removes the vector store)
+        - If only a vector store was created, delete the vector store
+
+        Args:
+            result: The CreateCollectionResult from create containing resource IDs
+        """
+        if result.llm_service_name != OPENAI_VECTOR_STORE:
+            _backout(OpenAIAssistantCrud(self.client), result.llm_service_id)
+        else:
+            _backout(OpenAIVectorStoreCrud(self.client), result.llm_service_id)
diff --git a/backend/app/services/collections/providers/registry.py b/backend/app/services/collections/providers/registry.py
new file mode 100644
index 00000000..10d07d45
--- /dev/null
+++ b/backend/app/services/collections/providers/registry.py
@@ -0,0 +1,71 @@
+import logging
+
+from sqlmodel import Session
+from openai import OpenAI
+
+from app.crud import get_provider_credential
+from app.services.collections.providers.base import BaseProvider
+from app.services.collections.providers.openai import OpenAIProvider
+
+
+logger = logging.getLogger(__name__)
+
+
+class LLMProvider:
+    OPENAI = "openai"
+    # Future provider constants:
+    # BEDROCK = "bedrock"
+    # GEMINI = "gemini"
+
+    _registry: dict[str, type[BaseProvider]] = {
+        OPENAI: OpenAIProvider,
+        # Future providers:
+        # BEDROCK: BedrockProvider,
+        # GEMINI: GeminiProvider,
+    }
+
+    @classmethod
+    def get(cls, name: str) -> type[BaseProvider]:
+        """Return the provider class for a given name."""
+        provider = cls._registry.get(name)
+        if not provider:
+            raise ValueError(
+                f"Provider '{name}' is not supported. "
+                f"Supported providers: {', '.join(cls._registry.keys())}"
+            )
+        return provider
+
+    @classmethod
+    def supported_providers(cls) -> list[str]:
+        """Return a list of supported provider names."""
+        return list(cls._registry.keys())
+
+
+def get_llm_provider(
+    session: Session, provider: str, project_id: int, organization_id: int
+) -> BaseProvider:
+    """Resolve credentials for a provider and return an initialized instance."""
+    provider_class = LLMProvider.get(provider)
+
+    credentials = get_provider_credential(
+        session=session,
+        provider=provider,
+        project_id=project_id,
+        org_id=organization_id,
+    )
+
+    if not credentials:
+        raise ValueError(
+            f"Credentials for provider '{provider}' not configured for this project."
+        )
+
+    if provider == LLMProvider.OPENAI:
+        if "api_key" not in credentials:
+            raise ValueError("OpenAI credentials not configured for this project.")
+        client = OpenAI(api_key=credentials["api_key"])
+    else:
+        logger.error(
+            f"[get_llm_provider] No client initialization for provider: {provider}"
+        )
+        raise ValueError(
+            f"Provider '{provider}' has no client initialization configured."
+        )
+
+    return provider_class(client=client)
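Extending the registry is then a constant, a mapping entry, and a client branch; a hedged sketch using the hypothetical BedrockProvider stub from above (the module path and boto3 wiring are illustrative):

    # Hypothetical; no such module exists in this PR.
    from app.services.collections.providers.bedrock import BedrockProvider

    from app.services.collections.providers.registry import LLMProvider

    LLMProvider.BEDROCK = "bedrock"
    LLMProvider._registry[LLMProvider.BEDROCK] = BedrockProvider

    # get_llm_provider would also need a matching client branch, e.g.
    # client = boto3.client("bedrock-agent", region_name=credentials["region"])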