 import asyncio
 import os
-import subprocess
 from typing import Any

 import logfire

 from agents_mcp_usage.multi_mcp.mermaid_diagrams import (
     invalid_mermaid_diagram_easy,
+    invalid_mermaid_diagram_medium,
+    invalid_mermaid_diagram_hard,
     valid_mermaid_diagram,
 )
+from mcp_servers.mermaid_validator import validate_mermaid_diagram

 load_dotenv()

 logfire.instrument_pydantic_ai()

 # Default model to use
-DEFAULT_MODEL = "gemini-2.5-pro-preview-03-25"
-# DEFAULT_MODEL = "openai:o4-mini"
+DEFAULT_MODEL = "gemini-2.5-pro-preview-05-06"
+
 # Configure MCP servers
 local_server = MCPServerStdio(
     command="uv",
     ],
 )
 mermaid_server = MCPServerStdio(
-    command="npx",
+    command="uv",
     args=[
-        "-y",
-        "@rtuin/mcp-mermaid-validator@latest",
+        "run",
+        "mcp_servers/mermaid_validator.py",
     ],
 )


 # Create Agent with MCP servers
-def create_agent(model: str = DEFAULT_MODEL):
+def create_agent(model: str = DEFAULT_MODEL, model_settings: dict[str, Any] = {}):
     return Agent(
         model,
         mcp_servers=[local_server, mermaid_server],
+        model_settings=model_settings,
     )


-agent = create_agent()
-Agent.instrument_all()
-
-
-async def main(
-    query: str = "Hi!", request_limit: int = 5, model: str = DEFAULT_MODEL
-) -> Any:
-    """
-    Main function to run the agent
-
-    Args:
-        query (str): The query to run the agent with
-        request_limit (int): The number of requests to make to the MCP servers
-        model (str): The model to use for the agent
-
-    Returns:
-        The result from the agent's execution
-    """
-    # Create a fresh agent with the specified model
-    current_agent = create_agent(model)
-
-    # Set a request limit for LLM calls
-    usage_limits = UsageLimits(request_limit=request_limit)
-
-    # Invoke the agent with the usage limits
-    async with current_agent.run_mcp_servers():
-        result = await current_agent.run(query, usage_limits=usage_limits)
-
-    return result
-
-
 # Define input and output schema for evaluations
 class MermaidInput(BaseModel):
     invalid_diagram: str
@@ -110,86 +82,35 @@ class MermaidDiagramValid(Evaluator[MermaidInput, MermaidOutput]):
     async def evaluate(
         self, ctx: EvaluatorContext[MermaidInput, MermaidOutput]
     ) -> float:
-        diagram = ctx.output.fixed_diagram
-
-        # Extract mermaid code from markdown code block if present
-        mermaid_code = diagram
-        if "```mermaid" in diagram and "```" in diagram:
-            start_idx = diagram.find("```mermaid") + len("```mermaid")
-            end_idx = diagram.rfind("```")
-            mermaid_code = diagram[start_idx:end_idx].strip()
-
-        # Validate using mmdc
-        is_valid, _ = self.validate_mermaid_string_via_mmdc(mermaid_code)
-        return 1.0 if is_valid else 0.0
-
-    def validate_mermaid_string_via_mmdc(
-        self, mermaid_code: str, mmdc_path: str = "mmdc"
-    ) -> tuple[bool, str]:
-        """
-        Validates a Mermaid string by attempting to compile it using the
-        Mermaid CLI (mmdc). Requires mmdc to be installed and in PATH,
-        or mmdc_path to be explicitly provided.
-
-        Args:
-            mermaid_code: The string containing the Mermaid diagram syntax.
-            mmdc_path: The command or path to the mmdc executable.
-
-        Returns:
-            A tuple (is_valid: bool, message: str).
-            'message' will contain stderr output if not valid, or a success message.
-        """
-        # Define temporary file names
-        temp_mmd_file = "temp_mermaid_for_validation.mmd"
-        # mmdc requires an output file, even if we don't use its content for validation.
-        temp_output_file = "temp_mermaid_output.svg"
-
-        # Write the mermaid code to a temporary file
-        with open(temp_mmd_file, "w", encoding="utf-8") as f:
-            f.write(mermaid_code)
-
-        try:
-            # Construct the command to run mmdc
-            command = [mmdc_path, "-i", temp_mmd_file, "-o", temp_output_file]
-
-            # Execute the mmdc command
-            process = subprocess.run(
-                command,
-                capture_output=True,  # Capture stdout and stderr
-                text=True,  # Decode output as text
-                check=False,  # Do not raise an exception for non-zero exit codes
-                encoding="utf-8",
+        # Strip whitespace, remove backticks and ```mermaid markers
+        input_str = ctx.output.fixed_diagram.strip()
+
+        # Remove ```mermaid and ``` markers
+        if input_str.startswith("```mermaid"):
+            input_str = input_str[len("```mermaid") :].strip()
+        if input_str.endswith("```"):
+            input_str = input_str[:-3].strip()
+
+        # Remove any remaining backticks
+        input_str = input_str.replace("`", "")
+
+        logfire.info(
+            "Evaluating mermaid diagram validity",
+            diagram_length=len(input_str),
+            diagram_preview=input_str[:100],
+        )
+
+        # Use the MCP server's validation function
+        result = await validate_mermaid_diagram(input_str)
+
+        if result.is_valid:
+            logfire.info("Mermaid diagram validation succeeded")
+        else:
+            logfire.warning(
+                "Mermaid diagram validation failed", error_message=result.error_message
             )

-            if process.returncode == 0:
-                return True, "Syntax appears valid (compiled successfully by mmdc)."
-            else:
-                # mmdc usually prints errors to stderr.
-                error_message = process.stderr.strip()
-                # Sometimes, syntax errors might also appear in stdout for certain mmdc versions or error types
-                if not error_message and process.stdout.strip():
-                    error_message = process.stdout.strip()
-                return (
-                    False,
-                    f"Invalid syntax or mmdc error (exit code {process.returncode}):\n{error_message}",
-                )
-        except FileNotFoundError:
-            return False, (
-                f"Validation failed: '{mmdc_path}' command not found. "
-                "Please ensure Mermaid CLI (mmdc) is installed and in your system's PATH, "
-                "or provide the full path to the executable."
-            )
-        except Exception as e:
-            return (
-                False,
-                f"Validation failed due to an unexpected error during mmdc execution: {e}",
-            )
-        finally:
-            # Clean up the temporary files
-            if os.path.exists(temp_mmd_file):
-                os.remove(temp_mmd_file)
-            if os.path.exists(temp_output_file):
-                os.remove(temp_output_file)
+        return 1.0 if result.is_valid else 0.0


 async def fix_mermaid_diagram(
@@ -206,9 +127,15 @@ async def fix_mermaid_diagram(
     """
     query = f"Add the current time and fix the mermaid diagram syntax using the validator: {inputs.invalid_diagram}. Return only the fixed mermaid diagram between backticks."

-    result = await main(query, model=model)
+    # Create a fresh agent for each invocation to avoid concurrent usage issues
+    current_agent = create_agent(model)
+    usage_limits = UsageLimits(request_limit=5)

-    # Extract the mermaid diagram from the output
+    # Use the agent's context manager directly in this function
+    async with current_agent.run_mcp_servers():
+        result = await current_agent.run(query, usage_limits=usage_limits)
+
+    # Extract the mermaid diagram from the result output
     output = result.output

     # Logic to extract the diagram from between backticks
@@ -232,12 +159,25 @@ def create_evaluation_dataset(judge_model: str = DEFAULT_MODEL):
         The evaluation dataset
     """
     return Dataset[MermaidInput, MermaidOutput, Any](
+        # Construct 3 tests, each asks the LLM to fix an invalid mermaid diagram of increasing difficulty
         cases=[
             Case(
-                name="fix_invalid_diagram_1",
+                name="fix_invalid_diagram_easy",
                 inputs=MermaidInput(invalid_diagram=invalid_mermaid_diagram_easy),
                 expected_output=MermaidOutput(fixed_diagram=valid_mermaid_diagram),
-                metadata={"test_type": "mermaid_easy_fix", "iteration": 1},
+                metadata={"test_type": "mermaid_easy_fix"},
+            ),
+            Case(
+                name="fix_invalid_diagram_medium",
+                inputs=MermaidInput(invalid_diagram=invalid_mermaid_diagram_medium),
+                expected_output=MermaidOutput(fixed_diagram=valid_mermaid_diagram),
+                metadata={"test_type": "mermaid_medium_fix"},
+            ),
+            Case(
+                name="fix_invalid_diagram_hard",
+                inputs=MermaidInput(invalid_diagram=invalid_mermaid_diagram_hard),
+                expected_output=MermaidOutput(fixed_diagram=valid_mermaid_diagram),
+                metadata={"test_type": "mermaid_hard_fix"},
             ),
         ],
         evaluators=[
@@ -249,9 +189,9 @@ def create_evaluation_dataset(judge_model: str = DEFAULT_MODEL):
                 model=judge_model,
             ),
             LLMJudge(
-                rubric="The fixed diagram should maintain the same overall structure and intent as the expected output diagram while fixing any syntax errors."
+                rubric="The output diagram should maintain the same overall structure and intent as the expected output diagram while fixing any syntax errors."
                 + "Check if nodes, connections, and labels are preserved."
254- + "The current time should be placeholder should be replace with a datetime" ,
194+ + "The current time should be placeholder should be replace with a valid datetime" ,
                 include_input=False,
                 model=judge_model,
             ),
@@ -276,20 +216,24 @@ async def fix_with_model(inputs: MermaidInput) -> MermaidOutput:
         return await fix_mermaid_diagram(inputs, model=model)

     report = await dataset.evaluate(
-        fix_with_model, name=f"{model}-multi-mcp-mermaid-diagram-fix-evals"
+        fix_with_model,
+        name=f"{model}-multi-mcp-mermaid-diagram-fix-evals",
+        max_concurrency=1,  # Run one evaluation at a time
     )

-    report.print(include_input=True, include_output=True)
+    report.print(include_input=False, include_output=False)
     return report


 if __name__ == "__main__":
     # You can use different models for the agent and the judge
-    agent_model = os.getenv("AGENT_MODEL", DEFAULT_MODEL)
+    # agent_model = os.getenv("AGENT_MODEL", DEFAULT_MODEL)
+    agent_model = "gemini-2.5-flash-preview-04-17"
+    # agent_model = "openai:o4-mini"
+    # agent_model = "gemini-2.5-flash-preview-04-17"
     judge_model = os.getenv("JUDGE_MODEL", DEFAULT_MODEL)

     async def run_all():
-        # Run evaluations
         await run_evaluations(model=agent_model, judge_model=judge_model)

     asyncio.run(run_all())