2 | 2 | import json |
3 | 3 | import asyncio |
4 | 4 | import logging |
| 5 | +import time |
5 | 6 | from typing import Optional, AsyncGenerator |
6 | 7 | from fastapi import FastAPI, Request, HTTPException |
7 | 8 | from fastapi.responses import StreamingResponse, Response |
|
14 | 15 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) |
15 | 16 | PROJECT_ROOT = os.path.dirname(SCRIPT_DIR) |
16 | 17 | DATA_DIR = os.path.join(PROJECT_ROOT, 'data') |
17 | | -WORKERS_CONFIG_PATH = os.path.join(DATA_DIR, 'workers.json') |
| 18 | + |
| 19 | +MANAGER_URL = "http://127.0.0.1:9000" |
| 20 | +RATE_LIMIT_KEYWORDS = [b"exceeded quota", b"out of free generations", b"rate limit"] |
18 | 21 |
|
19 | 22 | app = FastAPI(title="AIStudio2API Gateway") |
20 | 23 |
|
21 | | -workers = [] |
22 | | -current_index = 0 |
| 24 | +_session: Optional[aiohttp.ClientSession] = None |
| 25 | +_worker_cache = {"workers": [], "last_update": 0, "index": 0} |
| 26 | +CACHE_TTL = 5 |
23 | 27 |
|
24 | | -def load_workers(): |
25 | | - global workers |
26 | | - if os.path.exists(WORKERS_CONFIG_PATH): |
27 | | - try: |
28 | | - with open(WORKERS_CONFIG_PATH, 'r', encoding='utf-8') as f: |
29 | | - config = json.load(f) |
30 | | - workers = [w['port'] for w in config.get('workers', [])] |
31 | | - logger.info(f"Loaded {len(workers)} workers: {workers}") |
32 | | - except Exception as e: |
33 | | - logger.error(f"Load workers failed: {e}") |
| 28 | +async def get_session() -> aiohttp.ClientSession: |
| 29 | + global _session |
| 30 | + if _session is None or _session.closed: |
| 31 | + connector = aiohttp.TCPConnector(limit=100, limit_per_host=20, keepalive_timeout=30) |
| 32 | + _session = aiohttp.ClientSession(connector=connector) |
| 33 | + return _session |
34 | 34 |
|
35 | | -def get_next_worker() -> Optional[int]: |
36 | | - global current_index, workers |
37 | | - if not workers: |
| 35 | +async def refresh_workers(): |
| 36 | + cache = _worker_cache |
| 37 | + if time.time() - cache["last_update"] < CACHE_TTL and cache["workers"]: |
| 38 | + return |
| 39 | + try: |
| 40 | + session = await get_session() |
| 41 | + async with session.get(f"{MANAGER_URL}/api/workers", timeout=aiohttp.ClientTimeout(total=5)) as resp: |
| 42 | + workers = await resp.json() |
| 43 | + cache["workers"] = [w for w in workers if w.get("status") == "running"] |
| 44 | + cache["last_update"] = time.time() |
| 45 | + except Exception as e: |
| 46 | + logger.warning(f"Refresh workers failed: {e}") |
| 47 | + |
| 48 | +def get_next_worker(model: str = "") -> Optional[dict]: |
| 49 | + cache = _worker_cache |
| 50 | + available = cache["workers"] |
| 51 | + if not available: |
38 | 52 | return None |
39 | | - port = workers[current_index % len(workers)] |
40 | | - current_index += 1 |
41 | | - return port |
| 53 | + worker = available[cache["index"] % len(available)] |
| 54 | + cache["index"] += 1 |
| 55 | + return worker |
| 56 | + |
| 57 | +async def report_rate_limit(worker_id: str, model: str): |
| 58 | + try: |
| 59 | + session = await get_session() |
| 60 | + async with session.post(f"{MANAGER_URL}/api/workers/{worker_id}/rate-limit", json={"model": model}, timeout=aiohttp.ClientTimeout(total=2)): pass |
| 61 | + except Exception as e: |
| 62 | + logger.debug(f"Report rate limit failed: {e}") |
| 63 | + |
| 64 | +def check_rate_limit_in_response(content: bytes) -> bool: |
| 65 | + content_lower = content.lower() |
| 66 | + return any(kw in content_lower for kw in RATE_LIMIT_KEYWORDS) |
42 | 67 |
|
43 | 68 | @app.on_event("startup") |
44 | 69 | async def startup(): |
45 | | - load_workers() |
46 | | - logger.info(f"Gateway started with {len(workers)} workers") |
| 70 | + await refresh_workers() |
| 71 | + logger.info(f"Gateway started") |
| 72 | + |
| 73 | +@app.on_event("shutdown") |
| 74 | +async def shutdown(): |
| 75 | + global _session |
| 76 | + if _session and not _session.closed: |
| 77 | + await _session.close() |
47 | 78 |
|
48 | 79 | @app.get("/") |
49 | 80 | async def root(): |
50 | | - return {"status": "ok", "mode": "gateway", "workers": len(workers)} |
| 81 | + return {"status": "ok", "mode": "gateway", "workers": len(_worker_cache["workers"])} |
51 | 82 |
|
52 | 83 | @app.get("/v1/models") |
53 | 84 | async def models(): |
54 | | - port = get_next_worker() |
55 | | - if not port: |
| 85 | + await refresh_workers() |
| 86 | + worker = get_next_worker() |
| 87 | + if not worker: |
56 | 88 | raise HTTPException(status_code=503, detail="No workers available") |
57 | 89 |
|
| 90 | + port = worker["port"] |
58 | 91 | url = f"http://127.0.0.1:{port}/v1/models" |
59 | | - logger.info(f"GET /v1/models -> worker:{port}") |
60 | 92 |
|
61 | | - timeout = aiohttp.ClientTimeout(total=60) |
62 | | - async with aiohttp.ClientSession(timeout=timeout) as session: |
63 | | - try: |
64 | | - async with session.get(url) as resp: |
65 | | - content = await resp.read() |
66 | | - return Response(content=content, status_code=resp.status, media_type=resp.content_type) |
67 | | - except Exception as e: |
68 | | - logger.error(f"Forward /v1/models failed: {e}") |
69 | | - raise HTTPException(status_code=502, detail=str(e)) |
| 93 | + session = await get_session() |
| 94 | + try: |
| 95 | + async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp: |
| 96 | + content = await resp.read() |
| 97 | + return Response(content=content, status_code=resp.status, media_type=resp.content_type) |
| 98 | + except Exception as e: |
| 99 | + logger.error(f"Forward /v1/models failed: {e}") |
| 100 | + raise HTTPException(status_code=502, detail=str(e)) |
70 | 101 |
|
71 | 102 | @app.post("/v1/chat/completions") |
72 | 103 | async def chat_completions(request: Request): |
| 104 | + await refresh_workers() |
73 | 105 | body = await request.body() |
74 | 106 | body_json = json.loads(body) |
75 | 107 | is_stream = body_json.get("stream", False) |
| 108 | + model_id = body_json.get("model", "") |
76 | 109 |
|
77 | | - port = get_next_worker() |
78 | | - if not port: |
| 110 | + worker = get_next_worker(model_id) |
| 111 | + if not worker: |
79 | 112 | raise HTTPException(status_code=503, detail="No workers available") |
80 | 113 |
|
| 114 | + port = worker["port"] |
| 115 | + worker_id = worker.get("id", "") |
81 | 116 | url = f"http://127.0.0.1:{port}/v1/chat/completions" |
82 | | - req_id = f"gw-{current_index}" |
| 117 | + req_id = f"gw-{worker_id}" |
83 | 118 | logger.info(f"[{req_id}] POST -> worker:{port} (stream={is_stream})") |
84 | 119 |
|
85 | 120 | forward_headers = {'Content-Type': 'application/json'} |
86 | 121 | for k, v in request.headers.items(): |
87 | | - k_lower = k.lower() |
88 | | - if k_lower not in ('host', 'content-length', 'transfer-encoding', 'content-type'): |
| 122 | + if k.lower() not in ('host', 'content-length', 'transfer-encoding', 'content-type'): |
89 | 123 | forward_headers[k] = v |
90 | 124 |
|
| 125 | + session = await get_session() |
| 126 | + |
91 | 127 | if is_stream: |
92 | 128 | async def stream_proxy() -> AsyncGenerator[bytes, None]: |
93 | | - timeout = aiohttp.ClientTimeout(total=600, sock_read=300) |
94 | | - connector = aiohttp.TCPConnector(force_close=True) |
95 | | - async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session: |
96 | | - try: |
97 | | - async with session.post(url, data=body, headers=forward_headers) as resp: |
98 | | - logger.info(f"[{req_id}] Stream started, status={resp.status}") |
99 | | - chunk_count = 0 |
100 | | - async for chunk in resp.content.iter_chunks(): |
101 | | - data, end_of_chunk = chunk |
102 | | - if data: |
103 | | - chunk_count += 1 |
104 | | - yield data |
105 | | - logger.info(f"[{req_id}] Stream completed, chunks={chunk_count}") |
106 | | - except asyncio.CancelledError: |
107 | | - logger.warning(f"[{req_id}] Stream cancelled") |
108 | | - except Exception as e: |
109 | | - logger.error(f"[{req_id}] Stream error: {e}") |
110 | | - |
111 | | - return StreamingResponse( |
112 | | - stream_proxy(), |
113 | | - media_type="text/event-stream", |
114 | | - headers={ |
115 | | - "Cache-Control": "no-cache", |
116 | | - "Connection": "keep-alive", |
117 | | - "X-Accel-Buffering": "no", |
118 | | - "Transfer-Encoding": "chunked" |
119 | | - } |
120 | | - ) |
121 | | - else: |
122 | | - timeout = aiohttp.ClientTimeout(total=300) |
123 | | - async with aiohttp.ClientSession(timeout=timeout) as session: |
| 129 | + rate_limited = False |
| 130 | + check_count = 0 |
124 | 131 | try: |
125 | | - async with session.post(url, data=body, headers=forward_headers) as resp: |
126 | | - content = await resp.read() |
127 | | - logger.info(f"[{req_id}] Non-stream response, status={resp.status}, len={len(content)}") |
128 | | - return Response(content=content, status_code=resp.status, media_type=resp.content_type) |
| 132 | + async with session.post(url, data=body, headers=forward_headers, timeout=aiohttp.ClientTimeout(total=600, sock_read=300)) as resp: |
| 133 | + async for chunk in resp.content.iter_chunks(): |
| 134 | + data, _ = chunk |
| 135 | + if data: |
| 136 | + check_count += 1 |
| 137 | + if check_count <= 5 and not rate_limited: |
| 138 | + if check_rate_limit_in_response(data): |
| 139 | + rate_limited = True |
| 140 | + yield data |
| 141 | + if rate_limited and worker_id and model_id: |
| 142 | + asyncio.create_task(report_rate_limit(worker_id, model_id)) |
| 143 | + except asyncio.CancelledError: |
| 144 | + logger.warning(f"[{req_id}] Stream cancelled") |
129 | 145 | except Exception as e: |
130 | | - logger.error(f"[{req_id}] Forward failed: {e}") |
131 | | - raise HTTPException(status_code=502, detail=str(e)) |
| 146 | + logger.error(f"[{req_id}] Stream error: {e}") |
| 147 | + |
| 148 | + return StreamingResponse(stream_proxy(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}) |
| 149 | + else: |
| 150 | + try: |
| 151 | + async with session.post(url, data=body, headers=forward_headers, timeout=aiohttp.ClientTimeout(total=300)) as resp: |
| 152 | + content = await resp.read() |
| 153 | + if check_rate_limit_in_response(content) and worker_id and model_id: |
| 154 | + asyncio.create_task(report_rate_limit(worker_id, model_id)) |
| 155 | + return Response(content=content, status_code=resp.status, media_type=resp.content_type) |
| 156 | + except Exception as e: |
| 157 | + logger.error(f"[{req_id}] Forward failed: {e}") |
| 158 | + raise HTTPException(status_code=502, detail=str(e)) |
132 | 159 |
|
133 | 160 | @app.get("/health") |
134 | 161 | async def health(): |
135 | | - return {"status": "ok", "workers": workers} |
| 162 | + return {"status": "ok", "workers": len(_worker_cache["workers"])} |
136 | 163 |
|
137 | 164 | def main(): |
138 | 165 | import argparse |
139 | 166 | parser = argparse.ArgumentParser() |
140 | 167 | parser.add_argument('--port', type=int, default=2048) |
141 | 168 | args = parser.parse_args() |
142 | | - |
143 | | - logger.info(f"Starting Gateway on port {args.port}") |
144 | | - uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="info") |
| 169 | + uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="warning") |
145 | 170 |
|
146 | 171 | if __name__ == "__main__": |
147 | 172 | main() |
| 173 | + |
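
A quick way to exercise the gateway end to end, as a minimal sketch rather than part of this commit: it assumes the default port 2048 from main(), a manager reachable at MANAGER_URL that lists at least one worker with status "running", and aiohttp installed in the client environment; the model name below is a placeholder, not something the gateway defines.

    import asyncio
    import aiohttp

    async def demo():
        async with aiohttp.ClientSession() as session:
            # List models through the round-robin proxy.
            async with session.get("http://127.0.0.1:2048/v1/models") as resp:
                print(resp.status, await resp.text())
            # Non-streaming chat completion; if the worker's reply contains one of
            # RATE_LIMIT_KEYWORDS, the gateway reports it back to the manager.
            payload = {"model": "gemini-2.5-pro", "stream": False,
                       "messages": [{"role": "user", "content": "ping"}]}
            async with session.post("http://127.0.0.1:2048/v1/chat/completions",
                                    json=payload) as resp:
                print(resp.status, await resp.text())

    asyncio.run(demo())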