11from __future__ import annotations
22
3+ import os
34import time
4- from functools import lru_cache
5+ from functools import lru_cache , partial
56from pathlib import Path
67
78import magic
9+ import vertexai
810from google import genai
11+ from google .auth .exceptions import GoogleAuthError
912from google .genai import types
13+ from google .genai .errors import APIError
1014from google .genai .types import (
1115 CountTokensConfig ,
1216 File ,
2630from openai .types .chat .chat_completion import ChatCompletion , Choice
2731from pydantic import BaseModel
2832from pydantic_ai .messages import ModelMessage , ModelResponse
29- from pydantic_ai .models import Model , ModelRequestParameters , StreamedResponse
33+ from pydantic_ai .models import Model as PydanticAiModel
34+ from pydantic_ai .models import ModelRequestParameters , StreamedResponse
3035from pydantic_ai .models .gemini import GeminiModel
3136from pydantic_ai .settings import ModelSettings
3237from pydantic_ai .usage import Usage
4045 Type ,
4146 Union ,
4247)
48+ from vertexai .generative_models import GenerativeModel , SafetySetting
4349
4450from patchwork .common .client .llm .protocol import NOT_GIVEN , LlmClient , NotGiven
4551from patchwork .common .client .llm .utils import json_schema_to_model
52+ from patchwork .logger import logger
4653
4754
4855class GoogleLlmClient (LlmClient ):
@@ -51,30 +58,63 @@ class GoogleLlmClient(LlmClient):
5158 dict (category = "HARM_CATEGORY_SEXUALLY_EXPLICIT" , threshold = "BLOCK_NONE" ),
5259 dict (category = "HARM_CATEGORY_DANGEROUS_CONTENT" , threshold = "BLOCK_NONE" ),
5360 dict (category = "HARM_CATEGORY_HARASSMENT" , threshold = "BLOCK_NONE" ),
61+ dict (category = "HARM_CATEGORY_CIVIC_INTEGRITY" , threshold = "BLOCK_NONE" ),
62+ ]
63+ __VERTEX_SAFETY_SETTINGS = [
64+ SafetySetting (
65+ category = SafetySetting .HarmCategory .HARM_CATEGORY_HATE_SPEECH ,
66+ threshold = SafetySetting .HarmBlockThreshold .OFF ,
67+ ),
68+ SafetySetting (
69+ category = SafetySetting .HarmCategory .HARM_CATEGORY_DANGEROUS_CONTENT ,
70+ threshold = SafetySetting .HarmBlockThreshold .OFF ,
71+ ),
72+ SafetySetting (
73+ category = SafetySetting .HarmCategory .HARM_CATEGORY_SEXUALLY_EXPLICIT ,
74+ threshold = SafetySetting .HarmBlockThreshold .OFF ,
75+ ),
76+ SafetySetting (
77+ category = SafetySetting .HarmCategory .HARM_CATEGORY_HARASSMENT , threshold = SafetySetting .HarmBlockThreshold .OFF
78+ ),
79+ SafetySetting (
80+ category = SafetySetting .HarmCategory .HARM_CATEGORY_CIVIC_INTEGRITY ,
81+ threshold = SafetySetting .HarmBlockThreshold .OFF ,
82+ ),
5483 ]
5584 __MODEL_PREFIX = "models/"
5685
57- def __init__ (self , api_key : str , location : Optional [str ] = None ):
86+ def __init__ (self , api_key : Optional [str ] = None , is_gcp : bool = False ):
5887 self .__api_key = api_key
59- self .__location = location
60- self .client = genai .Client (api_key = api_key , location = location )
88+ self .__is_gcp = is_gcp
89+ if not self .__is_gcp :
90+ self .client = genai .Client (api_key = api_key )
91+ else :
92+ self .client = genai .Client (api_key = api_key , vertexai = True )
93+ location = os .environ .get ("GOOGLE_CLOUD_LOCATION" , "global" )
94+ vertexai .init (
95+ project = os .environ .get ("GOOGLE_CLOUD_PROJECT" ),
96+ location = location ,
97+ api_endpoint = f"{ location } -aiplatform.googleapis.com" ,
98+ )
6199
62100 @lru_cache (maxsize = 1 )
63101 def __get_models_info (self ) -> list [Model ]:
64- return list (self .client .models .list ())
102+ if not self .__is_gcp :
103+ return list (self .client .models .list ())
104+ else :
105+ return list ()
65106
66- def __get_pydantic_model (self , model_settings : ModelSettings | None ) -> Model :
107+ def __get_pydantic_model (self , model_settings : ModelSettings | None ) -> PydanticAiModel :
67108 if model_settings is None :
68109 raise ValueError ("Model settings cannot be None" )
69110 model_name = model_settings .get ("model" )
70111 if model_name is None :
71112 raise ValueError ("Model must be set cannot be None" )
72113
73- if self .__location is None :
114+ if not self .__is_gcp :
74115 return GeminiModel (model_name , api_key = self .__api_key )
75-
76- url_template = f"https://{ self .__location } -generativelanguage.googleapis.com/v1beta/models/{{model}}:"
77- return GeminiModel (model_name , api_key = self .__api_key , url_template = url_template )
116+ else :
117+ return GeminiModel (model_name , provider = "google-vertex" )
78118
79119 async def request (
80120 self ,
@@ -108,12 +148,15 @@ def __get_model_limits(self, model: str) -> int:
108148 return model_info .input_token_limit
109149 return 1_000_000
110150
111- @lru_cache
112- def get_models (self ) -> set [str ]:
113- return {model_info .name .removeprefix (self .__MODEL_PREFIX ) for model_info in self .__get_models_info ()}
151+ def test (self ):
152+ return
114153
115154 def is_model_supported (self , model : str ) -> bool :
116- return model in self .get_models ()
155+ if not self .__is_gcp :
156+ model_names = {model_info .name .removeprefix (self .__MODEL_PREFIX ) for model_info in self .__get_models_info ()}
157+ return model in model_names
158+ else :
159+ return True
117160
118161 def __upload (self , file : Path | NotGiven ) -> Part | File | None :
119162 if isinstance (file , NotGiven ):
@@ -163,6 +206,8 @@ def is_prompt_supported(
163206 top_p : Optional [float ] | NotGiven = NOT_GIVEN ,
164207 file : Path | NotGiven = NOT_GIVEN ,
165208 ) -> int :
209+ if self .__is_gcp :
210+ return 1
166211 system , contents = self .__openai_messages_to_google_messages (messages )
167212
168213 file_ref = self .__upload (file )
@@ -178,7 +223,12 @@ def is_prompt_supported(
178223 ),
179224 )
180225 token_count = token_response .total_tokens
226+ except GoogleAuthError :
227+ raise
228+ except APIError :
229+ raise
181230 except Exception as e :
231+ logger .debug (f"Error during token count at GoogleLlmClient: { e } " )
182232 return - 1
183233 model_limit = self .__get_model_limits (model )
184234 return model_limit - token_count
@@ -245,15 +295,25 @@ def chat_completion(
245295 if file_ref is not None :
246296 contents .append (file_ref )
247297
248- response = self .client .models .generate_content (
249- model = model ,
250- contents = contents ,
251- config = GenerateContentConfig (
252- system_instruction = system_content ,
253- safety_settings = self .__SAFETY_SETTINGS ,
254- ** NotGiven .remove_not_given (generation_dict ),
255- ),
256- )
298+ if not self .__is_gcp :
299+ generate_content_func = partial (
300+ self .client .models .generate_content ,
301+ model = model ,
302+ config = GenerateContentConfig (
303+ system_instruction = system_content ,
304+ safety_settings = self .__SAFETY_SETTINGS ,
305+ ** NotGiven .remove_not_given (generation_dict ),
306+ ),
307+ )
308+ else :
309+ vertexai_model = GenerativeModel (model , system_instruction = system_content )
310+ generate_content_func = partial (
311+ vertexai_model .generate_content ,
312+ safety_settings = self .__VERTEX_SAFETY_SETTINGS ,
313+ generation_config = NotGiven .remove_not_given (generation_dict ),
314+ )
315+
316+ response = generate_content_func (contents = contents )
257317 return self .__google_response_to_openai_response (response , model )
258318
259319 @staticmethod
0 commit comments