22# Copyright (c) Microsoft Corporation. All rights reserved.
33# ---------------------------------------------------------
44
5- from typing import Any , Dict , Optional
5+ from typing import Any , Dict , Optional , List
66
77from azure .ai .ml ._restclient .v2022_12_01_preview .models import (
88 ManagedNetworkSettings as RestManagedNetwork ,
1919
2020@experimental
2121class OutboundRule :
22+ """Base class for Outbound Rules, should not be instantiated directly.
23+
24+ :param rule_name: Name of the outbound rule.
25+ :type rule_name: str
26+ :param type: Type of the outbound rule. Supported types are "FQDN", "PrivateEndpoint", "ServiceTag"
27+ :type type: str
28+ """
29+
2230 def __init__ (
23- self , type : str = None , category : str = OutboundRuleCategory .USER_DEFINED # pylint: disable=redefined-builtin
31+ self ,
32+ * ,
33+ rule_name : str = None ,
34+ ** kwargs ,
2435 ) -> None :
25- self .type = type
26- self .category = category
36+ self .rule_name = rule_name
37+ self .type = kwargs .pop ("type" , None )
38+ self .category = kwargs .pop ("category" , OutboundRuleCategory .USER_DEFINED )
2739
2840 @classmethod
29- def _from_rest_object (cls , rest_obj : Any ) -> "OutboundRule" :
41+ def _from_rest_object (cls , rest_obj : Any , rule_name : str ) -> "OutboundRule" :
3042 if isinstance (rest_obj , RestFqdnOutboundRule ):
31- rule = FqdnDestination (destination = rest_obj .destination )
43+ rule = FqdnDestination (destination = rest_obj .destination , rule_name = rule_name )
3244 rule .category = rest_obj .category
3345 return rule
3446 if isinstance (rest_obj , RestPrivateEndpointOutboundRule ):
3547 rule = PrivateEndpointDestination (
3648 service_resource_id = rest_obj .destination .service_resource_id ,
3749 subresource_target = rest_obj .destination .subresource_target ,
3850 spark_enabled = rest_obj .destination .spark_enabled ,
51+ rule_name = rule_name ,
3952 )
4053 rule .category = rest_obj .category
4154 return rule
@@ -44,37 +57,44 @@ def _from_rest_object(cls, rest_obj: Any) -> "OutboundRule":
4457 service_tag = rest_obj .destination .service_tag ,
4558 protocol = rest_obj .destination .protocol ,
4659 port_ranges = rest_obj .destination .port_ranges ,
60+ rule_name = rule_name ,
4761 )
4862 rule .category = rest_obj .category
4963 return rule
5064
5165
5266@experimental
5367class FqdnDestination (OutboundRule ):
54- def __init__ (self , destination : str , category : str = OutboundRuleCategory . USER_DEFINED ) -> None :
68+ def __init__ (self , * , rule_name : str , destination : str , ** kwargs ) -> None :
5569 self .destination = destination
56- OutboundRule .__init__ (self , type = OutboundRuleType .FQDN , category = category )
70+ category = kwargs .pop ("category" , OutboundRuleCategory .USER_DEFINED )
71+ OutboundRule .__init__ (self , type = OutboundRuleType .FQDN , category = category , rule_name = rule_name )
5772
5873 def _to_rest_object (self ) -> RestFqdnOutboundRule :
5974 return RestFqdnOutboundRule (type = self .type , category = self .category , destination = self .destination )
6075
6176 def _to_dict (self ) -> Dict :
62- return {"type" : OutboundRuleType .FQDN , "category" : self .category , "destination" : self .destination }
77+ return {
78+ self .rule_name : {"type" : OutboundRuleType .FQDN , "category" : self .category , "destination" : self .destination }
79+ }
6380
6481
6582@experimental
6683class PrivateEndpointDestination (OutboundRule ):
6784 def __init__ (
6885 self ,
86+ * ,
87+ rule_name : str ,
6988 service_resource_id : str ,
7089 subresource_target : str ,
7190 spark_enabled : bool = False ,
72- category : str = OutboundRuleCategory . USER_DEFINED ,
91+ ** kwargs ,
7392 ) -> None :
7493 self .service_resource_id = service_resource_id
7594 self .subresource_target = subresource_target
7695 self .spark_enabled = spark_enabled
77- OutboundRule .__init__ (self , OutboundRuleType .PRIVATE_ENDPOINT , category = category )
96+ category = kwargs .pop ("category" , OutboundRuleCategory .USER_DEFINED )
97+ OutboundRule .__init__ (self , type = OutboundRuleType .PRIVATE_ENDPOINT , category = category , rule_name = rule_name )
7898
7999 def _to_rest_object (self ) -> RestPrivateEndpointOutboundRule :
80100 return RestPrivateEndpointOutboundRule (
@@ -89,25 +109,34 @@ def _to_rest_object(self) -> RestPrivateEndpointOutboundRule:
89109
90110 def _to_dict (self ) -> Dict :
91111 return {
92- "type" : OutboundRuleType .PRIVATE_ENDPOINT ,
93- "category" : self .category ,
94- "destination" : {
95- "service_resource_id" : self .service_resource_id ,
96- "subresource_target" : self .subresource_target ,
97- "spark_enabled" : self .spark_enabled ,
98- },
112+ self .rule_name : {
113+ "type" : OutboundRuleType .PRIVATE_ENDPOINT ,
114+ "category" : self .category ,
115+ "destination" : {
116+ "service_resource_id" : self .service_resource_id ,
117+ "subresource_target" : self .subresource_target ,
118+ "spark_enabled" : self .spark_enabled ,
119+ },
120+ }
99121 }
100122
101123
102124@experimental
103125class ServiceTagDestination (OutboundRule ):
104126 def __init__ (
105- self , service_tag : str , protocol : str , port_ranges : str , category : str = OutboundRuleCategory .USER_DEFINED
127+ self ,
128+ * ,
129+ rule_name : str ,
130+ service_tag : str ,
131+ protocol : str ,
132+ port_ranges : str ,
133+ ** kwargs ,
106134 ) -> None :
107135 self .service_tag = service_tag
108136 self .protocol = protocol
109137 self .port_ranges = port_ranges
110- OutboundRule .__init__ (self , OutboundRuleType .SERVICE_TAG , category = category )
138+ category = kwargs .pop ("category" , OutboundRuleCategory .USER_DEFINED )
139+ OutboundRule .__init__ (self , type = OutboundRuleType .SERVICE_TAG , category = category , rule_name = rule_name )
111140
112141 def _to_rest_object (self ) -> RestServiceTagOutboundRule :
113142 return RestServiceTagOutboundRule (
@@ -120,13 +149,15 @@ def _to_rest_object(self) -> RestServiceTagOutboundRule:
120149
121150 def _to_dict (self ) -> Dict :
122151 return {
123- "type" : OutboundRuleType .SERVICE_TAG ,
124- "category" : self .category ,
125- "destination" : {
126- "service_tag" : self .service_tag ,
127- "protocol" : self .protocol ,
128- "port_ranges" : self .port_ranges ,
129- },
152+ self .rule_name : {
153+ "type" : OutboundRuleType .SERVICE_TAG ,
154+ "category" : self .category ,
155+ "destination" : {
156+ "service_tag" : self .service_tag ,
157+ "protocol" : self .protocol ,
158+ "port_ranges" : self .port_ranges ,
159+ },
160+ }
130161 }
131162
132163
@@ -135,7 +166,7 @@ class ManagedNetwork:
135166 def __init__ (
136167 self ,
137168 isolation_mode : str = IsolationMode .DISABLED ,
138- outbound_rules : Optional [Dict [ str , OutboundRule ]] = None ,
169+ outbound_rules : Optional [List [ OutboundRule ]] = None ,
139170 network_id : Optional [str ] = None ,
140171 ) -> None :
141172 self .isolation_mode = isolation_mode
@@ -145,8 +176,8 @@ def __init__(
145176 def _to_rest_object (self ) -> RestManagedNetwork :
146177 rest_outbound_rules = (
147178 {
148- rule_name : self . outbound_rules [ rule_name ] ._to_rest_object () # pylint: disable=protected-access
149- for rule_name in self .outbound_rules
179+ outbound_rule . rule_name : outbound_rule ._to_rest_object () # pylint: disable=protected-access
180+ for outbound_rule in self .outbound_rules
150181 }
151182 if self .outbound_rules
152183 else None
@@ -156,12 +187,12 @@ def _to_rest_object(self) -> RestManagedNetwork:
156187 @classmethod
157188 def _from_rest_object (cls , obj : RestManagedNetwork ) -> "ManagedNetwork" :
158189 from_rest_outbound_rules = (
159- {
160- rule_name : OutboundRule ._from_rest_object ( # pylint: disable=protected-access
161- obj .outbound_rules [rule_name ]
190+ [
191+ OutboundRule ._from_rest_object ( # pylint: disable=protected-access
192+ obj .outbound_rules [rule_name ], rule_name = rule_name
162193 )
163194 for rule_name in obj .outbound_rules
164- }
195+ ]
165196 if obj .outbound_rules
166197 else {}
167198 )
0 commit comments