diff --git a/pyproject.toml b/pyproject.toml index bad7dee..8783671 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,8 @@ dependencies = [ "rstr", "exrex", "joblib", + #"automata-lib", + "automata-lib @ git+https://github.com/StefanosChaliasos/automata.git@add-support-for-charclass" ] [project.optional-dependencies] @@ -33,9 +35,8 @@ dev = [ [tool.ruff] line-length = 88 -target-version = "py38" -lint.select = ["E", "F", "W"] -lint.extend-select = ["I"] +target-version = "py312" +lint.select = ["E", "F", "W", "I"] lint.ignore = ["F401", "E501"] [tool.ruff.format] diff --git a/scripts/lint_and_tests.py b/scripts/lint_and_tests.py index 7154586..6e86cf9 100644 --- a/scripts/lint_and_tests.py +++ b/scripts/lint_and_tests.py @@ -23,7 +23,7 @@ def run_formatter() -> int: def run_tests() -> int: """Placeholder for future test execution.""" - print("Running Tests (Placeholder)...") + print("Running Tests...") return run_command("pytest") diff --git a/src/zkregex_fuzzer/chars.py b/src/zkregex_fuzzer/chars.py new file mode 100644 index 0000000..fe097a8 --- /dev/null +++ b/src/zkregex_fuzzer/chars.py @@ -0,0 +1,15 @@ +import string + + +def create_range(start_char: str, end_char: str) -> set[str]: + """ + Create a set of characters from start_char to end_char. + """ + return {chr(i) for i in range(ord(start_char), ord(end_char) + 1)} + + +LATIN_EXT_CHARS = create_range("¡", "ƿ") +GREEK_CHARS = create_range("Ͱ", "Ͽ") +CYRILLIC_CHARS = create_range("Ѐ", "ӿ") +ASCII_CHARS = set(string.printable) +ALL_CHARS = ASCII_CHARS.union(LATIN_EXT_CHARS).union(GREEK_CHARS).union(CYRILLIC_CHARS) diff --git a/src/zkregex_fuzzer/cli.py b/src/zkregex_fuzzer/cli.py index 6cdc66c..966e4c6 100644 --- a/src/zkregex_fuzzer/cli.py +++ b/src/zkregex_fuzzer/cli.py @@ -9,7 +9,7 @@ from pathlib import Path from zkregex_fuzzer.configs import GENERATORS, TARGETS, VALID_INPUT_GENERATORS -from zkregex_fuzzer.fuzzer import fuzz_with_database, fuzz_with_grammar +from zkregex_fuzzer.fuzzer import fuzz_with_database, fuzz_with_dfa, fuzz_with_grammar from zkregex_fuzzer.grammar import REGEX_GRAMMAR from zkregex_fuzzer.harness import HarnessStatus from zkregex_fuzzer.logger import logger @@ -240,6 +240,14 @@ def do_fuzz(args): inputs_num=args.inputs_num, kwargs=kwargs, ) + elif args.fuzzer == "dfa": + fuzz_with_dfa( + target_implementation=args.target, + oracle_params=(args.oracle == "valid", args.valid_input_generator), + regex_num=args.regex_num, + inputs_num=args.inputs_num, + kwargs=kwargs, + ) def do_reproduce(args): diff --git a/src/zkregex_fuzzer/configs.py b/src/zkregex_fuzzer/configs.py index e38adb6..bf9e27a 100644 --- a/src/zkregex_fuzzer/configs.py +++ b/src/zkregex_fuzzer/configs.py @@ -1,7 +1,16 @@ from zkregex_fuzzer.grammar import REGEX_GRAMMAR -from zkregex_fuzzer.regexgen import DatabaseRegexGenerator, GrammarRegexGenerator +from zkregex_fuzzer.regexgen import ( + DatabaseRegexGenerator, + DFARegexGenerator, + GrammarRegexGenerator, +) from zkregex_fuzzer.runner import CircomRunner, NoirRunner, PythonReRunner -from zkregex_fuzzer.vinpgen import ExrexGenerator, GrammarBasedGenerator, RstrGenerator +from zkregex_fuzzer.vinpgen import ( + DFAWalkerGenerator, + ExrexGenerator, + GrammarBasedGenerator, + RstrGenerator, +) TARGETS = { "circom": CircomRunner, @@ -17,9 +26,11 @@ "grammar": GrammarBasedGenerator, "rstr": RstrGenerator, "exrex": ExrexGenerator, + "dfa": DFAWalkerGenerator, } GENERATORS = { "grammar": GrammarRegexGenerator, "database": DatabaseRegexGenerator, + "dfa": DFARegexGenerator, } diff --git a/src/zkregex_fuzzer/dfa.py b/src/zkregex_fuzzer/dfa.py new file mode 100644 index 0000000..be7559d --- /dev/null +++ b/src/zkregex_fuzzer/dfa.py @@ -0,0 +1,497 @@ +""" +dfa + +A number of functions for working with DFAs. +""" + +import random +import re +import string +from typing import Dict, Optional, Set + +from automata.fa.dfa import DFA +from automata.fa.gnfa import GNFA +from automata.fa.nfa import NFA + +from zkregex_fuzzer.chars import ASCII_CHARS + + +def get_supported_symbols() -> set[str]: + """ + Get the set of symbols that are supported by the regex engine. + """ + # TODO make this configurable + # Symbols should include at least all ASCII characters + return ASCII_CHARS + + +def regex_to_dfa(regex: str) -> DFA: + """ + Convert a regex to a DFA. + """ + symbols = get_supported_symbols() + regex = unwrap_regex(regex) + + try: + nfa = NFA.from_regex(regex, input_symbols=symbols) + except Exception as e: + raise ValueError(f"Failed to parse '{regex}' into an automaton: {e}") + try: + return DFA.from_nfa(nfa, minify=True) + except Exception as e: + raise ValueError(f"Failed to convert NFA to DFA: {e}") + + +def has_multiple_accepting_states_regex(regex: str) -> bool: + """ + Returns True if converting the given regex to a DFA yields + multiple accepting (final) states. Returns False otherwise. + + NOTE: + - Only handles a subset of regex syntax recognized by automata-lib. + - For advanced Python regex features, a custom NFA builder is needed. + """ + dfa = regex_to_dfa(regex) + num_final_states = len(dfa.final_states) + + return num_final_states > 1 + + +def has_one_accepting_state_regex(regex: str) -> bool: + """ + Returns True if converting the given regex to a DFA yields + exactly one accepting (final) state. Returns False otherwise. + """ + dfa = regex_to_dfa(regex) + return len(dfa.final_states) == 1 + + +def unwrap_regex(regex: str) -> str: + """ + Unwrap a regex by removing the start and end anchors. + """ + if regex.startswith("^"): + regex = regex[1:] + # There are also some more cases with "starting" "^" + elif regex.startswith("(|^)"): + regex = regex[4:] + # Cases like '(\r\n|^)...', '(\r|^)...', '(\n|^)...' + elif bool(re.match(r"^\([\r\n]*\|\^\).*", regex)): + regex = regex[regex.find("^") + 2 :] + elif bool(re.match(r"^\([\\r\\n]*\|\^\).*", regex)): + regex = regex[regex.find("^") + 2 :] + if regex.endswith("$"): + regex = regex[:-1] + return regex.replace("\n", r"\n").replace("\r", r"\r") + + +def wrapped_has_one_accepting_state_regex(regex: str) -> bool: + """ + Returns True if converting the given regex to a DFA yields + exactly one accepting (final) state. Returns False otherwise. + + NOTE: + - As the automata-lib does not support starting with '^' and ending with '$', + we just remove them from the regex and check if the rest of the regex has one accepting state. + """ + return has_one_accepting_state_regex(unwrap_regex(regex)) + + +def has_multiple_accepting_states_dfa(dfa: DFA) -> bool: + """ + Returns True if the given DFA has multiple accepting (final) states. + Returns False otherwise. + """ + return len(dfa.final_states) > 1 + + +def transform_dfa_to_regex(dfa: DFA) -> str: + """ + Convert a DFA to a regular expression. + """ + # Convert the DFA to an equivalent GNFA + gnfa = GNFA.from_dfa(dfa) + # Use state elimination to get a regular expression + regex = gnfa.to_regex() + return regex + + +def _get_alphabet( + use_unicode: bool, num_states: int, min_size: int = 2, max_size: int = 10 +) -> Set[str]: + """ + Generate a random alphabet for a DFA. + """ + alphabet_size = random.randint(min_size, max_size) + if use_unicode: + alphabet = set() + while len(alphabet) < alphabet_size: + codepoint = random.randint(0, 0x10FFFF) + try: + char = chr(codepoint) + except ValueError: + continue # skip invalid code points (if any) + alphabet.add(char) + else: + # Restricted character set: letters, digits, punctuation, whitespace + allowed_pool = ( + string.ascii_letters + + string.digits + + string.punctuation + + string.whitespace + ) + alphabet = set(random.sample(allowed_pool, alphabet_size)) + return alphabet + + +def generate_random_dfa( + max_depth: int = 5, + use_unicode: bool = False, + single_final_state: bool = False, +) -> DFA: + """ + Generate a random DFA with a given seed for reproducibility. + + Randomly incorporates regex features like character classes, repetition, + and fixed string prefixes/suffixes. + + Parameters: + max_depth: Maximum number of states in the DFA + use_unicode: Whether to use Unicode characters in the alphabet + single_final_state: Whether to generate a DFA with exactly one final state + + TODO: + - Add regex features + - Add support for more complex regex features + - Add support for more complex DFA structures + """ + # Original implementation for generating a DFA directly + num_states = random.randint(1, max_depth) + + # Define state names (q0, q1, ..., qN) and the initial state + states = {f"q{i}" for i in range(num_states)} + initial_state = "q0" + + # Determine final state(s) + if single_final_state: + final_state = random.choice(list(states)) + final_states = {final_state} + else: + # One or more final states (randomly chosen subset of states) + num_finals = random.randint(1, num_states) # at least one final + final_states = set(random.sample(list(states), num_finals)) + + alphabet = _get_alphabet(use_unicode, num_states) + + # Construct transitions: for each state and each symbol, choose a random next state + transitions: Dict[str, Dict[str, str]] = {} + for state in states: + transitions[state] = {} + for sym in alphabet: + transitions[state][sym] = random.choice(list(states)) + + # Ensure at least one self-loop (cycle) + loop_exists = any( + state == dest for state in states for dest in transitions[state].values() + ) + if not loop_exists: + # Add a self-loop on a random state with a random symbol + some_state = random.choice(list(states)) + some_symbol = random.choice(list(alphabet)) + transitions[some_state][some_symbol] = some_state + + # Ensure at least one branching point (one state with two different outgoing targets) + if len(alphabet) >= 2: + branching_exists = any(len(set(transitions[s].values())) >= 2 for s in states) + if not branching_exists: + # Force branching on the initial state (as an example) + sym_list = list(alphabet) + # Make sure we have at least two symbols to create a branch + if len(sym_list) >= 2: + sym1, sym2 = sym_list[0], sym_list[1] + # Assign different targets for sym1 and sym2 from the initial state + if transitions[initial_state][sym1] == transitions[initial_state][sym2]: + # Pick a different state for sym2 if both symbols currently go to the same target + possible_targets = list(states - {transitions[initial_state][sym1]}) + if possible_targets: + transitions[initial_state][sym2] = random.choice( + possible_targets + ) + # (If no possible_targets, it means only one state exists, handled by loop above) + + # Introduce an "optional" path (allow skipping or taking a symbol): + # We do this by creating an alternate route to a final state. + if single_final_state and len(states) > 1: + # For a single final state, ensure multiple paths (direct & indirect) to it + final_state = next(iter(final_states)) # the only final state + # If initial state doesn't already have a direct transition to final, add one + if final_state not in transitions[initial_state].values(): + sym = random.choice(list(alphabet)) + transitions[initial_state][sym] = final_state + # Also ensure an indirect path: find a symbol from initial that goes to an intermediate state + intermediate_symbols = [ + sym + for sym, dest in transitions[initial_state].items() + if dest != final_state + ] + if intermediate_symbols: + sym = intermediate_symbols[0] + intermediate_state = transitions[initial_state][sym] + # Link the intermediate state to the final state on some symbol (if not already final) + if intermediate_state != final_state: + sym2 = random.choice(list(alphabet)) + transitions[intermediate_state][sym2] = final_state + elif not single_final_state: + # If multiple finals are allowed, we can treat the start state as an optional accepting state + # (Accept empty string or early termination) + if initial_state not in final_states: + final_states.add(initial_state) + + # Construct the DFA with the generated components + dfa = DFA( + states=states, + input_symbols=alphabet, + transitions=transitions, + initial_state=initial_state, + final_states=final_states, + ) + + # Minimize the DFA for a simpler equivalent automaton + try: + # If automata-lib provides a direct minification method + dfa = dfa.minify() + except AttributeError: + # Fallback: convert to NFA and use DFA.from_nfa for minimization + nfa_transitions: Dict[str, Dict[str, Set[str]]] = {} + for state, trans in transitions.items(): + # Each DFA transition becomes a singleton set in the NFA transition + nfa_transitions[state] = {sym: {dest} for sym, dest in trans.items()} + nfa = NFA( + states=states, + input_symbols=alphabet, + transitions=nfa_transitions, + initial_state=initial_state, + final_states=final_states, + ) + # Convert NFA to DFA with minimization + dfa = DFA.from_nfa(nfa, minify=True) + + return dfa + + +def transform_dfa_to_single_final_state(dfa: DFA) -> DFA: + """ + Convert a DFA with multiple final states to one with a single final state. + + This implementation follows a principled automata theory approach: + 1. Add a new final state + 2. Redirect transitions from original final states to this new state + 3. Make the new final state the only accepting state + 4. Ensure the DFA is complete + + Returns: + A new DFA with exactly one final state + """ + # If the DFA already has a single final state, return it as-is + if len(dfa.final_states) == 1: + return dfa + + # Create mutable copies of the DFA's components + states = set(dfa.states) + alphabet = set(dfa.input_symbols) + transitions = {} + for state in states: + transitions[state] = {} + for symbol in alphabet: + if state in dfa.transitions and symbol in dfa.transitions[state]: + transitions[state][symbol] = dfa.transitions[state][symbol] + + initial_state = dfa.initial_state + original_finals = set(dfa.final_states) + + # Step 1: Add a new single final state + new_final = max(states) + 1 + states.add(new_final) + transitions[new_final] = {} + + # Step 2: Redirect transitions from all existing final states to the new final state + for final_state in original_finals: + for symbol in alphabet: + if symbol in transitions[final_state]: + transitions[final_state][symbol] = new_final + if len(transitions[final_state]) == 0: + transitions[final_state][list(alphabet)[0]] = new_final + + # Step 4: Create the transformed DFA with single final state + new_dfa = DFA( + states=states, + input_symbols=alphabet, + transitions=transitions, + initial_state=initial_state, + final_states={new_final}, + allow_partial=True, + ) + # Step 5: Minimize the DFA to merge equivalent states + # The automata-lib library has a built-in minify method + try: + minimized_dfa = new_dfa.minify() + # check if we can transform the minimized dfa to a regex + regex = transform_dfa_to_regex(minimized_dfa) + if not regex: + raise Exception("Failed to transform minimized DFA to regex") + return minimized_dfa + except Exception as e: + raise Exception(f"DFA minimization failed: {e}") + + +def dfa_string_matching( + regex: str, + wanted_length: int = 50, + direct_match: bool = True, +) -> str: + """ + Convert `regex` to a DFA using automata-lib, then randomly generate a string + that the DFA accepts. Returns a string that the DFA accepts. + + Parameters: + regex: The regular expression to match + wanted_length: The desired length of the generated string + direct_match: If True, only follow paths that lead to accepting states + """ + regex = unwrap_regex(regex) + # Some hard limited length that we can't exceed + # TODO make this configurable + max_length = 500 + # Convert regex to NFA + nfa = NFA.from_regex(regex, input_symbols=get_supported_symbols()) + + # Start with the initial state and an empty string + current_states = nfa._get_lambda_closures()[nfa.initial_state] + result = "" + + # If we start in a final state and regex allows empty string, we might return empty + if not current_states.isdisjoint(nfa.final_states) and random.random() < 0.2: + return "" + + # If direct_match is True, precompute which states can reach a final state + reachable_to_final = None + if direct_match: + # Compute states that can reach a final state (reverse BFS) + reachable_to_final = set() + queue = list(nfa.final_states) + visited = set(queue) + + # Build reverse transition graph + reverse_transitions = {} + for state in nfa.states: + reverse_transitions[state] = [] + + for state in nfa.states: + if state in nfa.transitions: + for symbol, next_states in nfa.transitions[state].items(): + for next_state in next_states: + reverse_transitions[next_state].append((state, symbol)) + + # Do BFS from final states + while queue: + state = queue.pop(0) + reachable_to_final.add(state) + + for prev_state, _ in reverse_transitions[state]: + if prev_state not in visited: + visited.add(prev_state) + queue.append(prev_state) + + # Maximum number of attempts to find an accepting path + max_attempts = 5 + for attempt in range(max_attempts): + current_states = nfa._get_lambda_closures()[nfa.initial_state] + result = "" + + # Try to build a matching string by traversing the NFA + for _ in range(max_length): + # Get all possible transitions from current states + possible_moves = [] + for state in current_states: + if state in nfa.transitions: + for symbol, next_states in nfa.transitions[state].items(): + if symbol: # Skip lambda transitions + for next_state in next_states: + # If direct_match is True, only consider moves that can reach a final state + if not direct_match or next_state in reachable_to_final: + possible_moves.append((symbol, next_state)) + + # No more possible moves + if not possible_moves: + break + + # Choose moves with a bias toward making progress + # For longer patterns, we want to avoid getting stuck in loops + if len(possible_moves) > 1 and len(result) > wanted_length * 0.7: + # In later stages, prioritize moves that might lead to acceptance faster + # We'll do this by favoring transitions to states closer to final states + + # Group possible moves by their target state + moves_by_state = {} + for symbol, next_state in possible_moves: + if next_state not in moves_by_state: + moves_by_state[next_state] = [] + moves_by_state[next_state].append(symbol) + + # If we're in a state we've seen before, try to avoid it + # Convert states to string representation for hashing + current_state_str = "".join(str(s) for s in sorted(current_states)) + if hasattr(dfa_string_matching, "seen_states"): + if current_state_str in dfa_string_matching.seen_states: + # Try to choose a different path than before + dfa_string_matching.seen_states[current_state_str] += 1 + else: + dfa_string_matching.seen_states[current_state_str] = 1 + else: + dfa_string_matching.seen_states = {current_state_str: 1} + + # Bias towards less-visited transitions + weights = [] + for symbol, next_state in possible_moves: + next_state_str = "".join( + str(s) for s in sorted(nfa._get_lambda_closures()[next_state]) + ) + visits = dfa_string_matching.seen_states.get(next_state_str, 0) + # Weight inversely to number of visits (add 1 to avoid division by zero) + weights.append(1.0 / (visits + 1)) + + # Normalize weights + total = sum(weights) + if total > 0: + weights = [w / total for w in weights] + symbol, next_state = random.choices( + possible_moves, weights=weights, k=1 + )[0] + else: + symbol, next_state = random.choice(possible_moves) + else: + # Standard random choice for early parts of the pattern + symbol, next_state = random.choice(possible_moves) + + result += symbol + + # Update current states with the chosen move and its lambda closure + current_states = nfa._get_lambda_closures()[next_state] + + # If we're in a final state, we might choose to stop + if not current_states.isdisjoint(nfa.final_states): + if random.random() < 0.3: + break + # If we have reached the wanted length, we're more likely to stop + if len(result) >= wanted_length and random.random() < 0.9: + break + + # Check if our string is accepted by the NFA + if nfa.accepts_input(result): + return result + + # If we failed, we'll try again with a clean slate + if hasattr(dfa_string_matching, "seen_states"): + delattr(dfa_string_matching, "seen_states") + + raise ValueError(f"Failed to generate a string that the NFA accepts: {regex}") diff --git a/src/zkregex_fuzzer/fuzzer.py b/src/zkregex_fuzzer/fuzzer.py index cc0d4a4..ba442d4 100644 --- a/src/zkregex_fuzzer/fuzzer.py +++ b/src/zkregex_fuzzer/fuzzer.py @@ -8,7 +8,11 @@ from zkregex_fuzzer.configs import GRAMMARS, TARGETS, VALID_INPUT_GENERATORS from zkregex_fuzzer.harness import HarnessResult, HarnessStatus, harness from zkregex_fuzzer.logger import dynamic_filter, logger -from zkregex_fuzzer.regexgen import DatabaseRegexGenerator, GrammarRegexGenerator +from zkregex_fuzzer.regexgen import ( + DatabaseRegexGenerator, + DFARegexGenerator, + GrammarRegexGenerator, +) from zkregex_fuzzer.runner import PythonReRunner from zkregex_fuzzer.runner.base_runner import Runner from zkregex_fuzzer.transformers import regex_to_grammar @@ -56,6 +60,25 @@ def fuzz_with_database( fuzz_with_regexes(regexes, inputs_num, target_runner, oracle_params, kwargs) +def fuzz_with_dfa( + target_implementation: str, + oracle_params: tuple[bool, str], + regex_num: int, + inputs_num: int, + kwargs: dict, +): + """ + Fuzz test with DFA. + """ + target_runner = TARGETS[target_implementation] + + regex_generator = DFARegexGenerator() + regexes = regex_generator.generate_many(regex_num) + logger.info(f"Generated {len(regexes)} regexes.") + + fuzz_with_regexes(regexes, inputs_num, target_runner, oracle_params, kwargs) + + def fuzz_with_regexes( regexes: list[str], inputs_num: int, diff --git a/src/zkregex_fuzzer/regexgen.py b/src/zkregex_fuzzer/regexgen.py index f4c8394..b5c1161 100644 --- a/src/zkregex_fuzzer/regexgen.py +++ b/src/zkregex_fuzzer/regexgen.py @@ -22,6 +22,11 @@ from fuzzingbook.Grammars import Grammar +from zkregex_fuzzer.dfa import ( + generate_random_dfa, + regex_to_dfa, + transform_dfa_to_regex, +) from zkregex_fuzzer.logger import logger from zkregex_fuzzer.utils import ( check_zkregex_rules_basic, @@ -50,7 +55,9 @@ def generate(self) -> str: regex = self.generate_unsafe() if not is_valid_regex(regex): continue - if not check_zkregex_rules_basic(regex): + correct, accepting_state_check = check_zkregex_rules_basic(regex) + if not correct: + # TODO: We should try to fix the regex if it has multiple accepting states continue logger.debug(f"Generated regex: {regex}") return regex @@ -144,3 +151,33 @@ def generate_many(self, num): break return result + + +class DFARegexGenerator(RegexGenerator): + """ + Generate regexes using a DFA. + """ + + def __init__( + self, + max_depth: int = 5, + use_unicode: bool = False, + single_final_state: bool = True, + ): + self.max_depth = max_depth + self.use_unicode = use_unicode + self.single_final_state = single_final_state + + def generate_unsafe(self) -> str: + """ + Generate a regex using a DFA. + """ + while True: + try: + dfa = generate_random_dfa( + self.max_depth, self.use_unicode, self.single_final_state + ) + return transform_dfa_to_regex(dfa) + except Exception as e: + logger.debug(f"Error generating DFA: {e}") + continue diff --git a/src/zkregex_fuzzer/utils.py b/src/zkregex_fuzzer/utils.py index 90b56c0..01fbdb1 100644 --- a/src/zkregex_fuzzer/utils.py +++ b/src/zkregex_fuzzer/utils.py @@ -8,6 +8,8 @@ from fuzzingbook.Grammars import Grammar, simple_grammar_fuzzer +from zkregex_fuzzer.dfa import wrapped_has_one_accepting_state_regex + def is_valid_regex(regex: str) -> bool: """ @@ -20,65 +22,90 @@ def is_valid_regex(regex: str) -> bool: return False -def check_zkregex_rules_basic(regex: str) -> bool: +def has_lazy_quantifier(pattern: str) -> bool: """ - Check partial zk-regex constraints with a text-based approach: - 1) Must end with '$' - 2) If '^' is present, it is either at index 0 or in substring '(|^)' - 3) No lazy quantifiers like '*?' or '+?' or '??' or '{m,n}?' - Returns True if all checks pass, False otherwise. - - TODO: DFA Checks -- code that actually compiles the regex to an automaton and verifies: - - No loop from initial state back to itself (i.e. no .*-like or equivalent) - - Only one accepting state + Returns True if `pattern` contains any lazy quantifiers (i.e., *?, +?, ??, or {m,n}?), + False otherwise. + + This is a naive textual check and doesn't handle escaping inside character classes or + more advanced regex syntax. For most simple usage, however, it suffices. """ + # Regex to search for the typical lazy quantifier patterns: + # *? +? ?? {m,n}? + # We'll assume m,n are simple digit sets, e.g. {2,5} + lazy_check = re.compile(r"(\*\?)|(\+\?)|(\?\?)|\{\d+(,\d+)?\}\?") - # 1) Must end with '$' (if it present) - if "$" in regex and not regex.endswith("$"): - return False + match = lazy_check.search(pattern) + return bool(match) - # 2) '^' must be at start or in '(|^)' - # We'll allow no '^' at all. If it appears, check positions. - # We'll define a function to find all occurrences of '^'. - allowed_positions = set() - # If the string starts with '^', that’s allowed - if len(regex) > 0 and regex[0] == "^": - allowed_positions.add(0) - - # If the string contains '|^', that means '^' is at position (idx+1) - idx = 0 - while True: - idx = regex.find("|^", idx) - if idx == -1: - break - # '^' occurs at (idx + 1) - allowed_positions.add(idx + 1) - idx += 2 # skip past - - # If the string contains '[^]', that means '^' is at position (idx+1) - idx = 0 - while True: - idx = regex.find("[^", idx) - if idx == -1: - break - # '^' occurs at (idx + 1) - allowed_positions.add(idx + 1) - idx += 2 # skip past - - # Now see if there's any '^' outside those allowed positions - for match in re.finditer(r"\^", regex): - pos = match.start() - if pos not in allowed_positions: + +def correct_carret_position(regex: str) -> bool: + """ + Correct positions are: + - At the start of the regex + - In a capturing group that is at the start of the regex + - In a negated character class + Returns True if the '^' is in the correct position, False otherwise. + + This is a naive textual check and doesn't handle escaping inside character classes or + more advanced regex syntax. For most simple usage, however, it suffices. + """ + # Find all occurrences of '^' that are not escaped + caret_positions = [match.start() for match in re.finditer(r"(? 1: + continue + # Let's check if the '^' is in a group that is at the start of the regex + # and before '^' there is a '|' and before '|' there is either nothing or \r or \n until + # the beginning of the group + if ( + regex[pos - 1] == "|" + and regex[pos + 1] == ")" + and regex[0] == "(" + and bool(re.match(r"^\s*", regex[1 : pos - 1])) + ): + status = True + continue + # Let's check if the '^' is in a negated character class + if regex[pos - 1] == "[": + status = True + continue + if status is False: return False + return status - # 3) Check no lazy quantifiers like *?, +?, ??, or {m,n}? - # We do a simple regex search for them: - # Patterns we search for: (*?), (+?), (??), ({\d+(,\d+)?}\?) - lazy_pattern = re.compile(r"(\*\?|\+\?|\?\?|\{\d+(,\d+)?\}\?)") - if lazy_pattern.search(regex): - return False - return True +def check_zkregex_rules_basic(regex: str) -> tuple[bool, bool]: + """ + Check partial zk-regex constraints with a text-based approach: + 1) If '^' is present, it is either at index 0 or in substring '(|^)' or in (\r\n|^) or in substring '[^...]' + 2) No lazy quantifiers like '*?' or '+?' or '??' or '{m,n}?' + 3) Check that the regex has exactly one accepting state + Returns True if all checks pass, False otherwise. Also return the status of the accepting state check. + Returns (True, True) if all checks pass, (False, True) if the regex is invalid, (False, False) if the regex has multiple accepting states. + """ + # 1) If '^' is present, it is either at index 0 or in substring '(|^)' or in (\r\n|^) or in substring '[^...]' + if not correct_carret_position(regex): + return False, True # we return True as we haven't performed the DFA check + + # 2) Check no lazy quantifiers like *?, +?, ??, or {m,n}? + if has_lazy_quantifier(regex): + return False, True # we return True as we haven't performed the DFA check + + # 3) Check that the regex has exactly one accepting state + if not wrapped_has_one_accepting_state_regex(regex): + return False, False + + return True, True def check_if_string_is_valid(regex: str, string: str) -> bool: diff --git a/src/zkregex_fuzzer/vinpgen.py b/src/zkregex_fuzzer/vinpgen.py index 296e9b3..f9ff826 100644 --- a/src/zkregex_fuzzer/vinpgen.py +++ b/src/zkregex_fuzzer/vinpgen.py @@ -14,6 +14,7 @@ import exrex import rstr +from zkregex_fuzzer.dfa import dfa_string_matching from zkregex_fuzzer.logger import logger from zkregex_fuzzer.transformers import regex_to_grammar from zkregex_fuzzer.utils import check_if_string_is_valid, grammar_fuzzer, pretty_regex @@ -130,3 +131,16 @@ def __init__(self, regex: str): def generate_unsafe(self) -> str: return exrex.getone(self.regex) + + +class DFAWalkerGenerator(ValidInputGenerator): + """ + Generate valid inputs for a regex using a DFA walker. + """ + + def __init__(self, regex: str): + super().__init__(regex) + + def generate_unsafe(self) -> str: + inp = dfa_string_matching(self.regex) + return inp diff --git a/tests/test_dfa.py b/tests/test_dfa.py new file mode 100644 index 0000000..39c3809 --- /dev/null +++ b/tests/test_dfa.py @@ -0,0 +1,133 @@ +# ruff: noqa: I001 +from automata.regex.regex import isequal +from zkregex_fuzzer.dfa import ( + dfa_string_matching, + generate_random_dfa, + has_multiple_accepting_states_regex, + regex_to_dfa, + transform_dfa_to_regex, + transform_dfa_to_single_final_state, +) +import re + +regex_with_multiple_accepting_states = [ + r"(ab|aba)", + r"(ab|aba)*", + r"(hello|hell)", + r"b(aa|aaa)", + r"(cat|cats)", + r"(xy|xyx)", + r"(a|ab|abc)", + r"(1|12)", +] +regex_without_multiple_accepting_states = [ + r"(a|b)*", + r"abc", + r"(abc|def|ghi)", + r"(abc)*", + r"(hello)", + r"(ab)*", + r"(a|b|c)*", + r"((a|b|c)*abc)", + r"[a-zA-Z]+", + r"[0-9]+", + r"(abc|abcd|abcde)f", + r"(hello|helloo|hellooo)(foo|foob|fooba)?bar", + r"(foo|foob|fooba)?bar", + r"(abc|def)(gh|jk)(lm|nop)", +] +single_solution_regexes = [ + r"abc", + r"(hello)", +] +zkemail_regexes = [ + r">[^<>]+<.*", + r"to:[^\r\n]+\r\n", + r"subject:[^\r\n]+\r\n", + r"[A-Za-z0-9!#$%&'*+=?\-\^_`{|}~./]+@[A-Za-z0-9.\-@]+", + r"dkim-signature:([a-z]+=[^;]+; )+bh=[a-zA-Z0-9+/=]+;", + r"[A-Za-z0-9!#$%&'*+=?\-\^_`{|}~./@]+@[A-Za-z0-9.\-]+", + r"from:[^\r\n]+\r\n", + r"dkim-signature:([a-z]+=[^;]+; )+t=[0-9]+;", + r"message-id:<[A-Za-z0-9=@\.\+_-]+>\r\n", +] + + +def test_has_multiple_accepting_states_regex_without_multiple(): + for regex in regex_without_multiple_accepting_states: + assert not has_multiple_accepting_states_regex(regex) + + +def test_has_multiple_accepting_states_regex_with_multiple(): + for regex in regex_with_multiple_accepting_states: + assert has_multiple_accepting_states_regex(regex) + + +def test_transform_dfa_to_regex(): + regexes = [ + r"(ab|aba)", + r"(ab|aba)*", + r"(hello|hell)", + ] + for regex in regexes: + dfa = regex_to_dfa(regex) + transformed_regex = transform_dfa_to_regex(dfa) + assert isequal(regex, transformed_regex) + + +def test_transform_dfa_to_regex_with_multiple_accepting_states(): + for regex in regex_with_multiple_accepting_states: + dfa = regex_to_dfa(regex) + transformed_dfa = transform_dfa_to_single_final_state(dfa) + assert len(transformed_dfa.final_states) == 1 + transformed_regex = transform_dfa_to_regex(transformed_dfa) + new_dfa = regex_to_dfa(transformed_regex) + assert len(new_dfa.final_states) == 1 + + +def test_generate_dfa(): + while True: + try: + dfa_with_final = generate_random_dfa( + max_depth=10, use_unicode=False, single_final_state=True + ) + regex_with_final = transform_dfa_to_regex(dfa_with_final) + dfa_from_regex_with_final = regex_to_dfa(regex_with_final) + break + except Exception: + continue + assert len(dfa_with_final.final_states) == 1 + assert len(dfa_from_regex_with_final.final_states) == 1 + + while True: + try: + dfa_without_final = generate_random_dfa( + max_depth=10, use_unicode=False, single_final_state=False + ) + regex_without_final = transform_dfa_to_regex(dfa_without_final) + dfa_from_regex_without_final = regex_to_dfa(regex_without_final) + break + except Exception: + continue + assert len(dfa_without_final.final_states) >= 1 + assert len(dfa_from_regex_without_final.final_states) >= 1 + + +def test_dfa_string_matching(): + for regex in regex_without_multiple_accepting_states: + string = dfa_string_matching(regex) + assert string is not None + for _ in range(5): + string2 = dfa_string_matching(regex) + if string != string2: + break + if regex not in single_solution_regexes: + assert string != string2 + + +def test_dfa_string_matching_zkemail(): + for regex in zkemail_regexes: + string = dfa_string_matching(regex) + assert string is not None + # we also need to check against re module + assert re.match(regex, string) is not None diff --git a/tests/test_utils.py b/tests/test_utils.py index 9933b4c..29ffff3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,9 @@ -from zkregex_fuzzer.utils import is_valid_regex +from zkregex_fuzzer.utils import ( + check_zkregex_rules_basic, + correct_carret_position, + has_lazy_quantifier, + is_valid_regex, +) def test_valid_regex(): @@ -26,3 +31,129 @@ def test_invalid_regex(): ] for pattern in invalid_patterns: assert not is_valid_regex(pattern), f"Expected {pattern} to be invalid" + + +def test_has_lazy_quantifier(): + """Test that has_lazy_quantifier returns True for patterns with lazy quantifiers.""" + patterns = [ + (r"ab*c", False), + (r"a+?", True), + (r"(abc){2,5}?", True), + (r"xyz", False), + (r"[a-z]*", False), + (r".+?", True), + ] + for pattern, expected in patterns: + assert has_lazy_quantifier(pattern) == expected, ( + f"Expected {pattern} to have lazy quantifier: {expected}" + ) + + +def test_correct_carret_position(): + """ + Test the correct_carret_position function with various corner cases. + """ + # Test cases with expected results + test_cases = [ + # Basic cases + (r"^abc", True), # Start of regex + (r"abc", True), # No caret + (r"abc^", False), # Invalid position at end + # Capturing group cases + (r"(^abc)", False), # Start of capturing group + (r"(|^)", True), # Alternative with caret + (r"(abc|^def)", False), # Caret in middle of alternative + (r"(|^)", True), # Simple alternative with caret + (r"(\n|^)", True), # Newline alternative + (r"abc(\n|^)", False), # Not at start of regex + (r"(\r|^)", True), # Carriage return alternative + (r"(\r\n|^)", True), # CRLF alternative + (r"(\n\r|^)", True), # CRLF alternative + (r"( |^)", True), # Spaces before alternative + # Character class cases + (r"[^abc]", True), # Simple negated character class + (r"abc[^xyz]def", True), # Negated character class in middle + (r"[abc^]", False), # Caret not at start of character class + (r"[[^]]", True), # Nested character class + (r"[^]", True), # Empty negated character class + # Multiple caret cases + (r"^abc[^xyz]", True), # Valid multiple carets + (r"^abc^", False), # Invalid multiple carets + (r"[^abc][^xyz]", True), # Multiple negated character classes + # Edge cases + (r"", True), # Empty string + (r"^", True), # Just caret + (r"[]^]", False), # Invalid character class + (r"(^)|^", False), # Multiple start anchors + (r"(^abc|^def)", False), # Multiple start anchors in group + # Complex cases + (r"(|^)abc[^xyz]123", True), # Combination of valid cases + (r"^abc[^xyz](|^)def", False), # Invalid multiple start anchors + (r"[^abc]^[^xyz]", False), # Invalid caret between character classes + (r"( \r\n |^)abc", True), # Complex whitespace before alternative + # Escaped caret cases + (r"abc\^", True), + (r"abc\^def", True), + ] + for regex, expected in test_cases: + assert correct_carret_position(regex) == expected, ( + f"Expected {regex} to have correct caret position: {expected}" + ) + + +def test_check_zkregex_rules_basic(): + """ + Test the check_zkregex_rules_basic function with various test cases. + """ + # Test cases with expected results + test_cases = [ + # 1. Dollar sign tests + (r"abc$", (True, True)), # Valid dollar sign at end, + (r"abc$def", (True, True)), # Valid dollar sign in middle + (r"abc", (True, True)), # No dollar sign + (r"$abc", (True, True)), # Dollar sign at start + # 2. Caret position tests + (r"^abc", (True, True)), # Valid caret at start + (r"(|^)abc", (True, True)), # Valid caret in alternative + (r"(\r\n|^)abc", (True, True)), # Valid caret with CRLF alternative + (r"[^abc]", (True, True)), # Valid caret in character class + (r"abc^", (False, True)), # Invalid caret at end + (r"abc^def", (False, True)), # Invalid caret in middle + # 3. Lazy quantifier tests + (r"abc*", (True, True)), # Valid greedy quantifier + (r"abc*?", (False, True)), # Invalid lazy star quantifier + (r"abc+?", (False, True)), # Invalid lazy plus quantifier + (r"abc??", (False, True)), # Invalid lazy question mark quantifier + (r"abc{1,2}?", (False, True)), # Invalid lazy range quantifier + # 4. Combined valid cases + (r"^abc$", (True, True)), # Valid start and end anchors + (r"(|^)abc$", (True, True)), # Valid alternative and end anchor + (r"[^abc].*$", (True, True)), # Valid character class and end anchor + # 5. Combined invalid cases + (r"^abc$def", (True, True)), # Valid dollar position with caret + (r"abc^def$", (False, True)), # Invalid caret with dollar + (r"[^abc]*?$", (False, True)), # Invalid lazy quantifier with valid anchors + # 6. Complex cases + (r"(|^)abc[^xyz]*$", (True, True)), # Complex valid regex + (r"^abc[^xyz]+def$", (True, True)), # Complex valid regex with quantifiers + ( + r"(|^)abc*?[^xyz]$", + (False, True), + ), # Complex invalid regex with lazy quantifier + (r"[a-zA-Z0-9._%+-]+", (True, True)), + (r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$", (True, True)), + # 7. The common regexes from zkemail + (r">[^<>]+<.*", (True, True)), + (r"(\r\n|^)to:[^\r\n]+\r\n", (True, True)), + (r"(\r\n|^)subject:[^\r\n]+\r\n", (True, True)), + (r"[A-Za-z0-9!#$%&'*+=?\-\^_`{|}~.\/]+@[A-Za-z0-9.\-@]+", (True, True)), + (r"[A-Za-z0-9!#$%&'*+=?\-\^_`{|}~.\/@]+@[A-Za-z0-9.\-]+", (True, True)), + (r"(\r\n|^)from:[^\r\n]+\r\n", (True, True)), + (r"(\r\n|^)dkim-signature:([a-z]+=[^;]+; )+bh=[a-zA-Z0-9+/=]+;", (True, True)), + (r"(\r\n|^)dkim-signature:([a-z]+=[^;]+; )+t=[0-9]+;", (True, True)), + (r"(\r\n|^)message-id:<[A-Za-z0-9=@\.\+_-]+>\r\n", (True, True)), + ] + for regex, expected in test_cases: + assert check_zkregex_rules_basic(regex) == expected, ( + f"Expected {regex} to have correct zk-regex rules: {expected}" + )