Skip to content

Commit 2e1688b

Browse files
add sagemaker-hyperpod compute type to resolve its pods via VPC ENI (#3886)
* add sagemaker-hyperpod compute type to resolve its pods via VPC ENI * consolidate fargate/hyperpod pod flags in resolveViaCascadedLookup into isNonEc2Pod flag * introduce PodsByComputeType struct
1 parent 75b5793 commit 2e1688b

File tree

2 files changed

+224
-18
lines changed

2 files changed

+224
-18
lines changed

pkg/networking/pod_eni_info_resolver.go

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ const (
2626
// EC2:DescribeNetworkInterface supports up to 200 filters per call.
2727
describeNetworkInterfacesFiltersLimit = 200
2828

29-
labelEKSComputeType = "eks.amazonaws.com/compute-type"
29+
labelEKSComputeType = "eks.amazonaws.com/compute-type"
30+
labelSageMakerComputeType = "sagemaker.amazonaws.com/compute-type"
3031
)
3132

3233
// PodENIInfoResolver is responsible for resolving the AWS VPC ENI that supports the pod network.
@@ -141,20 +142,20 @@ func (r *defaultPodENIInfoResolver) saveENIInfosToCache(pods []k8s.PodInfo, eniI
141142
}
142143

143144
func (r *defaultPodENIInfoResolver) resolvePodsViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error) {
144-
podsOnEc2, podsOnFargate, err := r.classifyPodsByComputeType(ctx, pods)
145+
podsByComputeType, err := r.classifyPodsByComputeType(ctx, pods)
145146
if err != nil {
146147
return nil, err
147148
}
148149
eniInfoByPodKey := make(map[types.NamespacedName]ENIInfo)
149-
if len(podsOnEc2) > 0 {
150-
eniInfoByPodKeyEc2, err := r.resolveViaCascadedLookup(ctx, podsOnEc2, false)
150+
if len(podsByComputeType.ec2Pods) > 0 {
151+
eniInfoByPodKeyEc2, err := r.resolveViaCascadedLookup(ctx, podsByComputeType.ec2Pods, false)
151152
if err != nil {
152153
return nil, err
153154
}
154155
eniInfoByPodKey = eniInfoByPodKeyEc2
155156
}
156-
if len(podsOnFargate) > 0 {
157-
eniInfoByPodKeyFargate, err := r.resolveViaCascadedLookup(ctx, podsOnFargate, true)
157+
if len(podsByComputeType.fargatePods) > 0 {
158+
eniInfoByPodKeyFargate, err := r.resolveViaCascadedLookup(ctx, podsByComputeType.fargatePods, true)
158159
if err != nil {
159160
return nil, err
160161
}
@@ -164,17 +165,28 @@ func (r *defaultPodENIInfoResolver) resolvePodsViaCascadedLookup(ctx context.Con
164165
}
165166
}
166167
}
168+
if len(podsByComputeType.sageMakerHyperPodPods) > 0 {
169+
eniInfoByPodKeySageMakerHyperPod, err := r.resolveViaCascadedLookup(ctx, podsByComputeType.sageMakerHyperPodPods, true)
170+
if err != nil {
171+
return nil, err
172+
}
173+
if len(eniInfoByPodKeySageMakerHyperPod) > 0 {
174+
for podKey, eniInfo := range eniInfoByPodKeySageMakerHyperPod {
175+
eniInfoByPodKey[podKey] = eniInfo
176+
}
177+
}
178+
}
167179
return eniInfoByPodKey, nil
168180
}
169181

170-
func (r *defaultPodENIInfoResolver) resolveViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo, isFargateNode bool) (map[types.NamespacedName]ENIInfo, error) {
182+
func (r *defaultPodENIInfoResolver) resolveViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo, isNonEc2Pod bool) (map[types.NamespacedName]ENIInfo, error) {
171183
eniInfoByPodKey := make(map[types.NamespacedName]ENIInfo)
172184
resolveFuncs := []func(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error){
173185
r.resolveViaPodENIAnnotation,
174186
r.resolveViaNodeENIs,
175187
// TODO, add support for kubenet CNI plugin(kops) by resolve via routeTable.
176188
}
177-
if isFargateNode {
189+
if isNonEc2Pod {
178190
resolveFuncs = []func(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error){
179191
r.resolveViaVPCENIs,
180192
}
@@ -281,6 +293,7 @@ func (r *defaultPodENIInfoResolver) resolveViaNodeENIs(ctx context.Context, pods
281293

282294
// resolveViaVPCENIs tries to resolve pod ENI by matching podIP against ENIs in vpc.
283295
// with EKS fargate pods, podIP is supported by an ENI in vpc.
296+
// with SageMaker HyperPod pods, podIP is supported by the visible cross-account ENI in customer vpc.
284297
func (r *defaultPodENIInfoResolver) resolveViaVPCENIs(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error) {
285298
podKeysByIP := make(map[string][]types.NamespacedName, len(pods))
286299
for _, pod := range pods {
@@ -388,33 +401,45 @@ func (r *defaultPodENIInfoResolver) isPodSupportedByNodeENI(pod k8s.PodInfo, nod
388401
return false
389402
}
390403

391-
// classifyPodsByComputeType classifies in to ec2 and fargate groups
392-
func (r *defaultPodENIInfoResolver) classifyPodsByComputeType(ctx context.Context, pods []k8s.PodInfo) ([]k8s.PodInfo, []k8s.PodInfo, error) {
393-
podsOnFargate := make([]k8s.PodInfo, 0, len(pods))
394-
podsOnEc2 := make([]k8s.PodInfo, 0, len(pods))
404+
// PodsByComputeType groups pods based on their compute type (EC2, Fargate, SageMaker HyperPod)
405+
type PodsByComputeType struct {
406+
ec2Pods []k8s.PodInfo
407+
fargatePods []k8s.PodInfo
408+
sageMakerHyperPodPods []k8s.PodInfo
409+
}
410+
411+
// classifyPodsByComputeType classifies in to ec2, fargate and sagemaker-hyperpod groups
412+
func (r *defaultPodENIInfoResolver) classifyPodsByComputeType(ctx context.Context, pods []k8s.PodInfo) (PodsByComputeType, error) {
413+
var podsByComputeType PodsByComputeType
395414
nodeNameByComputeType := make(map[string]string)
396415
for _, pod := range pods {
397416
if _, exists := nodeNameByComputeType[pod.NodeName]; exists {
398417
if nodeNameByComputeType[pod.NodeName] == "fargate" {
399-
podsOnFargate = append(podsOnFargate, pod)
418+
podsByComputeType.fargatePods = append(podsByComputeType.fargatePods, pod)
419+
} else if nodeNameByComputeType[pod.NodeName] == "sagemaker-hyperpod" {
420+
podsByComputeType.sageMakerHyperPodPods = append(podsByComputeType.sageMakerHyperPodPods, pod)
400421
} else {
401-
podsOnEc2 = append(podsOnEc2, pod)
422+
podsByComputeType.ec2Pods = append(podsByComputeType.ec2Pods, pod)
402423
}
403424
}
425+
404426
nodeKey := types.NamespacedName{Name: pod.NodeName}
405427
node := &corev1.Node{}
406428
if err := r.k8sClient.Get(ctx, nodeKey, node); err != nil {
407-
return nil, nil, err
429+
return PodsByComputeType{}, err
408430
}
409431
if node.Labels[labelEKSComputeType] == "fargate" {
410-
podsOnFargate = append(podsOnFargate, pod)
432+
podsByComputeType.fargatePods = append(podsByComputeType.fargatePods, pod)
411433
nodeNameByComputeType[pod.NodeName] = "fargate"
434+
} else if node.Labels[labelSageMakerComputeType] == "hyperpod" {
435+
podsByComputeType.sageMakerHyperPodPods = append(podsByComputeType.sageMakerHyperPodPods, pod)
436+
nodeNameByComputeType[pod.NodeName] = "sagemaker-hyperpod"
412437
} else {
413-
podsOnEc2 = append(podsOnEc2, pod)
438+
podsByComputeType.ec2Pods = append(podsByComputeType.ec2Pods, pod)
414439
nodeNameByComputeType[pod.NodeName] = "ec2"
415440
}
416441
}
417-
return podsOnEc2, podsOnFargate, nil
442+
return podsByComputeType, nil
418443
}
419444

420445
// computePodENIInfoCacheKey computes the cacheKey for pod's ENIInfo cache.

pkg/networking/pod_eni_info_resolver_test.go

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -999,6 +999,187 @@ func Test_defaultPodENIInfoResolver_resolveViaCascadedLookup_Fargate(t *testing.
999999
}
10001000
}
10011001

1002+
func Test_defaultPodENIInfoResolver_resolveViaCascadedLookup_SageMakerHyperPod(t *testing.T) {
1003+
hyperPodNodeA := &corev1.Node{
1004+
ObjectMeta: metav1.ObjectMeta{
1005+
Name: "hyperpod-i-04442beca624ba65b",
1006+
Labels: map[string]string{
1007+
"sagemaker.amazonaws.com/compute-type": "hyperpod",
1008+
},
1009+
},
1010+
Spec: corev1.NodeSpec{
1011+
ProviderID: "aws:///usw2-az2/sagemaker/cluster/hyperpod-xxxxxxxxxxxx-i-04442beca624ba65b",
1012+
},
1013+
}
1014+
hyperPodNodeB := &corev1.Node{
1015+
ObjectMeta: metav1.ObjectMeta{
1016+
Name: "hyperpod-i-04159267183583d03",
1017+
Labels: map[string]string{
1018+
"sagemaker.amazonaws.com/compute-type": "hyperpod",
1019+
},
1020+
},
1021+
Spec: corev1.NodeSpec{
1022+
ProviderID: "aws:///usw2-az2/sagemaker/cluster/hyperpod-xxxxxxxxxxxx-i-04159267183583d03",
1023+
},
1024+
}
1025+
type describeNetworkInterfacesAsListCall struct {
1026+
req *ec2sdk.DescribeNetworkInterfacesInput
1027+
resp []ec2types.NetworkInterface
1028+
err error
1029+
}
1030+
type fetchNodeInstancesCall struct {
1031+
nodes []*corev1.Node
1032+
nodeInstanceByNodeKey map[types.NamespacedName]*ec2types.Instance
1033+
err error
1034+
}
1035+
type env struct {
1036+
nodes []*corev1.Node
1037+
}
1038+
type fields struct {
1039+
describeNetworkInterfacesAsListCalls []describeNetworkInterfacesAsListCall
1040+
fetchNodeInstancesCalls []fetchNodeInstancesCall
1041+
}
1042+
type args struct {
1043+
pods []k8s.PodInfo
1044+
}
1045+
tests := []struct {
1046+
name string
1047+
env env
1048+
fields fields
1049+
args args
1050+
want map[types.NamespacedName]ENIInfo
1051+
wantErr error
1052+
}{
1053+
{
1054+
name: "all pod's ENI resolved via VPC's ENIs",
1055+
env: env{
1056+
nodes: []*corev1.Node{hyperPodNodeA, hyperPodNodeB},
1057+
},
1058+
fields: fields{
1059+
describeNetworkInterfacesAsListCalls: []describeNetworkInterfacesAsListCall{
1060+
{
1061+
req: &ec2sdk.DescribeNetworkInterfacesInput{
1062+
Filters: []ec2types.Filter{
1063+
{
1064+
Name: awssdk.String("vpc-id"),
1065+
Values: []string{"vpc-0d6d9ee10bd062dcc"},
1066+
},
1067+
{
1068+
Name: awssdk.String("addresses.private-ip-address"),
1069+
Values: []string{"192.168.128.151", "192.168.128.152"},
1070+
},
1071+
},
1072+
},
1073+
resp: []ec2types.NetworkInterface{
1074+
{
1075+
NetworkInterfaceId: awssdk.String("eni-c"),
1076+
PrivateIpAddresses: []ec2types.NetworkInterfacePrivateIpAddress{
1077+
{
1078+
PrivateIpAddress: awssdk.String("192.168.128.150"),
1079+
},
1080+
{
1081+
PrivateIpAddress: awssdk.String("192.168.128.151"),
1082+
},
1083+
},
1084+
Groups: []ec2types.GroupIdentifier{
1085+
{
1086+
GroupId: awssdk.String("sg-c-1"),
1087+
},
1088+
},
1089+
},
1090+
{
1091+
NetworkInterfaceId: awssdk.String("eni-d"),
1092+
PrivateIpAddresses: []ec2types.NetworkInterfacePrivateIpAddress{
1093+
{
1094+
PrivateIpAddress: awssdk.String("192.168.128.152"),
1095+
},
1096+
{
1097+
PrivateIpAddress: awssdk.String("192.168.128.153"),
1098+
},
1099+
},
1100+
Groups: []ec2types.GroupIdentifier{
1101+
{
1102+
GroupId: awssdk.String("sg-d-1"),
1103+
},
1104+
},
1105+
},
1106+
},
1107+
},
1108+
},
1109+
},
1110+
args: args{
1111+
pods: []k8s.PodInfo{
1112+
{
1113+
Key: types.NamespacedName{Namespace: "default", Name: "pod-1"},
1114+
UID: types.UID("2d8740a6-f4b1-4074-a91c-f0084ec0bc01"),
1115+
NodeName: "hyperpod-i-04442beca624ba65b",
1116+
PodIP: "192.168.128.151",
1117+
},
1118+
{
1119+
Key: types.NamespacedName{Namespace: "default", Name: "pod-2"},
1120+
UID: types.UID("2d8740a6-f4b1-4074-a91c-f0084ec0bc02"),
1121+
NodeName: "hyperpod-i-04159267183583d03",
1122+
PodIP: "192.168.128.152",
1123+
},
1124+
},
1125+
},
1126+
want: map[types.NamespacedName]ENIInfo{
1127+
types.NamespacedName{Namespace: "default", Name: "pod-1"}: {
1128+
NetworkInterfaceID: "eni-c",
1129+
SecurityGroups: []string{"sg-c-1"},
1130+
},
1131+
types.NamespacedName{Namespace: "default", Name: "pod-2"}: {
1132+
NetworkInterfaceID: "eni-d",
1133+
SecurityGroups: []string{"sg-d-1"},
1134+
},
1135+
},
1136+
},
1137+
}
1138+
for _, tt := range tests {
1139+
t.Run(tt.name, func(t *testing.T) {
1140+
ctrl := gomock.NewController(t)
1141+
defer ctrl.Finish()
1142+
1143+
ec2Client := services.NewMockEC2(ctrl)
1144+
for _, call := range tt.fields.describeNetworkInterfacesAsListCalls {
1145+
ec2Client.EXPECT().DescribeNetworkInterfacesAsList(gomock.Any(), call.req).Return(call.resp, call.err)
1146+
}
1147+
k8sSchema := runtime.NewScheme()
1148+
clientgoscheme.AddToScheme(k8sSchema)
1149+
k8sClient := fake.NewClientBuilder().WithScheme(k8sSchema).Build()
1150+
for _, node := range tt.env.nodes {
1151+
assert.NoError(t, k8sClient.Create(context.Background(), node.DeepCopy()))
1152+
}
1153+
nodeInfoProvider := NewMockNodeInfoProvider(ctrl)
1154+
for _, call := range tt.fields.fetchNodeInstancesCalls {
1155+
updatedNodes := make([]*corev1.Node, 0, len(call.nodes))
1156+
for _, node := range call.nodes {
1157+
updatedNode := &corev1.Node{}
1158+
assert.NoError(t, k8sClient.Get(context.Background(), k8s.NamespacedName(node), updatedNode))
1159+
updatedNodes = append(updatedNodes, updatedNode)
1160+
}
1161+
nodeInfoProvider.EXPECT().FetchNodeInstances(gomock.Any(), gomock.InAnyOrder(updatedNodes)).Return(call.nodeInstanceByNodeKey, call.err)
1162+
}
1163+
r := &defaultPodENIInfoResolver{
1164+
ec2Client: ec2Client,
1165+
k8sClient: k8sClient,
1166+
nodeInfoProvider: nodeInfoProvider,
1167+
vpcID: "vpc-0d6d9ee10bd062dcc",
1168+
logger: logr.New(&log.NullLogSink{}),
1169+
describeNetworkInterfacesIPChunkSize: 2,
1170+
}
1171+
1172+
got, err := r.resolveViaCascadedLookup(context.Background(), tt.args.pods, true)
1173+
if tt.wantErr != nil {
1174+
assert.EqualError(t, err, tt.wantErr.Error())
1175+
} else {
1176+
assert.NoError(t, err)
1177+
assert.Equal(t, tt.want, got)
1178+
}
1179+
})
1180+
}
1181+
}
1182+
10021183
func Test_defaultPodENIInfoResolver_resolveViaPodENIAnnotation(t *testing.T) {
10031184
type describeNetworkInterfacesAsListCall struct {
10041185
req *ec2sdk.DescribeNetworkInterfacesInput

0 commit comments

Comments
 (0)