Skip to content

Commit 544d8d3

Browse files
committed
add generic openai api compatible provider
1 parent c54d6ee commit 544d8d3

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

src/shelloracle/providers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,13 @@ def _providers() -> dict[str, type[Provider]]:
7979
from shelloracle.providers.localai import LocalAI
8080
from shelloracle.providers.ollama import Ollama
8181
from shelloracle.providers.openai import OpenAI
82+
from shelloracle.providers.openai_compat import OpenAICompat
8283
from shelloracle.providers.xai import XAI
8384

8485
return {
8586
Ollama.name: Ollama,
8687
OpenAI.name: OpenAI,
88+
OpenAICompat.name: OpenAICompat,
8789
LocalAI.name: LocalAI,
8890
XAI.name: XAI,
8991
Deepseek.name: Deepseek,
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from collections.abc import AsyncIterator
2+
3+
from openai import APIError, AsyncOpenAI
4+
5+
from shelloracle.providers import Provider, ProviderError, Setting, system_prompt
6+
7+
8+
class OpenAICompat(Provider):
9+
name = "OpenAICompat"
10+
11+
base_url = Setting(default="")
12+
api_key = Setting(default="")
13+
model = Setting(default="")
14+
15+
def __init__(self):
16+
if not self.api_key:
17+
msg = "No API key provided"
18+
raise ProviderError(msg)
19+
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
20+
21+
async def generate(self, prompt: str) -> AsyncIterator[str]:
22+
try:
23+
stream = await self.client.chat.completions.create(
24+
model=self.model,
25+
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
26+
stream=True,
27+
)
28+
async for chunk in stream:
29+
if chunk.choices[0].delta.content is not None:
30+
yield chunk.choices[0].delta.content
31+
except APIError as e:
32+
msg = f"Something went wrong while querying OpenAICompat: {e}"
33+
raise ProviderError(msg) from e

0 commit comments

Comments
 (0)