1+ package wasm .miniwasm
2+
3+ import wasm .ast ._
4+ import wasm .memory ._
5+
6+ import scala .collection .mutable .ArrayBuffer
7+ import scala .collection .mutable .HashMap
8+ import Console .{GREEN , RED , RESET , YELLOW_B , UNDERLINED }
9+
10+ case class Trap () extends Exception
11+
12+ case class ModuleInstance (
13+ defs : List [Definition ],
14+ types : List [FuncLikeType ],
15+ tags : List [FuncType ],
16+ funcs : HashMap [Int , Callable ],
17+ memory : List [RTMemory ] = List (RTMemory ()),
18+ globals : List [RTGlobal ] = List (),
19+ exports : List [Export ] = List ()
20+ )
21+
22+ case class Frame (locals : ArrayBuffer [Value ])
23+
24+ object ModuleInstance {
25+ def apply (module : Module ): ModuleInstance = {
26+ val types = module.definitions
27+ .collect({
28+ case TypeDef (_, ft) => ft
29+ })
30+ .toList
31+ val tags = module.definitions
32+ .collect({
33+ case Tag (id, ty) => ty
34+ })
35+ .toList
36+
37+ val funcs = module.definitions
38+ .collect({
39+ case FuncDef (_, fndef @ FuncBodyDef (_, _, _, _)) => fndef
40+ })
41+ .toList
42+
43+ val globals = module.definitions
44+ .collect({
45+ case Global (_, GlobalValue (ty, e)) =>
46+ (e.head) match {
47+ case Const (c) => RTGlobal (ty, c)
48+ // Q: What is the default behavior if case in non-exhaustive
49+ case _ => ???
50+ }
51+ })
52+ .toList
53+
54+ // TODO: correct the behavior for memory
55+ val memory = module.definitions
56+ .collect({
57+ case Memory (id, MemoryType (min, max_opt)) =>
58+ RTMemory (min, max_opt)
59+ })
60+ .toList
61+
62+ val exports = module.definitions
63+ .collect({
64+ case e @ Export (_, ExportFunc (_)) => e
65+ })
66+ .toList
67+
68+ ModuleInstance (module.definitions, types, tags, module.funcEnv, memory, globals, exports)
69+ }
70+ }
71+
72+ def evalBinOp (op : BinOp , lhs : Value , rhs : Value ): Value = op match
73+ case Add (_) =>
74+ (lhs, rhs) match
75+ case (I32V (v1), I32V (v2)) => I32V (v1 + v2)
76+ case (I64V (v1), I64V (v2)) => I64V (v1 + v2)
77+ case (F32V (v1), F32V (v2)) => F32V (v1 + v2)
78+ case (F64V (v1), F64V (v2)) => F64V (v1 + v2)
79+ case _ => throw new Exception (" Invalid types" )
80+ case Mul (_) =>
81+ (lhs, rhs) match
82+ case (I32V (v1), I32V (v2)) => I32V (v1 * v2)
83+ case (I64V (v1), I64V (v2)) => I64V (v1 * v2)
84+ case _ => throw new Exception (" Invalid types" )
85+ case Sub (_) =>
86+ (lhs, rhs) match
87+ case (I32V (v1), I32V (v2)) => I32V (v1 - v2)
88+ case (I64V (v1), I64V (v2)) => I64V (v1 - v2)
89+ case _ => throw new Exception (" Invalid types" )
90+ case Shl (_) =>
91+ (lhs, rhs) match
92+ case (I32V (v1), I32V (v2)) => I32V (v1 << v2)
93+ case (I64V (v1), I64V (v2)) => I64V (v1 << v2)
94+ case _ => throw new Exception (" Invalid types" )
95+ case ShrU (_) =>
96+ (lhs, rhs) match
97+ case (I32V (v1), I32V (v2)) => I32V (v1 >>> v2)
98+ case (I64V (v1), I64V (v2)) => I64V (v1 >>> v2)
99+ case _ => throw new Exception (" Invalid types" )
100+ case And (_) =>
101+ (lhs, rhs) match
102+ case (I32V (v1), I32V (v2)) => I32V (v1 & v2)
103+ case (I64V (v1), I64V (v2)) => I64V (v1 & v2)
104+ case _ => throw new Exception (" Invalid types" )
105+ case _ => ???
106+
107+ def evalUnaryOp (op : UnaryOp , value : Value ) = op match
108+ case Clz (_) =>
109+ value match
110+ case I32V (v) => I32V (Integer .numberOfLeadingZeros(v))
111+ case I64V (v) => I64V (java.lang.Long .numberOfLeadingZeros(v))
112+ case _ => throw new Exception (" Invalid types" )
113+ case Ctz (_) =>
114+ value match
115+ case I32V (v) => I32V (Integer .numberOfTrailingZeros(v))
116+ case I64V (v) => I64V (java.lang.Long .numberOfTrailingZeros(v))
117+ case _ => throw new Exception (" Invalid types" )
118+ case Popcnt (_) =>
119+ value match
120+ case I32V (v) => I32V (Integer .bitCount(v))
121+ case I64V (v) => I64V (java.lang.Long .bitCount(v))
122+ case _ => throw new Exception (" Invalid types" )
123+ case _ => ???
124+
125+ def evalRelOp (op : RelOp , lhs : Value , rhs : Value ) = op match
126+ case Eq (_) =>
127+ (lhs, rhs) match
128+ case (I32V (v1), I32V (v2)) => I32V (if (v1 == v2) 1 else 0 )
129+ case (I64V (v1), I64V (v2)) => I32V (if (v1 == v2) 1 else 0 )
130+ case _ => throw new Exception (" Invalid types" )
131+ case Ne (_) =>
132+ (lhs, rhs) match
133+ case (I32V (v1), I32V (v2)) => I32V (if (v1 != v2) 1 else 0 )
134+ case (I64V (v1), I64V (v2)) => I32V (if (v1 != v2) 1 else 0 )
135+ case _ => throw new Exception (" Invalid types" )
136+ case LtS (_) =>
137+ (lhs, rhs) match
138+ case (I32V (v1), I32V (v2)) => I32V (if (v1 < v2) 1 else 0 )
139+ case (I64V (v1), I64V (v2)) => I32V (if (v1 < v2) 1 else 0 )
140+ case _ => throw new Exception (" Invalid types" )
141+ case LtU (_) =>
142+ (lhs, rhs) match
143+ case (I32V (v1), I32V (v2)) =>
144+ I32V (if (Integer .compareUnsigned(v1, v2) < 0 ) 1 else 0 )
145+ case (I64V (v1), I64V (v2)) =>
146+ I32V (if (java.lang.Long .compareUnsigned(v1, v2) < 0 ) 1 else 0 )
147+ case _ => throw new Exception (" Invalid types" )
148+ case GtS (_) =>
149+ (lhs, rhs) match
150+ case (I32V (v1), I32V (v2)) => I32V (if (v1 > v2) 1 else 0 )
151+ case (I64V (v1), I64V (v2)) => I32V (if (v1 > v2) 1 else 0 )
152+ case _ => throw new Exception (" Invalid types" )
153+ case GtU (_) =>
154+ (lhs, rhs) match
155+ case (I32V (v1), I32V (v2)) =>
156+ I32V (if (Integer .compareUnsigned(v1, v2) > 0 ) 1 else 0 )
157+ case (I64V (v1), I64V (v2)) =>
158+ I32V (if (java.lang.Long .compareUnsigned(v1, v2) > 0 ) 1 else 0 )
159+ case _ => throw new Exception (" Invalid types" )
160+ case LeS (_) =>
161+ (lhs, rhs) match
162+ case (I32V (v1), I32V (v2)) => I32V (if (v1 <= v2) 1 else 0 )
163+ case (I64V (v1), I64V (v2)) => I32V (if (v1 <= v2) 1 else 0 )
164+ case _ => throw new Exception (" Invalid types" )
165+ case LeU (_) =>
166+ (lhs, rhs) match
167+ case (I32V (v1), I32V (v2)) =>
168+ I32V (if (Integer .compareUnsigned(v1, v2) <= 0 ) 1 else 0 )
169+ case (I64V (v1), I64V (v2)) =>
170+ I32V (if (java.lang.Long .compareUnsigned(v1, v2) <= 0 ) 1 else 0 )
171+ case _ => throw new Exception (" Invalid types" )
172+ case GeS (_) =>
173+ (lhs, rhs) match
174+ case (I32V (v1), I32V (v2)) => I32V (if (v1 >= v2) 1 else 0 )
175+ case (I64V (v1), I64V (v2)) => I32V (if (v1 >= v2) 1 else 0 )
176+ case _ => throw new Exception (" Invalid types" )
177+ case GeU (_) =>
178+ (lhs, rhs) match
179+ case (I32V (v1), I32V (v2)) =>
180+ I32V (if (Integer .compareUnsigned(v1, v2) >= 0 ) 1 else 0 )
181+ case (I64V (v1), I64V (v2)) =>
182+ I32V (if (java.lang.Long .compareUnsigned(v1, v2) >= 0 ) 1 else 0 )
183+ case _ => throw new Exception (" Invalid types" )
184+
185+ def evalTestOp (op : TestOp , value : Value ) = op match
186+ case Eqz (_) =>
187+ value match
188+ case I32V (v) => I32V (if (v == 0 ) 1 else 0 )
189+ case I64V (v) => I32V (if (v == 0 ) 1 else 0 )
190+ case _ => throw new Exception (" Invalid types" )
191+
192+ def memOutOfBound (module : ModuleInstance , memoryIndex : Int , offset : Int , size : Int ) = {
193+ val memory = module.memory(memoryIndex)
194+ offset + size > memory.size
195+ }
196+
197+ def zero (t : ValueType ): Value = t match
198+ case NumType (kind) =>
199+ kind match
200+ case I32Type => I32V (0 )
201+ case I64Type => I64V (0 )
202+ case F32Type => F32V (0 )
203+ case F64Type => F64V (0 )
204+ case VecType (kind) => ???
205+ case RefType (kind) => RefNullV (kind)
206+
207+ def getFuncType (ty : BlockType ): FuncType =
208+ ty match
209+ case VarBlockType (_, None ) => ??? // TODO: fill this branch until we handle type index correctly
210+ case VarBlockType (_, Some (tipe)) => tipe
211+ case ValBlockType (Some (tipe)) => FuncType (List (), List (), List (tipe))
212+ case ValBlockType (None ) => FuncType (List (), List (), List ())
213+
214+ def extractMainInstrs (module : ModuleInstance , main : Option [String ]): List [Instr ] =
215+ main match
216+ case Some (func_name) =>
217+ module.defs.flatMap({
218+ case Export (`func_name`, ExportFunc (fid)) =>
219+ System .err.println(s " Entering function $main" )
220+ module.funcs(fid) match
221+ case FuncDef (_, FuncBodyDef (_, _, locals, body)) => body
222+ case _ => throw new Exception (" Entry function has no concrete body" )
223+ case _ => List ()
224+ })
225+ case None =>
226+ module.defs.flatMap({
227+ case Start (id) =>
228+ System .err.println(s " Entering unnamed function $id" )
229+ module.funcs(id) match
230+ case FuncDef (_, FuncBodyDef (_, _, locals, body)) => body
231+ case _ => throw new Exception (" Entry function has no concrete body" )
232+ case _ => List ()
233+ })
234+
235+ def extractLocals (module : ModuleInstance , main : Option [String ]): List [ValueType ] =
236+ main match
237+ case Some (func_name) =>
238+ module.defs.flatMap({
239+ case Export (`func_name`, ExportFunc (fid)) =>
240+ System .err.println(s " Entering function $main" )
241+ module.funcs(fid) match
242+ case FuncDef (_, FuncBodyDef (_, _, locals, _)) => locals
243+ case _ => throw new Exception (" Entry function has no concrete body" )
244+ case _ => List ()
245+ })
246+ case None =>
247+ module.defs.flatMap({
248+ case Start (id) =>
249+ System .err.println(s " Entering unnamed function $id" )
250+ module.funcs(id) match
251+ case FuncDef (_, FuncBodyDef (_, _, locals, body)) => locals
252+ case _ => throw new Exception (" Entry function has no concrete body" )
253+ case _ => List ()
254+ })
0 commit comments