From 5bb6861be41ab1c3472d3f718d3501b2b950b95a Mon Sep 17 00:00:00 2001
From: balaboom123
Date: Thu, 22 May 2025 00:07:26 +0800
Subject: [PATCH] add support for 13 LLM providers

1. Create config.py and llm_provider.py to manage the API calls for the
   different LLM providers and their model options.
2. To use all of the LangChain models, two additional packages must be
   installed. I added them to requirements.txt but did not update uv.lock:

```
langchain-mistralai==0.2.4
langchain-ibm==0.3.10
```

3. Supported providers:
- OpenAI
- Anthropic
- Google
- Azure OpenAI
- DeepSeek
- Mistral
- Ollama
- Alibaba
- Moonshot
- Unbound
- SiliconFlow
- IBM
- Grok
---
 workflows/.env.example                     |  42 ++-
 workflows/cli.py                           |  33 +-
 workflows/requirements.txt                 | Bin 0 -> 77662 bytes
 workflows/workflow_use/llm/config.py       |  87 +++++
 workflows/workflow_use/llm/llm_provider.py | 351 +++++++++++++++++++++
 5 files changed, 503 insertions(+), 10 deletions(-)
 create mode 100644 workflows/requirements.txt
 create mode 100644 workflows/workflow_use/llm/config.py
 create mode 100644 workflows/workflow_use/llm/llm_provider.py

diff --git a/workflows/.env.example b/workflows/.env.example
index 3d6a3f43..9b929fa2 100644
--- a/workflows/.env.example
+++ b/workflows/.env.example
@@ -1,2 +1,42 @@
 # We support all langchain models, openai only for demo purposes
-OPENAI_API_KEY=
\ No newline at end of file
+LLM_PROVIDER=openai
+MODEL_NAME=gpt-4o
+
+OPENAI_ENDPOINT=https://api.openai.com/v1
+OPENAI_API_KEY=
+
+ANTHROPIC_ENDPOINT=https://api.anthropic.com
+ANTHROPIC_API_KEY=
+
+GOOGLE_API_KEY=
+
+AZURE_OPENAI_ENDPOINT=
+AZURE_OPENAI_API_KEY=
+AZURE_OPENAI_API_VERSION=2025-01-01-preview
+
+DEEPSEEK_ENDPOINT=https://api.deepseek.com
+DEEPSEEK_API_KEY=
+
+MISTRAL_ENDPOINT=https://api.mistral.ai/v1
+MISTRAL_API_KEY=
+
+OLLAMA_ENDPOINT=http://localhost:11434
+
+ALIBABA_ENDPOINT=https://dashscope.aliyuncs.com/compatible-mode/v1
+ALIBABA_API_KEY=
+
+MOONSHOT_ENDPOINT=https://api.moonshot.cn/v1
+MOONSHOT_API_KEY=
+
+UNBOUND_ENDPOINT=https://api.getunbound.ai
+UNBOUND_API_KEY=
+
+SILICONFLOW_ENDPOINT=https://api.siliconflow.cn/v1/
+SILICONFLOW_API_KEY=
+
+IBM_ENDPOINT=https://us-south.ml.cloud.ibm.com
+IBM_API_KEY=
+IBM_PROJECT_ID=
+
+GROK_ENDPOINT=https://api.x.ai/v1
+GROK_API_KEY=
\ No newline at end of file
diff --git a/workflows/cli.py b/workflows/cli.py
index b2cc4734..c994bc61 100644
--- a/workflows/cli.py
+++ b/workflows/cli.py
@@ -1,18 +1,20 @@
+from dotenv import load_dotenv
+load_dotenv()
 import asyncio
 import json
 import tempfile  # For temporary file handling
 from pathlib import Path
+import os
 
 import typer
 from browser_use.browser.browser import Browser
 
-# Assuming OPENAI_API_KEY is set in the environment
-from langchain_openai import ChatOpenAI
-
 from workflow_use.builder.service import BuilderService
 from workflow_use.controller.service import WorkflowController
 from workflow_use.recorder.service import RecordingService  # Added import
 from workflow_use.workflow.service import Workflow
+from workflow_use.llm.config import model_names
+from workflow_use.llm.llm_provider import get_llm_model
 
 # Placeholder for recorder functionality
 # from src.recorder.service import RecorderService
@@ -27,13 +29,26 @@
 # Default LLM instance to None
 llm_instance = None
 try:
-    llm_instance = ChatOpenAI(model='gpt-4o')
+    # Read the provider and model name from the environment, defaulting to openai
+    provider = os.getenv('LLM_PROVIDER', 'openai').lower()
+    model_name = os.getenv('MODEL_NAME', '')
+
+    # If no model name is specified, prompt the user for one
+    if not model_name:
+        typer.echo(f'Available models for {provider}:')
+        typer.echo(f'{model_names[provider]}')
+
+        model_name = typer.prompt(f'Model name for {provider}', default=model_names[provider][0])
+        os.environ['MODEL_NAME'] = model_name
+
+    # Initialize the LLM with the selected provider and model
+    llm_instance = get_llm_model(provider, model_name=model_name)
+
 except Exception as e:
-    typer.secho(f'Error initializing LLM: {e}. Would you like to set your OPENAI_API_KEY?', fg=typer.colors.RED)
-    set_openai_api_key = input('Set OPENAI_API_KEY? (y/n): ')
-    if set_openai_api_key.lower() == 'y':
-        os.environ['OPENAI_API_KEY'] = input('Enter your OPENAI_API_KEY: ')
-        llm_instance = ChatOpenAI(model='gpt-4o')
+    typer.secho(
+        f'Error initializing LLM: {e}. Set the API key for your provider in .env.',
+        fg=typer.colors.RED,
+    )
 
 builder_service = BuilderService(llm=llm_instance) if llm_instance else None
 # recorder_service = RecorderService()  # Placeholder
diff --git a/workflows/requirements.txt b/workflows/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..47c5f0cddff837f9b31608f9a26cc359306c6375
Binary files /dev/null and b/workflows/requirements.txt differ
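The diff for workflows/workflow_use/llm/config.py fell inside the corrupted binary region above and could not be recovered. The module only needs to export the `model_names` mapping that cli.py consumes, with the first entry of each list serving as the CLI's default model. A minimal sketch is below; the model lists are illustrative assumptions taken from the defaults in get_llm_model, not the exact lists from the commit:

```
# workflows/workflow_use/llm/config.py -- minimal sketch (model lists assumed)
model_names: dict[str, list[str]] = {
    "openai": ["gpt-4o", "gpt-4o-mini"],
    "anthropic": ["claude-3-5-sonnet-20241022"],
    "google": ["gemini-2.0-flash-exp"],
    "azure_openai": ["gpt-4o"],
    "deepseek": ["deepseek-chat", "deepseek-reasoner"],
    "mistral": ["mistral-large-latest"],
    "ollama": ["qwen2.5:7b", "deepseek-r1:14b"],
    "alibaba": ["qwen-plus"],
    "moonshot": ["moonshot-v1-32k-vision-preview"],
    "unbound": ["gpt-4o-mini"],
    "siliconflow": ["Qwen/QwQ-32B"],
    "ibm": ["ibm/granite-vision-3.1-2b-preview"],
    "grok": ["grok-3"],
}
```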
z&SAC5HD#+{=R<1f^BwH;qft}y{orL-Ar({oJ7rMWZ?IkbQV3@*3!qn@9dfF!A7 z`jx9P7(OO0yC18B27VTb?&RmQtl*`5b|chLfzcj-^0zv-_C6i_AUpl!t=oGq3Z9o| zoK@KkJ9(_}7bozJWra$2@;0ALmoFPR{wy6iwdbqB)w&4kqmdSKGSy407xPJo5Rnhv zi>+JD_If79HmHWe7I%x85rkzkds9Q7y;Wgb^MFfsL2L$+W-;S=vo-W;zu{rmD!oWA zo+|I;^~+{#{dK{k68qJyMSC?-rIGzNC?pah?qenhr%h_8ZO`xF35N4RbLe|%nWF;# z=++>2-?JRUxu|;XdOVlT>l-hF%^LRkv)ubcLda+XF|I#N2gdWATOphnhV9GUnEEzx z-X!12p&%8g>cIi`sEuM@OU%UB zc9cJ7+vl#w{YsiAxt5uItnc!B^IW;8lbZJSvmQhj*omZqS3_HivwonBv){~p)>(Z_ zrYFCzqqN01RlDI%#Jc^8v<<6=brf#*q|Th3n!6siW40?Xtik2lFzwFyb!#C{t~yUW zxK6X%jM#K;;xKZ{qq98AE2~LiUCDoL#G}xYOHbT|%S?S=T#oMkeRz-i74tZq$^lyJ z%^?B?U(OdT*NeLC_^?DG@%?mW>e(u4BY4U98@hAkGPTMr?q|?BLavwUPpy^f4?p4b zu6ds1PkG9N^{Yvh-+_G)yM8P(#=98U`A62YBSvHHx_Lt_o=EVtFL~0qy*dJP-isf9 zbUERb%YLn^{a2LdCFUG#`%Gnt_xmO~dAvc@V(F+0U(4w%I)lDEXQi*7-y#%dLCD;Z zyt97T8MYD`#EV&;qx0uK$tRpRSVPPMbAKgs8u;DiG0t60|RIV z7M)6T9om#uB$+ki&yX)u^2?m_Z%zmAg;KALVF&7kd=L3$Uhhgc^F{C)@+`b#vz2&| zRoB{mh&^z}Iut%FPf{1k16Ebf;@`{j&gBOVT3K`q{Lwo>Tnf^R&2Tlbm-e+;uwXW zA_46A-HWw2b(*dDajIWeJ)ZYWp#-tBo%WzS{Xuw722Lf&-aU;CO3O|RXSK9cEjW3I zhNQmrira)G2TN@Ftj`H2YHy&?vy7AZM12*#%&>AZ6#wkh7z)t)jSp60uhV%&goj( z=8f;ydQ;zqq3J0PyeIaB*ypoo3$f3)fyY%VfGxX_&r@|u@YLy5sdWxHf=(}?UsM35 zbR)DKX7F6d6}L9wcZM<^gUGas824jVBlf+a-o(FQ z635j?!F(n13k7J;O6-(PF4X8GEs5%<==469dsa^W^?rqnec?ODBF$#uG+GJqQjH5g zg+DpjQ99C&K@6ij#T5)x8uUH#*uu9?qf None: + super().__init__(*args, **kwargs) + self.client = OpenAI( + base_url=kwargs.get("base_url"), + api_key=kwargs.get("api_key") + ) + + async def ainvoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> AIMessage: + message_history = [] + for input_ in input: + if isinstance(input_, SystemMessage): + message_history.append({"role": "system", "content": input_.content}) + elif isinstance(input_, AIMessage): + message_history.append({"role": "assistant", "content": input_.content}) + else: + message_history.append({"role": "user", "content": input_.content}) + + response = self.client.chat.completions.create( + model=self.model_name, + messages=message_history + ) + + reasoning_content = response.choices[0].message.reasoning_content + content = response.choices[0].message.content + return AIMessage(content=content, reasoning_content=reasoning_content) + + def invoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> AIMessage: + message_history = [] + for input_ in input: + if isinstance(input_, SystemMessage): + message_history.append({"role": "system", "content": input_.content}) + elif isinstance(input_, AIMessage): + message_history.append({"role": "assistant", "content": input_.content}) + else: + message_history.append({"role": "user", "content": input_.content}) + + response = self.client.chat.completions.create( + model=self.model_name, + messages=message_history + ) + + reasoning_content = response.choices[0].message.reasoning_content + content = response.choices[0].message.content + return AIMessage(content=content, reasoning_content=reasoning_content) + + +class DeepSeekR1ChatOllama(ChatOllama): + + async def ainvoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> AIMessage: + org_ai_message = await super().ainvoke(input=input) + org_content = org_ai_message.content + reasoning_content = org_content.split("")[0].replace("", "") + content = 
+
+
+def get_llm_model(provider: str, **kwargs):
+    """
+    Instantiate a LangChain chat model for the given provider.
+
+    :param provider: provider key, e.g. "openai", "anthropic" or "ollama".
+    :param kwargs: optional overrides such as model_name, temperature,
+        base_url, api_key, num_ctx and num_predict.
+    :return: a configured chat model instance.
+    """
+    # Every provider except Ollama needs an API key, passed in explicitly or
+    # read from the <PROVIDER>_API_KEY environment variable.
+    if provider != "ollama":
+        env_var = f"{provider.upper()}_API_KEY"
+        api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
+        if not api_key:
+            raise ValueError(f"🔑 Please set the `{env_var}` environment variable.")
+        kwargs["api_key"] = api_key
+
+    if provider == "anthropic":
+        base_url = kwargs.get("base_url") or os.getenv("ANTHROPIC_ENDPOINT", "https://api.anthropic.com")
+        return ChatAnthropic(
+            model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "mistral":
+        base_url = kwargs.get("base_url") or os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1")
+        return ChatMistralAI(
+            model=kwargs.get("model_name", "mistral-large-latest"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "openai":
+        base_url = kwargs.get("base_url") or os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "gpt-4o"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "grok":
+        base_url = kwargs.get("base_url") or os.getenv("GROK_ENDPOINT", "https://api.x.ai/v1")
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "grok-3"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "deepseek":
+        base_url = kwargs.get("base_url") or os.getenv("DEEPSEEK_ENDPOINT", "https://api.deepseek.com")
+        if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
+            # deepseek-reasoner returns the extra reasoning_content field
+            return DeepSeekR1ChatOpenAI(
+                model="deepseek-reasoner",
+                temperature=kwargs.get("temperature", 0.0),
+                base_url=base_url,
+                api_key=api_key,
+            )
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "deepseek-chat"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "google":
+        return ChatGoogleGenerativeAI(
+            model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
+            temperature=kwargs.get("temperature", 0.0),
+            api_key=api_key,
+        )
+    elif provider == "ollama":
+        base_url = kwargs.get("base_url") or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
+        if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
+            return DeepSeekR1ChatOllama(
+                model=kwargs.get("model_name", "deepseek-r1:14b"),
+                temperature=kwargs.get("temperature", 0.0),
+                num_ctx=kwargs.get("num_ctx", 32000),
+                base_url=base_url,
+            )
+        return ChatOllama(
+            model=kwargs.get("model_name", "qwen2.5:7b"),
+            temperature=kwargs.get("temperature", 0.0),
+            num_ctx=kwargs.get("num_ctx", 32000),
+            num_predict=kwargs.get("num_predict", 1024),
+            base_url=base_url,
+        )
+    elif provider == "azure_openai":
+        base_url = kwargs.get("base_url") or os.getenv("AZURE_OPENAI_ENDPOINT", "")
+        api_version = kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
+        return AzureChatOpenAI(
+            model=kwargs.get("model_name", "gpt-4o"),
+            temperature=kwargs.get("temperature", 0.0),
+            api_version=api_version,
+            azure_endpoint=base_url,
+            api_key=api_key,
+        )
+    elif provider == "alibaba":
+        base_url = kwargs.get("base_url") or os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1")
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "qwen-plus"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "ibm":
+        base_url = kwargs.get("base_url") or os.getenv("IBM_ENDPOINT", "https://us-south.ml.cloud.ibm.com")
+        parameters = {
+            "temperature": kwargs.get("temperature", 0.0),
+            "max_tokens": kwargs.get("num_ctx", 32000),
+        }
+        return ChatWatsonx(
+            model_id=kwargs.get("model_name", "ibm/granite-vision-3.1-2b-preview"),
+            url=base_url,
+            project_id=os.getenv("IBM_PROJECT_ID"),
+            apikey=api_key,
+            params=parameters,
+        )
+    elif provider == "moonshot":
+        base_url = kwargs.get("base_url") or os.getenv("MOONSHOT_ENDPOINT", "https://api.moonshot.cn/v1")
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "unbound":
+        base_url = kwargs.get("base_url") or os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai")
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "gpt-4o-mini"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "siliconflow":
+        base_url = kwargs.get("base_url") or os.getenv("SILICONFLOW_ENDPOINT", "https://api.siliconflow.cn/v1/")
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "Qwen/QwQ-32B"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    elif provider == "modelscope":
+        base_url = kwargs.get("base_url") or os.getenv("MODELSCOPE_ENDPOINT", "")
+        return ChatOpenAI(
+            model=kwargs.get("model_name", "Qwen/QwQ-32B"),
+            temperature=kwargs.get("temperature", 0.0),
+            base_url=base_url,
+            api_key=api_key,
+        )
+    else:
+        raise ValueError(f"Unsupported provider: {provider}")
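For reviewers, a minimal usage sketch of the new entry point, following the same flow as cli.py (the provider, model name, and prompt string here are only examples):

```
from dotenv import load_dotenv

load_dotenv()

from workflow_use.llm.llm_provider import get_llm_model

# Requires OPENAI_API_KEY in the environment; get_llm_model raises ValueError otherwise.
llm = get_llm_model("openai", model_name="gpt-4o", temperature=0.0)
print(llm.invoke("Reply with one word: hello").content)
```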