Skip to content

Commit 69e3bee

Browse files
committed
Add option to require a valid aceess token audience
1 parent 81f8713 commit 69e3bee

File tree

4 files changed

+36
-4
lines changed

4 files changed

+36
-4
lines changed

core/trino-main/src/main/java/io/trino/server/security/oauth2/NimbusOAuth2Client.java

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,26 @@ public NimbusOAuth2Client(OAuth2Config oauthConfig, OAuth2ServerConfigProvider s
117117
principalField = oauthConfig.getPrincipalField();
118118
maxClockSkew = oauthConfig.getMaxClockSkew();
119119
jwtType = oauthConfig.getJwtType();
120-
121-
accessTokenAudiences = new HashSet<>(oauthConfig.getAdditionalAudiences());
122-
accessTokenAudiences.add(clientId.getValue());
123-
accessTokenAudiences.add(null); // A null value in the set allows JWTs with no audience
120+
accessTokenAudiences = getAccessTokenAudiences(oauthConfig);
124121

125122
this.serverConfigurationProvider = requireNonNull(serverConfigurationProvider, "serverConfigurationProvider is null");
126123
this.httpClient = requireNonNull(httpClient, "httpClient is null");
127124
}
128125

126+
private static Set<String> getAccessTokenAudiences(OAuth2Config oauthConfig)
127+
{
128+
HashSet<String> audiences = new HashSet<>();
129+
audiences.add(clientId.getValue());
130+
audiences.addAll(oauthConfig.getAdditionalAudiences());
131+
132+
if (!oauthConfig.isRequireAudience()) {
133+
// A null value in the set allows JWTs with no audience
134+
audiences.add(null);
135+
}
136+
137+
return audiences;
138+
}
139+
129140
@Override
130141
public void load()
131142
{

core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Config.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ public class OAuth2Config
4242
private Set<String> scopes = ImmutableSet.of(OPENID_SCOPE);
4343
private String principalField = "sub";
4444
private Optional<String> groupsField = Optional.empty();
45+
private boolean requireAudience = false;
4546
private List<String> additionalAudiences = Collections.emptyList();
4647
private Duration challengeTimeout = new Duration(15, TimeUnit.MINUTES);
4748
private Duration maxClockSkew = new Duration(1, TimeUnit.MINUTES);
@@ -107,6 +108,20 @@ public OAuth2Config setClientSecret(String clientSecret)
107108
return this;
108109
}
109110

111+
@NotNull
112+
public boolean isRequireAudience()
113+
{
114+
return this.requireAudience;
115+
}
116+
117+
@Config("http-server.authentication.oauth2.require-audience")
118+
@ConfigDescription("Require a valid audience. If false (default), access tokens without an aud claim will be accepted.")
119+
public OAuth2Config setRequireAudience(boolean requireAudience)
120+
{
121+
this.requireAudience = requireAudience;
122+
return this;
123+
}
124+
110125
@NotNull
111126
public List<String> getAdditionalAudiences()
112127
{

core/trino-main/src/test/java/io/trino/server/security/oauth2/TestOAuth2Config.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ public void testDefaults()
4545
.setChallengeTimeout(new Duration(15, MINUTES))
4646
.setPrincipalField("sub")
4747
.setGroupsField(null)
48+
.setRequireAudience(false)
4849
.setAdditionalAudiences(Collections.emptyList())
4950
.setMaxClockSkew(new Duration(1, MINUTES))
5051
.setJwtType(null)
@@ -67,6 +68,7 @@ public void testExplicitPropertyMappings()
6768
.put("http-server.authentication.oauth2.scopes", "email,offline")
6869
.put("http-server.authentication.oauth2.principal-field", "some-field")
6970
.put("deprecated.http-server.authentication.oauth2.groups-field", "groups")
71+
.put("http-server.authentication.oauth2.require-audience", "true")
7072
.put("http-server.authentication.oauth2.additional-audiences", "test-aud1,test-aud2")
7173
.put("http-server.authentication.oauth2.challenge-timeout", "90s")
7274
.put("http-server.authentication.oauth2.max-clock-skew", "15s")
@@ -85,6 +87,7 @@ public void testExplicitPropertyMappings()
8587
.setScopes(ImmutableSet.of("email", "offline"))
8688
.setPrincipalField("some-field")
8789
.setGroupsField("groups")
90+
.setRequireAudience(true)
8891
.setAdditionalAudiences(List.of("test-aud1", "test-aud2"))
8992
.setChallengeTimeout(new Duration(90, SECONDS))
9093
.setMaxClockSkew(new Duration(15, SECONDS))

docs/src/main/sphinx/security/oauth2.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ The following configuration properties are available:
126126
- The public identifier of the Trino client.
127127
* - `http-server.authentication.oauth2.client-secret`
128128
- The secret used to authorize Trino client with the authorization server.
129+
* - `http-server.authentication.oauth2.require-audience`
130+
- Require a valid audience. If false (default), access tokens
131+
without an aud claim will be accepeted.
129132
* - `http-server.authentication.oauth2.additional-audiences`
130133
- Additional audiences to trust in addition to the client ID which is
131134
always a trusted audience.

0 commit comments

Comments
 (0)