Refer to the Kubernetes API documentation for the other fields of the
metadata field.
diff --git a/docs/guide/targetgroupbinding/targetgroupbinding.md b/docs/guide/targetgroupbinding/targetgroupbinding.md
index 0f3a055e74..e7d3ccf4aa 100644
--- a/docs/guide/targetgroupbinding/targetgroupbinding.md
+++ b/docs/guide/targetgroupbinding/targetgroupbinding.md
@@ -112,10 +112,108 @@ spec:
### AssumeRole
Sometimes the AWS LoadBalancer controller needs to manipulate target groups from different AWS accounts.
-The way to do that is assuming a role from such account. There are annotations that can help you with that:
+The way to do that is assuming a role from such an account. The following spec fields help you with that.
-* `alb.ingress.kubernetes.io/IamRoleArnToAssume`: the ARN that you need to assume
-* `alb.ingress.kubernetes.io/AssumeRoleExternalId`: the external ID for the assume role operation. Optional, but recommended. It helps you to prevent the confused deputy problem ( https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html )
+* `iamRoleArnToAssume`: the ARN that you need to assume
+* `assumeRoleExternalId`: the external ID for the assume role operation. Optional, but recommended. It helps you to prevent the confused deputy problem ( https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html )
+
+
+```yaml
+apiVersion: elbv2.k8s.aws/v1beta1
+kind: TargetGroupBinding
+metadata:
+ name: peered-tg
+ namespace: nlb-game-2048-1
+spec:
+ assumeRoleExternalId: very-secret-string-2
+ iamRoleArnToAssume: arn:aws:iam::155642222660:role/tg-management-role
+ networking:
+ ingress:
+ - from:
+ - securityGroup:
+ groupID: sg-0b6a41a2fd959623f
+ ports:
+ - port: 80
+ protocol: TCP
+ serviceRef:
+ name: service-2048
+ port: 80
+ targetGroupARN: arn:aws:elasticloadbalancing:us-west-2:155642222660:targetgroup/peered-tg/6a4ecf7bfae473c1
+```
+
+In the following examples, we will refer to Cluster Owner (CO) and Target Group Owner (TGO) accounts.
+
+First, in the TGO account creates a role that will allow the AWS LBC in the CO account to assume it.
+For improved security, we only allow the AWS LBC role in CO account to assume the role.
+
+```json
+{
+ "Version": "2012-10-17",
+ "Statement": [
+ {
+ "Sid": "",
+ "Effect": "Allow",
+ "Principal": {
+ "AWS": "arn:aws:iam::565768096483:role/eksctl-awslbc-loadtest-addon-iamserviceaccoun-Role1-13RdJCMqV6p2"
+ },
+ "Action": "sts:AssumeRole",
+ "Condition": {
+ "StringEquals": {
+ "sts:ExternalId": "very-secret-string"
+ }
+ }
+ }
+ ]
+}
+```
+
+Next, still in the TGO account we need to add the following permissions to the Role created in the first step.
+
+```json
+{
+ "Version": "2012-10-17",
+ "Statement": [
+ {
+ "Sid": "VisualEditor0",
+ "Effect": "Allow",
+ "Action": [
+ "elasticloadbalancing:RegisterTargets",
+ "elasticloadbalancing:DeregisterTargets"
+ ],
+ "Resource": [
+ "arn:aws:elasticloadbalancing:us-west-2:155642222660:targetgroup/tg1/*",
+ "arn:aws:elasticloadbalancing:us-west-2:155642222660:targetgroup/tg2/*"
+ // add more here //
+ ]
+ },
+ {
+ "Sid": "VisualEditor1",
+ "Effect": "Allow",
+ "Action": [
+ "elasticloadbalancing:DescribeTargetGroups",
+ "elasticloadbalancing:DescribeTargetHealth"
+ ],
+ "Resource": "*"
+ }
+ ]
+}
+```
+
+
+Next, in the CO account, we need to allow the AWS LBC to perform the AssumeRole call.
+By default, this permission is not a part of the standard IAM policy that is vended with the LBC installation scripts.
+For improved security, it is possible to scope the AssumeRole permissions down to only roles that you know ahead of time the
+LBC will need to Assume.
+
+```json
+ {
+ "Effect": "Allow",
+ "Action": [
+ "sts:AssumeRole"
+ ],
+ "Resource": "*"
+ }
+```
## Sample YAML
@@ -125,10 +223,9 @@ apiVersion: elbv2.k8s.aws/v1beta1
kind: TargetGroupBinding
metadata:
name: my-tgb
- annotations:
- alb.ingress.kubernetes.io/IamRoleArnToAssume: "arn:aws:iam::999999999999:role/alb-controller-policy-to-assume"
- alb.ingress.kubernetes.io/AssumeRoleExternalId: "some-magic-string"
spec:
+ iamRoleArnToAssume: "arn:aws:iam::999999999999:role/alb-controller-policy-to-assume"
+ assumeRoleExternalId: "some-magic-string"
...
```
diff --git a/go.sum b/go.sum
index 292ee96ce6..6676d7940c 100644
--- a/go.sum
+++ b/go.sum
@@ -60,6 +60,7 @@ github.com/aws/aws-sdk-go-v2/service/ec2 v1.173.0 h1:ta62lid9JkIpKZtZZXSj6rP2AqY
github.com/aws/aws-sdk-go-v2/service/ec2 v1.173.0/go.mod h1:o6QDjdVKpP5EF0dp/VlvqckzuSDATr1rLdHt3A5m0YY=
github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.43.1 h1:L9Wt9zgtoYKIlaeFTy+EztGjL4oaXBBGtVXA+jaeYko=
github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.43.1/go.mod h1:yxzLdxt7bVGvIOPYIKFtiaJCJnx2ChlIIvlhW4QgI6M=
+github.com/aws/aws-sdk-go-v2/service/iam v1.36.3/go.mod h1:HSvujsK8xeEHMIB18oMXjSfqaN9cVqpo/MtHJIksQRk=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE=
diff --git a/helm/aws-load-balancer-controller/crds/crds.yaml b/helm/aws-load-balancer-controller/crds/crds.yaml
index b72e687892..ea31acc114 100644
--- a/helm/aws-load-balancer-controller/crds/crds.yaml
+++ b/helm/aws-load-balancer-controller/crds/crds.yaml
@@ -317,6 +317,15 @@ spec:
spec:
description: TargetGroupBindingSpec defines the desired state of TargetGroupBinding
properties:
+ assumeRoleExternalId:
+ description: IAM Role ARN to assume when calling AWS APIs. Needed
+ to assume a role in another account and prevent the confused deputy
+ problem. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html
+ type: string
+ iamRoleArnToAssume:
+ description: IAM Role ARN to assume when calling AWS APIs. Useful
+ if the target group is in a different AWS account
+ type: string
multiClusterTargetGroup:
description: MultiClusterTargetGroup Denotes if the TargetGroup is
shared among multiple clusters
@@ -494,6 +503,15 @@ spec:
spec:
description: TargetGroupBindingSpec defines the desired state of TargetGroupBinding
properties:
+ assumeRoleExternalId:
+ description: IAM Role ARN to assume when calling AWS APIs. Needed
+ to assume a role in another account and prevent the confused deputy
+ problem. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html
+ type: string
+ iamRoleArnToAssume:
+ description: IAM Role ARN to assume when calling AWS APIs. Useful
+ if the target group is in a different AWS account
+ type: string
ipAddressType:
description: ipAddressType specifies whether the target group is of
type IPv4 or IPv6. If unspecified, it will be automatically inferred.
diff --git a/main.go b/main.go
index de2a3edc9d..c4641fb109 100644
--- a/main.go
+++ b/main.go
@@ -90,7 +90,7 @@ func main() {
awsMetricsCollector = awsmetrics.NewCollector(metrics.Registry)
}
- cloud, err := aws.NewCloud(controllerCFG.AWSConfig, awsMetricsCollector, ctrl.Log, nil)
+ cloud, err := aws.NewCloud(controllerCFG.AWSConfig, controllerCFG.ClusterName, awsMetricsCollector, ctrl.Log, nil)
if err != nil {
setupLog.Error(err, "unable to initialize AWS cloud")
os.Exit(1)
diff --git a/pkg/aws/aws_config.go b/pkg/aws/aws_config.go
new file mode 100644
index 0000000000..4172f34daf
--- /dev/null
+++ b/pkg/aws/aws_config.go
@@ -0,0 +1,81 @@
+package aws
+
+import (
+ "context"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
+ "github.com/aws/aws-sdk-go-v2/aws/ratelimit"
+ "github.com/aws/aws-sdk-go-v2/aws/retry"
+ "github.com/aws/aws-sdk-go-v2/config"
+ "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
+ smithymiddleware "github.com/aws/smithy-go/middleware"
+ "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/throttle"
+ awsmetrics "sigs.k8s.io/aws-load-balancer-controller/pkg/metrics/aws"
+ "sigs.k8s.io/aws-load-balancer-controller/pkg/version"
+)
+
+const (
+ userAgent = "elbv2.k8s.aws"
+)
+
+func NewAWSConfigGenerator(cfg CloudConfig, ec2IMDSEndpointMode imds.EndpointModeState, metricsCollector *awsmetrics.Collector) AWSConfigGenerator {
+ return &awsConfigGeneratorImpl{
+ cfg: cfg,
+ ec2IMDSEndpointMode: ec2IMDSEndpointMode,
+ metricsCollector: metricsCollector,
+ }
+
+}
+
+// AWSConfigGenerator is responsible for generating an aws config based on the running environment
+type AWSConfigGenerator interface {
+ GenerateAWSConfig(optFns ...func(*config.LoadOptions) error) (aws.Config, error)
+}
+
+type awsConfigGeneratorImpl struct {
+ cfg CloudConfig
+ ec2IMDSEndpointMode imds.EndpointModeState
+ metricsCollector *awsmetrics.Collector
+}
+
+func (gen *awsConfigGeneratorImpl) GenerateAWSConfig(optFns ...func(*config.LoadOptions) error) (aws.Config, error) {
+
+ defaultOpts := []func(*config.LoadOptions) error{
+ config.WithRegion(gen.cfg.Region),
+ config.WithRetryer(func() aws.Retryer {
+ return retry.NewStandard(func(o *retry.StandardOptions) {
+ o.RateLimiter = ratelimit.None
+ o.MaxAttempts = gen.cfg.MaxRetries
+ })
+ }),
+ config.WithEC2IMDSEndpointMode(gen.ec2IMDSEndpointMode),
+ config.WithAPIOptions([]func(stack *smithymiddleware.Stack) error{
+ awsmiddleware.AddUserAgentKeyValue(userAgent, version.GitVersion),
+ }),
+ }
+
+ defaultOpts = append(defaultOpts, optFns...)
+
+ awsConfig, err := config.LoadDefaultConfig(context.TODO(),
+ defaultOpts...,
+ )
+
+ if err != nil {
+ return aws.Config{}, err
+ }
+
+ if gen.cfg.ThrottleConfig != nil {
+ throttler := throttle.NewThrottler(gen.cfg.ThrottleConfig)
+ awsConfig.APIOptions = append(awsConfig.APIOptions, func(stack *smithymiddleware.Stack) error {
+ return throttle.WithSDKRequestThrottleMiddleware(throttler)(stack)
+ })
+ }
+
+ if gen.metricsCollector != nil {
+ awsConfig.APIOptions = awsmetrics.WithSDKMetricCollector(gen.metricsCollector, awsConfig.APIOptions)
+ }
+
+ return awsConfig, nil
+}
+
+var _ AWSConfigGenerator = &awsConfigGeneratorImpl{}
diff --git a/pkg/aws/cloud.go b/pkg/aws/cloud.go
index 1e9d50f878..bd733a657d 100644
--- a/pkg/aws/cloud.go
+++ b/pkg/aws/cloud.go
@@ -3,23 +3,18 @@ package aws
import (
"context"
"fmt"
- "log"
+ "k8s.io/apimachinery/pkg/util/cache"
"net"
"os"
"strings"
+ "sync"
+ "time"
- awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
- "github.com/aws/aws-sdk-go-v2/aws/ratelimit"
- "github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
- smithymiddleware "github.com/aws/smithy-go/middleware"
- "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/throttle"
- "sigs.k8s.io/aws-load-balancer-controller/pkg/version"
-
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
@@ -32,10 +27,12 @@ import (
aws_metrics "sigs.k8s.io/aws-load-balancer-controller/pkg/metrics/aws"
)
-const userAgent = "elbv2.k8s.aws"
+const (
+ cacheTTLBufferTime = 30 * time.Second
+)
// NewCloud constructs new Cloud implementation.
-func NewCloud(cfg CloudConfig, metricsCollector *aws_metrics.Collector, logger logr.Logger, awsClientsProvider provider.AWSClientsProvider) (services.Cloud, error) {
+func NewCloud(cfg CloudConfig, clusterName string, metricsCollector *aws_metrics.Collector, logger logr.Logger, awsClientsProvider provider.AWSClientsProvider) (services.Cloud, error) {
hasIPv4 := true
addrs, err := net.InterfaceAddrs()
if err == nil {
@@ -76,29 +73,11 @@ func NewCloud(cfg CloudConfig, metricsCollector *aws_metrics.Collector, logger l
}
cfg.Region = region
}
- awsConfig, err := config.LoadDefaultConfig(context.TODO(),
- config.WithRegion(cfg.Region),
- config.WithRetryer(func() aws.Retryer {
- return retry.NewStandard(func(o *retry.StandardOptions) {
- o.RateLimiter = ratelimit.None
- o.MaxAttempts = cfg.MaxRetries
- })
- }),
- config.WithEC2IMDSEndpointMode(ec2IMDSEndpointMode),
- config.WithAPIOptions([]func(stack *smithymiddleware.Stack) error{
- awsmiddleware.AddUserAgentKeyValue(userAgent, version.GitVersion),
- }),
- )
-
- if cfg.ThrottleConfig != nil {
- throttler := throttle.NewThrottler(cfg.ThrottleConfig)
- awsConfig.APIOptions = append(awsConfig.APIOptions, func(stack *smithymiddleware.Stack) error {
- return throttle.WithSDKRequestThrottleMiddleware(throttler)(stack)
- })
- }
- if metricsCollector != nil {
- awsConfig.APIOptions = aws_metrics.WithSDKMetricCollector(metricsCollector, awsConfig.APIOptions)
+ awsConfigGenerator := NewAWSConfigGenerator(cfg, ec2IMDSEndpointMode, metricsCollector)
+ awsConfig, err := awsConfigGenerator.GenerateAWSConfig()
+ if err != nil {
+ return nil, errors.Wrap(err, "Unable to generate AWS config")
}
if awsClientsProvider == nil {
@@ -119,6 +98,7 @@ func NewCloud(cfg CloudConfig, metricsCollector *aws_metrics.Collector, logger l
thisObj := &defaultCloud{
cfg: cfg,
+ clusterName: clusterName,
ec2: ec2Service,
acm: services.NewACM(awsClientsProvider),
wafv2: services.NewWAFv2(awsClientsProvider),
@@ -126,7 +106,10 @@ func NewCloud(cfg CloudConfig, metricsCollector *aws_metrics.Collector, logger l
shield: services.NewShield(awsClientsProvider),
rgt: services.NewRGT(awsClientsProvider),
- assumeRoleElbV2: make(map[string]services.ELBV2),
+ awsConfigGenerator: awsConfigGenerator,
+
+ assumeRoleElbV2Cache: cache.NewExpiring(),
+
awsClientsProvider: awsClientsProvider,
logger: logger,
}
@@ -220,77 +203,64 @@ type defaultCloud struct {
shield services.Shield
rgt services.RGT
- assumeRoleElbV2 map[string]services.ELBV2
+ clusterName string
+
+ awsConfigGenerator AWSConfigGenerator
+
+ // A cache holding elbv2 clients that are assuming a role.
+ assumeRoleElbV2Cache *cache.Expiring
+ // assumeRoleElbV2CacheMutex protects assumeRoleElbV2Cache
+ assumeRoleElbV2CacheMutex sync.RWMutex
+
awsClientsProvider provider.AWSClientsProvider
logger logr.Logger
}
-// returns ELBV2 client for the given assumeRoleArn, or the default ELBV2 client if assumeRoleArn is empty
-func (c *defaultCloud) GetAssumedRoleELBV2(ctx context.Context, assumeRoleArn string, externalId string) services.ELBV2 {
-
+// GetAssumedRoleELBV2 returns ELBV2 client for the given assumeRoleArn, or the default ELBV2 client if assumeRoleArn is empty
+func (c *defaultCloud) GetAssumedRoleELBV2(ctx context.Context, assumeRoleArn string, externalId string) (services.ELBV2, error) {
if assumeRoleArn == "" {
- return c.elbv2
+ return c.elbv2, nil
}
- assumedRoleELBV2, exists := c.assumeRoleElbV2[assumeRoleArn]
+ c.assumeRoleElbV2CacheMutex.RLock()
+ assumedRoleELBV2, exists := c.assumeRoleElbV2Cache.Get(assumeRoleArn)
+ c.assumeRoleElbV2CacheMutex.RUnlock()
+
if exists {
- return assumedRoleELBV2
+ return assumedRoleELBV2.(services.ELBV2), nil
}
- c.logger.Info("awsCloud", "method", "GetAssumedRoleELBV2", "AssumeRoleArn", assumeRoleArn, "externalId", externalId)
+ c.logger.Info("Constructing new elbv2 client", "AssumeRoleArn", assumeRoleArn, "externalId", externalId)
- ////////////////
- existingAwsConfig, _ := c.awsClientsProvider.GetAWSConfig(ctx, "GetAWSConfigForIAMRoleImpersonation")
+ stsClient, err := c.awsClientsProvider.GetSTSClient(ctx, "AssumeRole")
+ if err != nil {
+ // This should never happen, but let's be forward-looking.
+ return nil, err
+ }
- sourceAccount := sts.NewFromConfig(*existingAwsConfig)
- response, err := sourceAccount.AssumeRole(ctx, &sts.AssumeRoleInput{
+ response, err := stsClient.AssumeRole(ctx, &sts.AssumeRoleInput{
RoleArn: aws.String(assumeRoleArn),
- RoleSessionName: aws.String("aws-load-balancer-controller"),
+ RoleSessionName: aws.String(generateAssumeRoleSessionName(c.clusterName)),
ExternalId: aws.String(externalId),
})
if err != nil {
- log.Fatalf("Unable to assume target role, %v. Attempting to use default client", err)
- return c.elbv2
+ c.logger.Error(err, "Unable to assume target role", "roleArn", assumeRoleArn)
+ return nil, err
}
assumedRoleCreds := response.Credentials
newCreds := credentials.NewStaticCredentialsProvider(*assumedRoleCreds.AccessKeyId, *assumedRoleCreds.SecretAccessKey, *assumedRoleCreds.SessionToken)
- newAwsConfig, err := config.LoadDefaultConfig(ctx, config.WithRegion(c.cfg.Region), config.WithCredentialsProvider(newCreds))
+ newAwsConfig, err := c.awsConfigGenerator.GenerateAWSConfig(config.WithCredentialsProvider(newCreds))
if err != nil {
- log.Fatalf("Unable to load static credentials for service client config, %v. Attempting to use default client", err)
- return c.elbv2
+ c.logger.Error(err, "Create new service client config service client config", "roleArn", assumeRoleArn)
+ return nil, err
}
- existingAwsConfig.Credentials = newAwsConfig.Credentials // response.Credentials
-
- // // var assumedRoleCreds *stsTypes.Credentials = response.Credentials
-
- // // Create config with target service client, using assumed role
- // cfg, err = config.LoadDefaultConfig(ctx, config.WithRegion(region), config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(*assumedRoleCreds.AccessKeyId, *assumedRoleCreds.SecretAccessKey, *assumedRoleCreds.SessionToken)))
- // if err != nil {
- // log.Fatalf("unable to load static credentials for service client config, %v", err)
- // }
-
- // ////////////////
- // appCreds := stscreds.NewAssumeRoleProvider(client, assumeRoleArn)
- // value, err := appCreds.Retrieve(context.TODO())
- // if err != nil {
- // // handle error
- // }
- // /////////
-
- // ///////////// OLD
- // creds := stscreds.NewCredentials(c.session, assumeRoleArn, func(p *stscreds.AssumeRoleProvider) {
- // p.ExternalID = &externalId
- // })
- // //////////////
-
- // c.awsConfig.Credentials = creds
- // // newObj := services.NewELBV2(c.session, c, c.awsCFG)
- // newObj := services.NewELBV2(*c.awsConfig, c.endpointsResolver, c)
-
- newObj := services.NewELBV2(c.awsClientsProvider, c)
- c.assumeRoleElbV2[assumeRoleArn] = newObj
+ cacheTTL := assumedRoleCreds.Expiration.Sub(time.Now())
+ elbv2WithAssumedRole := services.NewELBV2FromStaticClient(c.awsClientsProvider.GenerateNewELBv2Client(newAwsConfig), c)
- return newObj
+ c.assumeRoleElbV2CacheMutex.Lock()
+ defer c.assumeRoleElbV2CacheMutex.Unlock()
+ c.assumeRoleElbV2Cache.Set(assumeRoleArn, elbv2WithAssumedRole, cacheTTL-cacheTTLBufferTime)
+ return elbv2WithAssumedRole, nil
}
func (c *defaultCloud) EC2() services.EC2 {
diff --git a/pkg/aws/cloud_util.go b/pkg/aws/cloud_util.go
new file mode 100644
index 0000000000..6da13b8f0b
--- /dev/null
+++ b/pkg/aws/cloud_util.go
@@ -0,0 +1,25 @@
+package aws
+
+import (
+ "fmt"
+ "regexp"
+)
+
+const (
+ sessionNamePrefix = "AWS-LBC-"
+ maxSessionNameLength = 2047
+)
+
+var illegalValuesInSessionName = regexp.MustCompile(`[^a-zA-Z0-9=,.@\-_]+`)
+
+func generateAssumeRoleSessionName(clusterName string) string {
+ safeClusterName := illegalValuesInSessionName.ReplaceAllString(clusterName, "")
+
+ sessionName := fmt.Sprintf("%s%s", sessionNamePrefix, safeClusterName)
+
+ if len(sessionName) > maxSessionNameLength {
+ return sessionName[:maxSessionNameLength]
+ }
+
+ return sessionName
+}
diff --git a/pkg/aws/cloud_util_test.go b/pkg/aws/cloud_util_test.go
new file mode 100644
index 0000000000..83e37b612b
--- /dev/null
+++ b/pkg/aws/cloud_util_test.go
@@ -0,0 +1,42 @@
+package aws
+
+import (
+ "github.com/stretchr/testify/assert"
+ "testing"
+)
+
+func TestUpdateTrackedTargets(t *testing.T) {
+ testCases := []struct {
+ name string
+ clusterName string
+ expectedSessionName string
+ }{
+ {
+ name: "no mods",
+ clusterName: "my-cluster-name",
+ expectedSessionName: "AWS-LBC-my-cluster-name",
+ },
+ {
+ name: "mix lower and upper case",
+ clusterName: "My-ClUsTeR-name",
+ expectedSessionName: "AWS-LBC-My-ClUsTeR-name",
+ },
+ {
+ name: "with legal characters",
+ clusterName: "my_cluster-name=foo,something@here.",
+ expectedSessionName: "AWS-LBC-my_cluster-name=foo,something@here.",
+ },
+ {
+ name: "with illegal characters",
+ clusterName: "my&*&*cluster()!(&name",
+ expectedSessionName: "AWS-LBC-myclustername",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ result := generateAssumeRoleSessionName(tc.clusterName)
+ assert.Equal(t, tc.expectedSessionName, result)
+ })
+ }
+}
diff --git a/pkg/aws/provider/default_aws_clients_provider.go b/pkg/aws/provider/default_aws_clients_provider.go
index 41cd780554..1d1a2b713e 100644
--- a/pkg/aws/provider/default_aws_clients_provider.go
+++ b/pkg/aws/provider/default_aws_clients_provider.go
@@ -8,6 +8,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
"github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi"
"github.com/aws/aws-sdk-go-v2/service/shield"
+ "github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go-v2/service/wafregional"
"github.com/aws/aws-sdk-go-v2/service/wafv2"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
@@ -21,11 +22,13 @@ type defaultAWSClientsProvider struct {
wafRegionClient *wafregional.Client
shieldClient *shield.Client
rgtClient *resourcegroupstaggingapi.Client
+ stsClient *sts.Client
- awsConfig *aws.Config
+ // used for dynamic creation of ELBv2 client
+ elbv2CustomEndpoint *string
}
-func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (*defaultAWSClientsProvider, error) {
+func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (AWSClientsProvider, error) {
ec2CustomEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID)
elbv2CustomEndpoint := endpointsResolver.EndpointFor(elasticloadbalancingv2.ServiceID)
acmCustomEndpoint := endpointsResolver.EndpointFor(acm.ServiceID)
@@ -33,17 +36,16 @@ func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.R
wafregionalCustomEndpoint := endpointsResolver.EndpointFor(wafregional.ServiceID)
shieldCustomEndpoint := endpointsResolver.EndpointFor(shield.ServiceID)
rgtCustomEndpoint := endpointsResolver.EndpointFor(resourcegroupstaggingapi.ServiceID)
+ stsCustomEndpoint := endpointsResolver.EndpointFor(sts.ServiceID)
ec2Client := ec2.NewFromConfig(cfg, func(o *ec2.Options) {
if ec2CustomEndpoint != nil {
o.BaseEndpoint = ec2CustomEndpoint
}
})
- elbv2Client := elasticloadbalancingv2.NewFromConfig(cfg, func(o *elasticloadbalancingv2.Options) {
- if elbv2CustomEndpoint != nil {
- o.BaseEndpoint = elbv2CustomEndpoint
- }
- })
+
+ elbv2Client := generateNewELBv2ClientHelper(cfg, elbv2CustomEndpoint)
+
acmClient := acm.NewFromConfig(cfg, func(o *acm.Options) {
if acmCustomEndpoint != nil {
o.BaseEndpoint = acmCustomEndpoint
@@ -68,6 +70,12 @@ func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.R
}
})
+ stsClient := sts.NewFromConfig(cfg, func(o *sts.Options) {
+ if stsCustomEndpoint != nil {
+ o.BaseEndpoint = stsCustomEndpoint
+ }
+ })
+
return &defaultAWSClientsProvider{
ec2Client: ec2Client,
elbv2Client: elbv2Client,
@@ -76,8 +84,9 @@ func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.R
wafRegionClient: wafregionalClient,
shieldClient: shieldClient,
rgtClient: rgtClient,
+ stsClient: stsClient,
- awsConfig: &cfg,
+ elbv2CustomEndpoint: elbv2CustomEndpoint,
}, nil
}
@@ -112,6 +121,18 @@ func (p *defaultAWSClientsProvider) GetRGTClient(ctx context.Context, operationN
return p.rgtClient, nil
}
-func (p *defaultAWSClientsProvider) GetAWSConfig(ctx context.Context, operationName string) (*aws.Config, error) {
- return p.awsConfig, nil
+func (p *defaultAWSClientsProvider) GetSTSClient(ctx context.Context, operationName string) (*sts.Client, error) {
+ return p.stsClient, nil
+}
+
+func (p *defaultAWSClientsProvider) GenerateNewELBv2Client(cfg aws.Config) *elasticloadbalancingv2.Client {
+ return generateNewELBv2ClientHelper(cfg, p.elbv2CustomEndpoint)
+}
+
+func generateNewELBv2ClientHelper(cfg aws.Config, elbv2CustomEndpoint *string) *elasticloadbalancingv2.Client {
+ return elasticloadbalancingv2.NewFromConfig(cfg, func(o *elasticloadbalancingv2.Options) {
+ if elbv2CustomEndpoint != nil {
+ o.BaseEndpoint = elbv2CustomEndpoint
+ }
+ })
}
diff --git a/pkg/aws/provider/provider.go b/pkg/aws/provider/provider.go
index 2cdff45745..66bb168286 100644
--- a/pkg/aws/provider/provider.go
+++ b/pkg/aws/provider/provider.go
@@ -8,6 +8,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
"github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi"
"github.com/aws/aws-sdk-go-v2/service/shield"
+ "github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go-v2/service/wafregional"
"github.com/aws/aws-sdk-go-v2/service/wafv2"
)
@@ -20,6 +21,6 @@ type AWSClientsProvider interface {
GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error)
GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error)
GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error)
-
- GetAWSConfig(ctx context.Context, operationName string) (*aws.Config, error)
+ GetSTSClient(ctx context.Context, operationName string) (*sts.Client, error)
+ GenerateNewELBv2Client(cfg aws.Config) *elasticloadbalancingv2.Client
}
diff --git a/pkg/aws/services/cloudInterface.go b/pkg/aws/services/cloudInterface.go
index 3e0ff558ec..8b11eaeb16 100644
--- a/pkg/aws/services/cloudInterface.go
+++ b/pkg/aws/services/cloudInterface.go
@@ -30,5 +30,5 @@ type Cloud interface {
// VpcID for the LoadBalancer resources.
VpcID() string
- GetAssumedRoleELBV2(ctx context.Context, assumeRoleArn string, externalId string) ELBV2
+ GetAssumedRoleELBV2(ctx context.Context, assumeRoleArn string, externalId string) (ELBV2, error)
}
diff --git a/pkg/aws/services/elbv2.go b/pkg/aws/services/elbv2.go
index 877cdd5f3a..293f61f3b2 100644
--- a/pkg/aws/services/elbv2.go
+++ b/pkg/aws/services/elbv2.go
@@ -60,7 +60,7 @@ type ELBV2 interface {
ModifyListenerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyListenerAttributesInput) (*elasticloadbalancingv2.ModifyListenerAttributesOutput, error)
ModifyCapacityReservationWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyCapacityReservationInput) (*elasticloadbalancingv2.ModifyCapacityReservationOutput, error)
DescribeCapacityReservationWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeCapacityReservationInput) (*elasticloadbalancingv2.DescribeCapacityReservationOutput, error)
- AssumeRole(ctx context.Context, assumeRoleArn string, externalId string) ELBV2
+ AssumeRole(ctx context.Context, assumeRoleArn string, externalId string) (ELBV2, error)
}
func NewELBV2(awsClientsProvider provider.AWSClientsProvider, cloud Cloud) ELBV2 {
@@ -70,21 +70,29 @@ func NewELBV2(awsClientsProvider provider.AWSClientsProvider, cloud Cloud) ELBV2
}
}
+func NewELBV2FromStaticClient(staticELBClient *elasticloadbalancingv2.Client, cloud Cloud) ELBV2 {
+ return &elbv2Client{
+ staticELBClient: staticELBClient,
+ cloud: cloud,
+ }
+}
+
// default implementation for ELBV2.
type elbv2Client struct {
awsClientsProvider provider.AWSClientsProvider
+ staticELBClient *elasticloadbalancingv2.Client
cloud Cloud
}
-func (c *elbv2Client) AssumeRole(ctx context.Context, assumeRoleArn string, externalId string) ELBV2 {
+func (c *elbv2Client) AssumeRole(ctx context.Context, assumeRoleArn string, externalId string) (ELBV2, error) {
if assumeRoleArn == "" {
- return c
+ return c, nil
}
return c.cloud.GetAssumedRoleELBV2(ctx, assumeRoleArn, externalId)
}
func (c *elbv2Client) AddListenerCertificatesWithContext(ctx context.Context, input *elasticloadbalancingv2.AddListenerCertificatesInput) (*elasticloadbalancingv2.AddListenerCertificatesOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "AddListenerCertificates")
+ client, err := c.getClient(ctx, "AddListenerCertificates")
if err != nil {
return nil, err
}
@@ -92,7 +100,7 @@ func (c *elbv2Client) AddListenerCertificatesWithContext(ctx context.Context, in
}
func (c *elbv2Client) RemoveListenerCertificatesWithContext(ctx context.Context, input *elasticloadbalancingv2.RemoveListenerCertificatesInput) (*elasticloadbalancingv2.RemoveListenerCertificatesOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "RemoveListenerCertificates")
+ client, err := c.getClient(ctx, "RemoveListenerCertificates")
if err != nil {
return nil, err
}
@@ -100,7 +108,7 @@ func (c *elbv2Client) RemoveListenerCertificatesWithContext(ctx context.Context,
}
func (c *elbv2Client) DescribeListenersWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeListenersInput) (*elasticloadbalancingv2.DescribeListenersOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeListeners")
+ client, err := c.getClient(ctx, "DescribeListeners")
if err != nil {
return nil, err
}
@@ -108,7 +116,7 @@ func (c *elbv2Client) DescribeListenersWithContext(ctx context.Context, input *e
}
func (c *elbv2Client) DescribeRulesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeRulesInput) (*elasticloadbalancingv2.DescribeRulesOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeRules")
+ client, err := c.getClient(ctx, "DescribeRules")
if err != nil {
return nil, err
}
@@ -116,7 +124,7 @@ func (c *elbv2Client) DescribeRulesWithContext(ctx context.Context, input *elast
}
func (c *elbv2Client) RegisterTargetsWithContext(ctx context.Context, input *elasticloadbalancingv2.RegisterTargetsInput) (*elasticloadbalancingv2.RegisterTargetsOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "RegisterTargets")
+ client, err := c.getClient(ctx, "RegisterTargets")
if err != nil {
return nil, err
}
@@ -124,7 +132,7 @@ func (c *elbv2Client) RegisterTargetsWithContext(ctx context.Context, input *ela
}
func (c *elbv2Client) DeregisterTargetsWithContext(ctx context.Context, input *elasticloadbalancingv2.DeregisterTargetsInput) (*elasticloadbalancingv2.DeregisterTargetsOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeregisterTargets")
+ client, err := c.getClient(ctx, "DeregisterTargets")
if err != nil {
return nil, err
}
@@ -132,7 +140,7 @@ func (c *elbv2Client) DeregisterTargetsWithContext(ctx context.Context, input *e
}
func (c *elbv2Client) DescribeTrustStoresWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTrustStoresInput) (*elasticloadbalancingv2.DescribeTrustStoresOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTrustStores")
+ client, err := c.getClient(ctx, "DescribeTrustStores")
if err != nil {
return nil, err
}
@@ -140,7 +148,7 @@ func (c *elbv2Client) DescribeTrustStoresWithContext(ctx context.Context, input
}
func (c *elbv2Client) ModifyRuleWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyRuleInput) (*elasticloadbalancingv2.ModifyRuleOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyRule")
+ client, err := c.getClient(ctx, "ModifyRule")
if err != nil {
return nil, err
}
@@ -148,7 +156,7 @@ func (c *elbv2Client) ModifyRuleWithContext(ctx context.Context, input *elasticl
}
func (c *elbv2Client) DeleteRuleWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteRuleInput) (*elasticloadbalancingv2.DeleteRuleOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeleteRule")
+ client, err := c.getClient(ctx, "DeleteRule")
if err != nil {
return nil, err
}
@@ -156,7 +164,7 @@ func (c *elbv2Client) DeleteRuleWithContext(ctx context.Context, input *elasticl
}
func (c *elbv2Client) CreateRuleWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateRuleInput) (*elasticloadbalancingv2.CreateRuleOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "CreateRule")
+ client, err := c.getClient(ctx, "CreateRule")
if err != nil {
return nil, err
}
@@ -164,7 +172,7 @@ func (c *elbv2Client) CreateRuleWithContext(ctx context.Context, input *elasticl
}
func (c *elbv2Client) WaitUntilLoadBalancerAvailableWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancersInput) error {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeLoadBalancers")
+ client, err := c.getClient(ctx, "DescribeLoadBalancers")
if err != nil {
return err
}
@@ -174,7 +182,7 @@ func (c *elbv2Client) WaitUntilLoadBalancerAvailableWithContext(ctx context.Cont
}
func (c *elbv2Client) DescribeLoadBalancersWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancersInput) (*elasticloadbalancingv2.DescribeLoadBalancersOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeLoadBalancers")
+ client, err := c.getClient(ctx, "DescribeLoadBalancers")
if err != nil {
return nil, err
}
@@ -182,7 +190,7 @@ func (c *elbv2Client) DescribeLoadBalancersWithContext(ctx context.Context, inpu
}
func (c *elbv2Client) DescribeTargetHealthWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetHealthInput) (*elasticloadbalancingv2.DescribeTargetHealthOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTargetHealth")
+ client, err := c.getClient(ctx, "DescribeTargetHealth")
if err != nil {
return nil, err
}
@@ -190,7 +198,7 @@ func (c *elbv2Client) DescribeTargetHealthWithContext(ctx context.Context, input
}
func (c *elbv2Client) DescribeTargetGroupsWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetGroupsInput) (*elasticloadbalancingv2.DescribeTargetGroupsOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTargetGroups")
+ client, err := c.getClient(ctx, "DescribeTargetGroups")
if err != nil {
return nil, err
}
@@ -198,7 +206,7 @@ func (c *elbv2Client) DescribeTargetGroupsWithContext(ctx context.Context, input
}
func (c *elbv2Client) DeleteTargetGroupWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteTargetGroupInput) (*elasticloadbalancingv2.DeleteTargetGroupOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeleteTargetGroup")
+ client, err := c.getClient(ctx, "DeleteTargetGroup")
if err != nil {
return nil, err
}
@@ -206,7 +214,7 @@ func (c *elbv2Client) DeleteTargetGroupWithContext(ctx context.Context, input *e
}
func (c *elbv2Client) ModifyTargetGroupWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyTargetGroupInput) (*elasticloadbalancingv2.ModifyTargetGroupOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyTargetGroup")
+ client, err := c.getClient(ctx, "ModifyTargetGroup")
if err != nil {
return nil, err
}
@@ -214,7 +222,7 @@ func (c *elbv2Client) ModifyTargetGroupWithContext(ctx context.Context, input *e
}
func (c *elbv2Client) CreateTargetGroupWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateTargetGroupInput) (*elasticloadbalancingv2.CreateTargetGroupOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "CreateTargetGroup")
+ client, err := c.getClient(ctx, "CreateTargetGroup")
if err != nil {
return nil, err
}
@@ -222,7 +230,7 @@ func (c *elbv2Client) CreateTargetGroupWithContext(ctx context.Context, input *e
}
func (c *elbv2Client) DescribeTargetGroupAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetGroupAttributesInput) (*elasticloadbalancingv2.DescribeTargetGroupAttributesOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTargetGroupAttributes")
+ client, err := c.getClient(ctx, "DescribeTargetGroupAttributes")
if err != nil {
return nil, err
}
@@ -230,7 +238,7 @@ func (c *elbv2Client) DescribeTargetGroupAttributesWithContext(ctx context.Conte
}
func (c *elbv2Client) ModifyTargetGroupAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyTargetGroupAttributesInput) (*elasticloadbalancingv2.ModifyTargetGroupAttributesOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyTargetGroupAttributes")
+ client, err := c.getClient(ctx, "ModifyTargetGroupAttributes")
if err != nil {
return nil, err
}
@@ -238,7 +246,7 @@ func (c *elbv2Client) ModifyTargetGroupAttributesWithContext(ctx context.Context
}
func (c *elbv2Client) SetSecurityGroupsWithContext(ctx context.Context, input *elasticloadbalancingv2.SetSecurityGroupsInput) (*elasticloadbalancingv2.SetSecurityGroupsOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "SetSecurityGroups")
+ client, err := c.getClient(ctx, "SetSecurityGroups")
if err != nil {
return nil, err
}
@@ -246,7 +254,7 @@ func (c *elbv2Client) SetSecurityGroupsWithContext(ctx context.Context, input *e
}
func (c *elbv2Client) SetSubnetsWithContext(ctx context.Context, input *elasticloadbalancingv2.SetSubnetsInput) (*elasticloadbalancingv2.SetSubnetsOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "SetSubnets")
+ client, err := c.getClient(ctx, "SetSubnets")
if err != nil {
return nil, err
}
@@ -254,7 +262,7 @@ func (c *elbv2Client) SetSubnetsWithContext(ctx context.Context, input *elasticl
}
func (c *elbv2Client) SetIpAddressTypeWithContext(ctx context.Context, input *elasticloadbalancingv2.SetIpAddressTypeInput) (*elasticloadbalancingv2.SetIpAddressTypeOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "SetIpAddressType")
+ client, err := c.getClient(ctx, "SetIpAddressType")
if err != nil {
return nil, err
}
@@ -262,7 +270,7 @@ func (c *elbv2Client) SetIpAddressTypeWithContext(ctx context.Context, input *el
}
func (c *elbv2Client) DeleteLoadBalancerWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteLoadBalancerInput) (*elasticloadbalancingv2.DeleteLoadBalancerOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeleteLoadBalancer")
+ client, err := c.getClient(ctx, "DeleteLoadBalancer")
if err != nil {
return nil, err
}
@@ -270,7 +278,7 @@ func (c *elbv2Client) DeleteLoadBalancerWithContext(ctx context.Context, input *
}
func (c *elbv2Client) CreateLoadBalancerWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateLoadBalancerInput) (*elasticloadbalancingv2.CreateLoadBalancerOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "CreateLoadBalancer")
+ client, err := c.getClient(ctx, "CreateLoadBalancer")
if err != nil {
return nil, err
}
@@ -278,7 +286,7 @@ func (c *elbv2Client) CreateLoadBalancerWithContext(ctx context.Context, input *
}
func (c *elbv2Client) DescribeLoadBalancerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancerAttributesInput) (*elasticloadbalancingv2.DescribeLoadBalancerAttributesOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeLoadBalancerAttributes")
+ client, err := c.getClient(ctx, "DescribeLoadBalancerAttributes")
if err != nil {
return nil, err
}
@@ -286,7 +294,7 @@ func (c *elbv2Client) DescribeLoadBalancerAttributesWithContext(ctx context.Cont
}
func (c *elbv2Client) ModifyLoadBalancerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyLoadBalancerAttributesInput) (*elasticloadbalancingv2.ModifyLoadBalancerAttributesOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyLoadBalancerAttributes")
+ client, err := c.getClient(ctx, "ModifyLoadBalancerAttributes")
if err != nil {
return nil, err
}
@@ -294,7 +302,7 @@ func (c *elbv2Client) ModifyLoadBalancerAttributesWithContext(ctx context.Contex
}
func (c *elbv2Client) ModifyListenerWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyListenerInput) (*elasticloadbalancingv2.ModifyListenerOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyListener")
+ client, err := c.getClient(ctx, "ModifyListener")
if err != nil {
return nil, err
}
@@ -302,7 +310,7 @@ func (c *elbv2Client) ModifyListenerWithContext(ctx context.Context, input *elas
}
func (c *elbv2Client) DeleteListenerWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteListenerInput) (*elasticloadbalancingv2.DeleteListenerOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeleteListener")
+ client, err := c.getClient(ctx, "DeleteListener")
if err != nil {
return nil, err
}
@@ -310,7 +318,7 @@ func (c *elbv2Client) DeleteListenerWithContext(ctx context.Context, input *elas
}
func (c *elbv2Client) CreateListenerWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateListenerInput) (*elasticloadbalancingv2.CreateListenerOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "CreateListener")
+ client, err := c.getClient(ctx, "CreateListener")
if err != nil {
return nil, err
}
@@ -318,7 +326,7 @@ func (c *elbv2Client) CreateListenerWithContext(ctx context.Context, input *elas
}
func (c *elbv2Client) DescribeTagsWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTagsInput) (*elasticloadbalancingv2.DescribeTagsOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTags")
+ client, err := c.getClient(ctx, "DescribeTags")
if err != nil {
return nil, err
}
@@ -326,7 +334,7 @@ func (c *elbv2Client) DescribeTagsWithContext(ctx context.Context, input *elasti
}
func (c *elbv2Client) AddTagsWithContext(ctx context.Context, input *elasticloadbalancingv2.AddTagsInput) (*elasticloadbalancingv2.AddTagsOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "AddTags")
+ client, err := c.getClient(ctx, "AddTags")
if err != nil {
return nil, err
}
@@ -334,7 +342,7 @@ func (c *elbv2Client) AddTagsWithContext(ctx context.Context, input *elasticload
}
func (c *elbv2Client) RemoveTagsWithContext(ctx context.Context, input *elasticloadbalancingv2.RemoveTagsInput) (*elasticloadbalancingv2.RemoveTagsOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "RemoveTags")
+ client, err := c.getClient(ctx, "RemoveTags")
if err != nil {
return nil, err
}
@@ -345,7 +353,7 @@ func (c *elbv2Client) DescribeLoadBalancersAsList(ctx context.Context, input *el
var result []types.LoadBalancer
var client *elasticloadbalancingv2.Client
var err error
- client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeLoadBalancers")
+ client, err = c.getClient(ctx, "DescribeLoadBalancers")
if err != nil {
return nil, err
}
@@ -364,7 +372,7 @@ func (c *elbv2Client) DescribeTargetGroupsAsList(ctx context.Context, input *ela
var result []types.TargetGroup
var client *elasticloadbalancingv2.Client
var err error
- client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTargetGroups")
+ client, err = c.getClient(ctx, "DescribeTargetGroups")
if err != nil {
return nil, err
}
@@ -383,7 +391,7 @@ func (c *elbv2Client) DescribeListenersAsList(ctx context.Context, input *elasti
var result []types.Listener
var client *elasticloadbalancingv2.Client
var err error
- client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeListeners")
+ client, err = c.getClient(ctx, "DescribeListeners")
if err != nil {
return nil, err
}
@@ -402,7 +410,7 @@ func (c *elbv2Client) DescribeListenerCertificatesAsList(ctx context.Context, in
var result []types.Certificate
var client *elasticloadbalancingv2.Client
var err error
- client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeListenerCertificates")
+ client, err = c.getClient(ctx, "DescribeListenerCertificates")
if err != nil {
return nil, err
}
@@ -421,7 +429,7 @@ func (c *elbv2Client) DescribeRulesAsList(ctx context.Context, input *elasticloa
var result []types.Rule
var client *elasticloadbalancingv2.Client
var err error
- client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeRules")
+ client, err = c.getClient(ctx, "DescribeRules")
if err != nil {
return nil, err
}
@@ -437,7 +445,7 @@ func (c *elbv2Client) DescribeRulesAsList(ctx context.Context, input *elasticloa
}
func (c *elbv2Client) DescribeListenerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeListenerAttributesInput) (*elasticloadbalancingv2.DescribeListenerAttributesOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeListenerAttributes")
+ client, err := c.getClient(ctx, "DescribeListenerAttributes")
if err != nil {
return nil, err
}
@@ -445,7 +453,7 @@ func (c *elbv2Client) DescribeListenerAttributesWithContext(ctx context.Context,
}
func (c *elbv2Client) ModifyListenerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyListenerAttributesInput) (*elasticloadbalancingv2.ModifyListenerAttributesOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyListenerAttributes")
+ client, err := c.getClient(ctx, "ModifyListenerAttributes")
if err != nil {
return nil, err
}
@@ -453,7 +461,7 @@ func (c *elbv2Client) ModifyListenerAttributesWithContext(ctx context.Context, i
}
func (c *elbv2Client) ModifyCapacityReservationWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyCapacityReservationInput) (*elasticloadbalancingv2.ModifyCapacityReservationOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyCapacityReservation")
+ client, err := c.getClient(ctx, "ModifyCapacityReservation")
if err != nil {
return nil, err
}
@@ -461,9 +469,16 @@ func (c *elbv2Client) ModifyCapacityReservationWithContext(ctx context.Context,
}
func (c *elbv2Client) DescribeCapacityReservationWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeCapacityReservationInput) (*elasticloadbalancingv2.DescribeCapacityReservationOutput, error) {
- client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeCapacityReservation")
+ client, err := c.getClient(ctx, "DescribeCapacityReservation")
if err != nil {
return nil, err
}
return client.DescribeCapacityReservation(ctx, input)
}
+
+func (c *elbv2Client) getClient(ctx context.Context, operation string) (*elasticloadbalancingv2.Client, error) {
+ if c.staticELBClient != nil {
+ return c.staticELBClient, nil
+ }
+ return c.awsClientsProvider.GetELBv2Client(ctx, operation)
+}
diff --git a/pkg/aws/services/elbv2_mocks.go b/pkg/aws/services/elbv2_mocks.go
index 8d5b540a7e..1f427dd566 100644
--- a/pkg/aws/services/elbv2_mocks.go
+++ b/pkg/aws/services/elbv2_mocks.go
@@ -67,11 +67,12 @@ func (mr *MockELBV2MockRecorder) AddTagsWithContext(arg0, arg1 interface{}) *gom
}
// AssumeRole mocks base method.
-func (m *MockELBV2) AssumeRole(arg0 context.Context, arg1, arg2 string) ELBV2 {
+func (m *MockELBV2) AssumeRole(arg0 context.Context, arg1, arg2 string) (ELBV2, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AssumeRole", arg0, arg1, arg2)
ret0, _ := ret[0].(ELBV2)
- return ret0
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
}
// AssumeRole indicates an expected call of AssumeRole.
diff --git a/pkg/targetgroupbinding/resource_manager.go b/pkg/targetgroupbinding/resource_manager.go
index 666d25b1c4..be3dfd81d2 100644
--- a/pkg/targetgroupbinding/resource_manager.go
+++ b/pkg/targetgroupbinding/resource_manager.go
@@ -96,7 +96,6 @@ func (m *defaultResourceManager) Reconcile(ctx context.Context, tgb *elbv2api.Ta
var oldCheckPoint string
var isDeferred bool
var err error
- AnnotationsToFields(tgb)
if *tgb.Spec.TargetType == elbv2api.TargetTypeIP {
newCheckPoint, oldCheckPoint, isDeferred, err = m.reconcileWithIPTargetType(ctx, tgb)
@@ -116,7 +115,6 @@ func (m *defaultResourceManager) Reconcile(ctx context.Context, tgb *elbv2api.Ta
}
func (m *defaultResourceManager) Cleanup(ctx context.Context, tgb *elbv2api.TargetGroupBinding) error {
- AnnotationsToFields(tgb)
if err := m.cleanupTargets(ctx, tgb); err != nil {
return err
}
@@ -546,42 +544,45 @@ func (m *defaultResourceManager) deregisterTargets(ctx context.Context, tgb *elb
func (m *defaultResourceManager) registerPodEndpoints(ctx context.Context, tgb *elbv2api.TargetGroupBinding, endpoints []backend.PodEndpoint) error {
vpcID := m.vpcID
- // Target group is in a different VPC from the cluster's VPC
if tgb.Spec.VpcID != "" && tgb.Spec.VpcID != m.vpcID {
vpcID = tgb.Spec.VpcID
m.logger.Info(fmt.Sprintf(
"registering endpoints using the targetGroup's vpcID %s which is different from the cluster's vpcID %s", tgb.Spec.VpcID, m.vpcID))
+ }
- if tgb.Spec.IamRoleArnToAssume != "" {
- // since we need to assume a role for this TGB,
- // it is from a different account
- // so the packets will need to leave the VPC and therefore
- // target.AvailabilityZone = awssdk.String("all") must be set
- // or else nothing will work
- sdkTargets := make([]elbv2types.TargetDescription, 0, len(endpoints))
- for _, endpoint := range endpoints {
- target := elbv2types.TargetDescription{
- Id: awssdk.String(endpoint.IP),
- Port: awssdk.Int32(endpoint.Port),
- }
- target.AvailabilityZone = awssdk.String("all")
- sdkTargets = append(sdkTargets, target)
- }
- return m.targetsManager.RegisterTargets(ctx, tgb, sdkTargets)
+ var overrideAzFn func(addr netip.Addr) bool
+ if tgb.Spec.IamRoleArnToAssume != "" {
+ // If we're interacting with another account, then we should always be setting "all" AZ to allow this
+ // target to get registered by the ELB API.
+ overrideAzFn = func(_ netip.Addr) bool {
+ return true
+ }
+ } else {
+ vpcInfo, err := m.vpcInfoProvider.FetchVPCInfo(ctx, vpcID)
+ if err != nil {
+ return err
+ }
+ var vpcRawCIDRs []string
+ vpcRawCIDRs = append(vpcRawCIDRs, vpcInfo.AssociatedIPv4CIDRs()...)
+ vpcRawCIDRs = append(vpcRawCIDRs, vpcInfo.AssociatedIPv6CIDRs()...)
+ vpcCIDRs, err := networking.ParseCIDRs(vpcRawCIDRs)
+ if err != nil {
+ return err
+ }
+ // If the pod ip resides out of all the VPC CIDRs, then the only way to force the ELB API is to use "all" AZ.
+ overrideAzFn = func(addr netip.Addr) bool {
+ return !networking.IsIPWithinCIDRs(addr, vpcCIDRs)
}
}
- vpcInfo, err := m.vpcInfoProvider.FetchVPCInfo(ctx, vpcID)
- if err != nil {
- return err
- }
- var vpcRawCIDRs []string
- vpcRawCIDRs = append(vpcRawCIDRs, vpcInfo.AssociatedIPv4CIDRs()...)
- vpcRawCIDRs = append(vpcRawCIDRs, vpcInfo.AssociatedIPv6CIDRs()...)
- vpcCIDRs, err := networking.ParseCIDRs(vpcRawCIDRs)
+
+ sdkTargets, err := m.prepareRegistrationCall(endpoints, overrideAzFn)
if err != nil {
return err
}
+ return m.targetsManager.RegisterTargets(ctx, tgb, sdkTargets)
+}
+func (m *defaultResourceManager) prepareRegistrationCall(endpoints []backend.PodEndpoint, doAzOverride func(addr netip.Addr) bool) ([]elbv2types.TargetDescription, error) {
sdkTargets := make([]elbv2types.TargetDescription, 0, len(endpoints))
for _, endpoint := range endpoints {
target := elbv2types.TargetDescription{
@@ -590,14 +591,14 @@ func (m *defaultResourceManager) registerPodEndpoints(ctx context.Context, tgb *
}
podIP, err := netip.ParseAddr(endpoint.IP)
if err != nil {
- return err
+ return sdkTargets, err
}
- if !networking.IsIPWithinCIDRs(podIP, vpcCIDRs) {
+ if doAzOverride(podIP) {
target.AvailabilityZone = awssdk.String("all")
}
sdkTargets = append(sdkTargets, target)
}
- return m.targetsManager.RegisterTargets(ctx, tgb, sdkTargets)
+ return sdkTargets, nil
}
func (m *defaultResourceManager) registerNodePortEndpoints(ctx context.Context, tgb *elbv2api.TargetGroupBinding, endpoints []backend.NodePortEndpoint) error {
diff --git a/pkg/targetgroupbinding/targets_manager.go b/pkg/targetgroupbinding/targets_manager.go
index b00f95a9b5..13e6117a80 100644
--- a/pkg/targetgroupbinding/targets_manager.go
+++ b/pkg/targetgroupbinding/targets_manager.go
@@ -89,7 +89,13 @@ func (m *cachedTargetsManager) RegisterTargets(ctx context.Context, tgb *elbv2ap
m.logger.Info("registering targets",
"arn", tgARN,
"targets", targetsChunk)
- _, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId).RegisterTargetsWithContext(ctx, req)
+
+ clientToUse, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId)
+ if err != nil {
+ return err
+ }
+
+ _, err = clientToUse.RegisterTargetsWithContext(ctx, req)
if err != nil {
return err
}
@@ -111,7 +117,11 @@ func (m *cachedTargetsManager) DeregisterTargets(ctx context.Context, tgb *elbv2
m.logger.Info("deRegistering targets",
"arn", tgARN,
"targets", targetsChunk)
- _, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId).DeregisterTargetsWithContext(ctx, req)
+ clientToUse, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId)
+ if err != nil {
+ return err
+ }
+ _, err = clientToUse.DeregisterTargetsWithContext(ctx, req)
if err != nil {
return err
}
@@ -199,7 +209,11 @@ func (m *cachedTargetsManager) listTargetsFromAWS(ctx context.Context, tgb *elbv
TargetGroupArn: aws.String(tgARN),
Targets: pointerizeTargetDescriptions(targets),
}
- resp, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId).DescribeTargetHealthWithContext(ctx, req)
+ clientToUse, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId)
+ if err != nil {
+ return nil, err
+ }
+ resp, err := clientToUse.DescribeTargetHealthWithContext(ctx, req)
if err != nil {
return nil, err
}
diff --git a/pkg/targetgroupbinding/targets_manager_test.go b/pkg/targetgroupbinding/targets_manager_test.go
index f4f476fe03..5b85a6b7c4 100644
--- a/pkg/targetgroupbinding/targets_manager_test.go
+++ b/pkg/targetgroupbinding/targets_manager_test.go
@@ -280,7 +280,7 @@ func Test_cachedTargetsManager_RegisterTargets(t *testing.T) {
ctx := context.Background()
for _, call := range tt.fields.registerTargetsWithContextCalls {
elbv2Client.EXPECT().RegisterTargetsWithContext(gomock.Any(), call.req).Return(call.resp, call.err)
- elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client)
+ elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil)
}
targetsCache := cache.NewExpiring()
@@ -526,7 +526,7 @@ func Test_cachedTargetsManager_DeregisterTargets(t *testing.T) {
ctx := context.Background()
for _, call := range tt.fields.deregisterTargetsWithContextCalls {
elbv2Client.EXPECT().DeregisterTargetsWithContext(gomock.Any(), call.req).Return(call.resp, call.err)
- elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client)
+ elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil)
}
targetsCache := cache.NewExpiring()
@@ -792,7 +792,7 @@ func Test_cachedTargetsManager_ListTargets(t *testing.T) {
elbv2Client := services.NewMockELBV2(ctrl)
for _, call := range tt.fields.describeTargetHealthWithContextCalls {
elbv2Client.EXPECT().DescribeTargetHealthWithContext(gomock.Any(), call.req).Return(call.resp, call.err)
- elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client)
+ elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil)
}
targetsCache := cache.NewExpiring()
targetsCacheTTL := 1 * time.Minute
@@ -1203,7 +1203,7 @@ func Test_cachedTargetsManager_refreshUnhealthyTargets(t *testing.T) {
elbv2Client := services.NewMockELBV2(ctrl)
for _, call := range tt.fields.describeTargetHealthWithContextCalls {
elbv2Client.EXPECT().DescribeTargetHealthWithContext(gomock.Any(), call.req).Return(call.resp, call.err)
- elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client)
+ elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil)
}
m := &cachedTargetsManager{
elbv2Client: elbv2Client,
@@ -1367,7 +1367,7 @@ func Test_cachedTargetsManager_listTargetsFromAWS(t *testing.T) {
elbv2Client := services.NewMockELBV2(ctrl)
for _, call := range tt.fields.describeTargetHealthWithContextCalls {
elbv2Client.EXPECT().DescribeTargetHealthWithContext(gomock.Any(), call.req).Return(call.resp, call.err)
- elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client)
+ elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil)
}
m := &cachedTargetsManager{
diff --git a/pkg/targetgroupbinding/utils.go b/pkg/targetgroupbinding/utils.go
index 013c94081d..057e08a77a 100644
--- a/pkg/targetgroupbinding/utils.go
+++ b/pkg/targetgroupbinding/utils.go
@@ -25,27 +25,8 @@ const (
// Index Key for "ServiceReference" index.
IndexKeyServiceRefName = "spec.serviceRef.name"
-
- // Annotation for IAM Role ARN to assume when calling AWS APIs.
- AnnotationIamRoleArnToAssume = "alb.ingress.kubernetes.io/IamRoleArnToAssume"
-
- // Annotation for IAM Role External ID to use when calling AWS APIs.
- AnnotationAssumeRoleExternalId = "alb.ingress.kubernetes.io/AssumeRoleExternalId"
)
-// AnnotationsToFields converts annotations to fields. Currently it's tgb.Spec.IamRoleArnToAssume and tgb.Spec.AssumeRoleExternalId
-func AnnotationsToFields(tgb *elbv2api.TargetGroupBinding) {
- for key, value := range tgb.Annotations {
- if key == AnnotationIamRoleArnToAssume {
- tgb.Spec.IamRoleArnToAssume = value
- } else {
- if key == AnnotationAssumeRoleExternalId {
- tgb.Spec.AssumeRoleExternalId = value
- }
- }
- }
-}
-
// BuildTargetHealthPodConditionType constructs the condition type for TargetHealth pod condition.
func BuildTargetHealthPodConditionType(tgb *elbv2api.TargetGroupBinding) corev1.PodConditionType {
return corev1.PodConditionType(fmt.Sprintf("%s/%s", TargetHealthPodConditionTypePrefix, tgb.Name))
diff --git a/test/framework/framework.go b/test/framework/framework.go
index 9830647bb4..eb99b67b8d 100644
--- a/test/framework/framework.go
+++ b/test/framework/framework.go
@@ -63,7 +63,7 @@ func InitFramework() (*Framework, error) {
VpcID: globalOptions.AWSVPCID,
MaxRetries: 3,
ThrottleConfig: throttle.NewDefaultServiceOperationsThrottleConfig(),
- }, nil, logger, nil)
+ }, "clusterName", nil, logger, nil)
if err != nil {
return nil, err
}
diff --git a/webhooks/elbv2/targetgroup_helper.go b/webhooks/elbv2/targetgroup_helper.go
new file mode 100644
index 0000000000..dc847dc6fe
--- /dev/null
+++ b/webhooks/elbv2/targetgroup_helper.go
@@ -0,0 +1,45 @@
+package elbv2
+
+import (
+ "context"
+ elbv2sdk "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
+ elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types"
+ "github.com/pkg/errors"
+ elbv2api "sigs.k8s.io/aws-load-balancer-controller/apis/elbv2/v1beta1"
+ "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services"
+)
+
+// getTargetGroupFromAWS returns the AWS target group corresponding to the arn
+func getTargetGroupFromAWS(ctx context.Context, elbv2Client services.ELBV2, tgb *elbv2api.TargetGroupBinding) (*elbv2types.TargetGroup, error) {
+ tgARN := tgb.Spec.TargetGroupARN
+ req := &elbv2sdk.DescribeTargetGroupsInput{
+ TargetGroupArns: []string{tgARN},
+ }
+ return getTargetGroupHelper(ctx, elbv2Client, tgb, tgARN, req)
+}
+
+// getTargetGroupsByNameFromAWS returns the AWS target group corresponding to the name
+func getTargetGroupsByNameFromAWS(ctx context.Context, elbv2Client services.ELBV2, tgb *elbv2api.TargetGroupBinding) (*elbv2types.TargetGroup, error) {
+ req := &elbv2sdk.DescribeTargetGroupsInput{
+ Names: []string{tgb.Spec.TargetGroupName},
+ }
+
+ return getTargetGroupHelper(ctx, elbv2Client, tgb, tgb.Spec.TargetGroupName, req)
+}
+
+func getTargetGroupHelper(ctx context.Context, elbv2Client services.ELBV2, tgb *elbv2api.TargetGroupBinding, tgIdentifier string, req *elbv2sdk.DescribeTargetGroupsInput) (*elbv2types.TargetGroup, error) {
+ clientToUse, err := elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId)
+
+ if err != nil {
+ return nil, err
+ }
+
+ tgList, err := clientToUse.DescribeTargetGroupsAsList(ctx, req)
+ if err != nil {
+ return nil, err
+ }
+ if len(tgList) != 1 {
+ return nil, errors.Errorf("expecting a single targetGroup with query [%s] but got %v", tgIdentifier, len(tgList))
+ }
+ return &tgList[0], nil
+}
diff --git a/webhooks/elbv2/targetgroupbinding_mutator.go b/webhooks/elbv2/targetgroupbinding_mutator.go
index 7818ce39d3..aed2f6dd65 100644
--- a/webhooks/elbv2/targetgroupbinding_mutator.go
+++ b/webhooks/elbv2/targetgroupbinding_mutator.go
@@ -6,14 +6,11 @@ import (
elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types"
awssdk "github.com/aws/aws-sdk-go-v2/aws"
- elbv2sdk "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
-
"github.com/go-logr/logr"
"github.com/pkg/errors"
"k8s.io/apimachinery/pkg/runtime"
elbv2api "sigs.k8s.io/aws-load-balancer-controller/apis/elbv2/v1beta1"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services"
- "sigs.k8s.io/aws-load-balancer-controller/pkg/targetgroupbinding"
"sigs.k8s.io/aws-load-balancer-controller/pkg/webhook"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
@@ -48,7 +45,6 @@ func (m *targetGroupBindingMutator) MutateCreate(ctx context.Context, obj runtim
if err := m.getArnFromNameIfNeeded(ctx, tgb); err != nil {
return nil, err
}
- targetgroupbinding.AnnotationsToFields(tgb)
if err := m.defaultingTargetType(ctx, tgb); err != nil {
return nil, err
}
@@ -63,7 +59,7 @@ func (m *targetGroupBindingMutator) MutateCreate(ctx context.Context, obj runtim
func (m *targetGroupBindingMutator) getArnFromNameIfNeeded(ctx context.Context, tgb *elbv2api.TargetGroupBinding) error {
if tgb.Spec.TargetGroupARN == "" && tgb.Spec.TargetGroupName != "" {
- tgObj, err := m.getTargetGroupsByNameFromAWS(ctx, tgb.Spec.TargetGroupName)
+ tgObj, err := getTargetGroupsByNameFromAWS(ctx, m.elbv2Client, tgb)
if err != nil {
return err
}
@@ -123,7 +119,7 @@ func (m *targetGroupBindingMutator) defaultingVpcID(ctx context.Context, tgb *el
}
func (m *targetGroupBindingMutator) obtainSDKTargetTypeFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (string, error) {
- targetGroup, err := m.getTargetGroupFromAWS(ctx, tgb)
+ targetGroup, err := getTargetGroupFromAWS(ctx, m.elbv2Client, tgb)
if err != nil {
return "", err
}
@@ -132,7 +128,7 @@ func (m *targetGroupBindingMutator) obtainSDKTargetTypeFromAWS(ctx context.Conte
// getTargetGroupIPAddressTypeFromAWS returns the target group IP address type of AWS target group
func (m *targetGroupBindingMutator) getTargetGroupIPAddressTypeFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (elbv2api.TargetGroupIPAddressType, error) {
- targetGroup, err := m.getTargetGroupFromAWS(ctx, tgb)
+ targetGroup, err := getTargetGroupFromAWS(ctx, m.elbv2Client, tgb)
if err != nil {
return "", err
}
@@ -148,37 +144,8 @@ func (m *targetGroupBindingMutator) getTargetGroupIPAddressTypeFromAWS(ctx conte
return ipAddressType, nil
}
-func (m *targetGroupBindingMutator) getTargetGroupFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (*elbv2types.TargetGroup, error) {
- tgARN := tgb.Spec.TargetGroupARN
- req := &elbv2sdk.DescribeTargetGroupsInput{
- TargetGroupArns: []string{tgARN},
- }
- tgList, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId).DescribeTargetGroupsAsList(ctx, req)
- if err != nil {
- return nil, err
- }
- if len(tgList) != 1 {
- return nil, errors.Errorf("expecting a single targetGroup but got %v", len(tgList))
- }
- return &tgList[0], nil
-}
-
-func (m *targetGroupBindingMutator) getTargetGroupsByNameFromAWS(ctx context.Context, tgName string) (*elbv2types.TargetGroup, error) {
- req := &elbv2sdk.DescribeTargetGroupsInput{
- Names: []string{tgName},
- }
- tgList, err := m.elbv2Client.DescribeTargetGroupsAsList(ctx, req)
- if err != nil {
- return nil, err
- }
- if len(tgList) != 1 {
- return nil, errors.Errorf("expecting a single targetGroup with name [%s] but got %v", tgName, len(tgList))
- }
- return &tgList[0], nil
-}
-
func (m *targetGroupBindingMutator) getVpcIDFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (string, error) {
- targetGroup, err := m.getTargetGroupFromAWS(ctx, tgb)
+ targetGroup, err := getTargetGroupFromAWS(ctx, m.elbv2Client, tgb)
if err != nil {
return "", err
}
diff --git a/webhooks/elbv2/targetgroupbinding_mutator_test.go b/webhooks/elbv2/targetgroupbinding_mutator_test.go
index 032440ad14..0348d21de0 100644
--- a/webhooks/elbv2/targetgroupbinding_mutator_test.go
+++ b/webhooks/elbv2/targetgroupbinding_mutator_test.go
@@ -308,7 +308,7 @@ func Test_targetGroupBindingMutator_MutateCreate(t *testing.T) {
ctx := context.Background()
for _, call := range tt.fields.describeTargetGroupsAsListCalls {
elbv2Client.EXPECT().DescribeTargetGroupsAsList(gomock.Any(), call.req).Return(call.resp, call.err).AnyTimes()
- elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client).AnyTimes()
+ elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil).AnyTimes()
}
m := &targetGroupBindingMutator{
@@ -415,7 +415,7 @@ func Test_targetGroupBindingMutator_obtainSDKTargetTypeFromAWS(t *testing.T) {
elbv2Client := services.NewMockELBV2(ctrl)
for _, call := range tt.fields.describeTargetGroupsAsListCalls {
elbv2Client.EXPECT().DescribeTargetGroupsAsList(gomock.Any(), call.req).Return(call.resp, call.err)
- elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client).AnyTimes()
+ elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil).AnyTimes()
}
m := &targetGroupBindingMutator{
@@ -545,7 +545,7 @@ func Test_targetGroupBindingMutator_getIPAddressTypeFromAWS(t *testing.T) {
elbv2Client := services.NewMockELBV2(ctrl)
for _, call := range tt.fields.describeTargetGroupsAsListCalls {
elbv2Client.EXPECT().DescribeTargetGroupsAsList(gomock.Any(), call.req).Return(call.resp, call.err)
- elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client).AnyTimes()
+ elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil).AnyTimes()
}
m := &targetGroupBindingMutator{
@@ -631,7 +631,7 @@ func Test_targetGroupBindingMutator_obtainSDKVpcIDFromAWS(t *testing.T) {
elbv2Client := services.NewMockELBV2(ctrl)
for _, call := range tt.fields.describeTargetGroupsAsListCalls {
elbv2Client.EXPECT().DescribeTargetGroupsAsList(gomock.Any(), call.req).Return(call.resp, call.err)
- elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client).AnyTimes()
+ elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil).AnyTimes()
}
m := &targetGroupBindingMutator{
diff --git a/webhooks/elbv2/targetgroupbinding_validator.go b/webhooks/elbv2/targetgroupbinding_validator.go
index bb6ee770a5..2e3e63bfee 100644
--- a/webhooks/elbv2/targetgroupbinding_validator.go
+++ b/webhooks/elbv2/targetgroupbinding_validator.go
@@ -9,14 +9,12 @@ import (
elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types"
awssdk "github.com/aws/aws-sdk-go-v2/aws"
- elbv2sdk "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
"github.com/go-logr/logr"
"github.com/pkg/errors"
"k8s.io/apimachinery/pkg/runtime"
elbv2api "sigs.k8s.io/aws-load-balancer-controller/apis/elbv2/v1beta1"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services"
"sigs.k8s.io/aws-load-balancer-controller/pkg/k8s"
- "sigs.k8s.io/aws-load-balancer-controller/pkg/targetgroupbinding"
"sigs.k8s.io/aws-load-balancer-controller/pkg/webhook"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -56,7 +54,6 @@ func (v *targetGroupBindingValidator) Prototype(_ admission.Request) (runtime.Ob
func (v *targetGroupBindingValidator) ValidateCreate(ctx context.Context, obj runtime.Object) error {
tgb := obj.(*elbv2api.TargetGroupBinding)
- targetgroupbinding.AnnotationsToFields(tgb)
if err := v.checkRequiredFields(ctx, tgb); err != nil {
return err
}
@@ -72,13 +69,15 @@ func (v *targetGroupBindingValidator) ValidateCreate(ctx context.Context, obj ru
if err := v.checkTargetGroupVpcID(ctx, tgb); err != nil {
return err
}
+ if err := v.checkAssumeRoleConfig(tgb); err != nil {
+ return err
+ }
return nil
}
func (v *targetGroupBindingValidator) ValidateUpdate(ctx context.Context, obj runtime.Object, oldObj runtime.Object) error {
tgb := obj.(*elbv2api.TargetGroupBinding)
oldTgb := oldObj.(*elbv2api.TargetGroupBinding)
- targetgroupbinding.AnnotationsToFields(tgb)
if err := v.checkRequiredFields(ctx, tgb); err != nil {
return err
}
@@ -88,6 +87,9 @@ func (v *targetGroupBindingValidator) ValidateUpdate(ctx context.Context, obj ru
if err := v.checkNodeSelector(tgb); err != nil {
return err
}
+ if err := v.checkAssumeRoleConfig(tgb); err != nil {
+ return err
+ }
return nil
}
@@ -113,7 +115,7 @@ func (v *targetGroupBindingValidator) checkRequiredFields(ctx context.Context, t
By changing the object here I guarantee as early as possible that that assumption is true.
*/
- tgObj, err := v.getTargetGroupsByNameFromAWS(ctx, tgb.Spec.TargetGroupName)
+ tgObj, err := getTargetGroupsByNameFromAWS(ctx, v.elbv2Client, tgb)
if err != nil {
return fmt.Errorf("searching TargetGroup with name %s: %w", tgb.Spec.TargetGroupName, err)
}
@@ -215,7 +217,7 @@ func (v *targetGroupBindingValidator) checkTargetGroupVpcID(ctx context.Context,
// getTargetGroupIPAddressTypeFromAWS returns the target group IP address type of AWS target group
func (v *targetGroupBindingValidator) getTargetGroupIPAddressTypeFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (elbv2api.TargetGroupIPAddressType, error) {
- targetGroup, err := v.getTargetGroupFromAWS(ctx, tgb)
+ targetGroup, err := getTargetGroupFromAWS(ctx, v.elbv2Client, tgb)
if err != nil {
return "", err
}
@@ -231,43 +233,25 @@ func (v *targetGroupBindingValidator) getTargetGroupIPAddressTypeFromAWS(ctx con
return ipAddressType, nil
}
-// getTargetGroupFromAWS returns the AWS target group corresponding to the ARN
-func (v *targetGroupBindingValidator) getTargetGroupFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (*elbv2types.TargetGroup, error) {
- tgARN := tgb.Spec.TargetGroupARN
- req := &elbv2sdk.DescribeTargetGroupsInput{
- TargetGroupArns: []string{tgARN},
- }
- tgList, err := v.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId).DescribeTargetGroupsAsList(ctx, req)
- if err != nil {
- return nil, err
- }
- if len(tgList) != 1 {
- return nil, errors.Errorf("expecting a single targetGroup but got %v", len(tgList))
- }
- return &tgList[0], nil
-}
-
func (v *targetGroupBindingValidator) getVpcIDFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (string, error) {
- targetGroup, err := v.getTargetGroupFromAWS(ctx, tgb)
+ targetGroup, err := getTargetGroupFromAWS(ctx, v.elbv2Client, tgb)
if err != nil {
return "", err
}
return awssdk.ToString(targetGroup.VpcId), nil
}
-// getTargetGroupFromAWS returns the AWS target group corresponding to the tgName
-func (v *targetGroupBindingValidator) getTargetGroupsByNameFromAWS(ctx context.Context, tgName string) (*elbv2types.TargetGroup, error) {
- req := &elbv2sdk.DescribeTargetGroupsInput{
- Names: []string{tgName},
- }
- tgList, err := v.elbv2Client.DescribeTargetGroupsAsList(ctx, req)
- if err != nil {
- return nil, err
+// checkAssumeRoleConfig various checks for using cross account target group bindings.
+func (v *targetGroupBindingValidator) checkAssumeRoleConfig(tgb *elbv2api.TargetGroupBinding) error {
+ if tgb.Spec.IamRoleArnToAssume == "" {
+ return nil
}
- if len(tgList) != 1 {
- return nil, errors.Errorf("expecting a single targetGroup with name [%s] but got %v", tgName, len(tgList))
+
+ if tgb.Spec.TargetType != nil && *tgb.Spec.TargetType == elbv2api.TargetTypeInstance {
+ return errors.New("Unable to use instance target type while using assume role")
}
- return &tgList[0], nil
+
+ return nil
}
// +kubebuilder:webhook:path=/validate-elbv2-k8s-aws-v1beta1-targetgroupbinding,mutating=false,failurePolicy=fail,groups=elbv2.k8s.aws,resources=targetgroupbindings,verbs=create;update,versions=v1beta1,name=vtargetgroupbinding.elbv2.k8s.aws,sideEffects=None,webhookVersions=v1,admissionReviewVersions=v1beta1
diff --git a/webhooks/elbv2/targetgroupbinding_validator_test.go b/webhooks/elbv2/targetgroupbinding_validator_test.go
index 4889604942..33fe3607a4 100644
--- a/webhooks/elbv2/targetgroupbinding_validator_test.go
+++ b/webhooks/elbv2/targetgroupbinding_validator_test.go
@@ -333,7 +333,7 @@ func Test_targetGroupBindingValidator_ValidateCreate(t *testing.T) {
elbv2Client := services.NewMockELBV2(ctrl)
for _, call := range tt.fields.describeTargetGroupsAsListCalls {
elbv2Client.EXPECT().DescribeTargetGroupsAsList(gomock.Any(), call.req).Return(call.resp, call.err)
- elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client).AnyTimes()
+ elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil).AnyTimes()
}
v := &targetGroupBindingValidator{
k8sClient: k8sClient,
@@ -1395,7 +1395,7 @@ func Test_targetGroupBindingValidator_checkTargetGroupVpcID(t *testing.T) {
elbv2Client := services.NewMockELBV2(ctrl)
for _, call := range tt.fields.describeTargetGroupsAsListCalls {
elbv2Client.EXPECT().DescribeTargetGroupsAsList(gomock.Any(), call.req).Return(call.resp, call.err)
- elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client).AnyTimes()
+ elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil).AnyTimes()
}
v := &targetGroupBindingValidator{
k8sClient: k8sClient,
@@ -1412,6 +1412,67 @@ func Test_targetGroupBindingValidator_checkTargetGroupVpcID(t *testing.T) {
}
}
+func TestCheckAssumeRoleConfig(t *testing.T) {
+ instance := elbv2api.TargetTypeInstance
+ ip := elbv2api.TargetTypeIP
+ testCases := []struct {
+ name string
+ tgb *elbv2api.TargetGroupBinding
+ err error
+ }{
+ {
+ name: "ip no assume role",
+ tgb: &elbv2api.TargetGroupBinding{
+ Spec: elbv2api.TargetGroupBindingSpec{
+ TargetType: &ip,
+ },
+ },
+ },
+ {
+ name: "instance no assume role",
+ tgb: &elbv2api.TargetGroupBinding{
+ Spec: elbv2api.TargetGroupBindingSpec{
+ TargetType: &instance,
+ },
+ },
+ },
+ {
+ name: "ip with assume role",
+ tgb: &elbv2api.TargetGroupBinding{
+ Spec: elbv2api.TargetGroupBindingSpec{
+ TargetType: &ip,
+ IamRoleArnToAssume: "foo",
+ },
+ },
+ },
+ {
+ name: "instance with assume role",
+ tgb: &elbv2api.TargetGroupBinding{
+ Spec: elbv2api.TargetGroupBindingSpec{
+ TargetType: &instance,
+ IamRoleArnToAssume: "foo",
+ },
+ },
+ err: errors.New("Unable to use instance target type while using assume role"),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v := &targetGroupBindingValidator{
+ logger: logr.New(&log.NullLogSink{}),
+ }
+
+ err := v.checkAssumeRoleConfig(tc.tgb)
+ if tc.err == nil {
+ assert.Nil(t, err)
+ } else {
+ assert.EqualError(t, err, tc.err.Error())
+ }
+ })
+ }
+}
+
func generateRandomString(n int, addChars ...rune) string {
const letters = "0123456789abcdef"
|