Skip to content

Commit 3d01273

Browse files
committed
fix: rate limit 實施
1 parent e46e117 commit 3d01273

File tree

3 files changed

+59
-3
lines changed

3 files changed

+59
-3
lines changed

src/api/request_processor.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,15 @@ async def create_stream_generator_from_helper(event_to_set: Event, task_to_cance
292292
continue
293293
elif isinstance(raw_data, dict):
294294
data = raw_data
295+
if data.get('error') == 'rate_limit':
296+
logger.warning(f"[{req_id}] 🚨 接收到来自代理的速率限制信号: {data}")
297+
try:
298+
error_chunk = {'id': chat_completion_id, 'object': 'chat.completion.chunk', 'model': model_name_for_stream, 'created': created_timestamp, 'choices': [{'index': 0, 'delta': {'role': 'assistant', 'content': f"\n\n[System: Rate Limit Exceeded - {data.get('detail', 'Quota exceeded')}]"}, 'finish_reason': 'stop', 'native_finish_reason': 'stop'}]}
299+
yield f"data: {json.dumps(error_chunk, ensure_ascii=False, separators=(',', ':'))}\n\n"
300+
except: pass
301+
if not event_to_set.is_set():
302+
event_to_set.set()
303+
break
295304
else:
296305
logger.warning(f'[{req_id}] 未知的流数据类型: {type(raw_data)}')
297306
continue
@@ -339,6 +348,30 @@ async def create_stream_generator_from_helper(event_to_set: Event, task_to_cance
339348
choice_item = {'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': 'stop', 'native_finish_reason': 'stop'}
340349
output = {'id': chat_completion_id, 'object': 'chat.completion.chunk', 'model': model_name_for_stream, 'created': created_timestamp, 'choices': [choice_item]}
341350
yield f"data: {json.dumps(output, ensure_ascii=False, separators=(',', ':'))}\n\n"
351+
352+
# Late Rate Limit Check
353+
late_check_wait = 2.0 if len(full_body_content) < 50 else 0.2
354+
if late_check_wait > 0.5:
355+
logger.info(f"[{req_id}] 内容较短 ({len(full_body_content)}), 等待 {late_check_wait}s 检查延迟 Rate Limit")
356+
await asyncio.sleep(late_check_wait)
357+
try:
358+
from server import STREAM_QUEUE
359+
import queue
360+
if STREAM_QUEUE:
361+
while True:
362+
try:
363+
msg = STREAM_QUEUE.get_nowait()
364+
if isinstance(msg, dict) and msg.get('error') == 'rate_limit':
365+
logger.warning(f"[{req_id}] 🚨 捕获到延迟的 Rate Limit 信号: {msg}")
366+
try:
367+
error_chunk = {'id': chat_completion_id, 'object': 'chat.completion.chunk', 'model': model_name_for_stream, 'created': created_timestamp, 'choices': [{'index': 0, 'delta': {'role': 'assistant', 'content': f"\n\n[System: Rate Limit Exceeded - {msg.get('detail', 'Quota exceeded')}]"}, 'finish_reason': 'stop', 'native_finish_reason': 'stop'}]}
368+
yield f"data: {json.dumps(error_chunk, ensure_ascii=False, separators=(',', ':'))}\n\n"
369+
except: pass
370+
except queue.Empty:
371+
break
372+
except Exception as e:
373+
logger.error(f"[{req_id}] Late check failed: {e}")
374+
342375
except ClientDisconnectedError as disconnect_err:
343376
abort_handler = AbortSignalHandler()
344377
disconnect_info = abort_handler.handle_error(disconnect_err, req_id)
@@ -427,6 +460,9 @@ async def create_stream_generator_from_helper(event_to_set: Event, task_to_cance
427460
continue
428461
elif isinstance(raw_data, dict):
429462
data = raw_data
463+
if data.get('error') == 'rate_limit':
464+
logger.warning(f"[{req_id}] 🚨 非流式请求中接收到速率限制: {data}")
465+
raise HTTPException(status_code=429, detail=f"Rate limit exceeded: {data.get('detail')}")
430466
else:
431467
logger.warning(f'[{req_id}] 非流式未知数据类型: {type(raw_data)}')
432468
continue

src/gateway.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ async def stream_proxy() -> AsyncGenerator[bytes, None]:
134134
data, _ = chunk
135135
if data:
136136
check_count += 1
137-
if check_count <= 5 and not rate_limited:
137+
if not rate_limited:
138138
if check_rate_limit_in_response(data):
139139
rate_limited = True
140140
yield data

src/proxy/server.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ def _get_tls_context(self, domain: str):
4545
certfile=self.cert_store.storage_dir / f'{domain}.crt',
4646
keyfile=self.cert_store.storage_dir / f'{domain}.key'
4747
)
48+
try:
49+
ctx.set_alpn_protocols(['http/1.1'])
50+
except Exception:
51+
pass
4852

4953
if len(self._context_cache) > 50:
5054
self._context_cache.clear()
@@ -131,8 +135,13 @@ async def _process_tunnel(
131135
)
132136

133137
try:
138+
upstream_ctx = ssl.create_default_context()
139+
try:
140+
upstream_ctx.set_alpn_protocols(['http/1.1'])
141+
except Exception:
142+
pass
134143
server_reader, server_writer = await self.connector.open_connection(
135-
host, port, ssl.create_default_context()
144+
host, port, upstream_ctx
136145
)
137146
await self._relay_with_inspection(
138147
client_reader, client_writer,
@@ -219,7 +228,18 @@ async def process_upstream():
219228
client_buf.clear()
220229
continue
221230

222-
if 'GenerateContent' in path:
231+
if 'jserror' in path:
232+
inspect_response = False
233+
try:
234+
path_str = path
235+
if 'quota' in path_str or 'limit' in path_str or 'exceeded' in path_str:
236+
self.log.info(f"Rate limit keyword found in jserror: {path_str}")
237+
if self.message_queue is not None:
238+
self.message_queue.put({'error': 'rate_limit', 'detail': 'Rate limit detected via jserror', 'source': 'jserror', 'path': path_str})
239+
except Exception as e:
240+
self.log.error(f"Error inspecting jserror: {e}")
241+
server_writer.write(client_buf)
242+
elif 'GenerateContent' in path:
223243
inspect_response = True
224244
processed = await self.response_handler.handle_request(
225245
body_bytes, host, path

0 commit comments

Comments
 (0)