"""
This module provides functionality for handling chat completions and embeddings requests using FastAPI, Pydantic, and
the Llama library. It includes classes for defining the structure of requests and responses, and functions for setting
up and running an OpenAI API-like server.

Classes:
- ChatCompletionsRequestToolFunctionParametersProperty: Represents a property within the parameters of a tool function
  in a chat completions request.
- ChatCompletionsRequestToolFunctionParameters: Defines the parameters for a tool function used in a chat completions
  request.
- ChatCompletionsRequestToolFunction: Describes a function tool within a chat completions request.
- ChatCompletionsRequestTool: Represents a tool in a chat completions request.
- ChatCompletionsRequestMessageFunctionCall: Represents a function call within a chat message.
- ChatCompletionsRequestMessage: Represents a single message in a chat completions request.
- ChatCompletionsRequest: Defines the structure of a request for chat completions.
- EmbeddingsRequest: Represents a request for generating embeddings.

Functions:
- completions_endpoint: Asynchronous endpoint for handling chat completions requests.
- embeddings_endpoint: Asynchronous endpoint for handling embeddings requests.
- start_openai_api_server: Starts a thread running an OpenAI API server using FastAPI.

The module combines FastAPI for the web server, Pydantic for data validation and serialization, and the llama_cpp
bindings for local model inference. It is designed to create a server capable of processing chat completion and
embeddings requests in a format similar to the OpenAI API, using structured models for input and output data.
"""

import threading
from contextlib import asynccontextmanager
from threading import Event
from typing import Dict, List, Literal, Optional, Union

import uvicorn
from fastapi import FastAPI
from llama_cpp import Llama
from pydantic import BaseModel


class ChatCompletionsRequestToolFunctionParametersProperty(BaseModel):
    """
    Represents a single property within the parameters of a tool function in a chat completions request.

    Attributes:
    - title (str): The title of the property.
    - type (str): The data type of the property.
    """
    title: str
    type: str


class ChatCompletionsRequestToolFunctionParameters(BaseModel):
    """
    Defines the parameters for a tool function used in a chat completions request.

    Attributes:
    - type (Literal["object"]): Specifies that the parameter type is an object.
    - title (str): The title of the parameters object.
    - required (List[str]): A list of names of required parameters.
    - properties (Dict[str, ChatCompletionsRequestToolFunctionParametersProperty]): A dictionary mapping
      parameter names to their properties.
    """
    type: Literal["object"]
    title: str
    required: List[str]
    properties: Dict[str, ChatCompletionsRequestToolFunctionParametersProperty]


class ChatCompletionsRequestToolFunction(BaseModel):
    """
    Describes a function tool within a chat completions request.

    Attributes:
    - name (str): The name of the function tool.
    - description (Optional[str]): An optional description of the function tool.
    - parameters (ChatCompletionsRequestToolFunctionParameters): The parameters for the function tool.
    """
    name: str
    description: Optional[str] = None
    parameters: ChatCompletionsRequestToolFunctionParameters


class ChatCompletionsRequestTool(BaseModel):
    """
    Represents a tool in a chat completions request.

    Attributes:
    - type (Literal["function"]): Indicates that the tool is a function.
    - function (ChatCompletionsRequestToolFunction): The function tool description.
    """
    type: Literal["function"]
    function: ChatCompletionsRequestToolFunction


class ChatCompletionsRequestMessageFunctionCall(BaseModel):
    """
    Represents a function call within a chat message.

    Attributes:
    - name (str): The name of the function being called.
    - arguments (str): The arguments passed to the function call.
    """
    name: str
    arguments: str


class ChatCompletionsRequestMessage(BaseModel):
    """
    Represents a single message in a chat completions request.

    Attributes:
    - content (Optional[str]): The content of the message. Can be None.
    - role (str): The role associated with the message (e.g., 'user', 'system').
    - name (Optional[str]): An optional name associated with the message.
    - function_call (Optional[ChatCompletionsRequestMessageFunctionCall]): An optional function call associated
      with the message.
    """
    content: Optional[str] = None
    role: str
    name: Optional[str] = None
    function_call: Optional[ChatCompletionsRequestMessageFunctionCall] = None


class ChatCompletionsRequest(BaseModel):
    """
    Defines the structure of a request for chat completions.

    Attributes:
    - messages (List[ChatCompletionsRequestMessage]): A list of messages involved in the chat completion request.
    - model (str): The model to be used for generating chat completions.
    - tools (Optional[List[ChatCompletionsRequestTool]]): An optional list of tools to be used in the chat completion
      request.
    """
    messages: List[ChatCompletionsRequestMessage]
    model: str
    tools: Optional[List[ChatCompletionsRequestTool]] = None


class EmbeddingsRequest(BaseModel):
    """
    Represents a request for generating embeddings.

    Attributes:
    - model (str): The model to be used for generating embeddings.
    - input (Union[str, List[str]]): The input data for which embeddings are to be generated. Can be a single string
      or a list of strings.
    - encoding_format (Optional[str]): An optional encoding format for the embeddings.
    """
    model: str
    input: Union[str, List[str]]
    encoding_format: Optional[str] = None


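# For reference, a sample payload accepted by ChatCompletionsRequest. The shape
# follows the models above; the "get_weather" tool and the field values are
# illustrative placeholders, not part of this module:
#
# {
#     "model": "local-model",
#     "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
#     "tools": [{
#         "type": "function",
#         "function": {
#             "name": "get_weather",
#             "description": "Look up the current weather for a city.",
#             "parameters": {
#                 "type": "object",
#                 "title": "get_weather_parameters",
#                 "required": ["city"],
#                 "properties": {"city": {"title": "City", "type": "string"}}
#             }
#         }
#     }]
# }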
async def completions_endpoint(llm: Llama, request: ChatCompletionsRequest):
    """
    Asynchronous endpoint for handling chat completions requests.

    This function processes a chat completions request, making adjustments to the request data as necessary, and then
    calls the appropriate method on the Llama (llm) instance to generate chat completions.

    Args:
    - llm (Llama): An instance of the Llama class, used to generate chat completions.
    - request (ChatCompletionsRequest): The request object containing details for the chat completion. This object
      should be an instance of a model derived from BaseModel, containing chat messages and optionally tools.

    Returns:
    - The response from the Llama instance's create_chat_completion method, which contains the generated chat
      completions.

    The function first dumps the request model to a dict, excluding any None values. It then restores the 'content'
    fields in the messages that were suppressed by the dump. Finally, it calls the Llama instance's
    create_chat_completion method with the processed messages and tools.
    """
    payload = request.model_dump(exclude_none=True)

    # model_dump(exclude_none=True) drops messages' None-valued 'content' keys;
    # restore them so llama_cpp receives the shape it expects.
    messages = payload["messages"]
    for message in messages:
        if "content" not in message:
            message["content"] = None

    return llm.create_chat_completion(
        messages=messages,
        tools=payload.get("tools"),
    )


async def embeddings_endpoint(llm: Llama, request: EmbeddingsRequest):
    """
    Asynchronous endpoint for handling embeddings requests.

    This function calls the Llama (llm) instance to generate embeddings based on the provided request.

    Args:
    - llm (Llama): An instance of the Llama class, used to generate embeddings.
    - request (EmbeddingsRequest): The request object containing the input data for which embeddings are to be
      generated. The request should be an instance of a model derived from BaseModel, containing the model name and
      the input data.

    Returns:
    - The response from the Llama instance's create_embedding method, which contains the generated embeddings.

    The function simply calls the Llama instance's create_embedding method with the input data and model specified in
    the request.
    """
    return llm.create_embedding(
        input=request.input,
        model=request.model,
    )
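

# Note: llama_cpp requires the model to be loaded with embedding=True for
# create_embedding to succeed. For reference, a sample payload accepted by
# EmbeddingsRequest (values are illustrative):
#   {"model": "local-model", "input": ["first text", "second text"]}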


def start_openai_api_server(llm: Llama, host: str = "localhost", port: int = 8000):
    """
    Starts a thread running an OpenAI API server using FastAPI.

    This function creates a FastAPI application with endpoints for handling chat completions and embeddings requests.
    It runs the FastAPI application in a separate thread and waits until the server is ready before returning.

    Args:
    - llm (Llama): An instance of the Llama class, which is used to process the chat completions and embeddings
      requests.
    - host (str, optional): The hostname on which the FastAPI server will listen. Defaults to "localhost".
    - port (int, optional): The port on which the FastAPI server will listen. Defaults to 8000.

    Returns:
    - Tuple[Thread, FastAPI]: A tuple containing the thread running the FastAPI server and the FastAPI app instance.

    The FastAPI application defines two POST endpoints:
    1. "/v1/chat/completions": Accepts requests in the format of `ChatCompletionsRequest` and uses the Llama instance
       to create chat completions.
    2. "/v1/embeddings": Accepts requests in the format of `EmbeddingsRequest` and uses the Llama instance to create
       embeddings.

    The server runs in a daemon thread, ensuring that it does not block the main program from exiting.
    """
    server_ready = Event()

    # The lifespan handler runs once uvicorn has finished starting up, so
    # setting the event here signals that the server is ready to accept
    # requests.
    @asynccontextmanager
    async def lifespan(app: FastAPI):
        server_ready.set()
        yield

    app = FastAPI(lifespan=lifespan)

    @app.post("/v1/chat/completions")
    async def completions(request: ChatCompletionsRequest):
        return await completions_endpoint(llm, request)

    @app.post("/v1/embeddings")
    async def embeddings(request: EmbeddingsRequest):
        return await embeddings_endpoint(llm, request)

    # Run uvicorn in a daemon thread so the server does not keep the process
    # alive on its own.
    thread = threading.Thread(
        daemon=True,
        target=lambda: uvicorn.run(
            app,
            host=host,
            port=port,
        ),
    )

    thread.start()
    # Block until the lifespan handler reports that the server has started.
    server_ready.wait()
    return (thread, app)
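

# A minimal usage sketch. The model path below is a hypothetical placeholder,
# not part of this module; embedding=True is needed so /v1/embeddings can
# serve requests.
if __name__ == "__main__":
    llm = Llama(model_path="./models/model.gguf", embedding=True)  # hypothetical path

    thread, app = start_openai_api_server(llm, host="localhost", port=8000)

    # The server now accepts OpenAI-style requests, e.g.:
    #   curl http://localhost:8000/v1/chat/completions \
    #     -H "Content-Type: application/json" \
    #     -d '{"model": "local-model", "messages": [{"role": "user", "content": "Hello"}]}'

    # The uvicorn thread is a daemon, so join it to keep the process running.
    thread.join()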