|
| 1 | +#!/usr/bin/env python3 |
| 2 | +import asyncio |
| 3 | +import aiohttp |
| 4 | +import base64 |
| 5 | +import sys |
| 6 | +import os |
| 7 | +import time |
| 8 | +from pathlib import Path |
| 9 | +from typing import Tuple, List |
| 10 | +from dataclasses import dataclass, field |
| 11 | + |
| 12 | +BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:2048") |
| 13 | +OUTPUT_DIR = Path(__file__).parent / "test_output" |
| 14 | +TIMEOUT_SHORT = 30 |
| 15 | +TIMEOUT_MEDIUM = 120 |
| 16 | +TIMEOUT_LONG = 600 |
| 17 | + |
| 18 | + |
| 19 | +@dataclass |
| 20 | +class TestResult: |
| 21 | + passed: int = 0 |
| 22 | + failed: int = 0 |
| 23 | + results: List[Tuple[str, bool, str, float]] = field(default_factory=list) |
| 24 | + |
| 25 | + def record(self, name: str, success: bool, message: str = "", duration: float = 0): |
| 26 | + icon = "[OK]" if success else "[X]" |
| 27 | + self.results.append((name, success, message, duration)) |
| 28 | + if success: |
| 29 | + self.passed += 1 |
| 30 | + print(f" {icon} {name} ({duration:.2f}s)") |
| 31 | + else: |
| 32 | + self.failed += 1 |
| 33 | + print(f" {icon} {name}: {message}") |
| 34 | + return success |
| 35 | + |
| 36 | + |
| 37 | +def ensure_output_dir(): |
| 38 | + OUTPUT_DIR.mkdir(exist_ok=True) |
| 39 | + |
| 40 | + |
| 41 | +async def test_health(session: aiohttp.ClientSession) -> Tuple[bool, str, float]: |
| 42 | + start = time.time() |
| 43 | + try: |
| 44 | + async with session.get(f"{BASE_URL}/health", timeout=aiohttp.ClientTimeout(total=TIMEOUT_SHORT)) as resp: |
| 45 | + data = await resp.json() |
| 46 | + return True, f"status={data.get('status')}", time.time() - start |
| 47 | + except Exception as e: |
| 48 | + return False, str(e), time.time() - start |
| 49 | + |
| 50 | + |
| 51 | +async def test_models(session: aiohttp.ClientSession) -> Tuple[bool, str, float]: |
| 52 | + start = time.time() |
| 53 | + try: |
| 54 | + async with session.get(f"{BASE_URL}/v1/models", timeout=aiohttp.ClientTimeout(total=TIMEOUT_SHORT)) as resp: |
| 55 | + data = await resp.json() |
| 56 | + count = len(data.get("data", [])) |
| 57 | + return True, f"count={count}", time.time() - start |
| 58 | + except Exception as e: |
| 59 | + return False, str(e), time.time() - start |
| 60 | + |
| 61 | + |
| 62 | +async def test_chat(session: aiohttp.ClientSession) -> Tuple[bool, str, float]: |
| 63 | + start = time.time() |
| 64 | + try: |
| 65 | + payload = { |
| 66 | + "model": "gemini-2.5-flash", |
| 67 | + "messages": [{"role": "user", "content": "Say 'test ok' in 2 words."}], |
| 68 | + "max_tokens": 50 |
| 69 | + } |
| 70 | + async with session.post( |
| 71 | + f"{BASE_URL}/v1/chat/completions", |
| 72 | + json=payload, |
| 73 | + timeout=aiohttp.ClientTimeout(total=TIMEOUT_MEDIUM) |
| 74 | + ) as resp: |
| 75 | + data = await resp.json() |
| 76 | + content = data["choices"][0]["message"]["content"][:30] |
| 77 | + return True, f"response={content}", time.time() - start |
| 78 | + except Exception as e: |
| 79 | + return False, str(e), time.time() - start |
| 80 | + |
| 81 | + |
| 82 | +async def test_tts(session: aiohttp.ClientSession) -> Tuple[bool, str, float]: |
| 83 | + start = time.time() |
| 84 | + try: |
| 85 | + payload = { |
| 86 | + "model": "gemini-2.5-flash-preview-tts", |
| 87 | + "contents": "Hello, this is a test.", |
| 88 | + "generationConfig": { |
| 89 | + "responseModalities": ["AUDIO"], |
| 90 | + "speechConfig": { |
| 91 | + "voiceConfig": { |
| 92 | + "prebuiltVoiceConfig": {"voiceName": "Kore"} |
| 93 | + } |
| 94 | + } |
| 95 | + } |
| 96 | + } |
| 97 | + async with session.post( |
| 98 | + f"{BASE_URL}/generate-speech", |
| 99 | + json=payload, |
| 100 | + timeout=aiohttp.ClientTimeout(total=TIMEOUT_MEDIUM) |
| 101 | + ) as resp: |
| 102 | + data = await resp.json() |
| 103 | + audio_b64 = data["candidates"][0]["content"]["parts"][0]["inlineData"]["data"] |
| 104 | + audio_bytes = base64.b64decode(audio_b64) |
| 105 | + output_path = OUTPUT_DIR / "tts_output.wav" |
| 106 | + output_path.write_bytes(audio_bytes) |
| 107 | + return True, f"size={len(audio_bytes)}", time.time() - start |
| 108 | + except Exception as e: |
| 109 | + return False, str(e), time.time() - start |
| 110 | + |
| 111 | + |
| 112 | +async def test_imagen(session: aiohttp.ClientSession) -> Tuple[bool, str, float]: |
| 113 | + start = time.time() |
| 114 | + try: |
| 115 | + payload = { |
| 116 | + "prompt": "A mountain landscape at sunset", |
| 117 | + "model": "imagen-3.0-generate-002", |
| 118 | + "number_of_images": 1, |
| 119 | + "aspect_ratio": "16:9" |
| 120 | + } |
| 121 | + async with session.post( |
| 122 | + f"{BASE_URL}/generate-image", |
| 123 | + json=payload, |
| 124 | + timeout=aiohttp.ClientTimeout(total=TIMEOUT_LONG) |
| 125 | + ) as resp: |
| 126 | + data = await resp.json() |
| 127 | + img_b64 = data["generatedImages"][0]["image"]["imageBytes"] |
| 128 | + img_bytes = base64.b64decode(img_b64) |
| 129 | + output_path = OUTPUT_DIR / "imagen_output.png" |
| 130 | + output_path.write_bytes(img_bytes) |
| 131 | + return True, f"size={len(img_bytes)}", time.time() - start |
| 132 | + except Exception as e: |
| 133 | + return False, str(e), time.time() - start |
| 134 | + |
| 135 | + |
| 136 | +async def test_nano(session: aiohttp.ClientSession) -> Tuple[bool, str, float]: |
| 137 | + start = time.time() |
| 138 | + try: |
| 139 | + payload = { |
| 140 | + "model": "gemini-2.5-flash-image", |
| 141 | + "contents": [{"parts": [{"text": "A cute cartoon cat"}]}] |
| 142 | + } |
| 143 | + async with session.post( |
| 144 | + f"{BASE_URL}/nano/generate", |
| 145 | + json=payload, |
| 146 | + timeout=aiohttp.ClientTimeout(total=TIMEOUT_LONG) |
| 147 | + ) as resp: |
| 148 | + data = await resp.json() |
| 149 | + parts = data["candidates"][0]["content"]["parts"] |
| 150 | + for i, part in enumerate(parts): |
| 151 | + if "inlineData" in part: |
| 152 | + img_bytes = base64.b64decode(part["inlineData"]["data"]) |
| 153 | + output_path = OUTPUT_DIR / f"nano_output_{i}.png" |
| 154 | + output_path.write_bytes(img_bytes) |
| 155 | + return True, f"size={len(img_bytes)}", time.time() - start |
| 156 | + return False, "No image in response", time.time() - start |
| 157 | + except Exception as e: |
| 158 | + return False, str(e), time.time() - start |
| 159 | + |
| 160 | + |
| 161 | +async def test_veo(session: aiohttp.ClientSession) -> Tuple[bool, str, float]: |
| 162 | + start = time.time() |
| 163 | + try: |
| 164 | + payload = { |
| 165 | + "prompt": "Ocean waves on beach", |
| 166 | + "model": "veo-2.0-generate-001", |
| 167 | + "aspect_ratio": "16:9", |
| 168 | + "duration_seconds": 5 |
| 169 | + } |
| 170 | + async with session.post( |
| 171 | + f"{BASE_URL}/generate-video", |
| 172 | + json=payload, |
| 173 | + timeout=aiohttp.ClientTimeout(total=TIMEOUT_LONG) |
| 174 | + ) as resp: |
| 175 | + data = await resp.json() |
| 176 | + vid_b64 = data["generatedVideos"][0]["video"]["videoBytes"] |
| 177 | + vid_bytes = base64.b64decode(vid_b64) |
| 178 | + output_path = OUTPUT_DIR / "veo_output.mp4" |
| 179 | + output_path.write_bytes(vid_bytes) |
| 180 | + return True, f"size={len(vid_bytes)}", time.time() - start |
| 181 | + except Exception as e: |
| 182 | + return False, str(e), time.time() - start |
| 183 | + |
| 184 | + |
| 185 | +async def run_concurrent_tests(skip_veo: bool = True): |
| 186 | + print("=" * 50) |
| 187 | + print(" AIStudio2API Concurrent Tests") |
| 188 | + print(f" Base URL: {BASE_URL}") |
| 189 | + print(f" Mode: CONCURRENT (all tests run in parallel)") |
| 190 | + print("=" * 50) |
| 191 | + |
| 192 | + ensure_output_dir() |
| 193 | + |
| 194 | + async with aiohttp.ClientSession() as session: |
| 195 | + tests = [ |
| 196 | + ("Health", test_health(session)), |
| 197 | + ("Models", test_models(session)), |
| 198 | + ("Chat", test_chat(session)), |
| 199 | + ("TTS", test_tts(session)), |
| 200 | + ("Imagen", test_imagen(session)), |
| 201 | + ("Nano", test_nano(session)), |
| 202 | + ] |
| 203 | + |
| 204 | + if not skip_veo: |
| 205 | + tests.append(("Veo", test_veo(session))) |
| 206 | + |
| 207 | + print(f"\nRunning {len(tests)} tests concurrently...") |
| 208 | + start_all = time.time() |
| 209 | + |
| 210 | + tasks = [t[1] for t in tests] |
| 211 | + names = [t[0] for t in tests] |
| 212 | + |
| 213 | + results = await asyncio.gather(*tasks, return_exceptions=True) |
| 214 | + |
| 215 | + total_time = time.time() - start_all |
| 216 | + |
| 217 | + print("\n=== Results ===") |
| 218 | + result = TestResult() |
| 219 | + for name, res in zip(names, results): |
| 220 | + if isinstance(res, Exception): |
| 221 | + result.record(name, False, str(res), 0) |
| 222 | + else: |
| 223 | + success, msg, duration = res |
| 224 | + result.record(name, success, msg, duration) |
| 225 | + |
| 226 | + if skip_veo: |
| 227 | + print(" [--] Veo: SKIPPED (use --veo to include)") |
| 228 | + |
| 229 | + print("\n" + "=" * 50) |
| 230 | + total = result.passed + result.failed |
| 231 | + print(f" Results: {result.passed}/{total} passed") |
| 232 | + print(f" Total time: {total_time:.2f}s (concurrent)") |
| 233 | + print("=" * 50) |
| 234 | + |
| 235 | + return result.failed == 0 |
| 236 | + |
| 237 | + |
| 238 | +def main(): |
| 239 | + skip_veo = "--veo" not in sys.argv |
| 240 | + success = asyncio.run(run_concurrent_tests(skip_veo)) |
| 241 | + sys.exit(0 if success else 1) |
| 242 | + |
| 243 | + |
| 244 | +if __name__ == "__main__": |
| 245 | + main() |
0 commit comments