Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1597,7 +1597,7 @@ public PersistentAgentsClient(string endpoint, Azure.Core.TokenCredential creden
public static partial class PersistentAgentsClientExtensions
{
public static Microsoft.Extensions.AI.AITool AsAITool(this Azure.AI.Agents.Persistent.ToolDefinition tool) { throw null; }
public static Microsoft.Extensions.AI.IChatClient AsIChatClient(this Azure.AI.Agents.Persistent.PersistentAgentsClient client, string agentId, string? defaultThreadId = null) { throw null; }
public static Microsoft.Extensions.AI.IChatClient AsIChatClient(this Azure.AI.Agents.Persistent.PersistentAgentsClient client, string agentId, string? defaultThreadId = null, bool throwOnContentErrors = true) { throw null; }
}
public static partial class PersistentAgentsExtensions
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1597,7 +1597,7 @@ public PersistentAgentsClient(string endpoint, Azure.Core.TokenCredential creden
public static partial class PersistentAgentsClientExtensions
{
public static Microsoft.Extensions.AI.AITool AsAITool(this Azure.AI.Agents.Persistent.ToolDefinition tool) { throw null; }
public static Microsoft.Extensions.AI.IChatClient AsIChatClient(this Azure.AI.Agents.Persistent.PersistentAgentsClient client, string agentId, string? defaultThreadId = null) { throw null; }
public static Microsoft.Extensions.AI.IChatClient AsIChatClient(this Azure.AI.Agents.Persistent.PersistentAgentsClient client, string agentId, string? defaultThreadId = null, bool throwOnContentErrors = true) { throw null; }
}
public static partial class PersistentAgentsExtensions
{
Expand Down
2 changes: 1 addition & 1 deletion sdk/ai/Azure.AI.Agents.Persistent/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "net",
"TagPrefix": "net/ai/Azure.AI.Agents.Persistent",
"Tag": "net/ai/Azure.AI.Agents.Persistent_84020b2662"
"Tag": "net/ai/Azure.AI.Agents.Persistent_89f0bef6e6"
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,13 @@ internal partial class PersistentAgentsChatClient : IChatClient
/// <summary>Lazily-retrieved agent instance. Used for its properties.</summary>
private PersistentAgent? _agent;

/// <summary>
/// Indicates whether to throw exceptions when content errors are encountered.
/// </summary>
private readonly bool _throwOnContentErrors;

/// <summary>Initializes a new instance of the <see cref="PersistentAgentsChatClient"/> class for the specified <see cref="PersistentAgentsClient"/>.</summary>
public PersistentAgentsChatClient(PersistentAgentsClient client, string agentId, string? defaultThreadId = null)
public PersistentAgentsChatClient(PersistentAgentsClient client, string agentId, string? defaultThreadId = null, bool throwOnContentErrors = true)
{
Argument.AssertNotNull(client, nameof(client));
Argument.AssertNotNullOrWhiteSpace(agentId, nameof(agentId));
Expand All @@ -51,6 +56,7 @@ public PersistentAgentsChatClient(PersistentAgentsClient client, string agentId,
_defaultThreadId = defaultThreadId;

_metadata = new(ProviderName);
_throwOnContentErrors = throwOnContentErrors;
}

protected PersistentAgentsChatClient() { }
Expand Down Expand Up @@ -191,6 +197,14 @@ threadRun is not null &&

switch (ru)
{
case RunUpdate rup when rup.Value.Status == RunStatus.Failed && rup.Value.LastError is { } error:
if (_throwOnContentErrors)
{
throw new InvalidOperationException(error.Message) { Data = { ["ErrorCode"] = error.Code } };
}
ruUpdate.Contents.Add(new ErrorContent(error.Message) { ErrorCode = error.Code, RawRepresentation = error });
break;

case RequiredActionUpdate rau when rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName:
ruUpdate.Contents.Add(new FunctionCallContent(
JsonSerializer.Serialize([ru.Value.Id, toolCallId], AgentsChatClientJsonContext.Default.StringArray),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ public static class PersistentAgentsClientExtensions
/// <see cref="IChatClient.GetResponseAsync"/> or <see cref="IChatClient.GetStreamingResponseAsync"/> via the <see cref="ChatOptions.ConversationId"/>
/// property. If no thread ID is provided via either mechanism, a new thread will be created for the request.
/// </param>
/// <param name="throwOnContentErrors">Throws an exception if content errors are returned from the service. Default is <c>true</c>. This is useful to detect errors when tools are misconfigured that otherwise would be unnoticed because those come as a streaming data update.</param>
/// <returns>An <see cref="IChatClient"/> instance configured to interact with the specified agent and thread.</returns>
public static IChatClient AsIChatClient(this PersistentAgentsClient client, string agentId, string? defaultThreadId = null) =>
new PersistentAgentsChatClient(client, agentId, defaultThreadId);
public static IChatClient AsIChatClient(this PersistentAgentsClient client, string agentId, string? defaultThreadId = null, bool throwOnContentErrors = true) =>
new PersistentAgentsChatClient(client, agentId, defaultThreadId, throwOnContentErrors);

/// <summary>Creates an <see cref="AITool"/> to represent a <see cref="ToolDefinition"/>.</summary>
/// <param name="tool">The tool to wrap as an <see cref="AITool"/>.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.Pipeline;
using Azure.Core.TestFramework;
using Azure.Identity;
using Microsoft.Extensions.AI;
Expand All @@ -26,6 +28,11 @@ public class PersistentAgentsChatClientTests : RecordedTestBase<AIAgentsTestEnvi
private string _agentId;
private string _threadId;

private const string FakeAgentEndpoint = "https://fake-host";
private const string FakeAgentId = "agent-id";
private const string FakeRunId = "run-id";
private const string FakeThreadId = "thread-id";

public PersistentAgentsChatClientTests(bool isAsync) : base(isAsync)
{
TestDiagnostics = false;
Expand Down Expand Up @@ -350,7 +357,37 @@ public void TestGetService()
Assert.Throws<ArgumentNullException>(() => chatClient.GetService(null));
}

[RecordedTest]
public async Task TestContentErrorHandling()
{
var mockTransport = new MockTransport((request) =>
{
return GetResponse(request, emptyRunList: false);
});

PersistentAgentsClient client = GetClient(mockTransport);

IChatClient throwingChatClient = client.AsIChatClient(FakeAgentId, FakeThreadId, throwOnContentErrors: true);
IChatClient nonThrowingChatClient = client.AsIChatClient(FakeAgentId, FakeThreadId, throwOnContentErrors: false);

var exception = Assert.ThrowsAsync<InvalidOperationException>(() => throwingChatClient.GetResponseAsync(new ChatMessage(ChatRole.User, "Get Mike's favourite word")));
Assert.IsTrue(exception.Message.Contains("wrong-connection-id"));

var response = await nonThrowingChatClient.GetResponseAsync(new ChatMessage(ChatRole.User, "Get Mike's favourite word"));
var errorContent = response.Messages.SelectMany(m => m.Contents).OfType<ErrorContent>().Single();
Assert.IsTrue(errorContent.Message.Contains("wrong-connection-id"));
}

#region Helpers

private PersistentAgentsClient GetClient(HttpPipelineTransport transport)
{
return new(FakeAgentEndpoint, new MockCredential(), options: new PersistentAgentsAdministrationClientOptions()
{
Transport = transport
});
}

private class CompositeDisposable : IDisposable
{
private readonly List<IDisposable> _disposables = [];
Expand Down Expand Up @@ -471,5 +508,81 @@ private static string GetFile([CallerFilePath] string pth = "", string fileName
var dirName = Path.GetDirectoryName(pth) ?? "";
return Path.Combine(new string[] { dirName, "TestData", fileName });
}

private static MockResponse GetResponse(MockRequest request, bool? emptyRunList = true)
{
// Sent by client.Administration.GetAgentAsync(...) method
if (request.Method == RequestMethod.Get && request.Uri.Path == $"/assistants/{FakeAgentId}")
{
return CreateOKMockResponse($$"""
{
"id": "{{FakeAgentId}}"
}
""");
}
// Sent by client.Runs.GetRunsAsync(...) method
else if (request.Method == RequestMethod.Get && request.Uri.Path == $"/threads/{FakeThreadId}/runs")
{
return CreateOKMockResponse($$"""
{
"data": {{(emptyRunList is true
? "[]"
: $$"""[{"id": "{{FakeRunId}}"}]""")}}
}
""");
}
// Sent by client.Runs.CreateRunStreamingAsync(...) method
else if (request.Method == RequestMethod.Post && request.Uri.Path == $"/threads/{FakeThreadId}/runs")
{
// Content failure response
return CreateOKMockResponse(
$$$"""
event: thread.run.failed
data: { "id":"{{{FakeRunId}}}","object":"thread.run","created_at":1764170243,"assistant_id":"asst_uYPWW0weSBNqXK3VjgRMkuim","thread_id":"thread_dmz0AZPJtnO9MnAfrzP1AtJ6","status":"failed","started_at":1764170244,"expires_at":null,"cancelled_at":null,"failed_at":1764170244,"completed_at":null,"required_action":null,"last_error":{ "code":"tool_user_error","message":"Error: invalid_tool_input; The specified connection ID 'wrong-connection-id' was not found in the project or account connections. Please verify that the connection id in tool input is correct and exists in the project or account."},"model":"gpt-4o","instructions":"Use the bing grounding tool to answer questions.Use the bing grounding tool to answer questions.","tools":[{ "type":"bing_grounding","bing_grounding":{ "search_configurations":[{ "connection_id":"wrong-connection-id","market":"en-US","set_lang":"en","count":5}]} }],"tool_resources":{ "code_interpreter":{ "file_ids":[]} },"metadata":{ },"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{ "type":"auto","last_messages":null},"incomplete_details":null,"usage":{ "prompt_tokens":0,"completion_tokens":0,"total_tokens":0,"prompt_token_details":{ "cached_tokens":0} },"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}

event: done
data: [DONE]
"""
);
}
// Sent by client.Threads.CreateThreadAsync(...) method
else if (request.Method == RequestMethod.Post && request.Uri.Path == $"/threads")
{
return CreateOKMockResponse($$"""
{
"id": "{{FakeThreadId}}"
}
""");
}
// Sent by client.Runs.CancelRunAsync(...) method
else if (request.Method == RequestMethod.Post && request.Uri.Path == $"/threads/{FakeThreadId}/runs/{FakeRunId}/cancel")
{
return CreateOKMockResponse($$"""
{
"id": "{{FakeThreadId}}"
}
""");
}
// Sent by client.Runs.SubmitToolOutputsToStreamAsync(...) method
else if (request.Method == RequestMethod.Post && request.Uri.Path == $"/threads//runs/{FakeRunId}/submit_tool_outputs")
{
return CreateOKMockResponse($$"""
{
"data":[{
"id": "{{FakeRunId}}"
}]
}
""");
}

throw new InvalidOperationException("Unexpected request");
}

private static MockResponse CreateOKMockResponse(string content)
{
var response = new MockResponse(200);
response.SetContent(content);
return response;
}
}
}