diff --git a/main.go b/main.go index c58ebe64ef..e3426cd699 100644 --- a/main.go +++ b/main.go @@ -81,7 +81,7 @@ func main() { ctrl.SetLogger(appLogger) klog.SetLoggerWithOptions(appLogger, klog.ContextualLogger(true)) - cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log) + cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log, nil) if err != nil { setupLog.Error(err, "unable to initialize AWS cloud") os.Exit(1) diff --git a/pkg/aws/cloud.go b/pkg/aws/cloud.go index 1ebe084dab..cc679b9436 100644 --- a/pkg/aws/cloud.go +++ b/pkg/aws/cloud.go @@ -24,6 +24,7 @@ import ( "github.com/prometheus/client_golang/prometheus" amerrors "k8s.io/apimachinery/pkg/util/errors" epresolver "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" ) @@ -59,7 +60,7 @@ type Cloud interface { } // NewCloud constructs new Cloud implementation. -func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger) (Cloud, error) { +func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger, awsClientsProvider provider.AWSClientsProvider) (Cloud, error) { hasIPv4 := true addrs, err := net.InterfaceAddrs() if err == nil { @@ -129,7 +130,14 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l awsConfig.APIOptions = metrics.WithSDKMetricCollector(metricsCollector, awsConfig.APIOptions) } - ec2Service := services.NewEC2(awsConfig, endpointsResolver) + if awsClientsProvider == nil { + var err error + awsClientsProvider, err = provider.NewDefaultAWSClientsProvider(awsConfig, endpointsResolver) + if err != nil { + return nil, errors.Wrap(err, "failed to create aws clients provider") + } + } + ec2Service := services.NewEC2(awsClientsProvider) vpcID, err := getVpcID(cfg, ec2Service, ec2Metadata, logger) if err != nil { @@ -139,17 +147,16 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l return &defaultCloud{ cfg: cfg, ec2: ec2Service, - elbv2: services.NewELBV2(awsConfig, endpointsResolver), - acm: services.NewACM(awsConfig, endpointsResolver), - wafv2: services.NewWAFv2(awsConfig, endpointsResolver), - wafRegional: services.NewWAFRegional(awsConfig, endpointsResolver, cfg.Region), - shield: services.NewShield(awsConfig, endpointsResolver), //done - rgt: services.NewRGT(awsConfig, endpointsResolver), + elbv2: services.NewELBV2(awsClientsProvider), + acm: services.NewACM(awsClientsProvider), + wafv2: services.NewWAFv2(awsClientsProvider), + wafRegional: services.NewWAFRegional(awsClientsProvider, cfg.Region), + shield: services.NewShield(awsClientsProvider), + rgt: services.NewRGT(awsClientsProvider), }, nil } func getVpcID(cfg CloudConfig, ec2Service services.EC2, ec2Metadata services.EC2Metadata, logger logr.Logger) (string, error) { - if cfg.VpcID != "" { logger.V(1).Info("vpcid is specified using flag --aws-vpc-id, controller will use the value", "vpc: ", cfg.VpcID) return cfg.VpcID, nil diff --git a/pkg/aws/provider/default_aws_clients_provider.go b/pkg/aws/provider/default_aws_clients_provider.go new file mode 100644 index 0000000000..5a1f50dda8 --- /dev/null +++ b/pkg/aws/provider/default_aws_clients_provider.go @@ -0,0 +1,109 @@ +package provider + +import ( + "context" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/acm" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" + "github.com/aws/aws-sdk-go-v2/service/shield" + "github.com/aws/aws-sdk-go-v2/service/wafregional" + "github.com/aws/aws-sdk-go-v2/service/wafv2" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" +) + +type defaultAWSClientsProvider struct { + ec2Client *ec2.Client + elbv2Client *elasticloadbalancingv2.Client + acmClient *acm.Client + wafv2Client *wafv2.Client + wafRegionClient *wafregional.Client + shieldClient *shield.Client + rgtClient *resourcegroupstaggingapi.Client +} + +func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (*defaultAWSClientsProvider, error) { + ec2CustomEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID) + elbv2CustomEndpoint := endpointsResolver.EndpointFor(elasticloadbalancingv2.ServiceID) + acmCustomEndpoint := endpointsResolver.EndpointFor(acm.ServiceID) + wafv2CustomEndpoint := endpointsResolver.EndpointFor(wafv2.ServiceID) + wafregionalCustomEndpoint := endpointsResolver.EndpointFor(wafregional.ServiceID) + shieldCustomEndpoint := endpointsResolver.EndpointFor(shield.ServiceID) + rgtCustomEndpoint := endpointsResolver.EndpointFor(resourcegroupstaggingapi.ServiceID) + + ec2Client := ec2.NewFromConfig(cfg, func(o *ec2.Options) { + if ec2CustomEndpoint != nil { + o.BaseEndpoint = ec2CustomEndpoint + } + }) + elbv2Client := elasticloadbalancingv2.NewFromConfig(cfg, func(o *elasticloadbalancingv2.Options) { + if elbv2CustomEndpoint != nil { + o.BaseEndpoint = elbv2CustomEndpoint + } + }) + acmClient := acm.NewFromConfig(cfg, func(o *acm.Options) { + if acmCustomEndpoint != nil { + o.BaseEndpoint = acmCustomEndpoint + } + }) + wafv2Client := wafv2.NewFromConfig(cfg, func(o *wafv2.Options) { + if wafv2CustomEndpoint != nil { + o.BaseEndpoint = wafv2CustomEndpoint + } + }) + wafregionalClient := wafregional.NewFromConfig(cfg, func(o *wafregional.Options) { + o.Region = cfg.Region + o.BaseEndpoint = wafregionalCustomEndpoint + }) + sheildClient := shield.NewFromConfig(cfg, func(o *shield.Options) { + o.Region = "us-east-1" + o.BaseEndpoint = shieldCustomEndpoint + }) + rgtClient := resourcegroupstaggingapi.NewFromConfig(cfg, func(o *resourcegroupstaggingapi.Options) { + if rgtCustomEndpoint != nil { + o.BaseEndpoint = rgtCustomEndpoint + } + }) + + return &defaultAWSClientsProvider{ + ec2Client: ec2Client, + elbv2Client: elbv2Client, + acmClient: acmClient, + wafv2Client: wafv2Client, + wafRegionClient: wafregionalClient, + shieldClient: sheildClient, + rgtClient: rgtClient, + }, nil +} + +// DO NOT REMOVE operationName as parameter, this is on purpose +// to retain the default behavior for OSS controller to use the default client for each aws service +// for our internal controller, we will choose different client based on operationName +func (p *defaultAWSClientsProvider) GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) { + return p.ec2Client, nil +} + +func (p *defaultAWSClientsProvider) GetELBv2Client(ctx context.Context, operationName string) (*elasticloadbalancingv2.Client, error) { + return p.elbv2Client, nil +} + +func (p *defaultAWSClientsProvider) GetACMClient(ctx context.Context, operationName string) (*acm.Client, error) { + return p.acmClient, nil +} + +func (p *defaultAWSClientsProvider) GetWAFv2Client(ctx context.Context, operationName string) (*wafv2.Client, error) { + return p.wafv2Client, nil +} + +func (p *defaultAWSClientsProvider) GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error) { + return p.wafRegionClient, nil +} + +func (p *defaultAWSClientsProvider) GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error) { + return p.shieldClient, nil +} + +func (p *defaultAWSClientsProvider) GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error) { + return p.rgtClient, nil +} diff --git a/pkg/aws/provider/provider.go b/pkg/aws/provider/provider.go new file mode 100644 index 0000000000..95b1c47425 --- /dev/null +++ b/pkg/aws/provider/provider.go @@ -0,0 +1,22 @@ +package provider + +import ( + "context" + "github.com/aws/aws-sdk-go-v2/service/acm" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" + "github.com/aws/aws-sdk-go-v2/service/shield" + "github.com/aws/aws-sdk-go-v2/service/wafregional" + "github.com/aws/aws-sdk-go-v2/service/wafv2" +) + +type AWSClientsProvider interface { + GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) + GetELBv2Client(ctx context.Context, operationName string) (*elasticloadbalancingv2.Client, error) + GetACMClient(ctx context.Context, operationName string) (*acm.Client, error) + GetWAFv2Client(ctx context.Context, operationName string) (*wafv2.Client, error) + GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error) + GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error) + GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error) +} diff --git a/pkg/aws/services/acm.go b/pkg/aws/services/acm.go index 78ad10fd39..eab8e43191 100644 --- a/pkg/aws/services/acm.go +++ b/pkg/aws/services/acm.go @@ -2,10 +2,9 @@ package services import ( "context" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/acm" "github.com/aws/aws-sdk-go-v2/service/acm/types" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" ) type ACM interface { @@ -15,24 +14,23 @@ type ACM interface { } // NewACM constructs new ACM implementation. -func NewACM(cfg aws.Config, endpointsResolver *endpoints.Resolver) ACM { - customEndpoint := endpointsResolver.EndpointFor(acm.ServiceID) +func NewACM(awsClientsProvider provider.AWSClientsProvider) ACM { return &acmClient{ - acmClient: acm.NewFromConfig(cfg, func(o *acm.Options) { - if customEndpoint != nil { - o.BaseEndpoint = customEndpoint - } - }), + awsClientsProvider: awsClientsProvider, } } type acmClient struct { - acmClient *acm.Client + awsClientsProvider provider.AWSClientsProvider } func (c *acmClient) ListCertificatesAsList(ctx context.Context, input *acm.ListCertificatesInput) ([]types.CertificateSummary, error) { var result []types.CertificateSummary - paginator := acm.NewListCertificatesPaginator(c.acmClient, input) + client, err := c.awsClientsProvider.GetACMClient(ctx, "ListCertificates") + if err != nil { + return nil, err + } + paginator := acm.NewListCertificatesPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -44,5 +42,9 @@ func (c *acmClient) ListCertificatesAsList(ctx context.Context, input *acm.ListC } func (c *acmClient) DescribeCertificateWithContext(ctx context.Context, input *acm.DescribeCertificateInput) (*acm.DescribeCertificateOutput, error) { - return c.acmClient.DescribeCertificate(ctx, input) + client, err := c.awsClientsProvider.GetACMClient(ctx, "DescribeCertificate") + if err != nil { + return nil, err + } + return client.DescribeCertificate(ctx, input) } diff --git a/pkg/aws/services/ec2.go b/pkg/aws/services/ec2.go index f52969cc51..3acd3564da 100644 --- a/pkg/aws/services/ec2.go +++ b/pkg/aws/services/ec2.go @@ -2,10 +2,9 @@ package services import ( "context" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" ) type EC2 interface { @@ -37,28 +36,31 @@ type EC2 interface { } // NewEC2 constructs new EC2 implementation. -func NewEC2(cfg aws.Config, endpointsResolver *endpoints.Resolver) EC2 { - customEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID) +func NewEC2(awsClientsProvider provider.AWSClientsProvider) EC2 { return &ec2Client{ - ec2Client: ec2.NewFromConfig(cfg, func(o *ec2.Options) { - if customEndpoint != nil { - o.BaseEndpoint = customEndpoint - } - }), + awsClientsProvider: awsClientsProvider, } } type ec2Client struct { - ec2Client *ec2.Client + awsClientsProvider provider.AWSClientsProvider } func (c *ec2Client) DescribeInstancesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { - return c.ec2Client.DescribeInstances(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeInstances") + if err != nil { + return nil, err + } + return client.DescribeInstances(ctx, input) } func (c *ec2Client) DescribeInstancesAsList(ctx context.Context, input *ec2.DescribeInstancesInput) ([]types.Instance, error) { var result []types.Instance - paginator := ec2.NewDescribeInstancesPaginator(c.ec2Client, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeInstances") + if err != nil { + return nil, err + } + paginator := ec2.NewDescribeInstancesPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -73,7 +75,11 @@ func (c *ec2Client) DescribeInstancesAsList(ctx context.Context, input *ec2.Desc func (c *ec2Client) DescribeNetworkInterfacesAsList(ctx context.Context, input *ec2.DescribeNetworkInterfacesInput) ([]types.NetworkInterface, error) { var result []types.NetworkInterface - paginator := ec2.NewDescribeNetworkInterfacesPaginator(c.ec2Client, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeNetworkInterfaces") + if err != nil { + return nil, err + } + paginator := ec2.NewDescribeNetworkInterfacesPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -86,7 +92,11 @@ func (c *ec2Client) DescribeNetworkInterfacesAsList(ctx context.Context, input * func (c *ec2Client) DescribeSecurityGroupsAsList(ctx context.Context, input *ec2.DescribeSecurityGroupsInput) ([]types.SecurityGroup, error) { var result []types.SecurityGroup - paginator := ec2.NewDescribeSecurityGroupsPaginator(c.ec2Client, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeSecurityGroups") + if err != nil { + return nil, err + } + paginator := ec2.NewDescribeSecurityGroupsPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -99,7 +109,11 @@ func (c *ec2Client) DescribeSecurityGroupsAsList(ctx context.Context, input *ec2 func (c *ec2Client) DescribeSubnetsAsList(ctx context.Context, input *ec2.DescribeSubnetsInput) ([]types.Subnet, error) { var result []types.Subnet - paginator := ec2.NewDescribeSubnetsPaginator(c.ec2Client, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeSubnets") + if err != nil { + return nil, err + } + paginator := ec2.NewDescribeSubnetsPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -112,7 +126,11 @@ func (c *ec2Client) DescribeSubnetsAsList(ctx context.Context, input *ec2.Descri func (c *ec2Client) DescribeVPCsAsList(ctx context.Context, input *ec2.DescribeVpcsInput) ([]types.Vpc, error) { var result []types.Vpc - paginator := ec2.NewDescribeVpcsPaginator(c.ec2Client, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeVpcs") + if err != nil { + return nil, err + } + paginator := ec2.NewDescribeVpcsPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -124,33 +142,65 @@ func (c *ec2Client) DescribeVPCsAsList(ctx context.Context, input *ec2.DescribeV } func (c *ec2Client) CreateTagsWithContext(ctx context.Context, input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) { - return c.ec2Client.CreateTags(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "CreateTags") + if err != nil { + return nil, err + } + return client.CreateTags(ctx, input) } func (c *ec2Client) DeleteTagsWithContext(ctx context.Context, input *ec2.DeleteTagsInput) (*ec2.DeleteTagsOutput, error) { - return c.ec2Client.DeleteTags(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DeleteTags") + if err != nil { + return nil, err + } + return client.DeleteTags(ctx, input) } func (c *ec2Client) CreateSecurityGroupWithContext(ctx context.Context, input *ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error) { - return c.ec2Client.CreateSecurityGroup(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "CreateSecurityGroup") + if err != nil { + return nil, err + } + return client.CreateSecurityGroup(ctx, input) } func (c *ec2Client) DeleteSecurityGroupWithContext(ctx context.Context, input *ec2.DeleteSecurityGroupInput) (*ec2.DeleteSecurityGroupOutput, error) { - return c.ec2Client.DeleteSecurityGroup(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DeleteSecurityGroup") + if err != nil { + return nil, err + } + return client.DeleteSecurityGroup(ctx, input) } func (c *ec2Client) AuthorizeSecurityGroupIngressWithContext(ctx context.Context, input *ec2.AuthorizeSecurityGroupIngressInput) (*ec2.AuthorizeSecurityGroupIngressOutput, error) { - return c.ec2Client.AuthorizeSecurityGroupIngress(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "AuthorizeSecurityGroupIngress") + if err != nil { + return nil, err + } + return client.AuthorizeSecurityGroupIngress(ctx, input) } func (c *ec2Client) RevokeSecurityGroupIngressWithContext(ctx context.Context, input *ec2.RevokeSecurityGroupIngressInput) (*ec2.RevokeSecurityGroupIngressOutput, error) { - return c.ec2Client.RevokeSecurityGroupIngress(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "RevokeSecurityGroupIngress") + if err != nil { + return nil, err + } + return client.RevokeSecurityGroupIngress(ctx, input) } func (c *ec2Client) DescribeAvailabilityZonesWithContext(ctx context.Context, input *ec2.DescribeAvailabilityZonesInput) (*ec2.DescribeAvailabilityZonesOutput, error) { - return c.ec2Client.DescribeAvailabilityZones(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeAvailabilityZones") + if err != nil { + return nil, err + } + return client.DescribeAvailabilityZones(ctx, input) } func (c *ec2Client) DescribeVpcsWithContext(ctx context.Context, input *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) { - return c.ec2Client.DescribeVpcs(ctx, input) + client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeVpcs") + if err != nil { + return nil, err + } + return client.DescribeVpcs(ctx, input) } diff --git a/pkg/aws/services/elbv2.go b/pkg/aws/services/elbv2.go index 0ff0e7d187..4b49990656 100644 --- a/pkg/aws/services/elbv2.go +++ b/pkg/aws/services/elbv2.go @@ -2,12 +2,11 @@ package services import ( "context" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" "time" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" ) type ELBV2 interface { @@ -62,154 +61,284 @@ type ELBV2 interface { ModifyListenerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyListenerAttributesInput) (*elasticloadbalancingv2.ModifyListenerAttributesOutput, error) } -func NewELBV2(cfg aws.Config, endpointsResolver *endpoints.Resolver) ELBV2 { - customEndpoint := endpointsResolver.EndpointFor(elasticloadbalancingv2.ServiceID) - client := elasticloadbalancingv2.NewFromConfig(cfg, func(o *elasticloadbalancingv2.Options) { - if customEndpoint != nil { - o.BaseEndpoint = customEndpoint - } - }) - return &elbv2Client{elbv2Client: client} +func NewELBV2(awsClientsProvider provider.AWSClientsProvider) ELBV2 { + return &elbv2Client{ + awsClientsProvider: awsClientsProvider, + } } // default implementation for ELBV2. type elbv2Client struct { - elbv2Client *elasticloadbalancingv2.Client + awsClientsProvider provider.AWSClientsProvider } func (c *elbv2Client) AddListenerCertificatesWithContext(ctx context.Context, input *elasticloadbalancingv2.AddListenerCertificatesInput) (*elasticloadbalancingv2.AddListenerCertificatesOutput, error) { - return c.elbv2Client.AddListenerCertificates(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "AddListenerCertificates") + if err != nil { + return nil, err + } + return client.AddListenerCertificates(ctx, input) } func (c *elbv2Client) RemoveListenerCertificatesWithContext(ctx context.Context, input *elasticloadbalancingv2.RemoveListenerCertificatesInput) (*elasticloadbalancingv2.RemoveListenerCertificatesOutput, error) { - return c.elbv2Client.RemoveListenerCertificates(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "RemoveListenerCertificates") + if err != nil { + return nil, err + } + return client.RemoveListenerCertificates(ctx, input) } func (c *elbv2Client) DescribeListenersWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeListenersInput) (*elasticloadbalancingv2.DescribeListenersOutput, error) { - return c.elbv2Client.DescribeListeners(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeListeners") + if err != nil { + return nil, err + } + return client.DescribeListeners(ctx, input) } func (c *elbv2Client) DescribeRulesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeRulesInput) (*elasticloadbalancingv2.DescribeRulesOutput, error) { - return c.elbv2Client.DescribeRules(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeRules") + if err != nil { + return nil, err + } + return client.DescribeRules(ctx, input) } func (c *elbv2Client) RegisterTargetsWithContext(ctx context.Context, input *elasticloadbalancingv2.RegisterTargetsInput) (*elasticloadbalancingv2.RegisterTargetsOutput, error) { - return c.elbv2Client.RegisterTargets(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "RegisterTargets") + if err != nil { + return nil, err + } + return client.RegisterTargets(ctx, input) } func (c *elbv2Client) DeregisterTargetsWithContext(ctx context.Context, input *elasticloadbalancingv2.DeregisterTargetsInput) (*elasticloadbalancingv2.DeregisterTargetsOutput, error) { - return c.elbv2Client.DeregisterTargets(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeregisterTargets") + if err != nil { + return nil, err + } + return client.DeregisterTargets(ctx, input) } func (c *elbv2Client) DescribeTrustStoresWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTrustStoresInput) (*elasticloadbalancingv2.DescribeTrustStoresOutput, error) { - return c.elbv2Client.DescribeTrustStores(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTrustStores") + if err != nil { + return nil, err + } + return client.DescribeTrustStores(ctx, input) } func (c *elbv2Client) ModifyRuleWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyRuleInput) (*elasticloadbalancingv2.ModifyRuleOutput, error) { - return c.elbv2Client.ModifyRule(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyRule") + if err != nil { + return nil, err + } + return client.ModifyRule(ctx, input) } func (c *elbv2Client) DeleteRuleWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteRuleInput) (*elasticloadbalancingv2.DeleteRuleOutput, error) { - return c.elbv2Client.DeleteRule(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeleteRule") + if err != nil { + return nil, err + } + return client.DeleteRule(ctx, input) } func (c *elbv2Client) CreateRuleWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateRuleInput) (*elasticloadbalancingv2.CreateRuleOutput, error) { - return c.elbv2Client.CreateRule(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "CreateRule") + if err != nil { + return nil, err + } + return client.CreateRule(ctx, input) } func (c *elbv2Client) WaitUntilLoadBalancerAvailableWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancersInput) error { - waiter := elasticloadbalancingv2.NewLoadBalancerAvailableWaiter(c.elbv2Client) - err := waiter.Wait(ctx, input, 5*time.Minute) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeLoadBalancers") + if err != nil { + return err + } + waiter := elasticloadbalancingv2.NewLoadBalancerAvailableWaiter(client) + err = waiter.Wait(ctx, input, 5*time.Minute) return err } func (c *elbv2Client) DescribeLoadBalancersWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancersInput) (*elasticloadbalancingv2.DescribeLoadBalancersOutput, error) { - return c.elbv2Client.DescribeLoadBalancers(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeLoadBalancers") + if err != nil { + return nil, err + } + return client.DescribeLoadBalancers(ctx, input) } func (c *elbv2Client) DescribeTargetHealthWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetHealthInput) (*elasticloadbalancingv2.DescribeTargetHealthOutput, error) { - return c.elbv2Client.DescribeTargetHealth(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTargetHealth") + if err != nil { + return nil, err + } + return client.DescribeTargetHealth(ctx, input) } func (c *elbv2Client) DescribeTargetGroupsWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetGroupsInput) (*elasticloadbalancingv2.DescribeTargetGroupsOutput, error) { - return c.elbv2Client.DescribeTargetGroups(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTargetGroups") + if err != nil { + return nil, err + } + return client.DescribeTargetGroups(ctx, input) } func (c *elbv2Client) DeleteTargetGroupWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteTargetGroupInput) (*elasticloadbalancingv2.DeleteTargetGroupOutput, error) { - return c.elbv2Client.DeleteTargetGroup(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeleteTargetGroup") + if err != nil { + return nil, err + } + return client.DeleteTargetGroup(ctx, input) } func (c *elbv2Client) ModifyTargetGroupWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyTargetGroupInput) (*elasticloadbalancingv2.ModifyTargetGroupOutput, error) { - return c.elbv2Client.ModifyTargetGroup(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyTargetGroup") + if err != nil { + return nil, err + } + return client.ModifyTargetGroup(ctx, input) } func (c *elbv2Client) CreateTargetGroupWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateTargetGroupInput) (*elasticloadbalancingv2.CreateTargetGroupOutput, error) { - return c.elbv2Client.CreateTargetGroup(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "CreateTargetGroup") + if err != nil { + return nil, err + } + return client.CreateTargetGroup(ctx, input) } func (c *elbv2Client) DescribeTargetGroupAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetGroupAttributesInput) (*elasticloadbalancingv2.DescribeTargetGroupAttributesOutput, error) { - return c.elbv2Client.DescribeTargetGroupAttributes(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTargetGroupAttributes") + if err != nil { + return nil, err + } + return client.DescribeTargetGroupAttributes(ctx, input) } func (c *elbv2Client) ModifyTargetGroupAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyTargetGroupAttributesInput) (*elasticloadbalancingv2.ModifyTargetGroupAttributesOutput, error) { - return c.elbv2Client.ModifyTargetGroupAttributes(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyTargetGroupAttributes") + if err != nil { + return nil, err + } + return client.ModifyTargetGroupAttributes(ctx, input) } func (c *elbv2Client) SetSecurityGroupsWithContext(ctx context.Context, input *elasticloadbalancingv2.SetSecurityGroupsInput) (*elasticloadbalancingv2.SetSecurityGroupsOutput, error) { - return c.elbv2Client.SetSecurityGroups(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "SetSecurityGroups") + if err != nil { + return nil, err + } + return client.SetSecurityGroups(ctx, input) } func (c *elbv2Client) SetSubnetsWithContext(ctx context.Context, input *elasticloadbalancingv2.SetSubnetsInput) (*elasticloadbalancingv2.SetSubnetsOutput, error) { - return c.elbv2Client.SetSubnets(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "SetSubnets") + if err != nil { + return nil, err + } + return client.SetSubnets(ctx, input) } func (c *elbv2Client) SetIpAddressTypeWithContext(ctx context.Context, input *elasticloadbalancingv2.SetIpAddressTypeInput) (*elasticloadbalancingv2.SetIpAddressTypeOutput, error) { - return c.elbv2Client.SetIpAddressType(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "SetIpAddressType") + if err != nil { + return nil, err + } + return client.SetIpAddressType(ctx, input) } func (c *elbv2Client) DeleteLoadBalancerWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteLoadBalancerInput) (*elasticloadbalancingv2.DeleteLoadBalancerOutput, error) { - return c.elbv2Client.DeleteLoadBalancer(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeleteLoadBalancer") + if err != nil { + return nil, err + } + return client.DeleteLoadBalancer(ctx, input) } func (c *elbv2Client) CreateLoadBalancerWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateLoadBalancerInput) (*elasticloadbalancingv2.CreateLoadBalancerOutput, error) { - return c.elbv2Client.CreateLoadBalancer(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "CreateLoadBalancer") + if err != nil { + return nil, err + } + return client.CreateLoadBalancer(ctx, input) } func (c *elbv2Client) DescribeLoadBalancerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancerAttributesInput) (*elasticloadbalancingv2.DescribeLoadBalancerAttributesOutput, error) { - return c.elbv2Client.DescribeLoadBalancerAttributes(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeLoadBalancerAttributes") + if err != nil { + return nil, err + } + return client.DescribeLoadBalancerAttributes(ctx, input) } func (c *elbv2Client) ModifyLoadBalancerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyLoadBalancerAttributesInput) (*elasticloadbalancingv2.ModifyLoadBalancerAttributesOutput, error) { - return c.elbv2Client.ModifyLoadBalancerAttributes(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyLoadBalancerAttributes") + if err != nil { + return nil, err + } + return client.ModifyLoadBalancerAttributes(ctx, input) } func (c *elbv2Client) ModifyListenerWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyListenerInput) (*elasticloadbalancingv2.ModifyListenerOutput, error) { - return c.elbv2Client.ModifyListener(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyListener") + if err != nil { + return nil, err + } + return client.ModifyListener(ctx, input) } func (c *elbv2Client) DeleteListenerWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteListenerInput) (*elasticloadbalancingv2.DeleteListenerOutput, error) { - return c.elbv2Client.DeleteListener(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeleteListener") + if err != nil { + return nil, err + } + return client.DeleteListener(ctx, input) } func (c *elbv2Client) CreateListenerWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateListenerInput) (*elasticloadbalancingv2.CreateListenerOutput, error) { - return c.elbv2Client.CreateListener(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "CreateListener") + if err != nil { + return nil, err + } + return client.CreateListener(ctx, input) } func (c *elbv2Client) DescribeTagsWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTagsInput) (*elasticloadbalancingv2.DescribeTagsOutput, error) { - return c.elbv2Client.DescribeTags(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTags") + if err != nil { + return nil, err + } + return client.DescribeTags(ctx, input) } func (c *elbv2Client) AddTagsWithContext(ctx context.Context, input *elasticloadbalancingv2.AddTagsInput) (*elasticloadbalancingv2.AddTagsOutput, error) { - return c.elbv2Client.AddTags(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "AddTags") + if err != nil { + return nil, err + } + return client.AddTags(ctx, input) } func (c *elbv2Client) RemoveTagsWithContext(ctx context.Context, input *elasticloadbalancingv2.RemoveTagsInput) (*elasticloadbalancingv2.RemoveTagsOutput, error) { - return c.elbv2Client.RemoveTags(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "RemoveTags") + if err != nil { + return nil, err + } + return client.RemoveTags(ctx, input) } func (c *elbv2Client) DescribeLoadBalancersAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancersInput) ([]types.LoadBalancer, error) { var result []types.LoadBalancer - paginator := elasticloadbalancingv2.NewDescribeLoadBalancersPaginator(c.elbv2Client, input) + var client *elasticloadbalancingv2.Client + var err error + client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeLoadBalancers") + if err != nil { + return nil, err + } + paginator := elasticloadbalancingv2.NewDescribeLoadBalancersPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -222,7 +351,13 @@ func (c *elbv2Client) DescribeLoadBalancersAsList(ctx context.Context, input *el func (c *elbv2Client) DescribeTargetGroupsAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetGroupsInput) ([]types.TargetGroup, error) { var result []types.TargetGroup - paginator := elasticloadbalancingv2.NewDescribeTargetGroupsPaginator(c.elbv2Client, input) + var client *elasticloadbalancingv2.Client + var err error + client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTargetGroups") + if err != nil { + return nil, err + } + paginator := elasticloadbalancingv2.NewDescribeTargetGroupsPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -235,7 +370,13 @@ func (c *elbv2Client) DescribeTargetGroupsAsList(ctx context.Context, input *ela func (c *elbv2Client) DescribeListenersAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeListenersInput) ([]types.Listener, error) { var result []types.Listener - paginator := elasticloadbalancingv2.NewDescribeListenersPaginator(c.elbv2Client, input) + var client *elasticloadbalancingv2.Client + var err error + client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeListeners") + if err != nil { + return nil, err + } + paginator := elasticloadbalancingv2.NewDescribeListenersPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -248,7 +389,13 @@ func (c *elbv2Client) DescribeListenersAsList(ctx context.Context, input *elasti func (c *elbv2Client) DescribeListenerCertificatesAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeListenerCertificatesInput) ([]types.Certificate, error) { var result []types.Certificate - paginator := elasticloadbalancingv2.NewDescribeListenerCertificatesPaginator(c.elbv2Client, input) + var client *elasticloadbalancingv2.Client + var err error + client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeListenerCertificates") + if err != nil { + return nil, err + } + paginator := elasticloadbalancingv2.NewDescribeListenerCertificatesPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -261,7 +408,13 @@ func (c *elbv2Client) DescribeListenerCertificatesAsList(ctx context.Context, in func (c *elbv2Client) DescribeRulesAsList(ctx context.Context, input *elasticloadbalancingv2.DescribeRulesInput) ([]types.Rule, error) { var result []types.Rule - paginator := elasticloadbalancingv2.NewDescribeRulesPaginator(c.elbv2Client, input) + var client *elasticloadbalancingv2.Client + var err error + client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeRules") + if err != nil { + return nil, err + } + paginator := elasticloadbalancingv2.NewDescribeRulesPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { @@ -273,9 +426,17 @@ func (c *elbv2Client) DescribeRulesAsList(ctx context.Context, input *elasticloa } func (c *elbv2Client) DescribeListenerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeListenerAttributesInput) (*elasticloadbalancingv2.DescribeListenerAttributesOutput, error) { - return c.elbv2Client.DescribeListenerAttributes(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeListenerAttributes") + if err != nil { + return nil, err + } + return client.DescribeListenerAttributes(ctx, input) } func (c *elbv2Client) ModifyListenerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyListenerAttributesInput) (*elasticloadbalancingv2.ModifyListenerAttributesOutput, error) { - return c.elbv2Client.ModifyListenerAttributes(ctx, input) + client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyListenerAttributes") + if err != nil { + return nil, err + } + return client.ModifyListenerAttributes(ctx, input) } diff --git a/pkg/aws/services/rgt.go b/pkg/aws/services/rgt.go index 1d39a0bf11..1558e0e4e1 100644 --- a/pkg/aws/services/rgt.go +++ b/pkg/aws/services/rgt.go @@ -5,7 +5,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" rgttypes "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi/types" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" ) const ( @@ -18,23 +18,23 @@ type RGT interface { } // NewRGT constructs new RGT implementation. -func NewRGT(cfg aws.Config, endpointsResolver *endpoints.Resolver) RGT { - customEndpoint := endpointsResolver.EndpointFor(resourcegroupstaggingapi.ServiceID) - client := resourcegroupstaggingapi.NewFromConfig(cfg, func(o *resourcegroupstaggingapi.Options) { - if customEndpoint != nil { - o.BaseEndpoint = customEndpoint - } - }) - return &rgtClient{rgtClient: client} +func NewRGT(awsClientsProvider provider.AWSClientsProvider) RGT { + return &rgtClient{ + awsClientsProvider: awsClientsProvider, + } } type rgtClient struct { - rgtClient *resourcegroupstaggingapi.Client + awsClientsProvider provider.AWSClientsProvider } func (c *rgtClient) GetResourcesAsList(ctx context.Context, input *resourcegroupstaggingapi.GetResourcesInput) ([]rgttypes.ResourceTagMapping, error) { + client, err := c.awsClientsProvider.GetRGTClient(ctx, "GetResources") + if err != nil { + return nil, err + } var result []rgttypes.ResourceTagMapping - paginator := resourcegroupstaggingapi.NewGetResourcesPaginator(c.rgtClient, input) + paginator := resourcegroupstaggingapi.NewGetResourcesPaginator(client, input) for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { diff --git a/pkg/aws/services/shield.go b/pkg/aws/services/shield.go index 02def20c4e..ad97be240b 100644 --- a/pkg/aws/services/shield.go +++ b/pkg/aws/services/shield.go @@ -2,9 +2,8 @@ package services import ( "context" - "github.com/aws/aws-sdk-go-v2/aws" shieldsdk "github.com/aws/aws-sdk-go-v2/service/shield" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" ) type Shield interface { @@ -15,33 +14,45 @@ type Shield interface { } // NewShield constructs new Shield implementation. -func NewShield(cfg aws.Config, endpointsResolver *endpoints.Resolver) Shield { - customEndpoint := endpointsResolver.EndpointFor(shieldsdk.ServiceID) - // shield is only available as a global API in us-east-1. - client := shieldsdk.NewFromConfig(cfg, func(o *shieldsdk.Options) { - o.Region = "us-east-1" - o.BaseEndpoint = customEndpoint - }) - return &shieldClient{shieldClient: client} +func NewShield(awsClientsProvider provider.AWSClientsProvider) Shield { + return &shieldClient{ + awsClientsProvider: awsClientsProvider, + } } // default implementation for Shield. type shieldClient struct { - shieldClient *shieldsdk.Client + awsClientsProvider provider.AWSClientsProvider } func (s *shieldClient) GetSubscriptionStateWithContext(ctx context.Context, input *shieldsdk.GetSubscriptionStateInput) (*shieldsdk.GetSubscriptionStateOutput, error) { - return s.shieldClient.GetSubscriptionState(ctx, input) + client, err := s.awsClientsProvider.GetShieldClient(ctx, "GetSubscriptionState") + if err != nil { + return nil, err + } + return client.GetSubscriptionState(ctx, input) } func (s *shieldClient) DescribeProtectionWithContext(ctx context.Context, input *shieldsdk.DescribeProtectionInput) (*shieldsdk.DescribeProtectionOutput, error) { - return s.shieldClient.DescribeProtection(ctx, input) + client, err := s.awsClientsProvider.GetShieldClient(ctx, "DescribeProtection") + if err != nil { + return nil, err + } + return client.DescribeProtection(ctx, input) } func (s *shieldClient) CreateProtectionWithContext(ctx context.Context, input *shieldsdk.CreateProtectionInput) (*shieldsdk.CreateProtectionOutput, error) { - return s.shieldClient.CreateProtection(ctx, input) + client, err := s.awsClientsProvider.GetShieldClient(ctx, "CreateProtection") + if err != nil { + return nil, err + } + return client.CreateProtection(ctx, input) } func (s *shieldClient) DeleteProtectionWithContext(ctx context.Context, input *shieldsdk.DeleteProtectionInput) (*shieldsdk.DeleteProtectionOutput, error) { - return s.shieldClient.DeleteProtection(ctx, input) + client, err := s.awsClientsProvider.GetShieldClient(ctx, "DeleteProtection") + if err != nil { + return nil, err + } + return client.DeleteProtection(ctx, input) } diff --git a/pkg/aws/services/wafregional.go b/pkg/aws/services/wafregional.go index e11a81c376..9ffdf0661f 100644 --- a/pkg/aws/services/wafregional.go +++ b/pkg/aws/services/wafregional.go @@ -2,9 +2,8 @@ package services import ( "context" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/wafregional" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" ) type WAFRegional interface { @@ -16,21 +15,17 @@ type WAFRegional interface { } // NewWAFRegional constructs new WAFRegional implementation. -func NewWAFRegional(cfg aws.Config, endpointsResolver *endpoints.Resolver, region string) WAFRegional { - customEndpoint := endpointsResolver.EndpointFor(wafregional.ServiceID) +func NewWAFRegional(awsClientsProvider provider.AWSClientsProvider, region string) WAFRegional { return &wafRegionalClient{ - wafRegionalClient: wafregional.NewFromConfig(cfg, func(o *wafregional.Options) { - o.Region = region - o.BaseEndpoint = customEndpoint - }), - region: region, + awsClientsProvider: awsClientsProvider, + region: region, } } // default implementation for WAFRegional. type wafRegionalClient struct { - wafRegionalClient *wafregional.Client - region string + awsClientsProvider provider.AWSClientsProvider + region string } func (c *wafRegionalClient) Available() bool { @@ -42,13 +37,25 @@ func (c *wafRegionalClient) Available() bool { } func (c *wafRegionalClient) AssociateWebACLWithContext(ctx context.Context, input *wafregional.AssociateWebACLInput) (*wafregional.AssociateWebACLOutput, error) { - return c.wafRegionalClient.AssociateWebACL(ctx, input) + client, err := c.awsClientsProvider.GetWAFRegionClient(ctx, "AssociateWebACL") + if err != nil { + return nil, err + } + return client.AssociateWebACL(ctx, input) } func (c *wafRegionalClient) DisassociateWebACLWithContext(ctx context.Context, input *wafregional.DisassociateWebACLInput) (*wafregional.DisassociateWebACLOutput, error) { - return c.wafRegionalClient.DisassociateWebACL(ctx, input) + client, err := c.awsClientsProvider.GetWAFRegionClient(ctx, "DisassociateWebACL") + if err != nil { + return nil, err + } + return client.DisassociateWebACL(ctx, input) } func (c *wafRegionalClient) GetWebACLForResourceWithContext(ctx context.Context, input *wafregional.GetWebACLForResourceInput) (*wafregional.GetWebACLForResourceOutput, error) { - return c.wafRegionalClient.GetWebACLForResource(ctx, input) + client, err := c.awsClientsProvider.GetWAFRegionClient(ctx, "GetWebACLForResource") + if err != nil { + return nil, err + } + return client.GetWebACLForResource(ctx, input) } diff --git a/pkg/aws/services/wafv2.go b/pkg/aws/services/wafv2.go index 3547a8678c..e186409200 100644 --- a/pkg/aws/services/wafv2.go +++ b/pkg/aws/services/wafv2.go @@ -2,9 +2,8 @@ package services import ( "context" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/wafv2" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" ) type WAFv2 interface { @@ -14,28 +13,36 @@ type WAFv2 interface { } // NewWAFv2 constructs new WAFv2 implementation. -func NewWAFv2(cfg aws.Config, endpointsResolver *endpoints.Resolver) WAFv2 { - customEndpoint := endpointsResolver.EndpointFor(wafv2.ServiceID) - client := wafv2.NewFromConfig(cfg, func(o *wafv2.Options) { - if customEndpoint != nil { - o.BaseEndpoint = customEndpoint - } - }) - return &wafv2Client{wafv2Client: client} +func NewWAFv2(awsClientsProvider provider.AWSClientsProvider) WAFv2 { + return &wafv2Client{ + awsClientsProvider: awsClientsProvider, + } } type wafv2Client struct { - wafv2Client *wafv2.Client + awsClientsProvider provider.AWSClientsProvider } func (c *wafv2Client) AssociateWebACLWithContext(ctx context.Context, req *wafv2.AssociateWebACLInput) (*wafv2.AssociateWebACLOutput, error) { - return c.wafv2Client.AssociateWebACL(ctx, req) + client, err := c.awsClientsProvider.GetWAFv2Client(ctx, "AssociateWebACL") + if err != nil { + return nil, err + } + return client.AssociateWebACL(ctx, req) } func (c *wafv2Client) DisassociateWebACLWithContext(ctx context.Context, req *wafv2.DisassociateWebACLInput) (*wafv2.DisassociateWebACLOutput, error) { - return c.wafv2Client.DisassociateWebACL(ctx, req) + client, err := c.awsClientsProvider.GetWAFv2Client(ctx, "DisassociateWebACL") + if err != nil { + return nil, err + } + return client.DisassociateWebACL(ctx, req) } func (c *wafv2Client) GetWebACLForResourceWithContext(ctx context.Context, req *wafv2.GetWebACLForResourceInput) (*wafv2.GetWebACLForResourceOutput, error) { - return c.wafv2Client.GetWebACLForResource(ctx, req) + client, err := c.awsClientsProvider.GetWAFv2Client(ctx, "GetWebACLForResource") + if err != nil { + return nil, err + } + return client.GetWebACLForResource(ctx, req) } diff --git a/test/framework/framework.go b/test/framework/framework.go index 4402f16d17..52171817a1 100644 --- a/test/framework/framework.go +++ b/test/framework/framework.go @@ -62,7 +62,7 @@ func InitFramework() (*Framework, error) { VpcID: globalOptions.AWSVPCID, MaxRetries: 3, ThrottleConfig: throttle.NewDefaultServiceOperationsThrottleConfig(), - }, nil, logger) + }, nil, logger, nil) if err != nil { return nil, err }