Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ inThisBuild(
)

val scala213 = "2.13.16"
val scala3 = "3.3.5"
val scala3 = "3.3.6"
val jdkVersion = 11
val allScalaVersions = List(scala213, scala3)
val jvmScalaVersions = allScalaVersions
Expand All @@ -40,7 +40,7 @@ val commonSettings = Seq(
case Some((2, _)) => Seq(s"-target:jvm-$jdkVersion")
case _ => Seq(s"-java-output-version:$jdkVersion")
}
}
},
)

val commonJvmSettings = Seq(
Expand Down Expand Up @@ -136,6 +136,23 @@ val smithy4s = projectMatrix
buildTimeProtocolDependency
)

val smithy4sTests = projectMatrix
.in(file("modules") / "smithy4sTests")
.jvmPlatform(jvmScalaVersions, commonJvmSettings)
.jsPlatform(jsScalaVersions)
.nativePlatform(Seq(scala3))
.disablePlugins(AssemblyPlugin)
.enablePlugins(Smithy4sCodegenPlugin)
.dependsOn(smithy4s, fs2 % Test)
.settings(
commonSettings,
publish / skip := true,
libraryDependencies ++= Seq(
"io.circe" %%% "circe-generic" % "0.14.7"
),
buildTimeProtocolDependency
)

val exampleServer = projectMatrix
.in(file("modules") / "examples/server")
.jvmPlatform(List(scala213), commonJvmSettings)
Expand Down Expand Up @@ -235,6 +252,7 @@ val root = project
exampleClient,
smithy,
smithy4s,
smithy4sTests,
exampleSmithyShared,
exampleSmithyServer,
exampleSmithyClient
Expand Down
12 changes: 7 additions & 5 deletions modules/core/src/main/scala/jsonrpclib/Channel.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
package jsonrpclib

import io.circe.Codec
import jsonrpclib.ErrorCodec.errorPayloadCodec
import io.circe.Encoder
import io.circe.Decoder

trait Channel[F[_]] {
def mountEndpoint(endpoint: Endpoint[F]): F[Unit]
def unmountEndpoint(method: String): F[Unit]

def notificationStub[In: Codec](method: String): In => F[Unit]
def simpleStub[In: Codec, Out: Codec](method: String): In => F[Out]
def stub[In: Codec, Err: ErrorCodec, Out: Codec](method: String): In => F[Either[Err, Out]]
def notificationStub[In: Encoder](method: String): In => F[Unit]
def simpleStub[In: Encoder, Out: Decoder](method: String): In => F[Out]
def stub[In: Encoder, Err: ErrorDecoder, Out: Decoder](method: String): In => F[Either[Err, Out]]
def stub[In, Err, Out](template: StubTemplate[In, Err, Out]): In => F[Either[Err, Out]]
}

Expand All @@ -27,7 +29,7 @@ object Channel {
(in: In) => F.doFlatMap(stub(in))(unit => F.doPure(Right(unit)))
}

final def simpleStub[In: Codec, Out: Codec](method: String): In => F[Out] = {
final def simpleStub[In: Encoder, Out: Decoder](method: String): In => F[Out] = {
val s = stub[In, ErrorPayload, Out](method)
(in: In) =>
F.doFlatMap(s(in)) {
Expand Down
43 changes: 27 additions & 16 deletions modules/core/src/main/scala/jsonrpclib/Endpoint.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package jsonrpclib

import io.circe.Codec
import jsonrpclib.ErrorCodec.errorPayloadCodec
import io.circe.Decoder
import io.circe.Encoder

sealed trait Endpoint[F[_]] {
def method: String

def mapK[G[_]](f: PolyFunction[F, G]): Endpoint[G]
}

object Endpoint {
Expand All @@ -16,44 +20,51 @@ object Endpoint {
class PartiallyAppliedEndpoint[F[_]](method: MethodPattern) {
def apply[In, Err, Out](
run: In => F[Either[Err, Out]]
)(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): Endpoint[F] =
RequestResponseEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec, errCodec, outCodec)
)(implicit inCodec: Decoder[In], errEncoder: ErrorEncoder[Err], outCodec: Encoder[Out]): Endpoint[F] =
RequestResponseEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec, errEncoder, outCodec)

def full[In, Err, Out](
run: (InputMessage, In) => F[Either[Err, Out]]
)(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): Endpoint[F] =
RequestResponseEndpoint(method, run, inCodec, errCodec, outCodec)
)(implicit inCodec: Decoder[In], errEncoder: ErrorEncoder[Err], outCodec: Encoder[Out]): Endpoint[F] =
RequestResponseEndpoint(method, run, inCodec, errEncoder, outCodec)

def simple[In, Out](
run: In => F[Out]
)(implicit F: Monadic[F], inCodec: Codec[In], outCodec: Codec[Out]) =
)(implicit F: Monadic[F], inCodec: Decoder[In], outCodec: Encoder[Out]) =
apply[In, ErrorPayload, Out](in =>
F.doFlatMap(F.doAttempt(run(in))) {
case Left(error) => F.doPure(Left(ErrorPayload(0, error.getMessage(), None)))
case Right(value) => F.doPure(Right(value))
}
)

def notification[In](run: In => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] =
def notification[In](run: In => F[Unit])(implicit inCodec: Decoder[In]): Endpoint[F] =
NotificationEndpoint(method, (_: InputMessage, in: In) => run(in), inCodec)

def notificationFull[In](run: (InputMessage, In) => F[Unit])(implicit inCodec: Codec[In]): Endpoint[F] =
def notificationFull[In](run: (InputMessage, In) => F[Unit])(implicit inCodec: Decoder[In]): Endpoint[F] =
NotificationEndpoint(method, run, inCodec)

}

final case class NotificationEndpoint[F[_], In](
private[jsonrpclib] final case class NotificationEndpoint[F[_], In](
method: MethodPattern,
run: (InputMessage, In) => F[Unit],
inCodec: Codec[In]
) extends Endpoint[F]
inCodec: Decoder[In]
) extends Endpoint[F] {

def mapK[G[_]](f: PolyFunction[F, G]): Endpoint[G] =
NotificationEndpoint[G, In](method, (msg, in) => f(run(msg, in)), inCodec)
}

final case class RequestResponseEndpoint[F[_], In, Err, Out](
private[jsonrpclib] final case class RequestResponseEndpoint[F[_], In, Err, Out](
method: Method,
run: (InputMessage, In) => F[Either[Err, Out]],
inCodec: Codec[In],
errCodec: ErrorCodec[Err],
outCodec: Codec[Out]
) extends Endpoint[F]
inCodec: Decoder[In],
errEncoder: ErrorEncoder[Err],
outCodec: Encoder[Out]
) extends Endpoint[F] {

def mapK[G[_]](f: PolyFunction[F, G]): Endpoint[G] =
RequestResponseEndpoint[G, In, Err, Out](method, (msg, in) => f(run(msg, in)), inCodec, errEncoder, outCodec)
}
}
9 changes: 6 additions & 3 deletions modules/core/src/main/scala/jsonrpclib/ErrorCodec.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package jsonrpclib

trait ErrorCodec[E] {

trait ErrorEncoder[E] {
def encode(a: E): ErrorPayload
def decode(error: ErrorPayload): Either[ProtocolError, E]
}

trait ErrorDecoder[E] {
def decode(error: ErrorPayload): Either[ProtocolError, E]
}

trait ErrorCodec[E] extends ErrorDecoder[E] with ErrorEncoder[E]

object ErrorCodec {
implicit val errorPayloadCodec: ErrorCodec[ErrorPayload] = new ErrorCodec[ErrorPayload] {
def encode(a: ErrorPayload): ErrorPayload = a
Expand Down
15 changes: 15 additions & 0 deletions modules/core/src/main/scala/jsonrpclib/Monadic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,19 @@ object Monadic {

override def doMap[A, B](fa: Future[A])(f: A => B): Future[B] = fa.map(f)
}

object syntax {
implicit class MonadicOps[F[_], A](fa: F[A]) {
def flatMap[B](f: A => F[B])(implicit m: Monadic[F]): F[B] = m.doFlatMap(fa)(f)
def map[B](f: A => B)(implicit m: Monadic[F]): F[B] = m.doMap(fa)(f)
def attempt[B](implicit m: Monadic[F]): F[Either[Throwable, A]] = m.doAttempt(fa)
def void(implicit m: Monadic[F]): F[Unit] = m.doVoid(fa)
}
implicit class MonadicOpsPure[A](a: A) {
def pure[F[_]](implicit m: Monadic[F]): F[A] = m.doPure(a)
}
implicit class MonadicOpsThrowable(t: Throwable) {
def raiseError[F[_], A](implicit m: Monadic[F]): F[A] = m.doRaiseError(t)
}
}
}
5 changes: 5 additions & 0 deletions modules/core/src/main/scala/jsonrpclib/PolyFunction.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package jsonrpclib

trait PolyFunction[F[_], G[_]] { self =>
def apply[A0](fa: => F[A0]): G[A0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

smort 👍

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.util.Try
import io.circe.Codec
import io.circe.Encoder

abstract class FutureBasedChannel(endpoints: List[Endpoint[Future]])(implicit ec: ExecutionContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ import jsonrpclib.Endpoint.RequestResponseEndpoint
import jsonrpclib.OutputMessage.ErrorMessage
import jsonrpclib.OutputMessage.ResponseMessage
import scala.util.Try
import io.circe.Codec
import io.circe.HCursor
import io.circe.Encoder
import io.circe.Decoder

private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F]) extends Channel.MonadicChannel[F] {

Expand All @@ -22,21 +23,21 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F
protected def storePendingCall(callId: CallId, handle: OutputMessage => F[Unit]): F[Unit]
protected def removePendingCall(callId: CallId): F[Option[OutputMessage => F[Unit]]]

def notificationStub[In](method: String)(implicit inCodec: Codec[In]): In => F[Unit] = { (input: In) =>
def notificationStub[In](method: String)(implicit inCodec: Encoder[In]): In => F[Unit] = { (input: In) =>
val encoded = inCodec(input)
val message = InputMessage.NotificationMessage(method, Some(Payload(encoded)))
sendMessage(message)
}

def stub[In, Err, Out](
method: String
)(implicit inCodec: Codec[In], errCodec: ErrorCodec[Err], outCodec: Codec[Out]): In => F[Either[Err, Out]] = {
)(implicit inCodec: Encoder[In], errDecoder: ErrorDecoder[Err], outCodec: Decoder[Out]): In => F[Either[Err, Out]] = {
(input: In) =>
val encoded = inCodec(input)
doFlatMap(nextCallId()) { callId =>
val message = InputMessage.RequestMessage(method, callId, Some(Payload(encoded)))
doFlatMap(createPromise[Either[Err, Out]](callId)) { case (fulfill, future) =>
val pc = createPendingCall(errCodec, outCodec, fulfill)
val pc = createPendingCall(errDecoder, outCodec, fulfill)
doFlatMap(storePendingCall(callId, pc))(_ => doFlatMap(sendMessage(message))(_ => future()))
}
}
Expand Down Expand Up @@ -80,13 +81,17 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F
case (InputMessage.RequestMessage(_, callId, Some(params)), ep: RequestResponseEndpoint[F, in, err, out]) =>
ep.inCodec(HCursor.fromJson(params.data)) match {
case Right(value) =>
doFlatMap(ep.run(input, value)) {
case Right(data) =>
doFlatMap(doAttempt(ep.run(input, value))) {
case Right(Right(data)) =>
val responseData = ep.outCodec(data)
sendMessage(OutputMessage.ResponseMessage(callId, Payload(responseData)))
case Left(error) =>
val errorPayload = ep.errCodec.encode(error)
case Right(Left(error)) =>
val errorPayload = ep.errEncoder.encode(error)
sendMessage(OutputMessage.ErrorMessage(callId, errorPayload))
case Left(err) =>
sendMessage(
OutputMessage.ErrorMessage(callId, ErrorPayload(0, s"ServerInternalError: ${err.getMessage}", None))
)
}
case Left(pError) =>
sendProtocolError(callId, ProtocolError.ParseError(pError.getMessage))
Expand All @@ -111,13 +116,13 @@ private[jsonrpclib] abstract class MessageDispatcher[F[_]](implicit F: Monadic[F
}

private def createPendingCall[Err, Out](
errCodec: ErrorCodec[Err],
outCodec: Codec[Out],
errDecoder: ErrorDecoder[Err],
outCodec: Decoder[Out],
fulfill: Try[Either[Err, Out]] => F[Unit]
): OutputMessage => F[Unit] = { (message: OutputMessage) =>
message match {
case ErrorMessage(_, errorPayload) =>
errCodec.decode(errorPayload) match {
errDecoder.decode(errorPayload) match {
case Left(_) => fulfill(scala.util.Failure(errorPayload))
case Right(value) => fulfill(scala.util.Success(Left(value)))
}
Expand Down
1 change: 0 additions & 1 deletion modules/core/src/main/scala/jsonrpclib/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@ package object jsonrpclib {

type ErrorCode = Int
type ErrorMessage = String

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import smithy4s.codecs.PayloadPath
import smithy4s.Document.{Decoder => _, _}
import io.circe._

private[jsonrpclib] object CirceJson {
object CirceJsonCodec {

def fromSchema[A](implicit schema: Schema[A]): Codec[A] = Codec.from(
c => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ private class ClientStub[Alg[_[_, _, _, _, _]], F[_]: Monadic](val service: Serv
endpointSpec: EndpointSpec
): I => F[O] = {

implicit val inputCodec: Codec[I] = CirceJson.fromSchema(smithy4sEndpoint.input)
implicit val outputCodec: Codec[O] = CirceJson.fromSchema(smithy4sEndpoint.output)
implicit val inputCodec: Codec[I] = CirceJsonCodec.fromSchema(smithy4sEndpoint.input)
implicit val outputCodec: Codec[O] = CirceJsonCodec.fromSchema(smithy4sEndpoint.output)

endpointSpec match {
case EndpointSpec.Notification(methodName) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@ import smithy4s.Service
import smithy4s.kinds.FunctorAlgebra
import smithy4s.kinds.FunctorInterpreter
import jsonrpclib.Monadic
import jsonrpclib.Payload
import jsonrpclib.ErrorPayload
import io.circe.Codec
import jsonrpclib.Monadic.syntax._
import jsonrpclib.ErrorEncoder
import smithy4s.schema.ErrorSchema

object ServerEndpoints {

Expand All @@ -24,27 +29,53 @@ object ServerEndpoints {
}
}

// TODO : codify errors at smithy level and handle them.
def jsonRPCEndpoint[F[_]: Monadic, Op[_, _, _, _, _], I, E, O, SI, SO](
smithy4sEndpoint: Smithy4sEndpoint[Op, I, E, O, SI, SO],
endpointSpec: EndpointSpec,
impl: FunctorInterpreter[Op, F]
): Endpoint[F] = {

implicit val inputCodec: Codec[I] = CirceJson.fromSchema(smithy4sEndpoint.input)
implicit val outputCodec: Codec[O] = CirceJson.fromSchema(smithy4sEndpoint.output)
implicit val inputCodec: Codec[I] = CirceJsonCodec.fromSchema(smithy4sEndpoint.input)
implicit val outputCodec: Codec[O] = CirceJsonCodec.fromSchema(smithy4sEndpoint.output)

def errorResponse(throwable: Throwable): F[E] = throwable match {
case smithy4sEndpoint.Error((_, e)) => e.pure
case e: Throwable => e.raiseError
}

endpointSpec match {
case EndpointSpec.Notification(methodName) =>
Endpoint[F](methodName).notification { (input: I) =>
val op = smithy4sEndpoint.wrap(input)
Monadic[F].doVoid(impl(op))
impl(op).void
}
case EndpointSpec.Request(methodName) =>
Endpoint[F](methodName).simple { (input: I) =>
val op = smithy4sEndpoint.wrap(input)
impl(op)
smithy4sEndpoint.error match {
case None =>
Endpoint[F](methodName).simple[I, O] { (input: I) =>
val op = smithy4sEndpoint.wrap(input)
impl(op)
}
case Some(errorSchema) =>
implicit val errorCodec: ErrorEncoder[E] = errorCodecFromSchema(errorSchema)
Endpoint[F](methodName).apply[I, E, O] { (input: I) =>
val op = smithy4sEndpoint.wrap(input)
impl(op).attempt.flatMap {
case Left(err) => errorResponse(err).map(r => Left(r): Either[E, O])
case Right(success) => (Right(success): Either[E, O]).pure
}
}
}
}
}

private def errorCodecFromSchema[A](s: ErrorSchema[A]): ErrorEncoder[A] = {
val circeCodec = CirceJsonCodec.fromSchema(s.schema)
(a: A) =>
ErrorPayload(
0,
Option(s.unliftError(a).getMessage()).getOrElse("JSONRPC-smithy4s application error"),
Some(Payload(circeCodec.apply(a)))
)
}
}
Loading