Skip to content

Commit 4b6911f

Browse files
authored
Replace MsiCredential with CloudShellCredential (Azure#15880)
1 parent 78b7a38 commit 4b6911f

File tree

8 files changed

+276
-422
lines changed

8 files changed

+276
-422
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import functools
6+
import os
7+
from typing import TYPE_CHECKING
8+
9+
from azure.core.pipeline.transport import HttpRequest
10+
11+
from .. import CredentialUnavailableError
12+
from .._constants import EnvironmentVariables
13+
from .._internal.managed_identity_client import ManagedIdentityClient
14+
from .._internal.get_token_mixin import GetTokenMixin
15+
16+
if TYPE_CHECKING:
17+
from typing import Any, Optional
18+
from azure.core.credentials import AccessToken
19+
20+
21+
class CloudShellCredential(GetTokenMixin):
22+
def __init__(self, **kwargs):
23+
# type: (**Any) -> None
24+
super(CloudShellCredential, self).__init__()
25+
url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT)
26+
if url:
27+
self._available = True
28+
self._client = ManagedIdentityClient(
29+
request_factory=functools.partial(_get_request, url),
30+
base_headers={"Metadata": "true"},
31+
_identity_config=kwargs.pop("identity_config", None),
32+
**kwargs
33+
)
34+
else:
35+
self._available = False
36+
37+
def get_token(self, *scopes, **kwargs):
38+
# type: (*str, **Any) -> AccessToken
39+
if not self._available:
40+
raise CredentialUnavailableError(
41+
message="Cloud Shell managed identity configuration not found in environment"
42+
)
43+
return super(CloudShellCredential, self).get_token(*scopes, **kwargs)
44+
45+
def _acquire_token_silently(self, *scopes):
46+
# type: (*str) -> Optional[AccessToken]
47+
return self._client.get_cached_token(*scopes)
48+
49+
def _request_token(self, *scopes, **kwargs):
50+
# type: (*str, **Any) -> AccessToken
51+
return self._client.request_token(*scopes, **kwargs)
52+
53+
54+
def _get_request(url, scope, identity_config):
55+
# type: (str, str, dict) -> HttpRequest
56+
request = HttpRequest("POST", url, data=dict({"resource": scope}, **identity_config))
57+
return request

sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py

Lines changed: 4 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@ def __init__(self, **kwargs):
6161

6262
self._credential = AppServiceCredential(**kwargs)
6363
else:
64-
_LOGGER.info("%s will use MSI", self.__class__.__name__)
65-
self._credential = MsiCredential(**kwargs)
64+
_LOGGER.info("%s will use Cloud Shell managed identity", self.__class__.__name__)
65+
from .cloud_shell import CloudShellCredential
66+
67+
self._credential = CloudShellCredential(**kwargs)
6668
elif os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT):
6769
if (
6870
os.environ.get(EnvironmentVariables.IDENTITY_HEADER)
@@ -231,65 +233,3 @@ def _refresh_token(self, *scopes):
231233
# any other error is unexpected
232234
six.raise_from(ClientAuthenticationError(message=ex.message, response=ex.response), None)
233235
return token
234-
235-
236-
class MsiCredential(_ManagedIdentityBase):
237-
"""Authenticates via the MSI endpoint in an App Service or Cloud Shell environment.
238-
239-
:keyword str client_id: ID of a user-assigned identity. Leave unspecified to use a system-assigned identity.
240-
"""
241-
242-
def __init__(self, **kwargs):
243-
# type: (**Any) -> None
244-
self._endpoint = os.environ.get(EnvironmentVariables.MSI_ENDPOINT)
245-
if self._endpoint:
246-
super(MsiCredential, self).__init__(endpoint=self._endpoint, client_cls=AuthnClient, **kwargs)
247-
248-
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
249-
# type: (*str, **Any) -> AccessToken
250-
"""Request an access token for `scopes`.
251-
252-
This method is called automatically by Azure SDK clients.
253-
254-
:param str scopes: desired scope for the access token. This credential allows only one scope per request.
255-
:rtype: :class:`azure.core.credentials.AccessToken`
256-
:raises ~azure.identity.CredentialUnavailableError: the MSI endpoint is unavailable
257-
"""
258-
259-
if not self._endpoint:
260-
message = "ManagedIdentityCredential authentication unavailable, no managed identity endpoint found."
261-
raise CredentialUnavailableError(message=message)
262-
263-
if len(scopes) != 1:
264-
raise ValueError("This credential requires exactly one scope per token request.")
265-
266-
token = self._client.get_cached_token(scopes)
267-
if not token:
268-
token = self._refresh_token(*scopes)
269-
elif self._client.should_refresh(token):
270-
try:
271-
token = self._refresh_token(*scopes)
272-
except Exception: # pylint: disable=broad-except
273-
pass
274-
return token
275-
276-
def _refresh_token(self, *scopes):
277-
resource = scopes[0]
278-
if resource.endswith("/.default"):
279-
resource = resource[: -len("/.default")]
280-
secret = os.environ.get(EnvironmentVariables.MSI_SECRET)
281-
if secret:
282-
# MSI_ENDPOINT and MSI_SECRET set -> App Service
283-
token = self._request_app_service_token(scopes=scopes, resource=resource, secret=secret)
284-
else:
285-
# only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell)
286-
token = self._request_legacy_token(scopes=scopes, resource=resource)
287-
return token
288-
289-
def _request_app_service_token(self, scopes, resource, secret):
290-
params = dict({"api-version": "2017-09-01", "resource": resource}, **self._identity_config)
291-
return self._client.request_token(scopes, method="GET", headers={"secret": secret}, params=params)
292-
293-
def _request_legacy_token(self, scopes, resource):
294-
form_data = dict({"resource": resource}, **self._identity_config)
295-
return self._client.request_token(scopes, method="POST", form_data=form_data)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import functools
6+
import os
7+
from typing import TYPE_CHECKING
8+
9+
from .._internal import AsyncContextManager
10+
from .._internal.get_token_mixin import GetTokenMixin
11+
from .._internal.managed_identity_client import AsyncManagedIdentityClient
12+
from ... import CredentialUnavailableError
13+
from ..._constants import EnvironmentVariables
14+
from ..._credentials.cloud_shell import _get_request
15+
16+
if TYPE_CHECKING:
17+
from typing import Any, Optional
18+
from azure.core.credentials import AccessToken
19+
20+
21+
class CloudShellCredential(AsyncContextManager, GetTokenMixin):
22+
def __init__(self, **kwargs: "Any") -> None:
23+
super(CloudShellCredential, self).__init__()
24+
url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT)
25+
if url:
26+
self._available = True
27+
self._client = AsyncManagedIdentityClient(
28+
request_factory=functools.partial(_get_request, url),
29+
base_headers={"Metadata": "true"},
30+
_identity_config=kwargs.pop("identity_config", None),
31+
**kwargs,
32+
)
33+
else:
34+
self._available = False
35+
36+
async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
37+
if not self._available:
38+
raise CredentialUnavailableError(
39+
message="Cloud Shell managed identity configuration not found in environment"
40+
)
41+
return await super(CloudShellCredential, self).get_token(*scopes, **kwargs)
42+
43+
async def close(self) -> None:
44+
await self._client.close()
45+
46+
async def _acquire_token_silently(self, *scopes: str) -> "Optional[AccessToken]":
47+
return self._client.get_cached_token(*scopes)
48+
49+
async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
50+
return await self._client.request_token(*scopes, **kwargs)

sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py

Lines changed: 4 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ def __init__(self, **kwargs: "Any") -> None:
4949

5050
self._credential = AppServiceCredential(**kwargs)
5151
else:
52-
_LOGGER.info("%s will use MSI", self.__class__.__name__)
53-
self._credential = MsiCredential(**kwargs)
52+
_LOGGER.info("%s will use Cloud Shell managed identity", self.__class__.__name__)
53+
from .cloud_shell import CloudShellCredential
54+
55+
self._credential = CloudShellCredential(**kwargs)
5456
elif os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT):
5557
if (
5658
os.environ.get(EnvironmentVariables.IDENTITY_HEADER)
@@ -192,63 +194,3 @@ async def _refresh_token(self, *scopes):
192194
# any other error is unexpected
193195
raise ClientAuthenticationError(message=ex.message, response=ex.response) from None
194196
return token
195-
196-
197-
class MsiCredential(_AsyncManagedIdentityBase):
198-
"""Authenticates via the MSI endpoint in an App Service or Cloud Shell environment.
199-
200-
:keyword str client_id: ID of a user-assigned identity. Leave unspecified to use a system-assigned identity.
201-
"""
202-
203-
def __init__(self, **kwargs: "Any") -> None:
204-
self._endpoint = os.environ.get(EnvironmentVariables.MSI_ENDPOINT)
205-
if self._endpoint:
206-
super().__init__(endpoint=self._endpoint, **kwargs)
207-
208-
async def get_token(self, *scopes: str, **kwargs: "Any") -> AccessToken: # pylint:disable=unused-argument
209-
"""Asynchronously request an access token for `scopes`.
210-
211-
This method is called automatically by Azure SDK clients.
212-
213-
:param str scopes: desired scope for the access token. This credential allows only one scope per request.
214-
:rtype: :class:`azure.core.credentials.AccessToken`
215-
:raises ~azure.identity.CredentialUnavailableError: the MSI endpoint is unavailable
216-
"""
217-
if not self._endpoint:
218-
message = "ManagedIdentityCredential authentication unavailable, no managed identity endpoint found."
219-
raise CredentialUnavailableError(message=message)
220-
221-
if len(scopes) != 1:
222-
raise ValueError("This credential requires exactly one scope per token request.")
223-
224-
token = self._client.get_cached_token(scopes)
225-
if not token:
226-
token = await self._refresh_token(*scopes)
227-
elif self._client.should_refresh(token):
228-
try:
229-
token = await self._refresh_token(*scopes)
230-
except Exception: # pylint: disable=broad-except
231-
pass
232-
return token
233-
234-
async def _refresh_token(self, *scopes):
235-
resource = scopes[0]
236-
if resource.endswith("/.default"):
237-
resource = resource[: -len("/.default")]
238-
239-
secret = os.environ.get(EnvironmentVariables.MSI_SECRET)
240-
if secret:
241-
# MSI_ENDPOINT and MSI_SECRET set -> App Service
242-
token = await self._request_app_service_token(scopes=scopes, resource=resource, secret=secret)
243-
else:
244-
# only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell)
245-
token = await self._request_legacy_token(scopes=scopes, resource=resource)
246-
return token
247-
248-
async def _request_app_service_token(self, scopes, resource, secret):
249-
params = {"api-version": "2017-09-01", "resource": resource, **self._identity_config}
250-
return await self._client.request_token(scopes, method="GET", headers={"secret": secret}, params=params)
251-
252-
async def _request_legacy_token(self, scopes, resource):
253-
form_data = {"resource": resource, **self._identity_config}
254-
return await self._client.request_token(scopes, method="POST", form_data=form_data)

0 commit comments

Comments
 (0)