From 53980774c327675e886179c0a2c140744dcf9b95 Mon Sep 17 00:00:00 2001 From: Jonas Maier Date: Sat, 6 Jun 2026 12:15:52 +0200 Subject: special cased regex for performance --- src/regex/bc.rs | 752 ++++++++++++++++++++++++++++++++++++++++++++++++ src/regex/byte_range.rs | 204 +++++++++++++ src/regex/dfa.rs | 381 ++++++++++++++++++++++++ src/regex/enfa.rs | 741 +++++++++++++++++++++++++++++++++++++++++++++++ src/regex/mod.rs | 344 ++++++++++++++++++++++ src/regex/simple.rs | 125 ++++++++ 6 files changed, 2547 insertions(+) create mode 100644 src/regex/bc.rs create mode 100644 src/regex/byte_range.rs create mode 100644 src/regex/dfa.rs create mode 100644 src/regex/enfa.rs create mode 100644 src/regex/mod.rs create mode 100644 src/regex/simple.rs (limited to 'src/regex') diff --git a/src/regex/bc.rs b/src/regex/bc.rs new file mode 100644 index 0000000..4a79485 --- /dev/null +++ b/src/regex/bc.rs @@ -0,0 +1,752 @@ +use std::collections::{HashMap, VecDeque}; + +use super::{ + CharacterClass, GreedyBehavior, LookDirection, LookPolarity, Match, Pattern, RegexEngine, + byte_range::ByteRange, +}; +use crate::bitset::BitSet; + +trait Flavor: Clone { + type CustomInstr: Copy + Clone + std::fmt::Debug; + type ThreadData: Clone; + type StepData<'a, 'b> + where + 'b: 'a; + + fn accepts<'a, 'b>( + thread: &mut Thread, + instr: Self::CustomInstr, + sd: &mut Self::StepData<'a, 'b>, + ) -> bool; + + fn save(x: u32) -> Option; +} + +#[derive(Copy, Clone, Debug)] +struct MainFlavor; +impl Flavor for MainFlavor { + type CustomInstr = MainInstr; + type ThreadData = Box<[Option]>; + type StepData<'a, 'b> + = (usize, &'a BitSet, &'a mut LookaheadVM<'b>) + where + 'b: 'a; + + fn accepts<'a, 'b>( + thread: &mut Thread, + instr: Self::CustomInstr, + data: &mut Self::StepData<'a, 'b>, + ) -> bool { + match instr { + MainInstr::Save(reg) => { + thread.data[reg as usize] = Some(data.0); + true + } + MainInstr::Join(assertion) => { + let should_match = assertion.pol == LookPolarity::Positive; + 
let state = assertion.target as usize; + let is_matching = match assertion.dir { + LookDirection::Ahead => data.2.get_state(data.0, state), + LookDirection::Behind => data.1.get(state), + }; + is_matching == should_match + } + } + } + + fn save(x: u32) -> Option { + Some(MainInstr::Save(x)) + } +} + +#[derive(Copy, Clone, Debug)] +enum Nothing {} + +#[derive(Copy, Clone, Debug)] +struct AssertionFlavor; +impl Flavor for AssertionFlavor { + type CustomInstr = Nothing; + type ThreadData = (); + type StepData<'a, 'b> + = () + where + 'b: 'a; + + fn accepts(_thread: &mut Thread, instr: Self::CustomInstr, _sd: &mut ()) -> bool { + match instr {} + } + + fn save(_: u32) -> Option { + None + } +} + +type JumpTarget = u32; +type Register = u32; + +#[derive(Copy, Clone, Debug)] +struct Assertion { + target: JumpTarget, + dir: LookDirection, + pol: LookPolarity, +} + +#[derive(Copy, Clone, Debug)] +enum Instr { + Class(CharacterClass), + Consume(ByteRange), + Jump(JumpTarget), + Fork(JumpTarget, JumpTarget), + Custom(F::CustomInstr), +} + +#[derive(Copy, Clone, Debug)] +enum MainInstr { + Save(Register), + Join(Assertion), +} + +#[derive(Clone)] +struct Thread { + pc: JumpTarget, + data: F::ThreadData, +} + +struct VM<'p, F: Flavor> { + instr: &'p [Instr], + passive_threads: VecDeque>, + active_threads: VecDeque>, + hot: BitSet, + warm: BitSet, +} + +impl<'p, F: Flavor> VM<'p, F> { + fn new(instr: &'p [Instr], starting_thread: Thread) -> Self { + Self { + instr, + passive_threads: vec![starting_thread].into(), + active_threads: VecDeque::new(), + hot: BitSet::new(instr.len()), + warm: BitSet::new(instr.len()), + } + } + + fn step_epsilon<'a>(&mut self, sd: &mut F::StepData<'a, 'p>) { + std::mem::swap(&mut self.active_threads, &mut self.passive_threads); + self.hot.set_all(false); + self.warm.set_all(false); + + macro_rules! 
add_thread { + ($t:expr) => {{ + let t = $t; + let bit = t.pc as usize; + if !self.warm.get(bit) { + self.warm.set(bit, true); + self.active_threads.push_front(t); + } + }}; + } + + while let Some(mut thread) = self.active_threads.pop_front() { + match self.instr[thread.pc as usize] { + Instr::Class(_) | Instr::Consume(_) => { + if !self.hot.get(thread.pc as usize) { + self.hot.set(thread.pc as usize, true); + self.passive_threads.push_back(thread); + } + } + Instr::Jump(j) => { + thread.pc = j; + add_thread!(thread); + } + Instr::Fork(a, b) => { + add_thread!(Thread { + pc: b, + data: thread.data.clone(), + }); + add_thread!(Thread { + pc: a, + data: thread.data.clone(), + }); + } + Instr::Custom(instr) => { + if F::accepts(&mut thread, instr, sd) { + thread.pc += 1; + add_thread!(thread); + } + } + } + } + } + + fn step_consume(&mut self, byte: u8) { + self.hot.set_all(false); + self.passive_threads + .retain_mut(|thread| match self.instr[thread.pc as usize] { + Instr::Class(class) => { + if class.matches(byte) { + thread.pc += 1; + self.hot.set(thread.pc as usize, true); + true + } else { + false + } + } + Instr::Consume(bytes) => { + if bytes.contains(byte) { + thread.pc += 1; + self.hot.set(thread.pc as usize, true); + true + } else { + false + } + } + _ => false, + }); + } +} + +struct LookaheadVM<'a> { + vm: VM<'a, AssertionFlavor>, + data: &'a [u8], + cached: bool, + cache_data: Vec, + loc_offset: usize, +} + +impl<'a> LookaheadVM<'a> { + fn new(vm: VM<'a, AssertionFlavor>, data: &'a [u8]) -> Self { + Self { + vm, + data, + cached: false, + cache_data: Vec::new(), + loc_offset: 0, + } + } + + fn get_state(&mut self, loc: usize, state: usize) -> bool { + if !self.cached { + assert!(self.cache_data.is_empty()); + assert_eq!(self.loc_offset, 0); + self.loc_offset = loc; + self.vm.step_epsilon(&mut ()); + self.cache_data.push(self.vm.hot.clone()); + for i in (loc..self.data.len()).rev() { + self.vm.step_consume(self.data[i]); + self.vm.step_epsilon(&mut ()); + 
self.cache_data.push(self.vm.hot.clone()); + } + self.cache_data.reverse(); + self.cached = true; + } + + assert!( + loc >= self.loc_offset, + "get_state must be called with non-decreasing arguments." + ); + self.cache_data[loc - self.loc_offset].get(state) + } +} + +struct VirtualMachine<'a> { + vm0: VM<'a, AssertionFlavor>, + vm1: VM<'a, MainFlavor>, + vm2: LookaheadVM<'a>, + accepting: &'a BitSet, +} + +impl<'a> VirtualMachine<'a> { + fn step_epsilon_1(&mut self, loc: usize) { + self.vm1 + .step_epsilon(&mut (loc, &self.vm0.hot, &mut self.vm2)); + } + + fn step_epsilon(&mut self, loc: usize) { + self.vm0.step_epsilon(&mut ()); + self.step_epsilon_1(loc); + } + + fn step_consume(&mut self, byte: u8) { + self.vm0.step_consume(byte); + self.vm1.step_consume(byte); + } + + fn step(&mut self, byte: u8, loc: usize) { + self.step_epsilon(loc); + self.step_consume(byte); + } + + fn extract_match(&self) -> Option { + self.vm1 + .passive_threads + .iter() + .filter(|t| self.accepting.get(t.pc as usize)) + .map(|t| { + let submatches: Vec<_> = t.data.windows(2).map(|x| Some(x[0]?..x[1]?)).collect(); + + Match { + submatches: submatches.into(), + } + }) + .next() + } +} + +fn fmt_instructions( + f: &mut std::fmt::Formatter<'_>, + label: &str, + insns: &[Instr], +) -> std::fmt::Result { + writeln!(f, "# {label}")?; + for (idx, ins) in insns.iter().enumerate() { + writeln!(f, "{idx}: {ins:?}")?; + } + Ok(()) +} + +pub struct BytecodeCompiledRegex { + instrs0: Box<[Instr]>, + instrs1: Box<[Instr]>, + instrs2: Box<[Instr]>, + no_lookbehind: bool, + submatch_count: usize, + accepting: BitSet, +} + +impl std::fmt::Debug for BytecodeCompiledRegex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt_instructions(f, "behind", &self.instrs0)?; + fmt_instructions(f, "ahead", &self.instrs2)?; + fmt_instructions(f, "main", &self.instrs1)?; + writeln!(f, "accepting: {:?}", self.accepting) + } +} + +impl BytecodeCompiledRegex { + pub fn re_match(&self, data: 
&[u8]) -> Option { + let vm0 = VM::new(&self.instrs0, Thread { pc: 0, data: () }); + let vm1 = VM::new( + &self.instrs1, + Thread { + pc: 0, + data: vec![None; 2 * self.submatch_count].into(), + }, + ); + let vm2 = VM::new(&self.instrs2, Thread { pc: 0, data: () }); + let vm2 = LookaheadVM::new(vm2, data); + let mut vm = VirtualMachine { + vm0, + vm1, + vm2, + accepting: &self.accepting, + }; + if self.no_lookbehind { + for (i, ch) in data.iter().cloned().enumerate() { + vm.step_epsilon_1(i); + vm.vm1.step_consume(ch); + } + vm.step_epsilon(data.len()); + } else { + for (i, ch) in data.iter().cloned().enumerate() { + vm.step(ch, i); + } + vm.step_epsilon(data.len()); + } + vm.extract_match() + } + + pub fn matches(&self, data: &[u8]) -> bool { + self.re_match(data).is_some() + } +} + +type AssertionHandler<'a, F> = + Box CompileResult>>; + +#[derive(Copy, Clone)] +struct CompiledSnippet { + begin: JumpTarget, + end: JumpTarget, +} + +struct Compiler<'a, F: Flavor> { + instrs: Vec>, + map: HashMap, + assertion_handler: AssertionHandler<'a, F>, + assertion_fork_base: usize, + submatch_count: usize, +} + +fn fork(repeat: usize, exit: usize, greedy: GreedyBehavior) -> Instr { + let repeat = repeat as JumpTarget; + let exit = exit as JumpTarget; + match greedy { + GreedyBehavior::Greedy => Instr::Fork(repeat, exit), + GreedyBehavior::NonGreedy => Instr::Fork(exit, repeat), + } +} + +impl<'a, F: Flavor> Compiler<'a, F> { + fn new( + assertion_handler: impl 'a + + FnMut(LookDirection, LookPolarity, Pattern) -> CompileResult>, + ) -> Self { + Self { + instrs: Vec::new(), + map: HashMap::new(), + assertion_handler: Box::new(assertion_handler), + assertion_fork_base: usize::MAX, + submatch_count: 0, + } + } + + fn rep_1_or_more(&mut self, pat: Pattern, greedy: GreedyBehavior) -> CompileResult { + let base = self.instrs.len(); + self.compile(pat)?; + let exit = self.instrs.len() + 1; + self.instrs.push(fork(base, exit, greedy)); + Ok(()) + } + + fn rep_0_or_1(&mut self, pat: 
Pattern, greedy: GreedyBehavior) -> CompileResult { + let base = self.instrs.len(); + self.instrs.push(Instr::Jump(u32::MAX)); + self.compile(pat)?; + self.instrs[base] = fork(base + 1, self.instrs.len(), greedy); + Ok(()) + } + + fn rep_any_amt(&mut self, pat: Pattern, greedy: GreedyBehavior) -> CompileResult { + let base = self.instrs.len(); + self.instrs.push(Instr::Jump(u32::MAX)); + self.compile(pat)?; + let fork_pos = self.instrs.len(); + let after = fork_pos + 1; + self.instrs.push(fork(base + 1, after, greedy)); + self.instrs[base] = Instr::Jump(fork_pos as JumpTarget); + Ok(()) + } + + fn compile(&mut self, pat: Pattern) -> CompileResult { + match pat { + Pattern::Byte(x) => self.instrs.push(Instr::Consume(ByteRange::new_single(x))), + Pattern::Range(a, b) => self.instrs.push(Instr::Consume(ByteRange::new_range(a, b))), + Pattern::CharacterClass(cc) => { + self.instrs.push(Instr::Class(cc)); + } + Pattern::Alt(patterns) => { + let branch_factor = patterns.len(); + assert!(branch_factor > 0); + + let base = self.instrs.len(); + + // placeholders to later place in forks + for _ in 0..patterns.len() - 1 { + self.instrs.push(Instr::Jump(u32::MAX)); + } + + let mut enter_pats = Vec::new(); + let mut leave_pats = Vec::new(); + for pat in patterns.into_iter() { + enter_pats.push(self.instrs.len()); + self.compile(pat)?; + leave_pats.push(self.instrs.len()); + + // placeholder to place in join + self.instrs.push(Instr::Jump(u32::MAX)); + } + + self.instrs.pop(); // remove last jump + let join_point = self.instrs.len(); + + // link forks + for i in 0..branch_factor - 1 { + let a = enter_pats[i]; + let b = if i == branch_factor - 2 { + enter_pats[i + 1] + } else { + base + i + 1 + }; + self.instrs[base + i] = Instr::Fork(a as JumpTarget, b as JumpTarget); + } + + // link joins + for i in 0..branch_factor - 1 { + self.instrs[leave_pats[i]] = Instr::Jump(join_point as JumpTarget); + } + } + Pattern::Concat(patterns) => { + for pat in patterns.into_iter() { + 
self.compile(pat)?; + } + } + Pattern::Rep(pat, 0, None, greed) => { + self.rep_any_amt(*pat, greed)?; + } + Pattern::Rep(pat, min, None, greed) => { + let pat = *pat; + for _ in 1..min { + self.compile(pat.clone())?; + } + self.rep_1_or_more(pat, greed)?; + } + Pattern::Rep(pat, min, Some(max), greed) => { + let pat = *pat; + let opt = max - min; + for _ in 0..min { + self.compile(pat.clone())?; + } + for _ in 0..opt { + self.rep_0_or_1(pat.clone(), greed)?; + } + } + Pattern::Assertion(look_direction, look_polarity, pattern) => { + let ins = (self.assertion_handler)(look_direction, look_polarity, *pattern)?; + self.instrs.push(ins); + } + Pattern::Nothing => {} + Pattern::Submatch(pat) => { + let i = self.submatch_count as u32 * 2; + self.submatch_count += 1; + if let Some(ins) = F::save(i) { + self.instrs.push(Instr::Custom(ins)); + } + self.compile(*pat)?; + if let Some(ins) = F::save(i + 1) { + self.instrs.push(Instr::Custom(ins)); + } + } + } + Ok(()) + } + + fn compile_and_memoize(&mut self, pat: Pattern) -> CompileResult { + if let Some(&jt) = self.map.get(&pat) { + return Ok(jt); + } + let begin = self.instrs.len() as JumpTarget; + self.compile(pat.clone())?; + let end = self.instrs.len() as JumpTarget; + self.instrs.push(Instr::Class(CharacterClass::Nothing)); + let bounds = CompiledSnippet { begin, end }; + self.map.insert(pat, bounds); + Ok(bounds) + } + + fn finalize_assertion_forks(&mut self) { + let fork_targets: Vec = self.map.values().map(|v| v.begin).collect(); + let fork_begin = self.instrs.len() as JumpTarget; + match fork_targets.len() { + 0 => { + self.instrs[self.assertion_fork_base] = Instr::Class(CharacterClass::Nothing); + } + 1 => { + self.instrs[self.assertion_fork_base] = Instr::Jump(fork_targets[0]); + } + 2 => { + self.instrs[self.assertion_fork_base] = + Instr::Fork(fork_targets[0], fork_targets[1]); + } + n => { + self.instrs[self.assertion_fork_base] = Instr::Fork(fork_targets[0], fork_begin); + for i in 1..n - 1 { + let fork = if 
i == n - 2 { + Instr::Fork(fork_targets[i], fork_targets[i + 1]) + } else { + Instr::Fork(fork_targets[i], self.instrs.len() as JumpTarget + 1) + }; + self.instrs.push(fork); + } + } + } + } +} + +fn assertion_compiler() -> Compiler<'static, AssertionFlavor> { + let mut c = Compiler::new(|_, _, _| Err(RegexCompilationError::NestedLookaroundNotSupported)); + c.rep_any_amt( + Pattern::CharacterClass(CharacterClass::Everything), + GreedyBehavior::NonGreedy, + ) + .expect("characterclass should always compile"); + c.assertion_fork_base = c.instrs.len(); + c.instrs.push(Instr::Jump(u32::MAX)); // in the end this gets replaced by a jump to a fork-list for all the assertions + c +} + +#[derive(Clone, Debug)] +pub enum RegexCompilationError { + NestedLookaroundNotSupported, +} + +pub type CompileResult = Result; + +impl TryFrom for BytecodeCompiledRegex { + type Error = RegexCompilationError; + + fn try_from(value: Pattern) -> Result { + let mut neg = assertion_compiler(); + let mut pos = assertion_compiler(); + let (final_state, instrs, submatch_count) = { + let mut main: Compiler = Compiler::new(|dir, pol, pat| { + let target = match dir { + LookDirection::Ahead => pos.compile_and_memoize(pat.reverse()), + LookDirection::Behind => neg.compile_and_memoize(pat), + }? 
+ .end; + + Ok(Instr::Custom(MainInstr::Join(Assertion { + target, + dir, + pol, + }))) + }); + main.compile(value)?; + let end = main.instrs.len(); + main.instrs.push(Instr::Class(CharacterClass::Nothing)); + (end, main.instrs, main.submatch_count) + }; + neg.finalize_assertion_forks(); + pos.finalize_assertion_forks(); + + let mut accepting = BitSet::new(instrs.len()); + accepting.set(final_state, true); + + Ok(Self { + no_lookbehind: neg.map.is_empty(), + instrs0: neg.instrs.into(), + instrs1: instrs.into(), + instrs2: pos.instrs.into(), + accepting, + submatch_count, + }) + } +} + +impl RegexEngine for BytecodeCompiledRegex { + type CompileError = RegexCompilationError; + + fn compile(pat: Pattern) -> Result { + Self::try_from(pat) + } + + fn run(&self, input: &[u8]) -> Option { + self.re_match(input) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parse::Parse; + fn regex(s: &str) -> BytecodeCompiledRegex { + let pat = Pattern::parse_from_bytes(s.as_bytes()).unwrap(); + let compiled = BytecodeCompiledRegex::try_from(pat).unwrap(); + compiled + } + + #[test] + fn print_compiled_vm() { + let compiled = regex("a?b?"); + println!("{compiled:#?}"); + assert_eq!(compiled.matches(b"ab"), true); + assert_eq!(compiled.matches(b"a"), true); + assert_eq!(compiled.matches(b"b"), true); + assert_eq!(compiled.matches(b""), true); + } + + #[test] + fn nongreedy_star() { + let re = regex("(ab*?)bb*"); + assert_eq!( + re.re_match(b"abbb").unwrap().submatches[0].clone().unwrap(), + 0..1 + ); + assert_eq!( + re.re_match(b"abbbbb").unwrap().submatches[0] + .clone() + .unwrap(), + 0..1 + ); + } + + #[test] + fn greedy_star() { + let re = regex("(ab*)bb*"); + assert_eq!( + re.re_match(b"abbb").unwrap().submatches[0].clone().unwrap(), + 0..3 + ); + assert_eq!( + re.re_match(b"abbbbb").unwrap().submatches[0] + .clone() + .unwrap(), + 0..5 + ); + } + + #[test] + fn nongreedy_plus() { + let re = regex("(ab+?)bb*"); + assert_eq!( + 
/// A non-empty range of byte values with both ends inclusive.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct ByteRange {
    /// inclusive
    from: u8,
    /// inclusive
    to: u8,
}

impl From<RangeInclusive<u8>> for ByteRange {
    fn from(value: RangeInclusive<u8>) -> Self {
        Self::new_range(*value.start(), *value.end())
    }
}

impl ByteRange {
    /// Creates the range `from..=to`.
    ///
    /// # Panics
    /// Panics if `from > to`.
    pub fn new_range(from: u8, to: u8) -> Self {
        assert!(from <= to, "{from} <= {to}");
        Self { from, to }
    }

    /// Creates the range containing only `c`.
    pub fn new_single(c: u8) -> Self {
        Self::new_range(c, c)
    }

    /// Returns whether `c` lies inside the range.
    pub fn contains(&self, c: u8) -> bool {
        self.from <= c && c <= self.to
    }

    /// Returns whether the two ranges share at least one byte.
    pub fn overlaps(&self, other: Self) -> bool {
        self.from.max(other.from) <= self.to.min(other.to)
    }

    /// Splits possibly overlapping `ranges` into sorted, pairwise disjoint
    /// ranges that cover exactly the same bytes, cutting at every boundary of
    /// every input range.
    pub fn split_to_disjoint(ranges: Vec<ByteRange>) -> Vec<ByteRange> {
        // Boundaries are kept as u16 so that the exclusive end of a range that
        // reaches u8::MAX (256) is representable. Leaving that boundary out
        // would silently drop the last segment.
        let mut points: Vec<u16> = Vec::with_capacity(ranges.len() * 2);
        for r in &ranges {
            points.push(u16::from(r.from));
            points.push(u16::from(r.to) + 1);
        }
        points.sort_unstable();
        points.dedup();

        points
            .windows(2)
            .filter_map(|window| {
                let (start, end_exclusive) = (window[0], window[1]);
                // No input boundary lies strictly inside the segment, so it is
                // covered entirely or not at all; checking `start` is enough.
                let covered = ranges
                    .iter()
                    .any(|r| u16::from(r.from) <= start && start <= u16::from(r.to));
                // `start` is never the largest boundary and `end_exclusive` is
                // at most 256, so both casts are lossless.
                covered.then(|| ByteRange {
                    from: start as u8,
                    to: (end_exclusive - 1) as u8,
                })
            })
            .collect()
    }
}

impl std::fmt::Debug for ByteRange {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.from == self.to {
            write!(f, "{}", [self.from].escape_ascii())
        } else {
            write!(
                f,
                "{}-{}",
                [self.from].escape_ascii(),
                [self.to].escape_ascii()
            )
        }
    }
}
+ for j in 0..i { + assert!( + !ranges2[i].overlaps(ranges2[j]), + "{i} and {j} overlap: {:?}, {:?}", + ranges2[i], + ranges2[j] + ); + } + } + } + + #[test] + fn overlap_correct() { + assert!(ByteRange::new_range(b'a', b'g').overlaps(ByteRange::new_single(b'f'))); + assert!(!ByteRange::new_range(b'a', b'g').overlaps(ByteRange::new_single(b'h'))); + } + + #[test] + fn empty() { + run(vec![]); + } + + #[test] + fn singleton() { + run(vec![b'0'..=b'9']); + } + + #[test] + fn contained1() { + run(vec![b'0'..=b'9', b'5'..=b'6']); + } + + #[test] + fn contained2() { + run(vec![b'5'..=b'6', b'0'..=b'9']); + } + + #[test] + fn overlap2() { + run(vec![b'1'..=b'6', b'4'..=b'9']) + } + + #[test] + fn overlap3() { + run(vec![b'a'..=b'f', b'd'..=b'j', b'g'..=b'm']) + } + + #[test] + fn overlap4() { + run(vec![b'a'..=b'f', b'd'..=b'j', b'g'..=b'm', b'k'..=b'q']) + } +} diff --git a/src/regex/dfa.rs b/src/regex/dfa.rs new file mode 100644 index 0000000..c55d99d --- /dev/null +++ b/src/regex/dfa.rs @@ -0,0 +1,381 @@ +use core::fmt; +use std::collections::{BinaryHeap, HashMap, HashSet}; + +use crate::regex::{Match, RegexEngine, enfa::EnfaTranslationError}; + +use super::{ + Pattern, + byte_range::ByteRange, + enfa::{ENFA, MultiState}, +}; + +pub type StateId = usize; + +pub struct State { + trans: HashMap, + default_trans: StateId, + accept: bool, +} + +#[allow(clippy::upper_case_acronyms)] +pub struct DFA { + start: StateId, + states: Vec, +} + +impl fmt::Debug for DFA { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "DFA {{")?; + for (i, s) in self.states.iter().enumerate() { + if self.start == i { + write!(f, "-> {i}: ")?; + } else { + write!(f, " {i}: ")?; + } + + for (chr, to) in s.trans.iter() { + write!(f, "{chr:?} to {to}, ")?; + } + + write!(f, "dfl to {}", s.default_trans)?; + if s.accept { + write!(f, ", accept")?; + } + writeln!(f)?; + } + writeln!(f, "}}") + } +} + +impl From for DFA { + fn from(mut nfa: ENFA) -> Self { + 
nfa.remove_unreachable(); + + let mut multi_states = nfa.all_multi_states(); + multi_states.insert(nfa.void_multi_state()); + let mut len = 0; + let multi_to_dfa: HashMap = multi_states + .clone() + .into_iter() + .map(|ms| { + len += 1; + (ms, len - 1) + }) + .collect(); + + let void = multi_to_dfa[&nfa.void_multi_state()]; + + let mut states: Vec = (0..len) + .map(|_| State { + trans: HashMap::new(), + default_trans: void, + accept: false, + }) + .collect(); + + for ms in multi_states.iter() { + let i: usize = multi_to_dfa[ms]; + states[i].accept = ms.accept(); + for t in ms.possible_transitions() { + let k = multi_to_dfa[&ms.transition(t)]; + states[i].trans.insert(t, k); + } + } + + let mut this = Self { + start: multi_to_dfa[&nfa.start_multi_state()], + states, + }; + this.minify(); + this + } +} + +#[derive(Clone, Debug)] +pub enum DFACompileError { + SubmatchesNotSupported, + #[allow(unused)] + NFAError(EnfaTranslationError), +} + +impl RegexEngine for DFA { + type CompileError = DFACompileError; + + fn compile(pat: Pattern) -> Result { + match ENFA::try_from(pat) { + Ok(nfa) => { + if nfa.has_submatches { + Err(DFACompileError::SubmatchesNotSupported) + } else { + Ok(Self::from(nfa)) + } + } + Err(e) => Err(DFACompileError::NFAError(e)), + } + } + + fn run(&self, input: &[u8]) -> Option { + if self.matches(input) { + Some(Match::new_empty()) + } else { + None + } + } +} + +impl DFA { + pub fn matches(&self, x: &[u8]) -> bool { + let mut state = self.start; + 'next_byte: for &b in x.iter() { + for (range, &next_state) in self.states[state].trans.iter() { + if range.contains(b) { + state = next_state; + continue 'next_byte; + } + } + state = self.states[state].default_trans; + } + self.states[state].accept + } +} + +mod state_set { + #[derive(Hash, Clone, PartialEq, Eq)] + pub struct StateSet { + set: Vec, + } + + impl StateSet { + pub fn new(mut set: Vec) -> Self { + set.sort(); + set.dedup(); + Self { set } + } + + pub fn iter(&self) -> impl Iterator { + 
self.set.iter().cloned() + } + + pub fn intersection(&self, other: &Self) -> Self { + let a = &self.set; + let b = &other.set; + + let mut i = 0; + let mut j = 0; + let mut out = Vec::new(); + + while i < a.len() && j < b.len() { + match a[i].cmp(&b[j]) { + std::cmp::Ordering::Less => i += 1, + std::cmp::Ordering::Greater => j += 1, + std::cmp::Ordering::Equal => { + out.push(a[i]); + i += 1; + j += 1; + } + } + } + + Self::new(out) + } + + pub fn difference(&self, other: &Self) -> Self { + let a = &self.set; + let b = &other.set; + + let mut i = 0; + let mut j = 0; + let mut out = Vec::new(); + + while i < a.len() && j < b.len() { + match a[i].cmp(&b[j]) { + std::cmp::Ordering::Less => { + out.push(a[i]); + i += 1; + } + std::cmp::Ordering::Greater => { + j += 1; + } + std::cmp::Ordering::Equal => { + i += 1; + j += 1; + } + } + } + + out.extend_from_slice(&a[i..]); + Self::new(out) + } + + pub fn is_empty(&self) -> bool { + self.set.is_empty() + } + + pub fn len(&self) -> usize { + self.set.len() + } + + pub fn primary_state(&self) -> Option { + self.set.first().cloned() + } + } + + // custom implementation such that smaller sets come first in a BinaryHeap + impl Ord for StateSet { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match self.set.len().cmp(&other.set.len()) { + std::cmp::Ordering::Equal => {} + other => return other.reverse(), + } + self.set.cmp(&other.set) + } + } + + impl PartialOrd for StateSet { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } +} + +use state_set::StateSet; + +trait GoesTo { + fn goes_to(&self, to: usize) -> bool; +} +impl GoesTo for (&State, ByteRange) { + fn goes_to(&self, target: usize) -> bool { + let from = self.0; + let ch = self.1; + for (c, &to) in from.trans.iter() { + if c.overlaps(ch) && to == target { + return true; + } + } + from.default_trans == target + } +} + +impl DFA { + fn states_where(&self, mut f: impl FnMut(&State) -> bool) -> StateSet { + let states = self + .states 
+ .iter() + .enumerate() + .filter_map(|(i, s)| if f(s) { Some(i) } else { None }) + .collect(); + StateSet::new(states) + } + + /// https://en.wikipedia.org/wiki/DFA_minimization + fn hopcroft_minimization(&mut self) { + let mut partitions: HashSet = HashSet::new(); + partitions.insert(self.states_where(|s| s.accept)); + partitions.insert(self.states_where(|s| !s.accept)); + + let mut ranges: Vec = self + .states + .iter() + .flat_map(|s| s.trans.iter().map(|t| *t.0)) + .chain([ByteRange::new_range(u8::MIN, u8::MAX)]) + .collect(); + ranges.sort(); + ranges.dedup(); + let ranges = ByteRange::split_to_disjoint(ranges); + + let mut queue: BinaryHeap = partitions.iter().cloned().collect(); + let mut queue_set = partitions.clone(); + + while let Some(a) = queue.pop() { + if !queue_set.contains(&a) { + continue; + } + + for &c in ranges.iter() { + let x = self.states_where(|s| a.iter().any(|a| (s, c).goes_to(a))); + + let mut del_list = HashSet::new(); + let mut add_list = Vec::new(); + for y in partitions.iter() { + let i = x.intersection(y); + let d = y.difference(&x); + + if !i.is_empty() && !d.is_empty() { + del_list.insert(y.clone()); + add_list.push(i.clone()); + add_list.push(d.clone()); + + if queue_set.contains(y) { + queue_set.remove(y); + + queue.push(i.clone()); + queue_set.insert(i); + + queue.push(d.clone()); + queue_set.insert(d); + } else if i.len() < d.len() { + queue.push(i.clone()); + queue_set.insert(i); + } else { + queue.push(d.clone()); + queue_set.insert(d); + } + } + } + + partitions.retain(|i| !del_list.contains(i)); + for x in add_list { + partitions.insert(x); + } + } + } + + let mut replacement = vec![None; self.states.len()]; + for partition in partitions { + if let Some(x) = partition.primary_state() { + for state in partition.iter() { + assert!(replacement[state].is_none()); + replacement[state] = Some(x); + } + } + } + + // replacement indices in original index space + let replacement: Vec = replacement.into_iter().map(|x| 
/// Builds every combination that picks exactly one element from each inner
/// vector, keeping the order of the outer vector.
///
/// An empty outer vector yields one empty combination; any empty inner vector
/// yields no combinations at all.
fn cartesian_product<T: Clone>(x: Vec<Vec<T>>) -> Vec<Vec<T>> {
    x.into_iter().fold(vec![Vec::new()], |prefixes, choices| {
        prefixes
            .iter()
            .flat_map(|prefix| {
                choices.iter().map(move |choice| {
                    let mut combination = prefix.clone();
                    combination.push(choice.clone());
                    combination
                })
            })
            .collect()
    })
}
single_dimension() { + let x = vec![vec![1, 2, 3]]; + + let out = cartesian_product(x); + + assert_eq!(out, vec![vec![1], vec![2], vec![3],]); + } + + #[test] + fn empty_outer_vector() { + let x: Vec> = vec![]; + + let out: Vec> = cartesian_product(x); + let r: Vec> = vec![vec![]]; + + // One empty combination. + assert_eq!(out, r); + } + + #[test] + fn empty_inner_vector() { + let x = vec![vec![1, 2], vec![], vec![3, 4]]; + + let out = cartesian_product(x); + + assert!(out.is_empty()); + } + + #[test] + fn output_size_matches_product() { + let x = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]]; + + let out = cartesian_product(x); + + assert_eq!(out.len(), 3 * 2 * 4); + } + + #[test] + fn every_output_has_correct_length() { + let x = vec![vec!['a', 'b'], vec!['x', 'y', 'z'], vec!['0']]; + + let out = cartesian_product(x); + + assert!(out.iter().all(|v| v.len() == 3)); + } + + #[test] + fn works_with_strings() { + let x = vec![ + vec!["a".to_string(), "b".to_string()], + vec!["x".to_string()], + ]; + + let out = cartesian_product(x); + + assert_eq!( + out, + vec![ + vec!["a".to_string(), "x".to_string()], + vec!["b".to_string(), "x".to_string()], + ] + ); + } +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +struct Thread { + state: StateId, + positives: Vec, + negatives: Vec, +} + +impl Thread { + fn new_simple(state: StateId) -> Self { + Self { + state, + positives: Vec::new(), + negatives: Vec::new(), + } + } + + fn new( + enfa: &ENFA, + state: StateId, + mut positives: Vec, + mut negatives: Vec, + ) -> Option { + positives.sort(); + positives.dedup(); + positives.retain(|t| !t.accept(enfa)); + positives.shrink_to_fit(); + + negatives.sort(); + negatives.dedup(); + negatives.shrink_to_fit(); + + if negatives.iter().any(|t| t.accept(enfa)) { + return None; + } + + Some(Self { + state, + positives, + negatives, + }) + } + + fn accept(&self, enfa: &ENFA) -> bool { + let pos = self.positives.iter().all(|t| t.accept(enfa)); + let neg = 
self.negatives.iter().all(|t| !t.accept(enfa)); + let this = match enfa.states[self.state].accept { + Acceptance::Accept => true, + Acceptance::Assertion => true, + Acceptance::NotYet => false, + }; + pos && neg && this + } + + fn step_epsilon0( + self, + enfa: &ENFA, + ret: &mut impl FnMut(Thread), + visited: &mut [bool], + new_assertions: Vec, + ) { + if visited[self.state] { + return; + } + visited[self.state] = true; + + for t in enfa.states[self.state].trans.iter() { + if t.consumes.is_some() { + continue; + } + + let mut new_assertions = new_assertions.clone(); + if let Some(assertion) = enfa.states[self.state].assert.as_ref() { + new_assertions.push(assertion.clone()); + } + + Self { + state: t.to, + ..self.clone() + } + .step_epsilon0(enfa, ret, visited, new_assertions); + } + + let Self { + state, + positives, + negatives, + } = self; + let mut p = Vec::new(); + let mut n = Vec::new(); + for assertion in new_assertions { + let threads = Self::new_simple(assertion.to).step_epsilon(enfa); + let vec = match assertion.polarity { + LookPolarity::Positive => &mut p, + LookPolarity::Negative => &mut n, + }; + vec.push(threads); + } + let p = cartesian_product(p); + for mut p in p { + let mut positives = positives.clone(); + let mut negatives = negatives.clone(); + positives.append(&mut p); + negatives.append(&mut n.iter().flatten().cloned().collect()); + if let Some(thread) = Self::new(enfa, state, positives, negatives) { + ret(thread); + } + } + } + + fn step_epsilon(self, enfa: &ENFA) -> Vec { + let mut vec = Vec::new(); + self.step_epsilon0( + enfa, + &mut |t| vec.push(t), + &mut vec![false; enfa.states.len()], + Vec::new(), + ); + vec + } + + fn step0(self, enfa: &ENFA, input: ByteRange, ret: &mut impl FnMut(Thread)) { + let positives = self + .positives + .clone() + .into_iter() + .map(|t| t.step(enfa, input)) + .collect(); + let negatives: Vec<_> = self + .negatives + .into_iter() + .flat_map(|t| t.step(enfa, input)) + .collect(); + let positives = 
cartesian_product(positives); + let next_states: Vec = enfa.states[self.state] + .trans + .iter() + .filter_map(|t| { + if let Some(ch) = t.consumes + && ch.overlaps(input) + { + Some(t.to) + } else { + None + } + }) + .collect(); + + for s in next_states { + for p in positives.clone() { + if let Some(thread) = Self::new(enfa, s, p.clone(), negatives.clone()) { + thread.step_epsilon0( + enfa, + ret, + &mut vec![false; enfa.states.len()], + Vec::new(), + ); + } + } + } + } + + fn step(self, enfa: &ENFA, input: ByteRange) -> Vec { + let mut vec = Vec::new(); + self.step0(enfa, input, &mut |x| vec.push(x)); + vec + } + + fn possible_transitions(&self, enfa: &ENFA, f: &mut impl FnMut(ByteRange)) { + for t in enfa.states[self.state].trans.iter() { + if let Some(ch) = t.consumes { + f(ch); + } + } + for t in self.positives.iter() { + t.possible_transitions(enfa, f); + } + for t in self.negatives.iter() { + t.possible_transitions(enfa, f); + } + } +} + +#[derive(Clone)] +pub struct MultiState<'a> { + nfa: &'a ENFA, + threads: Vec, + accept: bool, + hash: u64, +} + +impl<'a> PartialEq for MultiState<'a> { + fn eq(&self, other: &Self) -> bool { + (self.nfa as *const ENFA as u64) == (other.nfa as *const ENFA as u64) + && self.threads == other.threads + && self.accept == other.accept + && self.hash == other.hash + } +} +impl<'a> Eq for MultiState<'a> {} + +impl<'a> MultiState<'a> { + fn new(nfa: &'a ENFA, mut threads: Vec) -> Self { + threads.sort(); + threads.dedup(); + threads.shrink_to_fit(); + + let accept = threads.iter().any(|t| t.accept(nfa)); + let mut hasher = DefaultHasher::new(); + threads.hash(&mut hasher); + let hash = hasher.finish(); + + Self { + nfa, + threads, + accept, + hash, + } + } + + /// all the chars that will make an interesting transition + pub fn possible_transitions(&self) -> Vec { + let mut vec = Vec::new(); + for t in self.threads.iter() { + t.possible_transitions(self.nfa, &mut |x| vec.push(x)); + } + vec = ByteRange::split_to_disjoint(vec); + 
vec.sort(); + vec.dedup(); + vec.shrink_to_fit(); + vec + } + + pub fn transition(&self, ch: ByteRange) -> Self { + let new_states = self + .threads + .iter() + .flat_map(|t| t.clone().step(self.nfa, ch)) + .collect(); + + Self::new(self.nfa, new_states) + } + + pub fn accept(&self) -> bool { + self.accept + } +} + +impl<'a> Hash for MultiState<'a> { + fn hash(&self, state: &mut H) { + self.hash.hash(state) + } +} + +macro_rules! set { + () => { + std::collections::HashSet::new() + }; + ( $( $x:expr ),* ) => {{ + let mut set = std::collections::HashSet::new(); + $( + set.insert($x); + )* + set + }}; +} + +impl ENFA { + fn shift(self, amt: usize) -> Vec { + let mut s = self.states; + + for state in s.iter_mut() { + state.remap(|i| i + amt); + if state.accept == Acceptance::Accept { + state.accept = Acceptance::NotYet; + } + } + + s + } + + pub fn remove_unreachable(&mut self) { + let mut used = vec![false; self.states.len()]; + used[0] = true; + for s in self.states.iter() { + for i in s.reachable_states() { + used[i] = true; + } + } + let mut remap = vec![0; self.states.len()]; + let mut shift = 0; + for i in 0..self.states.len() { + if used[i] { + remap[i] = i - shift; + } else { + shift += 1; + } + } + for i in (0..self.states.len()).rev() { + if !used[i] { + self.states.remove(i); + } + } + for s in self.states.iter_mut() { + s.remap(|i| remap[i]); + } + } +} + +impl ENFA { + fn looping(self) -> Self { + let has_submatches = self.has_submatches; + let mut states = vec![EState::start()]; + states.append(&mut self.shift(1)); + let len = states.len(); + states[0].set_epsilon_transitions([Transition::epsilon(1), Transition::epsilon(len)]); + states[len - 1].set_epsilon_transitions([Transition::epsilon(0), Transition::epsilon(len)]); + states.push(EState::terminal()); + Self { states, has_submatches } + } + + fn repeat(self, times: usize) -> Self { + let reps = vec![self; times]; + Self::concat(reps) + } + + /// between 0 and x repetitions + fn optx(self, x: usize) 
-> Self { + let len = self.states.len(); + let mut repped = self.repeat(x); + assert_eq!(repped.states.len(), x * len); + for i in 1..=x { + repped.states[0] + .trans + .insert(Transition::epsilon(i * len - 1)); + } + repped + } + + fn concat(nfas: Vec) -> Self { + if nfas.is_empty() { + return Self { + states: vec![EState::terminal()], + has_submatches: false, + }; + } + + let mut has_submatches = false; + let mut states: Vec = Vec::new(); + for nfa in nfas.into_iter() { + has_submatches = has_submatches || nfa.has_submatches; + let len = states.len(); + let mut ns = nfa.shift(len); + if let Some(n) = states.last_mut() { + n.trans.retain(|t| t.consumes.is_some()); + n.trans.insert(Transition::epsilon(len)); + } + states.append(&mut ns); + } + + let len = states.len(); + states[len - 1].accept = Acceptance::Accept; + + Self { states, has_submatches } + } +} + +impl ENFA { + pub fn start_multi_state<'a>(&'a self) -> MultiState<'a> { + let threads = Thread::new_simple(0).step_epsilon(self); + MultiState::new(self, threads) + } + + pub fn void_multi_state<'a>(&'a self) -> MultiState<'a> { + MultiState::new(self, vec![]) + } + + pub fn all_multi_states<'a>(&'a self) -> HashSet> { + let mut states = set![self.start_multi_state()]; + let mut q = vec![self.start_multi_state()]; + + while let Some(state) = q.pop() { + let chars = state.possible_transitions(); + + for chr in chars { + let new = state.transition(chr); + + if !states.contains(&new) { + states.insert(new.clone()); + q.push(new); + } + } + } + + states + } +} + +impl std::fmt::Debug for ENFA { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "NFA {{")?; + for (i, s) in self.states.iter().enumerate() { + write!(f, " {i}: ")?; + if let Some(a) = s.assert.as_ref() { + match a.polarity { + LookPolarity::Positive => write!(f, "+{} ", a.to)?, + LookPolarity::Negative => write!(f, "-{} ", a.to)?, + } + } + for t in s.trans.iter() { + let k = t.to; + if let Some(c) = t.consumes { + 
write!(f, "{c:?}=>{k} ")?; + } else { + write!(f, "~>{k} ")?; + } + } + match s.accept { + Acceptance::Accept => write!(f, "accept")?, + Acceptance::Assertion => write!(f, "assert")?, + Acceptance::NotYet => {} + } + writeln!(f)?; + } + write!(f, "}}") + } +} + +pub type StateId = usize; + +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct Assertion { + to: StateId, + polarity: LookPolarity, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct Transition { + to: StateId, + consumes: Option, +} + +impl Transition { + fn new(consumes: ByteRange, to: StateId) -> Self { + Self { + to, + consumes: Some(consumes), + } + } + + fn epsilon(to: StateId) -> Self { + Self { to, consumes: None } + } + + fn remap(&mut self, mut f: impl FnMut(StateId) -> StateId) { + self.to = f(self.to); + } + + fn reachable_states(&self) -> impl Iterator { + [self.to].into_iter() + } +} + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum Acceptance { + Accept, + Assertion, + NotYet, +} + +#[derive(Debug, Clone)] +pub struct EState { + pub trans: HashSet, + pub assert: Option, + pub accept: Acceptance, +} + +impl EState { + fn set_epsilon_transitions(&mut self, trans: impl IntoIterator) { + self.trans.retain(|t| t.consumes.is_some()); + for transition in trans.into_iter() { + assert!(transition.consumes.is_none()); + self.trans.insert(transition); + } + } + + fn start() -> Self { + Self { + trans: HashSet::new(), + assert: None, + accept: Acceptance::NotYet, + } + } + fn terminal() -> Self { + Self { + trans: HashSet::new(), + assert: None, + accept: Acceptance::Accept, + } + } + + fn remap(&mut self, mut f: impl FnMut(StateId) -> StateId) { + self.trans = self + .trans + .iter() + .cloned() + .map(|mut t| { + t.remap(&mut f); + t + }) + .collect(); + if let Some(a) = self.assert.as_mut() { + a.to = f(a.to); + } + } + + fn reachable_states(&self) -> impl Iterator { + self.trans + .iter() + .flat_map(|t| t.reachable_states()) + .chain(self.assert.iter().map(|a| a.to)) + } +} + 
+#[derive(Clone, Debug)] +pub enum EnfaTranslationError { + CharacterClassNotSupported, + AssertionsNotSupported, +} + +impl TryFrom for ENFA { + type Error = EnfaTranslationError; + + fn try_from(value: Pattern) -> Result { + Ok(match value { + Pattern::Byte(c) => Self::try_from(Pattern::Range(c, c))?, + Pattern::Range(c1, c2) => Self { + states: vec![ + EState { + trans: set![Transition::new(ByteRange::new_range(c1, c2), 1)], + assert: None, + accept: Acceptance::NotYet, + }, + EState::terminal(), + ], + has_submatches: false, + }, + Pattern::CharacterClass(_) => { + return Err(EnfaTranslationError::CharacterClassNotSupported); + } + Pattern::Alt(alts) => { + let nfas: Vec = alts + .into_iter() + .map(Self::try_from) + .collect::>()?; + let mut states = vec![EState::start()]; + let mut ends = vec![]; + let mut has_submatches = false; + for nfa in nfas.into_iter() { + has_submatches = has_submatches || nfa.has_submatches; + let len = states.len(); + states[0].trans.insert(Transition::epsilon(len)); + states.append(&mut (nfa.shift(len))); + ends.push(states.len() - 1); + } + states.push(EState::terminal()); + for end in ends.into_iter() { + let last = states.len() - 1; + states[end].trans.insert(Transition::epsilon(last)); + } + Self { states, has_submatches } + } + Pattern::Concat(seq) => { + let nfas: Vec = seq + .into_iter() + .map(Self::try_from) + .collect::>()?; + Self::concat(nfas) + } + Pattern::Rep(regex, min, None, _) => { + let nfa = ENFA::try_from(*regex)?; + let base = nfa.clone().repeat(min as usize); + let tail = nfa.looping(); + Self::concat(vec![base, tail]) + } + Pattern::Rep(regex, min, Some(max), _) => { + assert!(min < max); + let nfa = Self::try_from(*regex)?; + let base = nfa.clone().repeat(min as usize); + let tail = nfa.optx((max - min) as usize); + Self::concat(vec![base, tail]) + } + Pattern::Nothing => Self { + states: vec![EState::terminal()], + has_submatches: false, + }, + Pattern::Assertion(dir, polarity, pat) => { + if dir == 
LookDirection::Behind { + return Err(EnfaTranslationError::AssertionsNotSupported); + } + let mut regex = Self::try_from(*pat)?; + for s in regex.states.iter_mut() { + if s.accept == Acceptance::Accept { + s.accept = Acceptance::Assertion; + } + } + let mut regex = regex.shift(1); + let mut states = Vec::with_capacity(regex.len() + 2); + states.push(EState { + trans: set![Transition::epsilon(regex.len() + 1)], + assert: Some(Assertion { to: 1, polarity }), + accept: Acceptance::NotYet, + }); + states.append(&mut regex); + states.push(EState::terminal()); + Self { states, has_submatches: false, } + } + Pattern::Submatch(pat) => { + let mut this = Self::try_from(*pat)?; + this.has_submatches = true; + this + } + }) + } +} diff --git a/src/regex/mod.rs b/src/regex/mod.rs new file mode 100644 index 0000000..be3026f --- /dev/null +++ b/src/regex/mod.rs @@ -0,0 +1,344 @@ +use crate::parse::Parse; + +pub mod bc; +mod byte_range; +pub mod dfa; +pub mod enfa; +pub mod simple; + +#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)] +pub enum LookDirection { + Ahead, + Behind, +} + +impl LookDirection { + pub fn reverse(self) -> Self { + match self { + Self::Ahead => Self::Behind, + Self::Behind => Self::Ahead, + } + } +} + +#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)] +pub enum LookPolarity { + Positive, + Negative, +} + +#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)] +pub enum CharacterClass { + Everything, + Nothing, + Whitespace, + Alphabetic, + Alphanumeric, +} + +impl CharacterClass { + pub fn matches(self, byte: u8) -> bool { + match self { + CharacterClass::Everything => true, + CharacterClass::Nothing => false, + CharacterClass::Whitespace => byte.is_ascii_whitespace(), + CharacterClass::Alphabetic => byte.is_ascii_alphabetic(), + CharacterClass::Alphanumeric => byte.is_ascii_alphanumeric(), + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum Pattern { + Byte(u8), + Range(u8, u8), + CharacterClass(CharacterClass), + Alt(Vec), + Concat(Vec), + 
Rep(Box, u32, Option, GreedyBehavior), + Assertion(LookDirection, LookPolarity, Box), + Submatch(Box), + Nothing, +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum ByteConsumption { + Bounded(usize), + Unbounded, +} + +impl ByteConsumption { + pub fn zero() -> Self { + Self::Bounded(0) + } + pub fn one() -> Self { + Self::Bounded(1) + } +} + +impl Ord for ByteConsumption { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match (self, other) { + (Self::Bounded(a), Self::Bounded(b)) => a.cmp(b), + (Self::Bounded(_), Self::Unbounded) => std::cmp::Ordering::Less, + (Self::Unbounded, Self::Bounded(_)) => std::cmp::Ordering::Greater, + (Self::Unbounded, Self::Unbounded) => std::cmp::Ordering::Equal, + } + } +} +impl PartialOrd for ByteConsumption { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl std::ops::Add for ByteConsumption { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + if let Self::Bounded(a) = self + && let Self::Bounded(b) = rhs + { + Self::Bounded(a + b) + } else { + Self::Unbounded + } + } +} + +impl std::ops::Mul for ByteConsumption { + type Output = Self; + + fn mul(self, rhs: usize) -> Self::Output { + match self { + Self::Bounded(x) => Self::Bounded(x * rhs), + Self::Unbounded => Self::Unbounded, + } + } +} + +impl std::iter::Sum for ByteConsumption { + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), |a, b| a + b) + } +} + +impl Pattern { + pub fn max_byte_consumption(&self) -> ByteConsumption { + match self { + Pattern::Byte(_) => ByteConsumption::one(), + Pattern::Range(_, _) => ByteConsumption::one(), + Pattern::CharacterClass(_) => ByteConsumption::one(), + Pattern::Alt(patterns) => patterns + .iter() + .map(Self::max_byte_consumption) + .max() + .unwrap_or(ByteConsumption::zero()), + Pattern::Concat(patterns) => patterns.iter().map(Self::max_byte_consumption).sum(), + Pattern::Rep(pattern, _, Some(max_reps), _) => { + pattern.max_byte_consumption() * (*max_reps as 
usize) + } + Pattern::Rep(_, _, None, _) => ByteConsumption::Unbounded, + Pattern::Assertion(_, _, _) => ByteConsumption::zero(), + Pattern::Nothing => ByteConsumption::zero(), + Pattern::Submatch(pat) => pat.max_byte_consumption(), + } + } + + pub fn reverse(self) -> Self { + use Pattern::*; + match self { + Byte(_) | Nothing | Range(..) | CharacterClass(_) => self, + Alt(patterns) => Alt(patterns.into_iter().map(Self::reverse).collect()), + Concat(patterns) => Concat(patterns.into_iter().map(Self::reverse).rev().collect()), + Rep(pattern, min, max, greedy) => Rep(Box::new(pattern.reverse()), min, max, greedy), + Assertion(dir, pol, pat) => Assertion(dir.reverse(), pol, Box::new(pat.reverse())), + Submatch(pat) => Submatch(Box::new(pat.reverse())), + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum GreedyBehavior { + Greedy, + NonGreedy, +} + +pub enum CompiledPattern { + Dfa(dfa::DFA), + Bytecode(bc::BytecodeCompiledRegex), + Anything(simple::Anything), + Nothing(simple::Nothing), + Exact(simple::Exact), +} + +impl CompiledPattern { + fn is_dfa(&self) -> bool { + matches!(self, Self::Dfa(_)) + } + fn is_bytecode(&self) -> bool { + matches!(self, Self::Bytecode(_)) + } + fn is_simple(&self) -> bool { + matches!(self, Self::Anything(_) | Self::Nothing(_) | Self::Exact(_)) + } +} + +#[derive(PartialEq, Eq, Debug)] +pub struct Match { + pub submatches: Box<[Option>]>, +} + +impl Match { + pub fn new_empty() -> Self { + Self { + submatches: [].into(), + } + } +} + +pub trait RegexEngine: Sized { + type CompileError: std::fmt::Debug + Clone; + + fn compile(pat: Pattern) -> Result; + + fn run(&self, input: &[u8]) -> Option; + + fn matches(&self, input: &[u8]) -> bool { + self.run(input).is_some() + } +} + +impl RegexEngine for CompiledPattern { + type CompileError = bc::RegexCompilationError; + + fn compile(pat: Pattern) -> Result { + if let Ok(c) = simple::Anything::compile(pat.clone()) { + Ok(Self::Anything(c)) + } else if let Ok(c) = 
simple::Nothing::compile(pat.clone()) { + Ok(Self::Nothing(c)) + } else if let Ok(c) = simple::Exact::compile(pat.clone()) { + Ok(Self::Exact(c)) + } else if let Ok(c) = dfa::DFA::compile(pat.clone()) { + Ok(Self::Dfa(c)) + } else { + bc::BytecodeCompiledRegex::compile(pat).map(Self::Bytecode) + } + } + + fn run(&self, input: &[u8]) -> Option { + match self { + CompiledPattern::Dfa(x) => x.run(input), + CompiledPattern::Bytecode(x) => x.run(input), + CompiledPattern::Anything(x) => x.run(input), + CompiledPattern::Nothing(x) => x.run(input), + CompiledPattern::Exact(x) => x.run(input), + } + } +} + +macro_rules! all_engines { + ($ty_name:ident, $($x:ident : $ty:ty,)*) => { + pub struct $ty_name { + $($x: Option<$ty>,)* + } + impl RegexEngine for $ty_name { + type CompileError = (); + + fn compile(pat: Pattern) -> Result { + let x = Self { + $($x: RegexEngine::compile(pat.clone()).ok(),)* + }; + if $(x.$x.is_none())&&* { + Err(()) + } else { + Ok(x) + } + } + + fn run(&self, input: &[u8]) -> Option { + $(let $x = self.$x.as_ref().map(|x| x.run(input));)* + let mut result = None; + $( + if let Some(res) = $x { + if let Some(result) = result { + assert_eq!(res, result, concat!("engine ", stringify!($x), " does not agree with previously run engines.")); + } + result = Some(res) + } + )* + result.unwrap() + } + } + } +} + +all_engines!( + AllEngines, + dfa: dfa::DFA, + bc: bc::BytecodeCompiledRegex, + any: simple::Anything, + nothing: simple::Nothing, + exact: simple::Exact, +); + +impl Pattern { + pub fn try_compile(self) -> Result { + CompiledPattern::compile(self) + } +} + +#[cfg(test)] +macro_rules! 
regex_matches { + ($regex:literal, $match:literal, $true:literal) => { + assert_eq!( + ::parse_from_bytes($regex.as_bytes()) + .unwrap() + .try_compile() + .unwrap() + .matches($match.as_bytes()), + $true + ) + }; +} + +#[test] +fn foo_matches_foo() { + regex_matches!("foo", "foo", true); +} + +#[test] +fn dot_star_is_simple() { + let x = Pattern::parse_from_bytes(b".*") + .unwrap() + .try_compile() + .unwrap(); + assert!(x.is_simple()); +} + +#[test] +fn match_is_bytecode() { + let x = Pattern::parse_from_bytes(b".*(ele.*phant).*") + .unwrap() + .try_compile() + .unwrap(); + assert!(x.is_bytecode()); +} + +#[test] +fn simple_word_is_exact() { + let x = Pattern::parse_from_bytes(b"Gnu[ ]plus[ ]Linux") + .unwrap() + .try_compile() + .unwrap(); + assert!(x.is_simple()); +} + +#[test] +fn no_match_is_dfa() { + let x = Pattern::parse_from_bytes(b".*Gnu.*plus.*Linux.*") + .unwrap() + .try_compile() + .unwrap(); + assert!(x.is_dfa()); +} diff --git a/src/regex/simple.rs b/src/regex/simple.rs new file mode 100644 index 0000000..00bc9b4 --- /dev/null +++ b/src/regex/simple.rs @@ -0,0 +1,125 @@ +use crate::regex::CharacterClass; + +use super::{Match, Pattern, RegexEngine}; + +fn empty_match() -> Option { + Some(Match { + submatches: [].into(), + }) +} + +pub struct Anything; + +#[derive(Debug, Clone)] +pub struct NotASimpleWildcard; + +impl RegexEngine for Anything { + type CompileError = NotASimpleWildcard; + + fn compile(pat: Pattern) -> Result { + if let Pattern::Rep(pat, 0, None, _) = pat + && let Pattern::CharacterClass(CharacterClass::Everything) = *pat + { + Ok(Anything) + } else { + Err(NotASimpleWildcard) + } + } + + fn run(&self, _input: &[u8]) -> Option { + empty_match() + } +} + +pub struct Nothing; +#[derive(Debug, Clone)] +pub struct NotASimpleNothing; + +impl RegexEngine for Nothing { + type CompileError = NotASimpleNothing; + + fn compile(pat: Pattern) -> Result { + match pat { + Pattern::Range(a, b) if a > b => Ok(Nothing), + 
Pattern::CharacterClass(CharacterClass::Nothing) => Ok(Nothing), + Pattern::Alt(pats) => { + let all_impossible = pats.into_iter().map(Self::compile).all(|p| p.is_ok()); + if all_impossible { + Ok(Nothing) + } else { + Err(NotASimpleNothing) + } + } + Pattern::Concat(pats) => { + if let Some(pat) = pats.into_iter().next() { + Self::compile(pat) + } else { + Err(NotASimpleNothing) + } + } + Pattern::Rep(_, x, Some(y), _) if y < x => Ok(Nothing), + Pattern::Rep(_, 0, None, _) => Err(NotASimpleNothing), + Pattern::Rep(pat, _gt_0, _, _) => Self::compile(*pat), + Pattern::Submatch(pat) => Self::compile(*pat), + _ => Err(NotASimpleNothing), + } + } + + fn run(&self, _input: &[u8]) -> Option { + None + } +} + +pub struct Exact { + bytes: Vec, +} + +const MEM_LIMIT: usize = 25_000; + +#[derive(Debug, Clone)] +pub struct NotSimplyAString; + +fn ce(pat: Pattern) -> Option> { + match pat { + Pattern::Byte(b) => Some(vec![b]), + Pattern::Concat(patterns) => { + let mut pats = patterns.into_iter().map(ce).collect::>>()?; + let mut out = Vec::new(); + for p in pats.iter_mut() { + out.append(p); + } + Some(out) + } + Pattern::Rep(pat, min, Some(max), _) if min == max => { + if let Some(bytes) = ce(*pat) + && bytes.len() * (min as usize) < MEM_LIMIT + { + Some(bytes.repeat(min as usize)) + } else { + None + } + } + Pattern::Submatch(_) => None, // TODO: submatches could be stored as constant offsets + Pattern::Nothing => Some(Vec::new()), + _ => None, + } +} + +impl RegexEngine for Exact { + type CompileError = NotSimplyAString; + + fn compile(pat: Pattern) -> Result { + match ce(pat) { + Some(bytes) => Ok(Self { bytes }), + None => Err(NotSimplyAString), + } + } + + fn run(&self, input: &[u8]) -> Option { + if input == self.bytes { + empty_match() + } else { + None + } + } +} -- cgit v1.2.3