diff --git a/build.sbt b/build.sbt index e915b7e..7e8852b 100644 --- a/build.sbt +++ b/build.sbt @@ -16,7 +16,7 @@ inThisBuild( ) val scala213 = "2.13.16" -val scala3 = "3.3.5" +val scala3 = "3.3.6" val jdkVersion = 11 val allScalaVersions = List(scala213, scala3) val jvmScalaVersions = allScalaVersions @@ -40,7 +40,7 @@ val commonSettings = Seq( case Some((2, _)) => Seq(s"-target:jvm-$jdkVersion") case _ => Seq(s"-java-output-version:$jdkVersion") } - } + }, ) val commonJvmSettings = Seq( @@ -136,6 +136,23 @@ val smithy4s = projectMatrix buildTimeProtocolDependency ) +val smithy4sTests = projectMatrix + .in(file("modules") / "smithy4sTests") + .jvmPlatform(jvmScalaVersions, commonJvmSettings) + .jsPlatform(jsScalaVersions) + .nativePlatform(Seq(scala3)) + .disablePlugins(AssemblyPlugin) + .enablePlugins(Smithy4sCodegenPlugin) + .dependsOn(smithy4s, fs2 % Test) + .settings( + commonSettings, + publish / skip := true, + libraryDependencies ++= Seq( + "io.circe" %%% "circe-generic" % "0.14.7" + ), + buildTimeProtocolDependency + ) + val exampleServer = projectMatrix .in(file("modules") / "examples/server") .jvmPlatform(List(scala213), commonJvmSettings) @@ -235,6 +252,7 @@ val root = project exampleClient, smithy, smithy4s, + smithy4sTests, exampleSmithyShared, exampleSmithyServer, exampleSmithyClient diff --git a/modules/core/src/main/scala/jsonrpclib/Channel.scala b/modules/core/src/main/scala/jsonrpclib/Channel.scala index ba533e4..24d9d74 100644 --- a/modules/core/src/main/scala/jsonrpclib/Channel.scala +++ b/modules/core/src/main/scala/jsonrpclib/Channel.scala @@ -1,14 +1,16 @@ package jsonrpclib -import io.circe.Codec +import jsonrpclib.ErrorCodec.errorPayloadCodec +import io.circe.Encoder +import io.circe.Decoder trait Channel[F[_]] { def mountEndpoint(endpoint: Endpoint[F]): F[Unit] def unmountEndpoint(method: String): F[Unit] - def notificationStub[In: Codec](method: String): In => F[Unit] - def simpleStub[In: Codec, Out: Codec](method: String): In => F[Out] - def stub[In: Codec, Err: ErrorCodec, Out: Codec](method: String): In => F[Either[Err, Out]] + def notificationStub[In: Encoder](method: String): In => F[Unit] + def simpleStub[In: Encoder, Out: Decoder](method: String): In => F[Out] + def stub[In: Encoder, Err: ErrorDecoder, Out: Decoder](method: String): In => F[Either[Err, Out]] def stub[In, Err, Out](template: StubTemplate[In, Err, Out]): In => F[Either[Err, Out]] } @@ -27,7 +29,7 @@ object Channel { (in: In) => F.doFlatMap(stub(in))(unit => F.doPure(Right(unit))) } - final def simpleStub[In: Codec, Out: Codec](method: String): In => F[Out] = { + final def simpleStub[In: Encoder, Out: Decoder](method: String): In => F[Out] = { val s = stub[In, ErrorPayload, Out](method) (in: In) => F.doFlatMap(s(in)) { diff --git a/modules/core/src/main/scala/jsonrpclib/Endpoint.scala b/modules/core/src/main/scala/jsonrpclib/Endpoint.scala index 1d7197a..d9d0910 100644 --- a/modules/core/src/main/scala/jsonrpclib/Endpoint.scala +++ b/modules/core/src/main/scala/jsonrpclib/Endpoint.scala @@ -1,9 +1,13 @@ package jsonrpclib -import io.circe.Codec +import jsonrpclib.ErrorCodec.errorPayloadCodec +import io.circe.Decoder +import io.circe.Encoder sealed trait Endpoint[F[_]] { def method: String + + def mapK[G[_]](f: PolyFunction[F, G]): Endpoint[G] } object Endpoint { @@ -16,17 +20,17 @@ object Endpoint { class PartiallyAppliedEndpoint[F[_]](method: MethodPattern) { def apply[In, Err, Out]( run: In => F[Either[Err, Out]] - )(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): Endpoint[F] = - RequestResponseEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec, errCodec, outCodec) + )(implicit inCodec: Decoder[In], errEncoder: ErrorEncoder[Err], outCodec: Encoder[Out]): Endpoint[F] = + RequestResponseEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec, errEncoder, outCodec) def full[In, Err, Out]( run: (InputMessage, In) => F[Either[Err, Out]] - )(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): Endpoint[F] = - RequestResponseEndpoint(method, run, inCodec, errCodec, outCodec) + )(implicit inCodec: Decoder[In], errEncoder: ErrorEncoder[Err], outCodec: Encoder[Out]): Endpoint[F] = + RequestResponseEndpoint(method, run, inCodec, errEncoder, outCodec) def simple[In, Out]( run: In => F[Out] - )(implicit F: Monadic[F], inCodec: Codec[In], outCodec: Codec[Out]) = + )(implicit F: Monadic[F], inCodec: Decoder[In], outCodec: Encoder[Out]) = apply[In, ErrorPayload, Out](in => F.doFlatMap(F.doAttempt(run(in))) { case Left(error) => F.doPure(Left(ErrorPayload(0, error.getMessage(), None))) @@ -34,26 +38,33 @@ object Endpoint { } ) - def notification[In](run: In => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] = + def notification[In](run: In => F[Unit])(implicit inCodec: Decoder[In]): Endpoint[F] = NotificationEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec) - def notificationFull[In](run: (InputMessage, In) => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] = + def notificationFull[In](run: (InputMessage, In) => F[Unit])(implicit inCodec: Decoder[In]): Endpoint[F] = NotificationEndpoint(method, run, inCodec) } - final case class NotificationEndpoint[F[_], In]( + private[jsonrpclib] final case class NotificationEndpoint[F[_], In]( method: MethodPattern, run: (InputMessage, In) => F[Unit], - inCodec: Codec[In] - ) extends Endpoint[F] + inCodec: Decoder[In] + ) extends Endpoint[F] { + + def mapK[G[_]](f: PolyFunction[F, G]): Endpoint[G] = + NotificationEndpoint[G, In](method, (msg, in) => f(run(msg, in)), inCodec) + } - final case class RequestResponseEndpoint[F[_], In, Err, Out]( + private[jsonrpclib] final case class RequestResponseEndpoint[F[_], In, Err, Out]( method: Method, run: (InputMessage, In) => F[Either[Err, Out]], - inCodec: Codec[In], - errCodec: ErrorCodec[Err], - outCodec: Codec[Out] - ) extends Endpoint[F] + inCodec: Decoder[In], + errEncoder: ErrorEncoder[Err], + outCodec: Encoder[Out] + ) extends Endpoint[F] { + def mapK[G[_]](f: PolyFunction[F, G]): Endpoint[G] = + RequestResponseEndpoint[G, In, Err, Out](method, (msg, in) => f(run(msg, in)), inCodec, errEncoder, outCodec) + } } diff --git a/modules/core/src/main/scala/jsonrpclib/ErrorCodec.scala b/modules/core/src/main/scala/jsonrpclib/ErrorCodec.scala index f150c1f..8af58e9 100644 --- a/modules/core/src/main/scala/jsonrpclib/ErrorCodec.scala +++ b/modules/core/src/main/scala/jsonrpclib/ErrorCodec.scala @@ -1,12 +1,15 @@ package jsonrpclib -trait ErrorCodec[E] { - +trait ErrorEncoder[E] { def encode(a: E): ErrorPayload - def decode(error: ErrorPayload): Either[ProtocolError, E] +} +trait ErrorDecoder[E] { + def decode(error: ErrorPayload): Either[ProtocolError, E] } +trait ErrorCodec[E] extends ErrorDecoder[E] with ErrorEncoder[E] + object ErrorCodec { implicit val errorPayloadCodec: ErrorCodec[ErrorPayload] = new ErrorCodec[ErrorPayload] { def encode(a: ErrorPayload): ErrorPayload = a diff --git a/modules/core/src/main/scala/jsonrpclib/Monadic.scala b/modules/core/src/main/scala/jsonrpclib/Monadic.scala index 0d5a7f0..a42aaa7 100644 --- a/modules/core/src/main/scala/jsonrpclib/Monadic.scala +++ b/modules/core/src/main/scala/jsonrpclib/Monadic.scala @@ -26,4 +26,19 @@ object Monadic { override def doMap[A, B](fa: Future[A])(f: A => B): Future[B] = fa.map(f) } + + object syntax { + implicit class MonadicOps[F[_], A](fa: F[A]) { + def flatMap[B](f: A => F[B])(implicit m: Monadic[F]): F[B] = m.doFlatMap(fa)(f) + def map[B](f: A => B)(implicit m: Monadic[F]): F[B] = m.doMap(fa)(f) + def attempt[B](implicit m: Monadic[F]): F[Either[Throwable, A]] = m.doAttempt(fa) + def void(implicit m: Monadic[F]): F[Unit] = m.doVoid(fa) + } + implicit class MonadicOpsPure[A](a: A) { + def pure[F[_]](implicit m: Monadic[F]): F[A] = m.doPure(a) + } + implicit class MonadicOpsThrowable(t: Throwable) { + def raiseError[F[_], A](implicit m: Monadic[F]): F[A] = m.doRaiseError(t) + } + } } diff --git a/modules/core/src/main/scala/jsonrpclib/PolyFunction.scala b/modules/core/src/main/scala/jsonrpclib/PolyFunction.scala new file mode 100644 index 0000000..3942a26 --- /dev/null +++ b/modules/core/src/main/scala/jsonrpclib/PolyFunction.scala @@ -0,0 +1,5 @@ +package jsonrpclib + +trait PolyFunction[F[_], G[_]] { self => + def apply[A0](fa: => F[A0]): G[A0] +} diff --git a/modules/core/src/main/scala/jsonrpclib/internals/FutureBaseChannel.scala b/modules/core/src/main/scala/jsonrpclib/internals/FutureBaseChannel.scala index 0dd6c15..3b50fd5 100644 --- a/modules/core/src/main/scala/jsonrpclib/internals/FutureBaseChannel.scala +++ b/modules/core/src/main/scala/jsonrpclib/internals/FutureBaseChannel.scala @@ -7,7 +7,6 @@ import scala.concurrent.ExecutionContext import scala.concurrent.Future import scala.concurrent.Promise import scala.util.Try -import io.circe.Codec import io.circe.Encoder abstract class FutureBasedChannel(endpoints: List[Endpoint[Future]])(implicit ec: ExecutionContext) diff --git a/modules/core/src/main/scala/jsonrpclib/internals/MessageDispatcher.scala b/modules/core/src/main/scala/jsonrpclib/internals/MessageDispatcher.scala index f64a12d..5950890 100644 --- a/modules/core/src/main/scala/jsonrpclib/internals/MessageDispatcher.scala +++ b/modules/core/src/main/scala/jsonrpclib/internals/MessageDispatcher.scala @@ -6,8 +6,9 @@ import jsonrpclib.Endpoint.RequestResponseEndpoint import jsonrpclib.OutputMessage.ErrorMessage import jsonrpclib.OutputMessage.ResponseMessage import scala.util.Try -import io.circe.Codec import io.circe.HCursor +import io.circe.Encoder +import io.circe.Decoder private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F]) extends Channel.MonadicChannel[F] { @@ -22,7 +23,7 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F protected def storePendingCall(callId: CallId, handle: OutputMessage => F[Unit]): F[Unit] protected def removePendingCall(callId: CallId): F[Option[OutputMessage => F[Unit]]] - def notificationStub[In](method: String)(implicit inCodec: Codec[In]): In => F[Unit] = { (input: In) => + def notificationStub[In](method: String)(implicit inCodec: Encoder[In]): In => F[Unit] = { (input: In) => val encoded = inCodec(input) val message = InputMessage.NotificationMessage(method, Some(Payload(encoded))) sendMessage(message) @@ -30,13 +31,13 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F def stub[In, Err, Out]( method: String - )(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): In => F[Either[Err, Out]] = { + )(implicit inCodec: Encoder[In], errDecoder: ErrorDecoder[Err], outCodec: Decoder[Out]): In => F[Either[Err, Out]] = { (input: In) => val encoded = inCodec(input) doFlatMap(nextCallId()) { callId => val message = InputMessage.RequestMessage(method, callId, Some(Payload(encoded))) doFlatMap(createPromise[Either[Err, Out]](callId)) { case (fulfill, future) => - val pc = createPendingCall(errCodec, outCodec, fulfill) + val pc = createPendingCall(errDecoder, outCodec, fulfill) doFlatMap(storePendingCall(callId, pc))(_ => doFlatMap(sendMessage(message))(_ => future())) } } @@ -80,13 +81,17 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F case (InputMessage.RequestMessage(_, callId, Some(params)), ep: RequestResponseEndpoint[F, in, err, out]) => ep.inCodec(HCursor.fromJson(params.data)) match { case Right(value) => - doFlatMap(ep.run(input, value)) { - case Right(data) => + doFlatMap(doAttempt(ep.run(input, value))) { + case Right(Right(data)) => val responseData = ep.outCodec(data) sendMessage(OutputMessage.ResponseMessage(callId, Payload(responseData))) - case Left(error) => - val errorPayload = ep.errCodec.encode(error) + case Right(Left(error)) => + val errorPayload = ep.errEncoder.encode(error) sendMessage(OutputMessage.ErrorMessage(callId, errorPayload)) + case Left(err) => + sendMessage( + OutputMessage.ErrorMessage(callId, ErrorPayload(0, s"ServerInternalError: ${err.getMessage}", None)) + ) } case Left(pError) => sendProtocolError(callId, ProtocolError.ParseError(pError.getMessage)) @@ -111,13 +116,13 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F } private def createPendingCall[Err, Out]( - errCodec: ErrorCodec[Err], - outCodec: Codec[Out], + errDecoder: ErrorDecoder[Err], + outCodec: Decoder[Out], fulfill: Try[Either[Err, Out]] => F[Unit] ): OutputMessage => F[Unit] = { (message: OutputMessage) => message match { case ErrorMessage(_, errorPayload) => - errCodec.decode(errorPayload) match { + errDecoder.decode(errorPayload) match { case Left(_) => fulfill(scala.util.Failure(errorPayload)) case Right(value) => fulfill(scala.util.Success(Left(value))) } diff --git a/modules/core/src/main/scala/jsonrpclib/package.scala b/modules/core/src/main/scala/jsonrpclib/package.scala index 9093575..5c0f070 100644 --- a/modules/core/src/main/scala/jsonrpclib/package.scala +++ b/modules/core/src/main/scala/jsonrpclib/package.scala @@ -2,5 +2,4 @@ package object jsonrpclib { type ErrorCode = Int type ErrorMessage = String - } diff --git a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceJson.scala b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceJsonCodec.scala similarity index 97% rename from modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceJson.scala rename to modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceJsonCodec.scala index 38dd09a..0c242b6 100644 --- a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceJson.scala +++ b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/CirceJsonCodec.scala @@ -7,7 +7,7 @@ import smithy4s.codecs.PayloadPath import smithy4s.Document.{Decoder => _, _} import io.circe._ -private[jsonrpclib] object CirceJson { +object CirceJsonCodec { def fromSchema[A](implicit schema: Schema[A]): Codec[A] = Codec.from( c => { diff --git a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ClientStub.scala b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ClientStub.scala index f947323..162e1b3 100644 --- a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ClientStub.scala +++ b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ClientStub.scala @@ -34,8 +34,8 @@ private class ClientStub[Alg[_[_, _, _, _, _]], F[_]: Monadic](val service: Serv endpointSpec: EndpointSpec ): I => F[O] = { - implicit val inputCodec: Codec[I] = CirceJson.fromSchema(smithy4sEndpoint.input) - implicit val outputCodec: Codec[O] = CirceJson.fromSchema(smithy4sEndpoint.output) + implicit val inputCodec: Codec[I] = CirceJsonCodec.fromSchema(smithy4sEndpoint.input) + implicit val outputCodec: Codec[O] = CirceJsonCodec.fromSchema(smithy4sEndpoint.output) endpointSpec match { case EndpointSpec.Notification(methodName) => diff --git a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ServerEndpoints.scala b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ServerEndpoints.scala index 9e8971d..d773b3b 100644 --- a/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ServerEndpoints.scala +++ b/modules/smithy4s/src/main/scala/jsonrpclib/smithy4sinterop/ServerEndpoints.scala @@ -6,7 +6,12 @@ import smithy4s.Service import smithy4s.kinds.FunctorAlgebra import smithy4s.kinds.FunctorInterpreter import jsonrpclib.Monadic +import jsonrpclib.Payload +import jsonrpclib.ErrorPayload import io.circe.Codec +import jsonrpclib.Monadic.syntax._ +import jsonrpclib.ErrorEncoder +import smithy4s.schema.ErrorSchema object ServerEndpoints { @@ -24,27 +29,53 @@ object ServerEndpoints { } } - // TODO : codify errors at smithy level and handle them. def jsonRPCEndpoint[F[_]: Monadic, Op[_, _, _, _, _], I, E, O, SI, SO]( smithy4sEndpoint: Smithy4sEndpoint[Op, I, E, O, SI, SO], endpointSpec: EndpointSpec, impl: FunctorInterpreter[Op, F] ): Endpoint[F] = { - implicit val inputCodec: Codec[I] = CirceJson.fromSchema(smithy4sEndpoint.input) - implicit val outputCodec: Codec[O] = CirceJson.fromSchema(smithy4sEndpoint.output) + implicit val inputCodec: Codec[I] = CirceJsonCodec.fromSchema(smithy4sEndpoint.input) + implicit val outputCodec: Codec[O] = CirceJsonCodec.fromSchema(smithy4sEndpoint.output) + + def errorResponse(throwable: Throwable): F[E] = throwable match { + case smithy4sEndpoint.Error((_, e)) => e.pure + case e: Throwable => e.raiseError + } endpointSpec match { case EndpointSpec.Notification(methodName) => Endpoint[F](methodName).notification { (input: I) => val op = smithy4sEndpoint.wrap(input) - Monadic[F].doVoid(impl(op)) + impl(op).void } case EndpointSpec.Request(methodName) => - Endpoint[F](methodName).simple { (input: I) => - val op = smithy4sEndpoint.wrap(input) - impl(op) + smithy4sEndpoint.error match { + case None => + Endpoint[F](methodName).simple[I, O] { (input: I) => + val op = smithy4sEndpoint.wrap(input) + impl(op) + } + case Some(errorSchema) => + implicit val errorCodec: ErrorEncoder[E] = errorCodecFromSchema(errorSchema) + Endpoint[F](methodName).apply[I, E, O] { (input: I) => + val op = smithy4sEndpoint.wrap(input) + impl(op).attempt.flatMap { + case Left(err) => errorResponse(err).map(r => Left(r): Either[E, O]) + case Right(success) => (Right(success): Either[E, O]).pure + } + } } } } + + private def errorCodecFromSchema[A](s: ErrorSchema[A]): ErrorEncoder[A] = { + val circeCodec = CirceJsonCodec.fromSchema(s.schema) + (a: A) => + ErrorPayload( + 0, + Option(s.unliftError(a).getMessage()).getOrElse("JSONRPC-smithy4s application error"), + Some(Payload(circeCodec.apply(a))) + ) + } } diff --git a/modules/smithy4sTests/src/main/smithy/spec.smithy b/modules/smithy4sTests/src/main/smithy/spec.smithy new file mode 100644 index 0000000..12ac63a --- /dev/null +++ b/modules/smithy4sTests/src/main/smithy/spec.smithy @@ -0,0 +1,52 @@ +$version: "2.0" + +namespace test + +use jsonrpclib#jsonNotification +use jsonrpclib#jsonRPC +use jsonrpclib#jsonRequest + +@jsonRPC +service TestServer { + operations: [Greet, Ping] +} + +@jsonRPC +service TestClient { + operations: [Pong] +} + +@jsonRequest("greet") +operation Greet { + input := { + @required + name: String + } + output := { + @required + message: String + } + errors: [NotWelcomeError] +} + +@error("client") +structure NotWelcomeError { + @required + msg: String +} + +@jsonNotification("ping") +operation Ping { + input := { + @required + ping: String + } +} + +@jsonNotification("pong") +operation Pong { + input := { + @required + pong: String + } +} diff --git a/modules/smithy4sTests/src/test/scala/jsonrpclib/smithy4sinterop/TestClientSpec.scala b/modules/smithy4sTests/src/test/scala/jsonrpclib/smithy4sinterop/TestClientSpec.scala new file mode 100644 index 0000000..0e911d4 --- /dev/null +++ b/modules/smithy4sTests/src/test/scala/jsonrpclib/smithy4sinterop/TestClientSpec.scala @@ -0,0 +1,68 @@ +package jsonrpclib.smithy4sinterop + +import cats.effect.IO +import fs2.Stream +import jsonrpclib._ +import test.TestServer +import weaver._ +import cats.syntax.all._ + +import scala.concurrent.duration._ +import jsonrpclib.fs2._ +import test.GreetOutput +import io.circe.Encoder +import test.GreetInput +import io.circe.Decoder +import test.PingInput +import _root_.fs2.concurrent.SignallingRef + +object TestClientSpec extends SimpleIOSuite { + def testRes(name: TestName)(run: Stream[IO, Expectations]): Unit = + test(name)(run.compile.lastOrError.timeout(10.second)) + + type ClientSideChannel = FS2Channel[IO] + def setup(endpoints: Endpoint[IO]*) = setupAux(endpoints, None) + def setup(cancelTemplate: CancelTemplate, endpoints: Endpoint[IO]*) = setupAux(endpoints, Some(cancelTemplate)) + def setupAux(endpoints: Seq[Endpoint[IO]], cancelTemplate: Option[CancelTemplate]): Stream[IO, ClientSideChannel] = { + for { + serverSideChannel <- FS2Channel.stream[IO](cancelTemplate = cancelTemplate) + clientSideChannel <- FS2Channel.stream[IO](cancelTemplate = cancelTemplate) + _ <- serverSideChannel.withEndpointsStream(endpoints) + _ <- Stream(()) + .concurrently(clientSideChannel.output.through(serverSideChannel.input)) + .concurrently(serverSideChannel.output.through(clientSideChannel.input)) + } yield { + clientSideChannel + } + } + + testRes("Round trip") { + implicit val greetInputDecoder: Decoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val greetOutputEncoder: Encoder[GreetOutput] = CirceJsonCodec.fromSchema + val endpoint: Endpoint[IO] = + Endpoint[IO]("greet").simple[GreetInput, GreetOutput](in => IO(GreetOutput(s"Hello ${in.name}"))) + + for { + clientSideChannel <- setup(endpoint) + clientStub = ClientStub(TestServer, clientSideChannel) + result <- clientStub.greet("Bob").toStream + } yield { + expect.same(result.message, "Hello Bob") + } + } + + testRes("Sending notification") { + implicit val pingInputDecoder: Decoder[PingInput] = CirceJsonCodec.fromSchema + + for { + ref <- SignallingRef[IO, Option[PingInput]](none).toStream + endpoint: Endpoint[IO] = Endpoint[IO]("ping").notification[PingInput](p => ref.set(p.some)) + clientSideChannel <- setup(endpoint) + clientStub = ClientStub(TestServer, clientSideChannel) + _ <- clientStub.ping("hello").toStream + result <- ref.discrete.dropWhile(_.isEmpty).take(1) + } yield { + expect.same(result, Some(PingInput("hello"))) + } + } +} diff --git a/modules/smithy4sTests/src/test/scala/jsonrpclib/smithy4sinterop/TestServerSpec.scala b/modules/smithy4sTests/src/test/scala/jsonrpclib/smithy4sinterop/TestServerSpec.scala new file mode 100644 index 0000000..be66690 --- /dev/null +++ b/modules/smithy4sTests/src/test/scala/jsonrpclib/smithy4sinterop/TestServerSpec.scala @@ -0,0 +1,181 @@ +package jsonrpclib.smithy4sinterop + +import cats.effect.IO +import fs2.Stream +import test.TestServer +import test.TestClient +import weaver._ +import smithy4s.kinds.FunctorAlgebra +import cats.syntax.all._ + +import scala.concurrent.duration._ +import jsonrpclib.fs2._ +import test.GreetOutput +import io.circe.Encoder +import test.GreetInput +import test.NotWelcomeError +import io.circe.Decoder +import smithy4s.Service +import jsonrpclib.Monadic +import test.PingInput +import fs2.concurrent.SignallingRef +import test.TestServerOperation +import test.TestServerOperation.GreetError +import jsonrpclib.ErrorPayload +import jsonrpclib.Payload + +object TestServerSpec extends SimpleIOSuite { + def testRes(name: TestName)(run: Stream[IO, Expectations]): Unit = + test(name)(run.compile.lastOrError.timeout(10.second)) + + type ClientSideChannel = FS2Channel[IO] + + class ServerImpl(client: TestClient[IO]) extends TestServer[IO] { + def greet(name: String): IO[GreetOutput] = IO.pure(GreetOutput(s"Hello $name")) + + def ping(ping: String): IO[Unit] = { + client.pong(s"Returned to sender: $ping") + } + } + + class Client(ref: SignallingRef[IO, Option[String]]) extends TestClient[IO] { + def pong(pong: String): IO[Unit] = ref.set(Some(pong)) + } + + trait AlgebraWrapper { + type Alg[_[_, _, _, _, _]] + + def algebra: FunctorAlgebra[Alg, IO] + def service: Service[Alg] + } + + object AlgebraWrapper { + def apply[A[_[_, _, _, _, _]]](alg: FunctorAlgebra[A, IO])(implicit srv: Service[A]): AlgebraWrapper = + new AlgebraWrapper { + type Alg[F[_, _, _, _, _]] = A[F] + + val algebra = alg + val service = srv + } + } + + def setup(mkServer: FS2Channel[IO] => AlgebraWrapper) = + setupAux(None, mkServer.andThen(Seq(_)), _ => Seq.empty) + + def setup(mkServer: FS2Channel[IO] => AlgebraWrapper, mkClient: FS2Channel[IO] => AlgebraWrapper) = + setupAux(None, mkServer.andThen(Seq(_)), mkClient.andThen(Seq(_))) + + def setup[Alg[_[_, _, _, _, _]]]( + cancelTemplate: CancelTemplate, + mkServer: FS2Channel[IO] => Seq[AlgebraWrapper], + mkClient: FS2Channel[IO] => Seq[AlgebraWrapper] + ) = setupAux(Some(cancelTemplate), mkServer, mkClient) + + def setupAux[Alg[_[_, _, _, _, _]]]( + cancelTemplate: Option[CancelTemplate], + mkServer: FS2Channel[IO] => Seq[AlgebraWrapper], + mkClient: FS2Channel[IO] => Seq[AlgebraWrapper] + ): Stream[IO, ClientSideChannel] = { + for { + serverSideChannel <- FS2Channel.stream[IO](cancelTemplate = cancelTemplate) + clientSideChannel <- FS2Channel.stream[IO](cancelTemplate = cancelTemplate) + serverChannelWithEndpoints <- serverSideChannel.withEndpointsStream(mkServer(serverSideChannel).flatMap { p => + ServerEndpoints(p.algebra)(p.service, Monadic[IO]) + }) + clientChannelWithEndpoints <- clientSideChannel.withEndpointsStream(mkClient(clientSideChannel).flatMap { p => + ServerEndpoints(p.algebra)(p.service, Monadic[IO]) + }) + _ <- Stream(()) + .concurrently(clientChannelWithEndpoints.output.through(serverChannelWithEndpoints.input)) + .concurrently(serverChannelWithEndpoints.output.through(clientChannelWithEndpoints.input)) + } yield { + clientSideChannel + } + } + + testRes("Round trip") { + implicit val greetInputEncoder: Encoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val greetOutputDecoder: Decoder[GreetOutput] = CirceJsonCodec.fromSchema + + for { + clientSideChannel <- setup(channel => { + val testClient = ClientStub(TestClient, channel) + AlgebraWrapper(new ServerImpl(testClient)) + }) + remoteFunction = clientSideChannel.simpleStub[GreetInput, GreetOutput]("greet") + result <- remoteFunction(GreetInput("Bob")).toStream + } yield { + expect.same(result.message, "Hello Bob") + } + } + + testRes("notification both ways") { + implicit val greetInputEncoder: Encoder[PingInput] = CirceJsonCodec.fromSchema + + for { + ref <- SignallingRef[IO, Option[String]](none).toStream + clientSideChannel <- setup( + channel => { + val testClient = ClientStub(TestClient, channel) + AlgebraWrapper(new ServerImpl(testClient)) + }, + _ => AlgebraWrapper(new Client(ref)) + ) + remoteFunction = clientSideChannel.notificationStub[PingInput]("ping") + _ <- remoteFunction(PingInput("hi server")).toStream + result <- ref.discrete.dropWhile(_.isEmpty).take(1) + } yield { + expect.same(result, "Returned to sender: hi server".some) + } + } + + testRes("server returns known error") { + implicit val greetInputEncoder: Encoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val greetOutputDecoder: Decoder[GreetOutput] = CirceJsonCodec.fromSchema + implicit val greetErrorEncoder: Encoder[TestServerOperation.GreetError] = CirceJsonCodec.fromSchema + + for { + clientSideChannel <- setup(_ => { + AlgebraWrapper(new TestServer[IO] { + override def greet(name: String): IO[GreetOutput] = IO.raiseError(NotWelcomeError(s"$name is not welcome")) + + override def ping(ping: String): IO[Unit] = ??? + }) + }) + remoteFunction = clientSideChannel.simpleStub[GreetInput, GreetOutput]("greet") + result <- remoteFunction(GreetInput("Alice")).attempt.toStream + } yield { + matches(result) { case Left(t: ErrorPayload) => + expect.same(t.code, 0) && + expect.same(t.message, "test.NotWelcomeError(Alice is not welcome)") && + expect.same( + t.data, + Payload(greetErrorEncoder.apply(GreetError.notWelcomeError(NotWelcomeError(s"Alice is not welcome")))).some + ) + } + } + } + + testRes("server returns unknown error") { + implicit val greetInputEncoder: Encoder[GreetInput] = CirceJsonCodec.fromSchema + implicit val greetOutputDecoder: Decoder[GreetOutput] = CirceJsonCodec.fromSchema + + for { + clientSideChannel <- setup(_ => { + AlgebraWrapper(new TestServer[IO] { + override def greet(name: String): IO[GreetOutput] = IO.raiseError(new RuntimeException("some other error")) + + override def ping(ping: String): IO[Unit] = ??? + }) + }) + remoteFunction = clientSideChannel.simpleStub[GreetInput, GreetOutput]("greet") + result <- remoteFunction(GreetInput("Alice")).attempt.toStream + } yield { + matches(result) { case Left(t: ErrorPayload) => + expect.same(t.code, 0) && + expect.same(t.message, "ServerInternalError: some other error") && + expect.same(t.data, none) + } + } + } +}