Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/ai/Azure.AI.Agents.Persistent/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "net",
"TagPrefix": "net/ai/Azure.AI.Agents.Persistent",
"Tag": "net/ai/Azure.AI.Agents.Persistent_84020b2662"
"Tag": "net/ai/Azure.AI.Agents.Persistent_89f0bef6e6"
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ internal partial class PersistentAgentsChatClient : IChatClient
/// <summary>Lazily-retrieved agent instance. Used for its properties.</summary>
private PersistentAgent? _agent;

/// <summary>Chat tool mode indicating that no tools should be used.</summary>
private readonly bool _throwOnContentErrors;

/// <summary>Initializes a new instance of the <see cref="PersistentAgentsChatClient"/> class for the specified <see cref="PersistentAgentsClient"/>.</summary>
public PersistentAgentsChatClient(PersistentAgentsClient client, string agentId, string? defaultThreadId = null)
public PersistentAgentsChatClient(PersistentAgentsClient client, string agentId, string? defaultThreadId = null, bool throwOnContentErrors = true)
{
Argument.AssertNotNull(client, nameof(client));
Argument.AssertNotNullOrWhiteSpace(agentId, nameof(agentId));
Expand All @@ -51,6 +54,7 @@ public PersistentAgentsChatClient(PersistentAgentsClient client, string agentId,
_defaultThreadId = defaultThreadId;

_metadata = new(ProviderName);
_throwOnContentErrors = throwOnContentErrors;
}

protected PersistentAgentsChatClient() { }
Expand Down Expand Up @@ -191,6 +195,14 @@ threadRun is not null &&

switch (ru)
{
case RunUpdate rup when rup.Value.Status == RunStatus.Failed && rup.Value.LastError is { } error:
if (_throwOnContentErrors)
{
throw new InvalidOperationException(error.Message) { Data = { ["ErrorCode"] = error.Code } };
}
ruUpdate.Contents.Add(new ErrorContent(error.Message) { ErrorCode = error.Code, RawRepresentation = error });
break;

case RequiredActionUpdate rau when rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName:
ruUpdate.Contents.Add(new FunctionCallContent(
JsonSerializer.Serialize([ru.Value.Id, toolCallId], AgentsChatClientJsonContext.Default.StringArray),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ public static class PersistentAgentsClientExtensions
/// <see cref="IChatClient.GetResponseAsync"/> or <see cref="IChatClient.GetStreamingResponseAsync"/> via the <see cref="ChatOptions.ConversationId"/>
/// property. If no thread ID is provided via either mechanism, a new thread will be created for the request.
/// </param>
/// <param name="throwOnContentErrors">Throws an exception if content errors are returned from the service. Default is <c>true</c>. This is useful to detect errors when tools are misconfigured that otherwise would be unnoticed because those come as a streaming data update.</param>
/// <returns>An <see cref="IChatClient"/> instance configured to interact with the specified agent and thread.</returns>
public static IChatClient AsIChatClient(this PersistentAgentsClient client, string agentId, string? defaultThreadId = null) =>
new PersistentAgentsChatClient(client, agentId, defaultThreadId);
public static IChatClient AsIChatClient(this PersistentAgentsClient client, string agentId, string? defaultThreadId = null, bool throwOnContentErrors = true) =>
new PersistentAgentsChatClient(client, agentId, defaultThreadId, throwOnContentErrors);

/// <summary>Creates an <see cref="AITool"/> to represent a <see cref="ToolDefinition"/>.</summary>
/// <param name="tool">The tool to wrap as an <see cref="AITool"/>.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;
using Azure.Core.TestFramework;
using Azure.Identity;
using Microsoft.Extensions.AI;
Expand All @@ -26,6 +28,11 @@ public class PersistentAgentsChatClientTests : RecordedTestBase<AIAgentsTestEnvi
private string _agentId;
private string _threadId;

private const string FakeAgentEndpoint = "https://fake-host";
private const string FakeAgentId = "agent-id";
private const string FakeRunId = "run-id";
private const string FakeThreadId = "thread-id";

public PersistentAgentsChatClientTests(bool isAsync) : base(isAsync)
{
TestDiagnostics = false;
Expand Down Expand Up @@ -350,7 +357,37 @@ public void TestGetService()
Assert.Throws<ArgumentNullException>(() => chatClient.GetService(null));
}

[RecordedTest]
public async Task TestContentErrorHandling()
{
var mockTransport = new MockTransport((request) =>
{
return GetResponse(request, emptyRunList: false);
});

PersistentAgentsClient client = GetClient(mockTransport);

IChatClient throwingChatClient = client.AsIChatClient(FakeAgentId, FakeThreadId, throwOnContentErrors: true);
IChatClient nonThrowingChatClient = client.AsIChatClient(FakeAgentId, FakeThreadId, throwOnContentErrors: false);

var exception = Assert.ThrowsAsync<InvalidOperationException>(() => throwingChatClient.GetResponseAsync(new ChatMessage(ChatRole.User, "Get Mike's favourite word")));
Assert.IsTrue(exception.Message.Contains("wrong-connection-id"));

var response = await nonThrowingChatClient.GetResponseAsync(new ChatMessage(ChatRole.User, "Get Mike's favourite word"));
var errorContent = response.Messages.SelectMany(m => m.Contents).OfType<ErrorContent>().Single();
Assert.IsTrue(errorContent.Message.Contains("wrong-connection-id"));
}

#region Helpers

private PersistentAgentsClient GetClient(HttpPipelineTransport transport)
{
return new(FakeAgentEndpoint, new MockCredential(), options: new PersistentAgentsAdministrationClientOptions()
{
Transport = transport
});
}

private class CompositeDisposable : IDisposable
{
private readonly List<IDisposable> _disposables = [];
Expand Down Expand Up @@ -471,5 +508,81 @@ private static string GetFile([CallerFilePath] string pth = "", string fileName
var dirName = Path.GetDirectoryName(pth) ?? "";
return Path.Combine(new string[] { dirName, "TestData", fileName });
}

private static MockResponse GetResponse(MockRequest request, bool? emptyRunList = true)
{
// Sent by client.Administration.GetAgentAsync(...) method
if (request.Method == RequestMethod.Get && request.Uri.Path == $"/assistants/{FakeAgentId}")
{
return CreateOKMockResponse($$"""
{
"id": "{{FakeAgentId}}"
}
""");
}
// Sent by by client.Runs.GetRunsAsync(...) method
else if (request.Method == RequestMethod.Get && request.Uri.Path == $"/threads/{FakeThreadId}/runs")
{
return CreateOKMockResponse($$"""
{
"data": {{(emptyRunList is true
? "[]"
: $$"""[{"id": "{{FakeRunId}}"}]""")}}
}
""");
}
// Sent by client.Runs.CreateRunStreamingAsync(...) method
else if (request.Method == RequestMethod.Post && request.Uri.Path == $"/threads/{FakeThreadId}/runs")
{
// Content failure response
return CreateOKMockResponse(
$$$"""
event: thread.run.failed
data: { "id":"{{{FakeRunId}}}","object":"thread.run","created_at":1764170243,"assistant_id":"asst_uYPWW0weSBNqXK3VjgRMkuim","thread_id":"thread_dmz0AZPJtnO9MnAfrzP1AtJ6","status":"failed","started_at":1764170244,"expires_at":null,"cancelled_at":null,"failed_at":1764170244,"completed_at":null,"required_action":null,"last_error":{ "code":"tool_user_error","message":"Error: invalid_tool_input; The specified connection ID 'wrong-connection-id' was not found in the project or account connections. Please verify that the connection id in tool input is correct and exists in the project or account."},"model":"gpt-4o","instructions":"Use the bing grounding tool to answer questions.Use the bing grounding tool to answer questions.","tools":[{ "type":"bing_grounding","bing_grounding":{ "search_configurations":[{ "connection_id":"wrong-connection-id","market":"en-US","set_lang":"en","count":5}]} }],"tool_resources":{ "code_interpreter":{ "file_ids":[]} },"metadata":{ },"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{ "type":"auto","last_messages":null},"incomplete_details":null,"usage":{ "prompt_tokens":0,"completion_tokens":0,"total_tokens":0,"prompt_token_details":{ "cached_tokens":0} },"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}

event: done
data: [DONE]
"""
);
}
// Sent by client.Threads.CreateThreadAsync(...) method
else if (request.Method == RequestMethod.Post && request.Uri.Path == $"/threads")
{
return CreateOKMockResponse($$"""
{
"id": "{{FakeThreadId}}"
}
""");
}
// Sent by client.Runs.CancelRunAsync(...) method
else if (request.Method == RequestMethod.Post && request.Uri.Path == $"/threads/{FakeThreadId}/runs/{FakeRunId}/cancel")
{
return CreateOKMockResponse($$"""
{
"id": "{{FakeThreadId}}"
}
""");
}
// Sent by client.Runs.SubmitToolOutputsToStreamAsync(...) method
else if (request.Method == RequestMethod.Post && request.Uri.Path == $"/threads//runs/{FakeRunId}/submit_tool_outputs")
{
return CreateOKMockResponse($$"""
{
"data":[{
"id": "{{FakeRunId}}"
}]
}
""");
}

throw new InvalidOperationException("Unexpected request");
}

private static MockResponse CreateOKMockResponse(string content)
{
var response = new MockResponse(200);
response.SetContent(content);
return response;
}
}
}
Loading