Skip to content

Commit 8b1c366

Browse files
authored
ManagedIdentityClient handles unexpected content-type (Azure#18137)
1 parent e85087d commit 8b1c366

File tree

3 files changed

+141
-8
lines changed

3 files changed

+141
-8
lines changed

sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from typing import TYPE_CHECKING
88

99
from msal import TokenCache
10+
import six
1011

1112
from azure.core.configuration import Configuration
1213
from azure.core.credentials import AccessToken
13-
from azure.core.exceptions import ClientAuthenticationError
14+
from azure.core.exceptions import ClientAuthenticationError, DecodeError
1415
from azure.core.pipeline import Pipeline
1516
from azure.core.pipeline.policies import (
1617
ContentDecodePolicy,
@@ -58,10 +59,19 @@ def __init__(self, request_factory, client_id=None, **kwargs):
5859
def _process_response(self, response, request_time):
5960
# type: (PipelineResponse, int) -> AccessToken
6061

61-
# ContentDecodePolicy sets this, and should have raised if it couldn't deserialize the response
62-
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response) # type: dict
63-
if not content:
64-
raise ClientAuthenticationError(message="No token received.", response=response.http_response)
62+
try:
63+
content = ContentDecodePolicy.deserialize_from_text(
64+
response.http_response.text(), mime_type="application/json"
65+
)
66+
if not content:
67+
raise ClientAuthenticationError(message="No token received.", response=response.http_response)
68+
except DecodeError as ex:
69+
if response.http_response.content_type.startswith("application/json"):
70+
message = "Failed to deserialize JSON from response"
71+
else:
72+
message = 'Unexpected content type "{}"'.format(response.http_response.content_type)
73+
six.raise_from(ClientAuthenticationError(message=message, response=response.http_response), ex)
74+
6575
if "access_token" not in content or not ("expires_in" in content or "expires_on" in content):
6676
if content and "access_token" in content:
6777
content["access_token"] = "****"
@@ -79,7 +89,8 @@ def _process_response(self, response, request_time):
7989

8090
# caching is the final step because TokenCache.add mutates its "event"
8191
self._cache.add(
82-
event={"response": content, "scope": [content["resource"]]}, now=request_time,
92+
event={"response": content, "scope": [content["resource"]]},
93+
now=request_time,
8394
)
8495

8596
return token

sdk/identity/azure-identity/tests/test_managed_identity_client.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5+
import json
56
import time
67

8+
from azure.core.exceptions import ClientAuthenticationError
79
from azure.core.pipeline.transport import HttpRequest
810
from azure.identity._internal.managed_identity_client import ManagedIdentityClient
11+
import pytest
912

1013
from helpers import mock, mock_response, Request, validating_transport
1114

@@ -44,3 +47,61 @@ def test_caching():
4447
token = client.get_cached_token(scope)
4548
assert token.expires_on == expected_expires_on
4649
assert token.token == expected_token
50+
51+
52+
def test_deserializes_json_from_text():
53+
"""The client should gracefully handle a response with a JSON body and content-type text/plain"""
54+
55+
scope = "scope"
56+
now = int(time.time())
57+
expected_expires_on = now + 3600
58+
expected_token = "*"
59+
60+
def send(request, **_):
61+
body = json.dumps(
62+
{
63+
"access_token": expected_token,
64+
"expires_in": 3600,
65+
"expires_on": expected_expires_on,
66+
"resource": scope,
67+
"token_type": "Bearer",
68+
}
69+
)
70+
return mock.Mock(
71+
status_code=200,
72+
headers={"Content-Type": "text/plain"},
73+
content_type="text/plain",
74+
text=lambda encoding=None: body,
75+
)
76+
77+
client = ManagedIdentityClient(
78+
request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=mock.Mock(send=send)
79+
)
80+
81+
token = client.request_token(scope)
82+
assert token.expires_on == expected_expires_on
83+
assert token.token == expected_token
84+
85+
86+
@pytest.mark.parametrize("content_type", ("text/html","application/json"))
87+
def test_unexpected_content(content_type):
88+
content = "<html><body>not JSON</body></html>"
89+
90+
def send(request, **_):
91+
return mock.Mock(
92+
status_code=200,
93+
headers={"Content-Type": content_type},
94+
content_type=content_type,
95+
text=lambda encoding=None: content,
96+
)
97+
98+
client = ManagedIdentityClient(
99+
request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=mock.Mock(send=send)
100+
)
101+
102+
with pytest.raises(ClientAuthenticationError) as ex:
103+
client.request_token("scope")
104+
assert ex.value.response.text() == content
105+
106+
if "json" not in content_type:
107+
assert content_type in ex.value.message

sdk/identity/azure-identity/tests/test_managed_identity_client_async.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,21 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5+
import json
56
import time
6-
from unittest.mock import patch
7+
from unittest.mock import Mock, patch
78

9+
from azure.core.exceptions import ClientAuthenticationError
810
from azure.core.pipeline.transport import HttpRequest
911
from azure.identity.aio._internal.managed_identity_client import AsyncManagedIdentityClient
1012
import pytest
1113

1214
from helpers import mock_response, Request
1315
from helpers_async import async_validating_transport
1416

17+
pytestmark = pytest.mark.asyncio
18+
1519

16-
@pytest.mark.asyncio
1720
async def test_caching():
1821
scope = "scope"
1922
now = int(time.time())
@@ -48,3 +51,61 @@ async def test_caching():
4851
token = client.get_cached_token(scope)
4952
assert token.expires_on == expected_expires_on
5053
assert token.token == expected_token
54+
55+
56+
async def test_deserializes_json_from_text():
57+
"""The client should gracefully handle a response with a JSON body and content-type text/plain"""
58+
59+
scope = "scope"
60+
now = int(time.time())
61+
expected_expires_on = now + 3600
62+
expected_token = "*"
63+
64+
async def send(request, **_):
65+
body = json.dumps(
66+
{
67+
"access_token": expected_token,
68+
"expires_in": 3600,
69+
"expires_on": expected_expires_on,
70+
"resource": scope,
71+
"token_type": "Bearer",
72+
}
73+
)
74+
return Mock(
75+
status_code=200,
76+
headers={"Content-Type": "text/plain"},
77+
content_type="text/plain",
78+
text=lambda encoding=None: body,
79+
)
80+
81+
client = AsyncManagedIdentityClient(
82+
request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=Mock(send=send)
83+
)
84+
85+
token = await client.request_token(scope)
86+
assert token.expires_on == expected_expires_on
87+
assert token.token == expected_token
88+
89+
90+
@pytest.mark.parametrize("content_type", ("text/html", "application/json"))
91+
async def test_unexpected_content(content_type):
92+
content = "<html><body>not JSON</body></html>"
93+
94+
async def send(request, **_):
95+
return Mock(
96+
status_code=200,
97+
headers={"Content-Type": content_type},
98+
content_type=content_type,
99+
text=lambda encoding=None: content,
100+
)
101+
102+
client = AsyncManagedIdentityClient(
103+
request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=Mock(send=send)
104+
)
105+
106+
with pytest.raises(ClientAuthenticationError) as ex:
107+
await client.request_token("scope")
108+
assert ex.value.response.text() == content
109+
110+
if "json" not in content_type:
111+
assert content_type in ex.value.message

0 commit comments

Comments
 (0)