Skip to content

Commit c191dde

Browse files
committed
Code changes to support assuming a role which could be a cross account role.
1 parent c2b35c2 commit c191dde

File tree

8 files changed

+149
-10
lines changed

8 files changed

+149
-10
lines changed

README.md

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ To use the configuration file, set the `advanced.auth-provider.class` to `softwa
119119
1. Set the `advanced.auth-provider.class` to `software.aws.mcs.auth.SigV4AuthProvider`.
120120
1. Set `basic.load-balancing-policy.local-datacenter` to the region name. In this case, use `us-east-2`.
121121

122-
The following is an example of this.
122+
The following is an example of this config without explicit role to be assumed.
123123

124124
``` text
125125
datastax-java-driver {
@@ -138,3 +138,24 @@ The following is an example of this.
138138
}
139139
}
140140
```
141+
142+
Dollowing is an example of this config with explicit role to be assumed.
143+
144+
``` text
145+
datastax-java-driver {
146+
basic.load-balancing-policy {
147+
class = DefaultLoadBalancingPolicy
148+
local-datacenter = us-east-2
149+
}
150+
advanced {
151+
auth-provider = {
152+
class = software.aws.mcs.auth.SigV4AuthProvider
153+
aws-region = us-east-2
154+
aws-role = "arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME"
155+
}
156+
ssl-engine-factory {
157+
class = DefaultSslEngineFactory
158+
}
159+
}
160+
}
161+
```

pom.xml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
4545
<jackson.version>2.13.4</jackson.version>
4646
<jackson-databind.version>2.13.4.2</jackson-databind.version>
47+
<aws-sdk-version>2.15.66</aws-sdk-version>
4748
</properties>
4849

4950
<dependencies>
@@ -88,7 +89,13 @@
8889
<dependency>
8990
<groupId>software.amazon.awssdk</groupId>
9091
<artifactId>auth</artifactId>
91-
<version>2.15.66</version>
92+
<version>${aws-sdk-version}</version>
93+
</dependency>
94+
<!-- https://mvnrepository.com/artifact/software.amazon.awssdk/sts -->
95+
<dependency>
96+
<groupId>software.amazon.awssdk</groupId>
97+
<artifactId>sts</artifactId>
98+
<version>${aws-sdk-version}</version>
9299
</dependency>
93100
</dependencies>
94101

src/main/java/software/aws/mcs/auth/SigV4AuthProvider.java

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import java.util.Arrays;
3333
import java.util.Collections;
3434
import java.util.List;
35+
import java.util.Optional;
3536
import java.util.concurrent.CompletableFuture;
3637
import java.util.concurrent.CompletionStage;
3738
import javax.crypto.Mac;
@@ -53,6 +54,11 @@
5354
import software.amazon.awssdk.auth.signer.internal.SignerConstant;
5455
import software.amazon.awssdk.regions.Region;
5556
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
57+
import software.amazon.awssdk.services.sts.StsClient;
58+
import software.amazon.awssdk.services.sts.StsClientBuilder;
59+
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
60+
import software.amazon.awssdk.services.sts.auth.StsGetSessionTokenCredentialsProvider;
61+
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
5662

5763
import static software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider.create;
5864

@@ -106,6 +112,12 @@ public String getPath() {
106112
}
107113
};
108114

115+
private final static DriverOption ROLE_OPTION = new DriverOption() {
116+
public String getPath() {
117+
return "advanced.auth-provider.aws-role";
118+
}
119+
};
120+
109121
/**
110122
* This constructor is provided so that the driver can create
111123
* instances of this class based on configuration. For example:
@@ -130,7 +142,8 @@ public String getPath() {
130142
* Unused for this plugin.
131143
*/
132144
public SigV4AuthProvider(DriverContext driverContext) {
133-
this(driverContext.getConfig().getDefaultProfile().getString(REGION_OPTION, null));
145+
this(driverContext.getConfig().getDefaultProfile().getString(REGION_OPTION, getDefaultRegion()),
146+
driverContext.getConfig().getDefaultProfile().getString(ROLE_OPTION, null));
134147
}
135148

136149
/**
@@ -139,8 +152,8 @@ public SigV4AuthProvider(DriverContext driverContext) {
139152
* null value indicates to use the AWS_REGION environment
140153
* variable, or the "aws.region" system property to configure it.
141154
*/
142-
public SigV4AuthProvider(final String region) {
143-
this(create(), region);
155+
public SigV4AuthProvider(final String region,final String roleArn) {
156+
this(Optional.ofNullable(roleArn).map(r->(AwsCredentialsProvider)createSTSRoleCredentialProvider(r,"keyspaces-session",region)).orElse(create()), region);
144157
}
145158

146159
/**
@@ -373,4 +386,36 @@ static int indexOf(byte[] target, byte[] pattern) {
373386
// Loop exhaustion means we did not find it
374387
return -1;
375388
}
389+
390+
391+
/**
392+
* Creates a STS role credential provider
393+
* @param roleArn The ARN of the role to assume
394+
* @param sessionName The name of the session
395+
* @param stsRegion The region of the STS endpoint
396+
* @return
397+
*/
398+
private static StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
399+
String sessionName, String stsRegion) {
400+
StsClient stsClient = StsClient.builder()
401+
.region(Region.of(stsRegion))
402+
.build();
403+
AssumeRoleRequest assumeRoleRequest=AssumeRoleRequest.builder()
404+
.roleArn(roleArn)
405+
.roleSessionName(sessionName)
406+
.build();
407+
return StsAssumeRoleCredentialsProvider.builder()
408+
.stsClient(stsClient)
409+
.refreshRequest(assumeRoleRequest)
410+
.build();
411+
}
412+
413+
/**
414+
* Gets the default region for SigV4 if region is not provided.
415+
* @return
416+
*/
417+
private static String getDefaultRegion() {
418+
DefaultAwsRegionProviderChain chain = new DefaultAwsRegionProviderChain();
419+
return Optional.ofNullable(chain.getRegion()).orElse(Region.US_EAST_1).toString();
420+
}
376421
}

src/test/java/software/aws/mcs/auth/TestSigV4Config.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public static void main(String[] args) throws Exception {
5555
//By default the reference.conf is loaded by the driver which contains all defaults.
5656
//You can override this by providing reference.conf on the classpath
5757
//to isolate test you can load conf with a custom name
58-
URL url = TestSigV4Config.class.getClassLoader().getResource("keyspaces-reference.conf");
58+
URL url = TestSigV4Config.class.getClassLoader().getResource("keyspaces-reference-norole.conf");
5959

6060
File file = new File(url.toURI());
6161
// The CqlSession object is the main entry point of the driver.
@@ -64,18 +64,16 @@ public static void main(String[] args) throws Exception {
6464
// it throughout your application.
6565
try (CqlSession session = CqlSession.builder()
6666
.withConfigLoader(DriverConfigLoader.fromFile(file))
67-
.addContactPoints(contactPoints)
68-
.withLocalDatacenter("us-west-2")
6967
.build()) {
7068

7169
// We use execute to send a query to Cassandra. This returns a ResultSet, which is essentially a collection
7270
// of Row objects.
73-
ResultSet rs = session.execute("select release_version from system.local");
71+
ResultSet rs = session.execute("select * from testkeyspace.testconf");
7472
// Extract the first row (which is the only one in this case).
7573
Row row = rs.one();
7674

7775
// Extract the value of the first (and only) column from the row.
78-
String releaseVersion = row.getString("release_version");
76+
String releaseVersion = row.getString("category");
7977
System.out.printf("Cassandra version is: %s%n", releaseVersion);
8078
}
8179
}

src/test/resources/ddl.cql

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
CREATE KEYSPACE "testkeyspace" WITH REPLICATION = {'class': 'SingleRegionStrategy'};
2+
3+
CREATE TABLE "testkeyspace"."testconf"(
4+
"id" ascii,
5+
"category" ascii,
6+
PRIMARY KEY("id")
7+
)
8+
WITH CUSTOM_PROPERTIES = {
9+
'capacity_mode':{
10+
'throughput_mode':'PAY_PER_REQUEST'
11+
},
12+
'point_in_time_recovery':{
13+
'status':'enabled'
14+
},
15+
'encryption_specification':{
16+
'encryption_type':'AWS_OWNED_KMS_KEY'
17+
}
18+
} ;

src/test/resources/dml.cql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
INSERT into testkeyspace.testconf(id,category) values('first','first Category');
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
datastax-java-driver {
2+
basic.contact-points = ["cassandra.ap-south-1.amazonaws.com:9142"]
3+
basic.load-balancing-policy {
4+
class = DefaultLoadBalancingPolicy
5+
local-datacenter = ap-south-1
6+
slow-replica-avoidance = false
7+
}
8+
basic.request {
9+
consistency = LOCAL_QUORUM
10+
}
11+
advanced {
12+
auth-provider = {
13+
class = software.aws.mcs.auth.SigV4AuthProvider
14+
aws-region = ap-south-1
15+
}
16+
ssl-engine-factory {
17+
class = DefaultSslEngineFactory
18+
truststore-path = "<path>/cassandra_truststore.jks"
19+
truststore-password = "<password>"
20+
hostname-validation=false
21+
}
22+
}
23+
advanced.connection.pool.local.size = 3
24+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
datastax-java-driver {
2+
basic.contact-points = ["cassandra.ap-south-1.amazonaws.com:9142"]
3+
basic.load-balancing-policy {
4+
class = DefaultLoadBalancingPolicy
5+
local-datacenter = ap-south-1
6+
slow-replica-avoidance = false
7+
}
8+
basic.request {
9+
consistency = LOCAL_QUORUM
10+
}
11+
advanced {
12+
auth-provider = {
13+
class = software.aws.mcs.auth.SigV4AuthProvider
14+
aws-region = ap-south-1
15+
aws-role = "arn:aws:iam::ACCOUNT_ID:role/keyspaces-act2-role"
16+
}
17+
ssl-engine-factory {
18+
class = DefaultSslEngineFactory
19+
truststore-path = "<path>/cassandra_truststore.jks"
20+
truststore-password = "<password>"
21+
hostname-validation=false
22+
}
23+
}
24+
advanced.connection.pool.local.size = 3
25+
}

0 commit comments

Comments
 (0)