3232import java .util .Arrays ;
3333import java .util .Collections ;
3434import java .util .List ;
35+ import java .util .Optional ;
3536import java .util .concurrent .CompletableFuture ;
3637import java .util .concurrent .CompletionStage ;
3738import javax .crypto .Mac ;
5354import software .amazon .awssdk .auth .signer .internal .SignerConstant ;
5455import software .amazon .awssdk .regions .Region ;
5556import software .amazon .awssdk .regions .providers .DefaultAwsRegionProviderChain ;
57+ import software .amazon .awssdk .services .sts .StsClient ;
58+ import software .amazon .awssdk .services .sts .StsClientBuilder ;
59+ import software .amazon .awssdk .services .sts .auth .StsAssumeRoleCredentialsProvider ;
60+ import software .amazon .awssdk .services .sts .auth .StsGetSessionTokenCredentialsProvider ;
61+ import software .amazon .awssdk .services .sts .model .AssumeRoleRequest ;
5662
5763import static software .amazon .awssdk .auth .credentials .DefaultCredentialsProvider .create ;
5864
@@ -106,6 +112,12 @@ public String getPath() {
106112 }
107113 };
108114
115+ private final static DriverOption ROLE_OPTION = new DriverOption () {
116+ public String getPath () {
117+ return "advanced.auth-provider.aws-role" ;
118+ }
119+ };
120+
109121 /**
110122 * This constructor is provided so that the driver can create
111123 * instances of this class based on configuration. For example:
@@ -130,7 +142,8 @@ public String getPath() {
130142 * Unused for this plugin.
131143 */
132144 public SigV4AuthProvider (DriverContext driverContext ) {
133- this (driverContext .getConfig ().getDefaultProfile ().getString (REGION_OPTION , null ));
145+ this (driverContext .getConfig ().getDefaultProfile ().getString (REGION_OPTION , getDefaultRegion ()),
146+ driverContext .getConfig ().getDefaultProfile ().getString (ROLE_OPTION , null ));
134147 }
135148
136149 /**
@@ -139,8 +152,8 @@ public SigV4AuthProvider(DriverContext driverContext) {
139152 * null value indicates to use the AWS_REGION environment
140153 * variable, or the "aws.region" system property to configure it.
141154 */
142- public SigV4AuthProvider (final String region ) {
143- this (create (), region );
155+ public SigV4AuthProvider (final String region , final String roleArn ) {
156+ this (Optional . ofNullable ( roleArn ). map ( r ->( AwsCredentialsProvider ) createSTSRoleCredentialProvider ( r , "keyspaces-session" , region )). orElse ( create () ), region );
144157 }
145158
146159 /**
@@ -373,4 +386,36 @@ static int indexOf(byte[] target, byte[] pattern) {
373386 // Loop exhaustion means we did not find it
374387 return -1 ;
375388 }
389+
390+
391+ /**
392+ * Creates a STS role credential provider
393+ * @param roleArn The ARN of the role to assume
394+ * @param sessionName The name of the session
395+ * @param stsRegion The region of the STS endpoint
396+ * @return
397+ */
398+ private static StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider (String roleArn ,
399+ String sessionName , String stsRegion ) {
400+ StsClient stsClient = StsClient .builder ()
401+ .region (Region .of (stsRegion ))
402+ .build ();
403+ AssumeRoleRequest assumeRoleRequest =AssumeRoleRequest .builder ()
404+ .roleArn (roleArn )
405+ .roleSessionName (sessionName )
406+ .build ();
407+ return StsAssumeRoleCredentialsProvider .builder ()
408+ .stsClient (stsClient )
409+ .refreshRequest (assumeRoleRequest )
410+ .build ();
411+ }
412+
413+ /**
414+ * Gets the default region for SigV4 if region is not provided.
415+ * @return
416+ */
417+ private static String getDefaultRegion () {
418+ DefaultAwsRegionProviderChain chain = new DefaultAwsRegionProviderChain ();
419+ return Optional .ofNullable (chain .getRegion ()).orElse (Region .US_EAST_1 ).toString ();
420+ }
376421}
0 commit comments