From 592de50daba9d2d975b25f33081f627d8542d378 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Tue, 22 Oct 2024 22:18:57 -0400 Subject: [PATCH 1/2] Elide more type checks --- compiler.py | 54 +++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/compiler.py b/compiler.py index fab688b5..c4b1dd14 100644 --- a/compiler.py +++ b/compiler.py @@ -30,6 +30,9 @@ type_of, IntType, StringType, + TyEmptyRow, + TyRow, + row_flatten, parse, # needed for /compilerepl tokenize, # needed for /compilerepl ) @@ -131,14 +134,31 @@ def _guard(self, cond: str, msg: Optional[str] = None) -> None: self._emit("abort();") self._emit("}") + def _is_int(self, exp: Object) -> bool: + return type_of(exp) == IntType + + def _is_list(self, exp: Object) -> bool: + ty = type_of(exp) + return isinstance(ty, TyCon) and ty.name == "list" + + def _is_record(self, exp: Object) -> bool: + return isinstance(type_of(exp), TyRow) or isinstance(type_of(exp), TyEmptyRow) + def _guard_int(self, exp: Object, c_name: str) -> None: - if type_of(exp) != IntType: + if not self._is_int(exp): self._guard(f"is_num({c_name})") def _guard_str(self, exp: Object, c_name: str) -> None: if type_of(exp) != StringType: self._guard(f"is_string({c_name})") + def _guaranteed_has_field(self, exp: Object, name: str) -> bool: + ty = type_of(exp) + if not isinstance(ty, TyRow): + return False + fields, _ = row_flatten(ty) + return name in fields + def _mktemp(self, exp: str) -> str: temp = self.gensym() return self._handle(temp, exp) @@ -193,10 +213,12 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En # TODO(max): Give `arg` an AST node so we can track its inferred type # and make use of that in pattern matching if isinstance(pattern, Int): - self._emit(f"if (!is_num_equal_word({arg}, {pattern.value})) {{ goto {fallthrough}; }}") + if not self._is_int(pattern): + self._emit(f"if (!is_num_equal_word({arg}, {pattern.value})) {{ goto {fallthrough}; }}") return {} if isinstance(pattern, Hole): - self._emit(f"if (!is_hole({arg})) {{ goto {fallthrough}; }}") + if not self._is_hole(pattern): + self._emit(f"if (!is_hole({arg})) {{ goto {fallthrough}; }}") return {} if isinstance(pattern, Variant): self.variant_tag(pattern.tag) # register it for the big enum @@ -205,7 +227,8 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En # necessary; the non-Hole case would work just fine. self._emit(f"if ({arg} != mk_immediate_variant(Tag_{pattern.tag})) {{ goto {fallthrough}; }}") return {} - self._emit(f"if (!is_variant({arg})) {{ goto {fallthrough}; }}") + if not self._is_variant(pattern): + self._emit(f"if (!is_variant({arg})) {{ goto {fallthrough}; }}") self._emit(f"if (variant_tag({arg}) != Tag_{pattern.tag}) {{ goto {fallthrough}; }}") return self.try_match(env, self._mktemp(f"variant_value({arg})"), pattern.value, fallthrough) @@ -214,7 +237,8 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En if len(value) < 8: self._emit(f"if ({arg} != mksmallstring({json.dumps(value)}, {len(value)})) {{ goto {fallthrough}; }}") return {} - self._emit(f"if (!is_string({arg})) {{ goto {fallthrough}; }}") + if not self._is_string(pattern): + self._emit(f"if (!is_string({arg})) {{ goto {fallthrough}; }}") self._emit( f"if (!string_equal_cstr_len({arg}, {json.dumps(value)}, {len(value)})) {{ goto {fallthrough}; }}" ) @@ -222,7 +246,8 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En if isinstance(pattern, Var): return {pattern.name: arg} if isinstance(pattern, List): - self._emit(f"if (!is_list({arg})) {{ goto {fallthrough}; }}") + if not self._is_list(pattern): + self._emit(f"if (!is_list({arg})) {{ goto {fallthrough}; }}") updates = {} the_list = arg use_spread = False @@ -242,7 +267,8 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En self._emit(f"if (!is_empty_list({the_list})) {{ goto {fallthrough}; }}") return updates if isinstance(pattern, Record): - self._emit(f"if (!is_record({arg})) {{ goto {fallthrough}; }}") + if not self._is_record(pattern): + self._emit(f"if (!is_record({arg})) {{ goto {fallthrough}; }}") updates = {} use_spread = False for key, pattern_value in pattern.data.items(): @@ -253,9 +279,11 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En break key_idx = self.record_key(key) record_value = self._mktemp(f"record_get({arg}, {key_idx})") - # TODO(max): If the key is present in the type, don't emit this - # check - self._emit(f"if ({record_value} == NULL) {{ goto {fallthrough}; }}") + # TODO(max): Figure out another way to do this. It's a bit of a + # hack to check the pattern type *even though* it's supposed to + # be unified with the arg type + if not self._guaranteed_has_field(pattern, key): + self._emit(f"if ({record_value} == NULL) {{ goto {fallthrough}; }}") updates.update(self.try_match(env, record_value, pattern_value, fallthrough)) if not use_spread: self._emit(f"if (record_num_fields({arg}) != {len(pattern.data)}) {{ goto {fallthrough}; }}") @@ -439,9 +467,11 @@ def compile(self, env: Env, exp: Object) -> str: record = self.compile(env, exp.obj) key_idx = self.record_key(exp.at.name) # Check if the record is a record - self._guard(f"is_record({record})", "not a record") + if not self._is_record(exp.obj): + self._guard(f"is_record({record})", "not a record") value = self._mktemp(f"record_get({record}, {key_idx})") - self._guard(f"{value} != NULL", f"missing key {exp.at.name!s}") + if not self._guaranteed_has_field(exp.obj, exp.at.name): + self._guard(f"{value} != NULL", f"missing key {exp.at.name!s}") return value if isinstance(exp, Function): # Anonymous function From e118afdb2b9edb8b349fc957649f78d16969e2ee Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Tue, 22 Oct 2024 22:28:16 -0400 Subject: [PATCH 2/2] . --- compiler.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/compiler.py b/compiler.py index c4b1dd14..0a9f61c2 100644 --- a/compiler.py +++ b/compiler.py @@ -32,6 +32,7 @@ StringType, TyEmptyRow, TyRow, + TyCon, row_flatten, parse, # needed for /compilerepl tokenize, # needed for /compilerepl @@ -141,6 +142,10 @@ def _is_list(self, exp: Object) -> bool: ty = type_of(exp) return isinstance(ty, TyCon) and ty.name == "list" + def _is_hole(self, exp: Object) -> bool: + ty = type_of(exp) + return isinstance(ty, TyCon) and ty.name == "hole" + def _is_record(self, exp: Object) -> bool: return isinstance(type_of(exp), TyRow) or isinstance(type_of(exp), TyEmptyRow) @@ -227,8 +232,8 @@ def try_match(self, env: Env, arg: str, pattern: Object, fallthrough: str) -> En # necessary; the non-Hole case would work just fine. self._emit(f"if ({arg} != mk_immediate_variant(Tag_{pattern.tag})) {{ goto {fallthrough}; }}") return {} - if not self._is_variant(pattern): - self._emit(f"if (!is_variant({arg})) {{ goto {fallthrough}; }}") + # TODO(max): Check if it's a variant + self._emit(f"if (!is_variant({arg})) {{ goto {fallthrough}; }}") self._emit(f"if (variant_tag({arg}) != Tag_{pattern.tag}) {{ goto {fallthrough}; }}") return self.try_match(env, self._mktemp(f"variant_value({arg})"), pattern.value, fallthrough)