Skip to content

Commit 04ca715

Browse files
authored
feat(sdk): provide access tokens dynamically to KAS (#51)
* give the `KASClient` a way to create channels that are configured in the way the user expects * add stub interfaces for `Policy` and `KASInfo`
1 parent 3774f07 commit 04ca715

File tree

5 files changed

+178
-37
lines changed

5 files changed

+178
-37
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package io.opentdf.platform.sdk;
2+
3+
import io.grpc.Channel;
4+
import io.opentdf.platform.kas.AccessServiceGrpc;
5+
import io.opentdf.platform.kas.PublicKeyRequest;
6+
import io.opentdf.platform.kas.RewrapRequest;
7+
8+
import java.util.HashMap;
9+
import java.util.function.Function;
10+
11+
public class KASClient implements SDK.KAS {
12+
13+
private final Function<SDK.KASInfo, Channel> channelFactory;
14+
15+
public KASClient(Function <SDK.KASInfo, Channel> channelFactory) {
16+
this.channelFactory = channelFactory;
17+
}
18+
19+
@Override
20+
public String getPublicKey(SDK.KASInfo kasInfo) {
21+
return getStub(kasInfo).publicKey(PublicKeyRequest.getDefaultInstance()).getPublicKey();
22+
}
23+
24+
@Override
25+
public byte[] unwrap(SDK.KASInfo kasInfo, SDK.Policy policy) {
26+
// this is obviously wrong. we still have to generate a correct request and decrypt the payload
27+
return getStub(kasInfo).rewrap(RewrapRequest.getDefaultInstance()).getEntityWrappedKey().toByteArray();
28+
}
29+
30+
private final HashMap<SDK.KASInfo, AccessServiceGrpc.AccessServiceBlockingStub> stubs = new HashMap<>();
31+
32+
private synchronized AccessServiceGrpc.AccessServiceBlockingStub getStub(SDK.KASInfo kasInfo) {
33+
if (!stubs.containsKey(kasInfo)) {
34+
var channel = channelFactory.apply(kasInfo);
35+
var stub = AccessServiceGrpc.newBlockingStub(channel);
36+
stubs.put(kasInfo, stub);
37+
}
38+
39+
return stubs.get(kasInfo);
40+
}
41+
}

sdk/src/main/java/io/opentdf/platform/sdk/SDK.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,25 @@
1717
public class SDK {
1818
private final Services services;
1919

20+
public interface KASInfo{
21+
String getAddress();
22+
}
23+
public interface Policy{}
24+
25+
interface KAS {
26+
String getPublicKey(KASInfo kasInfo);
27+
byte[] unwrap(KASInfo kasInfo, Policy policy);
28+
}
29+
2030
// TODO: add KAS
21-
public interface Services {
31+
interface Services {
2232
AttributesServiceFutureStub attributes();
2333
NamespaceServiceFutureStub namespaces();
2434
SubjectMappingServiceFutureStub subjectMappings();
2535
ResourceMappingServiceFutureStub resourceMappings();
36+
KAS kas();
2637

27-
static Services newServices(Channel channel) {
38+
static Services newServices(Channel channel, KAS kas) {
2839
var attributeService = AttributesServiceGrpc.newFutureStub(channel);
2940
var namespaceService = NamespaceServiceGrpc.newFutureStub(channel);
3041
var subjectMappingService = SubjectMappingServiceGrpc.newFutureStub(channel);
@@ -50,11 +61,16 @@ public SubjectMappingServiceFutureStub subjectMappings() {
5061
public ResourceMappingServiceFutureStub resourceMappings() {
5162
return resourceMappingService;
5263
}
64+
65+
@Override
66+
public KAS kas() {
67+
return kas;
68+
}
5369
};
5470
}
5571
}
5672

57-
public SDK(Services services) {
73+
SDK(Services services) {
5874
this.services = services;
5975
}
6076
}

sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import com.nimbusds.oauth2.sdk.id.ClientID;
1212
import com.nimbusds.oauth2.sdk.id.Issuer;
1313
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata;
14+
import io.grpc.Channel;
1415
import io.grpc.ManagedChannel;
1516
import io.grpc.ManagedChannelBuilder;
1617
import io.grpc.Status;
@@ -23,6 +24,7 @@
2324

2425
import java.io.IOException;
2526
import java.util.UUID;
27+
import java.util.function.Function;
2628

2729
/**
2830
* A builder class for creating instances of the SDK class.
@@ -33,9 +35,13 @@ public class SDKBuilder {
3335
private ClientAuthentication clientAuth = null;
3436
private Boolean usePlainText;
3537

38+
private static final Logger logger = LoggerFactory.getLogger(SDKBuilder.class);
39+
3640
public static SDKBuilder newBuilder() {
3741
SDKBuilder builder = new SDKBuilder();
3842
builder.usePlainText = false;
43+
builder.clientAuth = null;
44+
builder.platformEndpoint = null;
3945

4046
return builder;
4147
}
@@ -57,8 +63,16 @@ public SDKBuilder useInsecurePlaintextConnection(Boolean usePlainText) {
5763
return this;
5864
}
5965

60-
// this is not exposed publicly so that it can be tested
61-
ManagedChannel buildChannel() {
66+
private GRPCAuthInterceptor getGrpcAuthInterceptor() {
67+
if (platformEndpoint == null) {
68+
throw new SDKException("cannot build an SDK without specifying the platform endpoint");
69+
}
70+
71+
if (clientAuth == null) {
72+
// this simplifies things for now, if we need to support this case we can revisit
73+
throw new SDKException("cannot build an SDK without specifying OAuth credentials");
74+
}
75+
6276
// we don't add the auth listener to this channel since it is only used to call the
6377
// well known endpoint
6478
ManagedChannel bootstrapChannel = null;
@@ -107,24 +121,39 @@ ManagedChannel buildChannel() {
107121
throw new SDKException("Error generating DPoP key", e);
108122
}
109123

110-
GRPCAuthInterceptor interceptor = new GRPCAuthInterceptor(clientAuth, rsaKey, providerMetadata.getTokenEndpointURI());
124+
return new GRPCAuthInterceptor(clientAuth, rsaKey, providerMetadata.getTokenEndpointURI());
125+
}
111126

112-
return getManagedChannelBuilder()
113-
.intercept(interceptor)
114-
.build();
127+
SDK.Services buildServices() {
128+
var authInterceptor = getGrpcAuthInterceptor();
129+
var channel = getManagedChannelBuilder().intercept(authInterceptor).build();
130+
var client = new KASClient(getChannelFactory(authInterceptor));
131+
return SDK.Services.newServices(channel, client);
115132
}
116133

117134
public SDK build() {
118-
return new SDK(SDK.Services.newServices(buildChannel()));
135+
return new SDK(buildServices());
119136
}
120137

121138
private ManagedChannelBuilder<?> getManagedChannelBuilder() {
122-
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder
123-
.forTarget(platformEndpoint);
139+
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(platformEndpoint);
124140

125141
if (usePlainText) {
126142
channelBuilder = channelBuilder.usePlaintext();
127143
}
128144
return channelBuilder;
129145
}
146+
147+
Function<SDK.KASInfo, Channel> getChannelFactory(GRPCAuthInterceptor authInterceptor) {
148+
var pt = usePlainText; // no need to have the builder be able to influence things from beyond the grave
149+
return (SDK.KASInfo kasInfo) -> {
150+
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder
151+
.forTarget(kasInfo.getAddress())
152+
.intercept(authInterceptor);
153+
if (pt) {
154+
channelBuilder = channelBuilder.usePlaintext();
155+
}
156+
return channelBuilder.build();
157+
};
158+
}
130159
}

sdk/src/main/java/io/opentdf/platform/sdk/SDKException.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,8 @@ public class SDKException extends RuntimeException {
44
public SDKException(String message, Exception reason) {
55
super(message, reason);
66
}
7+
8+
public SDKException(String message) {
9+
super(message);
10+
}
711
}

sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java

Lines changed: 76 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,19 @@
1111
import io.grpc.ServerCallHandler;
1212
import io.grpc.ServerInterceptor;
1313
import io.grpc.stub.StreamObserver;
14+
import io.opentdf.platform.kas.AccessServiceGrpc;
15+
import io.opentdf.platform.kas.RewrapRequest;
16+
import io.opentdf.platform.kas.RewrapResponse;
17+
import io.opentdf.platform.policy.namespaces.GetNamespaceRequest;
18+
import io.opentdf.platform.policy.namespaces.GetNamespaceResponse;
19+
import io.opentdf.platform.policy.namespaces.NamespaceServiceGrpc;
1420
import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationRequest;
1521
import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse;
1622
import io.opentdf.platform.wellknownconfiguration.WellKnownServiceGrpc;
1723
import okhttp3.mockwebserver.MockResponse;
1824
import okhttp3.mockwebserver.MockWebServer;
25+
import org.junit.jupiter.api.AfterAll;
26+
import org.junit.jupiter.api.BeforeAll;
1927
import org.junit.jupiter.api.Test;
2028

2129
import java.io.IOException;
@@ -30,8 +38,9 @@
3038
public class SDKBuilderTest {
3139

3240
@Test
33-
void testCreatingSDKChannel() throws IOException, InterruptedException {
34-
Server wellknownServer = null;
41+
void testCreatingSDKServices() throws IOException, InterruptedException {
42+
Server platformServicesServer = null;
43+
Server kasServer = null;
3544
// we use the HTTP server for two things:
3645
// * it returns the OIDC configuration we use at bootstrapping time
3746
// * it fakes out being an IDP and returns an access token when need to retrieve an access token
@@ -51,6 +60,8 @@ void testCreatingSDKChannel() throws IOException, InterruptedException {
5160
.setHeader("Content-type", "application/json")
5261
);
5362

63+
// this service returns the platform_issuer url to the SDK during bootstrapping. This
64+
// tells the SDK where to download the OIDC discovery document from (our test webserver!)
5465
WellKnownServiceGrpc.WellKnownServiceImplBase wellKnownService = new WellKnownServiceGrpc.WellKnownServiceImplBase() {
5566
@Override
5667
public void getWellKnownConfiguration(GetWellKnownConfigurationRequest request, StreamObserver<GetWellKnownConfigurationResponse> responseObserver) {
@@ -65,55 +76,76 @@ public void getWellKnownConfiguration(GetWellKnownConfigurationRequest request,
6576
}
6677
};
6778

68-
AtomicReference<String> authHeaderFromRequest = new AtomicReference<>(null);
69-
AtomicReference<String> dpopHeaderFromRequest = new AtomicReference<>(null);
79+
// remember the auth headers that we received during GRPC calls to platform services
80+
AtomicReference<String> servicesAuthHeader = new AtomicReference<>(null);
81+
AtomicReference<String> servicesDPoPHeader = new AtomicReference<>(null);
7082

83+
// remember the auth headers that we received during GRPC calls to KAS
84+
AtomicReference<String> kasAuthHeader = new AtomicReference<>(null);
85+
AtomicReference<String> kasDPoPHeader = new AtomicReference<>(null);
7186
// we use the server in two different ways. the first time we use it to actually return
7287
// issuer for bootstrapping. the second time we use the interception functionality in order
7388
// to make sure that we are including a DPoP proof and an auth header
74-
int randomPort;
75-
try (ServerSocket socket = new ServerSocket(0)) {
76-
randomPort = socket.getLocalPort();
77-
}
78-
wellknownServer = ServerBuilder
79-
.forPort(randomPort)
89+
platformServicesServer = ServerBuilder
90+
.forPort(getRandomPort())
8091
.directExecutor()
8192
.addService(wellKnownService)
93+
.addService(new NamespaceServiceGrpc.NamespaceServiceImplBase() {})
94+
.intercept(new ServerInterceptor() {
95+
@Override
96+
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
97+
servicesAuthHeader.set(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)));
98+
servicesDPoPHeader.set(headers.get(Metadata.Key.of("DPoP", Metadata.ASCII_STRING_MARSHALLER)));
99+
return next.startCall(call, headers);
100+
}
101+
})
102+
.build()
103+
.start();
104+
105+
106+
kasServer = ServerBuilder
107+
.forPort(getRandomPort())
108+
.directExecutor()
109+
.addService(new AccessServiceGrpc.AccessServiceImplBase() {
110+
@Override
111+
public void rewrap(RewrapRequest request, StreamObserver<RewrapResponse> responseObserver) {
112+
responseObserver.onNext(RewrapResponse.getDefaultInstance());
113+
responseObserver.onCompleted();
114+
}
115+
})
82116
.intercept(new ServerInterceptor() {
83117
@Override
84118
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
85-
authHeaderFromRequest.set(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)));
86-
dpopHeaderFromRequest.set(headers.get(Metadata.Key.of("DPoP", Metadata.ASCII_STRING_MARSHALLER)));
119+
kasAuthHeader.set(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)));
120+
kasDPoPHeader.set(headers.get(Metadata.Key.of("DPoP", Metadata.ASCII_STRING_MARSHALLER)));
87121
return next.startCall(call, headers);
88122
}
89123
})
90124
.build()
91125
.start();
92126

93-
ManagedChannel channel = SDKBuilder
127+
SDK.Services services = SDKBuilder
94128
.newBuilder()
95129
.clientSecret("client-id", "client-secret")
96-
.platformEndpoint("localhost:" + wellknownServer.getPort())
130+
.platformEndpoint("localhost:" + platformServicesServer.getPort())
97131
.useInsecurePlaintextConnection(true)
98-
.buildChannel();
99-
100-
assertThat(channel).isNotNull();
101-
assertThat(channel.getState(false)).isEqualTo(ConnectivityState.IDLE);
132+
.buildServices();
102133

103-
var wellKnownStub = WellKnownServiceGrpc.newBlockingStub(channel);
134+
assertThat(services).isNotNull();
104135

105136
httpServer.enqueue(new MockResponse()
106137
.setBody("{\"access_token\": \"hereisthetoken\", \"token_type\": \"Bearer\"}")
107138
.setHeader("Content-Type", "application/json"));
108139

109-
var ignored = wellKnownStub.getWellKnownConfiguration(GetWellKnownConfigurationRequest.getDefaultInstance());
110-
channel.shutdownNow();
140+
var ignored = services.namespaces().getNamespace(GetNamespaceRequest.getDefaultInstance());
111141

112142
// we've now made two requests. one to get the bootstrapping info and one
113143
// call that should activate the token fetching logic
114144
assertThat(httpServer.getRequestCount()).isEqualTo(2);
115145

116146
httpServer.takeRequest();
147+
148+
// validate that we made a reasonable request to our fake IdP to get an access token
117149
var accessTokenRequest = httpServer.takeRequest();
118150
assertThat(accessTokenRequest).isNotNull();
119151
var authHeader = accessTokenRequest.getHeader("Authorization");
@@ -124,16 +156,35 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
124156
var usernameAndPassword = new String(Base64.getDecoder().decode(authHeaderParts[1]), StandardCharsets.UTF_8);
125157
assertThat(usernameAndPassword).isEqualTo("client-id:client-secret");
126158

127-
assertThat(dpopHeaderFromRequest.get()).isNotNull();
128-
assertThat(authHeaderFromRequest.get()).isEqualTo("DPoP hereisthetoken");
159+
// validate that during the request to the namespace service we supplied a valid token
160+
assertThat(servicesDPoPHeader.get()).isNotNull();
161+
assertThat(servicesAuthHeader.get()).isEqualTo("DPoP hereisthetoken");
129162

130163
var body = new String(accessTokenRequest.getBody().readByteArray(), StandardCharsets.UTF_8);
131164
assertThat(body).contains("grant_type=client_credentials");
132165

166+
// now call KAS _on a different server_ and make sure that the interceptors provide us with auth tokens
167+
int kasPort = kasServer.getPort();
168+
SDK.KASInfo kasInfo = () -> "localhost:" + kasPort;
169+
services.kas().unwrap(kasInfo, new SDK.Policy() {});
170+
171+
assertThat(kasDPoPHeader.get()).isNotNull();
172+
assertThat(kasAuthHeader.get()).isEqualTo("DPoP hereisthetoken");
133173
} finally {
134-
if (wellknownServer != null) {
135-
wellknownServer.shutdownNow();
174+
if (platformServicesServer != null) {
175+
platformServicesServer.shutdownNow();
176+
}
177+
if (kasServer != null) {
178+
kasServer.shutdownNow();
136179
}
137180
}
138181
}
182+
183+
private static int getRandomPort() throws IOException {
184+
int randomPort;
185+
try (ServerSocket socket = new ServerSocket(0)) {
186+
randomPort = socket.getLocalPort();
187+
}
188+
return randomPort;
189+
}
139190
}

0 commit comments

Comments
 (0)