Skip to content

Commit 58f403a

Browse files
authored
fix indentations and optional enum types (#16)
* fix indentations and optional enum types
1 parent a394056 commit 58f403a

File tree

4 files changed

+38783
-31075
lines changed

4 files changed

+38783
-31075
lines changed

build.sbt

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -229,22 +229,7 @@ def codegenTask(
229229
}
230230

231231
val files = listFilesRec(List(outDir), Nil)
232-
233-
// formatting (may need to find another way...)
234-
val fmtCmd = s"scala-cli fmt --scalafmt-conf=./.scalafmt.conf ${outDir.absolutePath}"
235-
logger.info(s"Formatting with '$fmtCmd'")
236-
val fmtErrs = scala.collection.mutable.ListBuffer.empty[String]
237-
fmtCmd ! ProcessLogger(
238-
_ => (),
239-
e => fmtErrs += e
240-
) match {
241-
case 0 => ()
242-
case c => throw new InterruptedException(s"Failure on code formatting: ${errs.mkString("\n")}")
243-
}
244-
245232
IO.delete(outDir / ".scala-build")
246-
logger.success(s"Formatting sources in $outPathRel done")
247-
248233
files
249234
}
250235
}

modules/core/shared/src/main/scala/codegen.scala

Lines changed: 34 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,6 @@ def generateBySpec(
255255
pkg = schemasPkg,
256256
objName = commonCodecsObj,
257257
jsonCodec = config.jsonCodec,
258-
dialect = config.dialect,
259258
hasProps = p => specs.hasProps(p),
260259
arrType = config.arrayType
261260
) match
@@ -301,14 +300,11 @@ def generateBySpec(
301300

302301
val scalaKeyWords = Set("type", "import", "val", "object", "enum", "export")
303302

303+
def toScalaTypeName(n: String): String = toScalaName(n.capitalize)
304+
304305
def toScalaName(n: String): String =
305-
n match
306-
// to fix a compiler warning like
307-
// import looks like a language import, but refers to something else: object language in object GoogleCloudAiplatformV1ExecutableCode
308-
case "language" => "Language"
309-
case _ =>
310-
if scalaKeyWords.contains(n) then s"`$n`"
311-
else n.replaceAll("[^a-zA-Z0-9_]", "")
306+
if scalaKeyWords.contains(n) then s"`$n`"
307+
else n.replaceAll("[^a-zA-Z0-9_]", "")
312308

313309
def resourceCode(
314310
rootPkg: String,
@@ -368,14 +364,14 @@ def resourceCode(
368364
case None => None
369365

370366
val (requiredParams, optParams) = method.scalaParameters.partition(_._2.required)
371-
val params =
372-
requiredParams.map((n, t) => s"${toComment(t.description)}$n: ${t.scalaType(arrType)}") :::
373-
req.toList.map(r => s"request: ${r.scalaType(arrType)}") :::
374-
uploadProtocol.toList.map((typ, default) => s"uploadProtocol: $typ = \"$default\"") :::
375-
optParams.map((n, t) => s"${toComment(t.description)}$n: ${t.scalaType(arrType)} = None") :::
367+
def params(indent: String) =
368+
requiredParams.map((n, t) => s"${toComment(t.description)}$indent$n: ${t.scalaType(arrType)}") :::
369+
req.toList.map(r => s"${indent}request: ${r.scalaType(arrType)}") :::
370+
uploadProtocol.toList.map((typ, default) => s"${indent}uploadProtocol: $typ = \"$default\"") :::
371+
optParams.map((n, t) => s"${toComment(t.description)}$indent$n: ${t.scalaType(arrType)} = None") :::
376372
List(
377-
s"endpointUrl: Uri = $rootPkg.baseUrl",
378-
"commonQueryParams: QueryParameters = " + ((
373+
s"${indent}endpointUrl: Uri = $rootPkg.baseUrl",
374+
s"${indent}commonQueryParams: QueryParameters = " + ((
379375
method.mediaUploads,
380376
commonQueryParams.collectFirst { case ("uploadType", Parameter(_, _, e: SchemaType.Enum, _, _)) => e }
381377
) match {
@@ -431,10 +427,10 @@ def resourceCode(
431427
)
432428
case _ => (responseType("String"), ".response(asEmptyResponse)")
433429

434-
s"""|def ${toScalaName(k)}(\n${params.mkString(",\n")}): $resType = {$queryParams
435-
| $setReqUri
436-
| resourceRequest.${method.httpMethod.toLowerCase()}(requestUri.addParams(params))$body$mapResponse
437-
|}""".stripMargin
430+
s"""| def ${toScalaName(k)}(\n${params(indent = " ").mkString(",\n")}\n ): $resType = {$queryParams
431+
| $setReqUri
432+
| resourceRequest.${method.httpMethod.toLowerCase()}(requestUri.addParams(params))$body$mapResponse
433+
| }""".stripMargin
438434
}
439435
.mkString("\n", "\n\n", "\n") +
440436
"}"
@@ -467,16 +463,16 @@ def schemasCode(
467463
if jsonCodec == JsonCodec.Jsoniter then
468464
enums
469465
.map((k, e) =>
470-
s"enum ${toScalaName(k)} {\n${e.values.map(v => s"${toComment(Some(v.enumDescription))} case ${toScalaName(v.value)}").mkString("\n ")}}\n"
466+
s" enum ${toScalaTypeName(k)} {\n${e.values.map(v => s"${toComment(Some(v.enumDescription), " ")} case ${toScalaName(v.value)}").mkString("\n ")}\n }\n"
471467
)
472468
.mkString("\n")
473469
else "",
474470
jsonCodec match
475471
case JsonCodec.ZioJson =>
476-
s"${implicitVal(dialect)} jsonCodec: JsonCodec[$objName] = JsonCodec.derived[$objName]"
472+
s" given jsonCodec: JsonCodec[$objName] = JsonCodec.derived[$objName]"
477473
case JsonCodec.Jsoniter =>
478-
s"""|${implicitVal(dialect)} jsonCodec: JsonValueCodec[$objName] =
479-
| JsonCodecMaker.make(CodecMakerConfig.withAllowRecursiveTypes(true).withDiscriminatorFieldName(None))""".stripMargin,
474+
s"""| given jsonCodec: JsonValueCodec[$objName] =
475+
| JsonCodecMaker.make(CodecMakerConfig.withAllowRecursiveTypes(true).withDiscriminatorFieldName(None))""".stripMargin,
480476
"}"
481477
).mkString("\n")
482478

@@ -488,13 +484,13 @@ def schemasCode(
488484
.map { (n, t) =>
489485
val enumType =
490486
if jsonCodec == JsonCodec.ZioJson then SchemaType.EnumType.Literal
491-
else SchemaType.EnumType.Nominal(s"$scalaName.$n")
492-
s"${toComment(t.withTypeDescription)}$n: ${
487+
else SchemaType.EnumType.Nominal(s"$scalaName.${toScalaTypeName(n)}")
488+
s"${toComment(t.withTypeDescription)} $n: ${
493489
(if (t.optional) s"${t.scalaType(arrType, enumType)} = None" else t.scalaType(arrType, enumType))
494490
}"
495491
}
496492
.mkString("", ",\n", "")}
497-
|) {\n${`def toJsonString`(scalaName)}\n}\n
493+
|) {\n ${`def toJsonString`(scalaName)}\n}\n
498494
|
499495
|${jsonDecoder(scalaName)}
500496
|""".stripMargin
@@ -522,7 +518,6 @@ def commonSchemaCodecs(
522518
pkg: String,
523519
objName: String,
524520
jsonCodec: JsonCodec,
525-
dialect: Dialect,
526521
hasProps: SchemaPath => Boolean,
527522
arrType: ArrayType
528523
): Option[String] = {
@@ -534,7 +529,7 @@ def commonSchemaCodecs(
534529
.collect { case (k, Property(_, SchemaType.Array(typ, _), _)) =>
535530
val enumType =
536531
if jsonCodec == JsonCodec.ZioJson then SchemaType.EnumType.Literal
537-
else SchemaType.EnumType.Nominal(s"${sk.lastOption.getOrElse("")}.$k")
532+
else SchemaType.EnumType.Nominal(s"${sk.lastOption.getOrElse("")}.${toScalaTypeName(k)}")
538533
typ.scalaType(arrType, enumType)
539534
}
540535
)
@@ -552,17 +547,17 @@ def commonSchemaCodecs(
552547
s"object $objName {",
553548
props
554549
.map { t =>
555-
val prefix = implicitVal(dialect) + " " + toScalaName(t + "ChunkCodec")
550+
val prefix = " given " + toScalaName(t + "ChunkCodec")
556551
s"""|${prefix}: JsonValueCodec[Chunk[$t]] = new JsonValueCodec[Chunk[$t]] {
557-
| val arrCodec: JsonValueCodec[Array[$t]] = JsonCodecMaker.make
552+
| val arrCodec: JsonValueCodec[Array[$t]] = JsonCodecMaker.make
558553
|
559-
| override val nullValue: Chunk[$t] = Chunk.empty
554+
| override val nullValue: Chunk[$t] = Chunk.empty
560555
|
561-
| override def decodeValue(in: JsonReader, default: Chunk[$t]): Chunk[$t] =
562-
| Chunk.fromArray(arrCodec.decodeValue(in, default.toArray))
556+
| override def decodeValue(in: JsonReader, default: Chunk[$t]): Chunk[$t] =
557+
| Chunk.fromArray(arrCodec.decodeValue(in, default.toArray))
563558
|
564-
| override def encodeValue(x: Chunk[$t], out: JsonWriter): Unit =
565-
| arrCodec.encodeValue(x.toArray, out)
559+
| override def encodeValue(x: Chunk[$t], out: JsonWriter): Unit =
560+
| arrCodec.encodeValue(x.toArray, out)
566561
|}""".stripMargin
567562
}
568563
.mkString("\n\n"),
@@ -691,7 +686,7 @@ case class Parameter(
691686
required: Boolean = false,
692687
pattern: Option[String] = None
693688
) {
694-
def scalaType(arrType: ArrayType) = typ.withOptional(!required).scalaType(arrType)
689+
def scalaType(arrType: ArrayType): String = typ.withOptional(!required).scalaType(arrType)
695690
}
696691

697692
object Parameter:
@@ -810,7 +805,7 @@ object SchemaType:
810805
read[List[String]](e)
811806
.zip(read[List[String]](o("enumDescriptions")))
812807
.map((v, vd) => EnumValue(value = v, enumDescription = vd)),
813-
false
808+
optional
814809
)
815810
case _ =>
816811
o.value.get("additionalProperties").map(_.obj) match
@@ -842,7 +837,7 @@ object SchemaPath:
842837

843838
extension (s: SchemaPath)
844839
def scalaName: String =
845-
s.filter(!Set("items", "properties").contains(_)).map(_.capitalize).mkString
840+
s.filter(!Set("items", "properties").contains(_)).map(toScalaTypeName(_)).mkString
846841

847842
def add(nested: String): SchemaPath = s.appended(nested)
848843
def hasNested: Boolean = s.size > 1
@@ -927,7 +922,7 @@ object ResourcePath:
927922
def apply(pp: Vector[String], p: String): ResourcePath = pp :+ p
928923
extension (r: ResourcePath)
929924
def add(p: String): ResourcePath = r :+ p
930-
def scalaName: String = r.last.capitalize
925+
def scalaName: String = toScalaTypeName(r.last)
931926
def pkgPath: Vector[String] = r.dropRight(1).map(camelToSnakeCase)
932927
def pkgName(base: String): String = s"$base${if pkgPath.nonEmpty then pkgPath.mkString(".", ".", "") else ""}"
933928
def dirPath(base: Path): Path = base / pkgPath
@@ -967,9 +962,6 @@ def camelToSnakeCase(camelCase: String): String = {
967962
camelCaseRegex.replaceAllIn(camelCase, matched => "_" + matched.group(0).toLowerCase)
968963
}
969964

970-
def implicitVal(dialect: Dialect) = dialect match
971-
case Dialect.Scala3 => "given"
972-
973965
// comment splitted into multipl lines
974966
private def toComment(content: Iterable[String], indent: String = " "): String =
975967
if content.isEmpty then ""

0 commit comments

Comments
 (0)