Skip to content

Commit 41e0674

Browse files
authored
[ml] Enable OBO flow on AmlSpark (Azure#30148)
* wip * wip * aml spark obo token * Get resource from request if not in response * Fix linter * use isodate instead of dateutil
1 parent 57c289c commit 41e0674

File tree

4 files changed

+131
-9
lines changed

4 files changed

+131
-9
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/identity/_aio/_internal/managed_identity_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ async def request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
3333
request = self._request_factory(resource, self._identity_config) # pylint: disable=no-member
3434
request_time = int(time.time())
3535
response = await self._pipeline.run(request, retry_on_methods=[request.method], **kwargs)
36-
token = self._process_response(response, request_time)
36+
token = self._process_response(response=response, request_time=request_time, resource=resource)
3737
return token
3838

3939
def _build_pipeline(self, **kwargs: "Any") -> "AsyncPipeline":
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
6+
import functools
7+
import os
8+
from typing import Optional
9+
10+
from azure.core.pipeline.transport import HttpRequest
11+
12+
from .._internal.managed_identity_base import ManagedIdentityBase
13+
from .._internal.managed_identity_client import ManagedIdentityClient
14+
15+
16+
class _AzureMLSparkOnBehalfOfCredential(ManagedIdentityBase):
    """Credential for the AzureML Spark on-behalf-of (OBO) token flow.

    Client construction is delegated to ``_get_client_args``; when that helper
    yields no arguments, no client is created.
    """

    def get_client(self, **kwargs) -> Optional[ManagedIdentityClient]:
        """Build the client used to request OBO tokens.

        :return: A ``ManagedIdentityClient`` constructed from the arguments produced by
            ``_get_client_args``, or ``None`` when no arguments are produced.
        :rtype: Optional[ManagedIdentityClient]
        """
        arguments = _get_client_args(**kwargs)
        if not arguments:
            return None
        return ManagedIdentityClient(**arguments)

    def get_unavailable_message(self) -> str:
        """Return the message reported when this credential cannot be used.

        :return: A fixed, human-readable explanation.
        :rtype: str
        """
        return "AzureML Spark On Behalf of credentials not available in this environment"
25+
26+
27+
def _get_client_args(**kwargs) -> Optional[dict]:
    """Assemble the keyword arguments for the Spark OBO ``ManagedIdentityClient``.

    Copies selected Spark configuration values into environment variables, lets
    matching keyword arguments override them, and builds the token request URL
    from the resulting environment.

    :keyword str AZUREML_SYNAPSE_CLUSTER_IDENTIFIER: Overrides the environment variable of the same name.
    :keyword str AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT: Overrides the environment variable of the same name.
    :keyword str AZUREML_RUN_ID: Overrides the environment variable of the same name.
    :keyword str AZUREML_RUN_TOKEN_EXPIRY: Overrides the environment variable of the same name.
    :return: A copy of ``kwargs`` with an added ``request_factory`` entry, or ``None`` when the
        ``AZUREML_OBO_CANARY_TOKEN`` environment variable is unset or empty.
    :rtype: Optional[dict]
    :raises Exception: If a Spark session cannot be obtained; the original error is chained as the cause.
    """
    from pyspark.sql import SparkSession  # cspell:disable-line # pylint: disable=import-error

    try:
        spark = SparkSession.builder.getOrCreate()
    except Exception as ex:
        # Chain the underlying failure so the root cause stays visible in the traceback.
        raise Exception("Fail to get spark session, please check if spark environment is set up.") from ex

    # Seed the environment from the Spark configuration; unset or empty values are skipped.
    spark_conf = spark.sparkContext.getConf()
    spark_conf_vars = {
        "AZUREML_SYNAPSE_CLUSTER_IDENTIFIER": "spark.synapse.clusteridentifier",
        "AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT": "spark.tokenServiceEndpoint",
    }
    for env_key, conf_key in spark_conf_vars.items():
        value = spark_conf.get(conf_key)
        if value:
            os.environ[env_key] = value

    # Override default settings if provided via arguments
    env_key_from_kwargs = [
        "AZUREML_SYNAPSE_CLUSTER_IDENTIFIER",
        "AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT",
        "AZUREML_RUN_ID",
        "AZUREML_RUN_TOKEN_EXPIRY",
    ]
    for env_key in env_key_from_kwargs:
        if env_key in kwargs:
            os.environ[env_key] = kwargs[env_key]

    # Without an OBO token the credential cannot be used; signal that with None.
    if not os.environ.get("AZUREML_OBO_CANARY_TOKEN"):
        return None

    token_service_endpoint = os.environ.get("AZUREML_SYNAPSE_TOKEN_SERVICE_ENDPOINT")
    subscription_id = os.environ.get("AZUREML_ARM_SUBSCRIPTION")
    resource_group = os.environ.get("AZUREML_ARM_RESOURCEGROUP")
    workspace_name = os.environ.get("AZUREML_ARM_WORKSPACE_NAME")

    # NOTE(review): none of the four values above is validated; an unset variable is
    # formatted into the URL as the literal text "None".
    # pylint: disable=line-too-long
    request_url_format = "https://{}/api/v1/proxy/obotoken/v1.0/subscriptions/{}/resourceGroups/{}/providers/Microsoft.MachineLearningServices/workspaces/{}/getuseraccesstokenforrun"  # cspell:disable-line
    # pylint: enable=line-too-long

    url = request_url_format.format(
        token_service_endpoint,
        subscription_id,
        resource_group,
        workspace_name,
    )

    # The URL is bound now; the client supplies the resource when it builds each request.
    return dict(
        kwargs,
        request_factory=functools.partial(_get_request, url),
    )
81+
82+
83+
def _get_request(url, resource) -> HttpRequest:
    """Create the POST request that asks the OBO token service for a token.

    All values other than ``url`` and ``resource`` are read from environment
    variables at call time.

    :param url: URL the request is sent to.
    :param resource: Resource the token is requested for; sent as ``resource`` in the JSON body.
    :return: The request with its headers and JSON body populated.
    :rtype: HttpRequest
    """
    env = os.environ
    # The same token is sent both in the body and in a header.
    canary_token = env.get("AZUREML_OBO_CANARY_TOKEN")

    payload = {
        "oboToken": canary_token,
        "oid": env.get("OID"),
        "tid": env.get("TID"),
        "resource": resource,
        "experimentName": env.get("AZUREML_ARM_PROJECT_NAME"),
        "runId": env.get("AZUREML_RUN_ID"),
    }

    http_request = HttpRequest(
        method="POST",
        url=url,
        headers={
            "Content-Type": "application/json;charset=utf-8",
            "x-ms-proxy-host": env.get("AZUREML_OBO_SERVICE_ENDPOINT"),
            "obo-access-token": canary_token,
            "x-ms-cluster-identifier": env.get("AZUREML_SYNAPSE_CLUSTER_IDENTIFIER"),
        },
    )
    http_request.set_json_body(payload)
    return http_request

sdk/ml/azure-ai-ml/azure/ai/ml/identity/_credentials/aml_on_behalf_of.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from azure.core.pipeline.transport import HttpRequest
99

10+
from ._AzureMLSparkOnBehalfOfCredential import _AzureMLSparkOnBehalfOfCredential
1011
from .._internal.managed_identity_base import ManagedIdentityBase
1112
from .._internal.managed_identity_client import ManagedIdentityClient
1213

@@ -22,7 +23,11 @@ class AzureMLOnBehalfOfCredential(object):
2223
# pylint: enable=line-too-long
2324

2425
def __init__(self, **kwargs):
25-
self._credential = _AzureMLOnBehalfOfCredential(**kwargs)
26+
provider_type = os.environ.get("AZUREML_DATAPREP_TOKEN_PROVIDER")
27+
if provider_type == "sparkobo": # cspell:disable-line
28+
self._credential = _AzureMLSparkOnBehalfOfCredential(**kwargs)
29+
else:
30+
self._credential = _AzureMLOnBehalfOfCredential(**kwargs)
2631

2732
def get_token(self, *scopes, **kwargs):
2833
"""Request an access token for `scopes`.

sdk/ml/azure-ai-ml/azure/ai/ml/identity/_internal/managed_identity_client.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import time
77
from typing import TYPE_CHECKING
88

9+
import isodate
910
import six
1011
from msal import TokenCache
1112

@@ -41,8 +42,8 @@ def __init__(self, request_factory, **kwargs):
4142
self._pipeline = self._build_pipeline(**kwargs)
4243
self._request_factory = request_factory
4344

44-
def _process_response(self, response, request_time):
45-
# type: (PipelineResponse, int) -> AccessToken
45+
def _process_response(self, response, request_time, resource):
46+
# type: (PipelineResponse, int, str) -> AccessToken
4647

4748
content = response.context.get(ContentDecodePolicy.CONTEXT_NAME)
4849
if not content:
@@ -63,9 +64,13 @@ def _process_response(self, response, request_time):
6364
if not content:
6465
raise ClientAuthenticationError(message="No token received.", response=response.http_response)
6566

66-
if "access_token" not in content or not ("expires_in" in content or "expires_on" in content):
67+
if not ("access_token" in content or "token" in content) or not (
68+
"expires_in" in content or "expires_on" in content or "expiresOn" in content
69+
):
6770
if content and "access_token" in content:
6871
content["access_token"] = "****"
72+
if content and "token" in content:
73+
content["token"] = "****"
6974
raise ClientAuthenticationError(
7075
message='Unexpected response "{}"'.format(content),
7176
response=response.http_response,
@@ -74,14 +79,18 @@ def _process_response(self, response, request_time):
7479
if self._content_callback:
7580
self._content_callback(content)
7681

77-
expires_on = int(content.get("expires_on") or int(content["expires_in"]) + request_time)
82+
if "expires_in" in content or "expires_on" in content:
83+
expires_on = int(content.get("expires_on") or int(content["expires_in"]) + request_time)
84+
else:
85+
expires_on = int(isodate.parse_datetime(content["expiresOn"]).timestamp())
7886
content["expires_on"] = expires_on
7987

80-
token = AccessToken(content["access_token"], content["expires_on"])
88+
access_token = content.get("access_token") or content["token"]
89+
token = AccessToken(access_token, content["expires_on"])
8190

8291
# caching is the final step because TokenCache.add mutates its "event"
8392
self._cache.add(
84-
event={"response": content, "scope": [content["resource"]]},
93+
event={"response": content, "scope": [content.get("resource") or resource]},
8594
now=request_time,
8695
)
8796

@@ -124,7 +133,7 @@ def request_token(self, *scopes, **kwargs):
124133
request = self._request_factory(resource)
125134
request_time = int(time.time())
126135
response = self._pipeline.run(request, retry_on_methods=[request.method], **kwargs)
127-
token = self._process_response(response, request_time)
136+
token = self._process_response(response=response, request_time=request_time, resource=resource)
128137
return token
129138

130139
def _build_pipeline(self, **kwargs):

0 commit comments

Comments
 (0)