Skip to content

Commit 3cab2b8

Browse files
committed
Refactor annotation-based API
Generate code based on the new subparser API
1 parent f902c2b commit 3cab2b8

File tree

40 files changed

+1986
-1209
lines changed

40 files changed

+1986
-1209
lines changed

argparse/sandbox/src/example.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// A sandbox to experiment with argarse
2+
3+
object api
4+
extends argparse.core.ParsersApi
5+
with argparse.core.MacroApi
6+
with argparse.core.TypesApi
7+
with argparse.core.ReadersApi:
8+
9+
override def defaultHelpFlags = Seq("-h")
10+
11+
object app:
12+
13+
/**
14+
* Hello world
15+
*
16+
* @param base the base value
17+
*/
18+
@api.command()
19+
class wrapper(base: Int):
20+
@api.command()
21+
def add(x: Int) = println(base + x)
22+
23+
/** A nested command */
24+
@api.command()
25+
def nested(y: Int = 2) = foo(y)
26+
class foo(y: Int):
27+
@api.command()
28+
def ok() = println("ok")
29+
30+
def main(args: Array[String]): Unit = argparse.main(this, args)
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
package argparse
2+
3+
import argparse.core.DocComment
4+
import argparse.core.TextUtils
5+
6+
case class Command[A](
7+
name: String,
8+
makeParser: (() => A) => core.ParsersApi#ArgumentParser
9+
):
10+
type Container = A
11+
12+
inline def main[Container](instance: Container, args: Iterable[String]): Unit = ${
13+
CommandMacros.mainImpl[Container]('instance, 'args)
14+
}
15+
16+
object Command:
17+
18+
inline def find[Container]: Seq[Command[Container]] = ${
19+
CommandMacros.findImpl[Container]
20+
}
21+
22+
object CommandMacros:
23+
import quoted.Expr
24+
import quoted.Quotes
25+
import quoted.Type
26+
27+
def mainImpl[Container: Type](using qctx: Quotes)(instance: Expr[Container], args: Expr[Iterable[String]]) =
28+
import qctx.reflect.*
29+
findAllImpl[Container] match
30+
case Nil =>
31+
report.error(s"No main method found in ${TypeRepr.of[Container].show}. The container object must contain exactly one method annotated with @command")
32+
'{???}
33+
case head :: Nil =>
34+
'{
35+
val parser = $head.makeParser(() => $instance)
36+
parser.parseOrExit($args)
37+
}
38+
case list =>
39+
report.error(s"Too many main methods found in ${TypeRepr.of[Container].show}. The container object must contain exactly one method annotated with @command")
40+
'{???}
41+
42+
def findImpl[Container: Type](using qctx: Quotes) =
43+
Expr.ofList(findAllImpl[Container])
44+
45+
def findAllImpl[Container: Type](using qctx: Quotes): List[Expr[Command[Container]]] =
46+
import qctx.reflect.*
47+
val CommandAnnot = TypeRepr.of[core.MacroApi#command]
48+
49+
val methods = TypeRepr.of[Container].typeSymbol.memberMethods
50+
val classes = TypeRepr.of[Container].typeSymbol.memberTypes.filter(_.isClassDef)
51+
52+
for
53+
sym <- (methods ++ classes)
54+
annots = sym.annotations
55+
annot = annots.find(_.tpe <:< CommandAnnot)
56+
if annot.isDefined
57+
yield
58+
val TypeRef(apiType: TermRef, _) = annot.get.tpe
59+
val api = Ref.term(apiType)
60+
61+
val method =
62+
if sym.isClassDef then
63+
sym.primaryConstructor
64+
else
65+
sym
66+
67+
val doc = DocComment.extract(sym.docstring.getOrElse(""))
68+
69+
apiType.asType match
70+
case '[t] if TypeRepr.of[t] <:< TypeRepr.of[core.MacroApi] =>
71+
makeCommand[Container, core.MacroApi](api.asExprOf[core.MacroApi], method, sym.name, doc)
72+
case '[t] =>
73+
report.error(s"wrong API ${Type.show[t]}")
74+
'{???}
75+
76+
def getDefaultParams(using qctx: Quotes)(instance: qctx.reflect.Term, method: qctx.reflect.Symbol): Map[qctx.reflect.Symbol, qctx.reflect.Term] =
77+
import qctx.reflect.*
78+
79+
val pairs = for
80+
(param, idx) <- method.paramSymss.flatten.zipWithIndex
81+
if (param.flags.is(Flags.HasDefault))
82+
yield {
83+
val tree = if method.isClassConstructor then
84+
val defaultName = s"$$lessinit$$greater$$default$$${idx + 1}"
85+
Select(
86+
Select(instance, method.owner.companionModule),
87+
method.owner.companionModule.memberMethod(defaultName).head
88+
)
89+
else
90+
val defaultName = s"${method.name}$$default$$${idx + 1}"
91+
Select(instance, method.owner.memberMethod(defaultName).head)
92+
param -> tree
93+
}
94+
pairs.toMap
95+
96+
// term.apply(List(argss(0)(0),...,argss(0)(N)), ..., List(argss(M)(0),...,argss(M)(N)))
97+
def call(using qctx: Quotes)(
98+
term: qctx.reflect.Term,
99+
paramss: List[List[qctx.reflect.TypeRepr]],
100+
argss: Expr[Seq[Seq[?]]]
101+
): qctx.reflect.Term =
102+
import qctx.reflect.*
103+
104+
val accesses =
105+
for i <- paramss.indices.toList yield
106+
for j <- paramss(i).indices.toList yield
107+
paramss(i)(j).asType match
108+
case '[t] =>
109+
'{$argss(${Expr(i)})(${Expr(j)}).asInstanceOf[t]}.asTerm
110+
111+
val application = accesses.foldLeft(term)((lhs, args) => Apply(lhs, args))
112+
application
113+
114+
def makeCommand[Container: Type, Api <: core.MacroApi: Type](using qctx: Quotes)(
115+
api: Expr[Api],
116+
method: qctx.reflect.Symbol,
117+
name: String, // name is separate because method name is not always representative (e.g. if method is class constructor)
118+
doc: DocComment
119+
): Expr[Command[Container]] =
120+
import qctx.reflect.*
121+
122+
val rtpe = method.tree.asInstanceOf[DefDef].returnTpt.tpe
123+
val ptpes = method.paramSymss.map(_.map(_.tree.asInstanceOf[ValDef].tpt.tpe))
124+
val inner = rtpe.asType match
125+
case '[t] => findAllImpl[t]
126+
127+
val makeParser = '{
128+
(instance: () => Container) =>
129+
val parser = $api.ArgumentParser(description = ${Expr(doc.paragraphs.mkString("\n"))})
130+
131+
val args: Seq[Seq[() => ?]] = ${
132+
val defaults = getDefaultParams(using qctx)('{instance()}.asTerm, method)
133+
134+
val accessors =
135+
for paramList <- method.paramSymss yield
136+
val ls = for param <- paramList yield
137+
val argAnnot: Expr[argparse.arg] = param.getAnnotation(TypeRepr.of[argparse.arg].typeSymbol) match
138+
case Some(a) => a.asExprOf[argparse.arg]
139+
case None => '{argparse.arg()} // use default arg() values
140+
// TODO: replace with `param.termRef.widenTermRefByName` when upgrading scala version
141+
val paramTpe = param.tree.asInstanceOf[ValDef].tpt.tpe
142+
143+
def summonReader(tpe: TypeRepr): Term =
144+
val readerType =
145+
TypeSelect(
146+
api.asTerm,
147+
"Reader"
148+
).tpe.appliedTo(List(tpe))
149+
Implicits.search(readerType) match
150+
case iss: ImplicitSearchSuccess => iss.tree
151+
case other =>
152+
report.error( s"No ${readerType.show} available for parameter ${param.name}.", param.pos.get)
153+
'{???}.asTerm
154+
155+
paramTpe match
156+
case t if t =:= TermRef(api.asTerm.tpe, "ArgumentParser") =>
157+
'{() => parser}
158+
case t if t <:< TypeRepr.of[Iterable[?]] =>
159+
val AppliedType(_, List(inner)) = paramTpe.dealias
160+
val reader = summonReader(inner)
161+
defaults.get(param) match
162+
case Some(default) => // --named repeated
163+
'{
164+
val p = $api
165+
val arg = parser.asInstanceOf[p.ArgumentParser].repeatedParam[Any](
166+
name =
167+
$argAnnot.name match
168+
case null => TextUtils.kebabify(${Expr(s"--${param.name}")})
169+
case other => other,
170+
aliases = $argAnnot.aliases,
171+
help = ${Expr(doc.params.getOrElse(param.name, ""))},
172+
flag = ${Expr(paramTpe =:= TypeRepr.of[Boolean])},
173+
endOfNamed = false,
174+
interactiveCompleter = $argAnnot.interactiveCompleter,
175+
standaloneCompleter = $argAnnot.standaloneCompleter
176+
)(using ${reader.asExpr}.asInstanceOf[p.Reader[Any]])
177+
() => arg.value
178+
}
179+
case None => // positional repeated
180+
'{
181+
val p = $api
182+
val arg = parser.asInstanceOf[p.ArgumentParser].repeatedParam[Any](
183+
name =
184+
$argAnnot.name match
185+
case null => TextUtils.kebabify(${Expr(param.name)})
186+
case other => other,
187+
aliases = $argAnnot.aliases,
188+
help = ${Expr(doc.params.getOrElse(param.name, ""))},
189+
flag = false,
190+
endOfNamed = false,
191+
interactiveCompleter = $argAnnot.interactiveCompleter,
192+
standaloneCompleter = $argAnnot.standaloneCompleter
193+
)(using ${reader.asExpr}.asInstanceOf[p.Reader[Any]])
194+
() => arg.value
195+
}
196+
case t =>
197+
val reader = summonReader(t)
198+
199+
defaults.get(param) match
200+
case Some(default) => // --named
201+
'{
202+
val p = $api
203+
val arg = parser.asInstanceOf[p.ArgumentParser].singleParam[Any](
204+
name =
205+
$argAnnot.name match
206+
case null => TextUtils.kebabify(${Expr(s"--${param.name}")})
207+
case other => other,
208+
default = Some(() => ${default.asExpr}),
209+
env = Option($argAnnot.env),
210+
aliases = $argAnnot.aliases,
211+
help = ${Expr(doc.params.getOrElse(param.name, ""))},
212+
flag = ${Expr(paramTpe =:= TypeRepr.of[Boolean])},
213+
endOfNamed = false,
214+
interactiveCompleter = Option($argAnnot.interactiveCompleter),
215+
standaloneCompleter = Option($argAnnot.standaloneCompleter),
216+
argName = None
217+
)(using ${reader.asExpr}.asInstanceOf[p.Reader[Any]])
218+
() => arg.value
219+
}
220+
case None => // positional
221+
'{
222+
val p = $api
223+
val arg = parser.asInstanceOf[p.ArgumentParser].singleParam[Any](
224+
name =
225+
$argAnnot.name match
226+
case null => TextUtils.kebabify(${Expr(param.name)})
227+
case other => other,
228+
default = None,
229+
env = Option($argAnnot.env),
230+
aliases = $argAnnot.aliases,
231+
help = ${Expr(doc.params.getOrElse(param.name, ""))},
232+
flag = false,
233+
endOfNamed = false,
234+
interactiveCompleter = Option($argAnnot.interactiveCompleter),
235+
standaloneCompleter = Option($argAnnot.standaloneCompleter),
236+
argName = None
237+
)(using ${reader.asExpr}.asInstanceOf[p.Reader[Any]])
238+
() => arg.value
239+
}
240+
Expr.ofSeq(ls)
241+
end for
242+
Expr.ofSeq(accessors)
243+
}
244+
245+
def callOrInstantiate() =
246+
val outer = instance()
247+
val results = args.map(_.map(_()))
248+
${
249+
if method.isClassConstructor then
250+
call(using qctx)(
251+
New(TypeSelect('{outer}.asTerm, rtpe.typeSymbol.name)).select(method),
252+
ptpes,
253+
'results
254+
).asExpr
255+
else
256+
call(using qctx)(
257+
Select('{outer}.asTerm, method),
258+
ptpes,
259+
'results
260+
).asExpr
261+
}
262+
263+
${
264+
if inner.isEmpty then
265+
'{
266+
parser.action{callOrInstantiate()}
267+
parser
268+
}
269+
else
270+
'{
271+
val commands = ${Expr.ofList(inner)}
272+
for cmd <- commands do
273+
parser.addSubparser(
274+
cmd.name,
275+
cmd.makeParser(() => callOrInstantiate().asInstanceOf[cmd.Container])
276+
)
277+
parser
278+
}
279+
}
280+
}
281+
'{
282+
Command(
283+
${Expr(TextUtils.kebabify(name))},
284+
$makeParser
285+
)
286+
}
287+
288+
end makeCommand
Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,24 @@
11
package argparse
22

3+
/** Annotate a method parameter with this annotation to override some aspects of
4+
* macro-generated code.
5+
*
6+
* In general, `null` means to let the macro generate code, and is usually the
7+
* default.
8+
*
9+
* @param name Override the name of the parameter. Note that the name will be
10+
* used as-is. In particular, this means that you need to specify leading
11+
* dashes for named parameters.
12+
*
13+
* @param aliases Set aliases for the parameter.
14+
*
15+
* @param aliases Set the environment variable from which this parameter may be
16+
* read if not specified on the command line.
17+
*/
318
case class arg(
19+
name: String = null,
420
aliases: Seq[String] = Seq(),
5-
env: String = null
21+
env: String = null,
22+
interactiveCompleter: String => Seq[String] = null,
23+
standaloneCompleter: BashCompleter = null
624
) extends annotation.StaticAnnotation

0 commit comments

Comments
 (0)