Skip to content

Commit df2b4e5

Browse files
committed
1. Refactoring to create getRoleNameFromArn method in SigV4AuthProvider.java and unit test to validate.
2. Code formatting changes in SigV4AuthProvider.java
1 parent 7335140 commit df2b4e5

File tree

2 files changed

+35
-8
lines changed

2 files changed

+35
-8
lines changed

src/main/java/software/aws/mcs/auth/SigV4AuthProvider.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ public SigV4AuthProvider() {
133133
*/
134134
public SigV4AuthProvider(DriverContext driverContext) {
135135
this(driverContext.getConfig().getDefaultProfile().getString(REGION_OPTION, getDefaultRegion()),
136-
driverContext.getConfig().getDefaultProfile().getString(ROLE_OPTION, null));
136+
driverContext.getConfig().getDefaultProfile().getString(ROLE_OPTION, null));
137137
}
138138

139139
/**
@@ -393,14 +393,10 @@ static int indexOf(byte[] target, byte[] pattern) {
393393
* @param stsRegion The region of the STS endpoint
394394
* @return The STS role credential provider
395395
*/
396-
private static StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
397-
String stsRegion) {
396+
private static StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(@NotNull String roleArn,
397+
@NotNull String stsRegion) {
398398
//Get role name from ARN
399-
String[] arnParts = roleArn.split("/");
400-
if(arnParts.length < 2){
401-
throw new IllegalArgumentException("Invalid role ARN");
402-
}
403-
String roleName = arnParts[arnParts.length - 1];
399+
String roleName = getRoleNameFromArn(roleArn);
404400
final String sessionName="keyspaces-session-"+roleName+System.currentTimeMillis();
405401
StsClient stsClient = StsClient.builder()
406402
.region(Region.of(stsRegion))
@@ -415,6 +411,20 @@ private static StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(
415411
.build();
416412
}
417413

414+
/**
415+
* Extracts the role name from the ARN.
416+
* @param roleArn The ARN of the role to assume
417+
* @return
418+
*/
419+
static String getRoleNameFromArn(@NotNull String roleArn) {
420+
String[] arnParts = roleArn.split("/");
421+
if(arnParts.length < 2){
422+
throw new IllegalArgumentException("Invalid role ARN");
423+
}
424+
String roleName = arnParts[arnParts.length - 1];
425+
return roleName;
426+
}
427+
418428
/**
419429
* Gets the default region for SigV4 if region is not provided.
420430
* @return Default region

src/test/java/software/aws/mcs/auth/SigV4AuthProviderTest.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,21 @@ public void testPartialTrailingIndexOf() {
8787
assertEquals(-1, SigV4AuthProvider.indexOf(target, pattern));
8888
}
8989

90+
91+
@Test
92+
public void testGetRoleNameFromArn() {
93+
String arn = "arn:aws:iam::ACCOUNT_ID:role/keyspaces-act2-role";
94+
assertEquals("keyspaces-act2-role", SigV4AuthProvider.getRoleNameFromArn(arn));
95+
}
96+
97+
@Test
98+
public void testGetRoleNameFromArnFailure() {
99+
assertThrows(IllegalArgumentException.class, () -> SigV4AuthProvider.getRoleNameFromArn(""));
100+
assertThrows(IllegalArgumentException.class, () -> SigV4AuthProvider.getRoleNameFromArn("roleName"));
101+
assertThrows(IllegalArgumentException.class, () -> SigV4AuthProvider.getRoleNameFromArn("illegalerolearn:rolename"));
102+
}
103+
104+
105+
106+
90107
}

0 commit comments

Comments
 (0)