use std::collections::{HashMap, VecDeque}; use super::{ CharacterClass, GreedyBehavior, LookDirection, LookPolarity, Match, Pattern, RegexEngine, byte_range::ByteRange, }; use crate::bitset::BitSet; trait Flavor: Clone { type CustomInstr: Copy + Clone + std::fmt::Debug; type ThreadData: Clone; type StepData<'a, 'b> where 'b: 'a; fn accepts<'a, 'b>( thread: &mut Thread, instr: Self::CustomInstr, sd: &mut Self::StepData<'a, 'b>, ) -> bool; fn save(x: u32) -> Option; } #[derive(Copy, Clone, Debug)] struct MainFlavor; impl Flavor for MainFlavor { type CustomInstr = MainInstr; type ThreadData = Box<[Option]>; type StepData<'a, 'b> = (usize, &'a BitSet, &'a mut LookaheadVM<'b>) where 'b: 'a; fn accepts<'a, 'b>( thread: &mut Thread, instr: Self::CustomInstr, data: &mut Self::StepData<'a, 'b>, ) -> bool { match instr { MainInstr::Save(reg) => { thread.data[reg as usize] = Some(data.0); true } MainInstr::Join(assertion) => { let should_match = assertion.pol == LookPolarity::Positive; let state = assertion.target as usize; let is_matching = match assertion.dir { LookDirection::Ahead => data.2.get_state(data.0, state), LookDirection::Behind => data.1.get(state), }; is_matching == should_match } } } fn save(x: u32) -> Option { Some(MainInstr::Save(x)) } } #[derive(Copy, Clone, Debug)] enum Nothing {} #[derive(Copy, Clone, Debug)] struct AssertionFlavor; impl Flavor for AssertionFlavor { type CustomInstr = Nothing; type ThreadData = (); type StepData<'a, 'b> = () where 'b: 'a; fn accepts(_thread: &mut Thread, instr: Self::CustomInstr, _sd: &mut ()) -> bool { match instr {} } fn save(_: u32) -> Option { None } } type JumpTarget = u32; type Register = u32; #[derive(Copy, Clone, Debug)] struct Assertion { target: JumpTarget, dir: LookDirection, pol: LookPolarity, } #[derive(Copy, Clone, Debug)] enum Instr { Class(CharacterClass), Consume(ByteRange), Jump(JumpTarget), Fork(JumpTarget, JumpTarget), Custom(F::CustomInstr), } #[derive(Copy, Clone, Debug)] enum MainInstr { Save(Register), Join(Assertion), } #[derive(Clone)] struct Thread { pc: JumpTarget, data: F::ThreadData, } struct VM<'p, F: Flavor> { instr: &'p [Instr], passive_threads: VecDeque>, active_threads: VecDeque>, hot: BitSet, warm: BitSet, } impl<'p, F: Flavor> VM<'p, F> { fn new(instr: &'p [Instr], starting_thread: Thread) -> Self { Self { instr, passive_threads: vec![starting_thread].into(), active_threads: VecDeque::new(), hot: BitSet::new(instr.len()), warm: BitSet::new(instr.len()), } } fn step_epsilon<'a>(&mut self, sd: &mut F::StepData<'a, 'p>) { std::mem::swap(&mut self.active_threads, &mut self.passive_threads); self.hot.set_all(false); self.warm.set_all(false); macro_rules! add_thread { ($t:expr) => {{ let t = $t; let bit = t.pc as usize; if !self.warm.get(bit) { self.warm.set(bit, true); self.active_threads.push_front(t); } }}; } while let Some(mut thread) = self.active_threads.pop_front() { match self.instr[thread.pc as usize] { Instr::Class(_) | Instr::Consume(_) => { if !self.hot.get(thread.pc as usize) { self.hot.set(thread.pc as usize, true); self.passive_threads.push_back(thread); } } Instr::Jump(j) => { thread.pc = j; add_thread!(thread); } Instr::Fork(a, b) => { add_thread!(Thread { pc: b, data: thread.data.clone(), }); add_thread!(Thread { pc: a, data: thread.data.clone(), }); } Instr::Custom(instr) => { if F::accepts(&mut thread, instr, sd) { thread.pc += 1; add_thread!(thread); } } } } } fn step_consume(&mut self, byte: u8) { self.hot.set_all(false); self.passive_threads .retain_mut(|thread| match self.instr[thread.pc as usize] { Instr::Class(class) => { if class.matches(byte) { thread.pc += 1; self.hot.set(thread.pc as usize, true); true } else { false } } Instr::Consume(bytes) => { if bytes.contains(byte) { thread.pc += 1; self.hot.set(thread.pc as usize, true); true } else { false } } _ => false, }); } } struct LookaheadVM<'a> { vm: VM<'a, AssertionFlavor>, data: &'a [u8], cached: bool, cache_data: Vec, loc_offset: usize, } impl<'a> LookaheadVM<'a> { fn new(vm: VM<'a, AssertionFlavor>, data: &'a [u8]) -> Self { Self { vm, data, cached: false, cache_data: Vec::new(), loc_offset: 0, } } fn get_state(&mut self, loc: usize, state: usize) -> bool { if !self.cached { assert!(self.cache_data.is_empty()); assert_eq!(self.loc_offset, 0); self.loc_offset = loc; self.vm.step_epsilon(&mut ()); self.cache_data.push(self.vm.hot.clone()); for i in (loc..self.data.len()).rev() { self.vm.step_consume(self.data[i]); self.vm.step_epsilon(&mut ()); self.cache_data.push(self.vm.hot.clone()); } self.cache_data.reverse(); self.cached = true; } assert!( loc >= self.loc_offset, "get_state must be called with non-decreasing arguments." ); self.cache_data[loc - self.loc_offset].get(state) } } struct VirtualMachine<'a> { vm0: VM<'a, AssertionFlavor>, vm1: VM<'a, MainFlavor>, vm2: LookaheadVM<'a>, accepting: &'a BitSet, } impl<'a> VirtualMachine<'a> { fn step_epsilon_1(&mut self, loc: usize) { self.vm1 .step_epsilon(&mut (loc, &self.vm0.hot, &mut self.vm2)); } fn step_epsilon(&mut self, loc: usize) { self.vm0.step_epsilon(&mut ()); self.step_epsilon_1(loc); } fn step_consume(&mut self, byte: u8) { self.vm0.step_consume(byte); self.vm1.step_consume(byte); } fn step(&mut self, byte: u8, loc: usize) { self.step_epsilon(loc); self.step_consume(byte); } fn extract_match(&self) -> Option { self.vm1 .passive_threads .iter() .filter(|t| self.accepting.get(t.pc as usize)) .map(|t| { let submatches: Vec<_> = t.data.windows(2).map(|x| Some(x[0]?..x[1]?)).collect(); Match { submatches: submatches.into(), } }) .next() } } fn fmt_instructions( f: &mut std::fmt::Formatter<'_>, label: &str, insns: &[Instr], ) -> std::fmt::Result { writeln!(f, "# {label}")?; for (idx, ins) in insns.iter().enumerate() { writeln!(f, "{idx}: {ins:?}")?; } Ok(()) } pub struct BytecodeCompiledRegex { instrs0: Box<[Instr]>, instrs1: Box<[Instr]>, instrs2: Box<[Instr]>, no_lookbehind: bool, submatch_count: usize, accepting: BitSet, } impl std::fmt::Debug for BytecodeCompiledRegex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fmt_instructions(f, "behind", &self.instrs0)?; fmt_instructions(f, "ahead", &self.instrs2)?; fmt_instructions(f, "main", &self.instrs1)?; writeln!(f, "accepting: {:?}", self.accepting) } } impl BytecodeCompiledRegex { pub fn re_match(&self, data: &[u8]) -> Option { let vm0 = VM::new(&self.instrs0, Thread { pc: 0, data: () }); let vm1 = VM::new( &self.instrs1, Thread { pc: 0, data: vec![None; 2 * self.submatch_count].into(), }, ); let vm2 = VM::new(&self.instrs2, Thread { pc: 0, data: () }); let vm2 = LookaheadVM::new(vm2, data); let mut vm = VirtualMachine { vm0, vm1, vm2, accepting: &self.accepting, }; if self.no_lookbehind { for (i, ch) in data.iter().cloned().enumerate() { vm.step_epsilon_1(i); vm.vm1.step_consume(ch); } vm.step_epsilon(data.len()); } else { for (i, ch) in data.iter().cloned().enumerate() { vm.step(ch, i); } vm.step_epsilon(data.len()); } vm.extract_match() } pub fn matches(&self, data: &[u8]) -> bool { self.re_match(data).is_some() } } type AssertionHandler<'a, F> = Box CompileResult>>; #[derive(Copy, Clone)] struct CompiledSnippet { begin: JumpTarget, end: JumpTarget, } struct Compiler<'a, F: Flavor> { instrs: Vec>, map: HashMap, assertion_handler: AssertionHandler<'a, F>, assertion_fork_base: usize, submatch_count: usize, } fn fork(repeat: usize, exit: usize, greedy: GreedyBehavior) -> Instr { let repeat = repeat as JumpTarget; let exit = exit as JumpTarget; match greedy { GreedyBehavior::Greedy => Instr::Fork(repeat, exit), GreedyBehavior::NonGreedy => Instr::Fork(exit, repeat), } } impl<'a, F: Flavor> Compiler<'a, F> { fn new( assertion_handler: impl 'a + FnMut(LookDirection, LookPolarity, Pattern) -> CompileResult>, ) -> Self { Self { instrs: Vec::new(), map: HashMap::new(), assertion_handler: Box::new(assertion_handler), assertion_fork_base: usize::MAX, submatch_count: 0, } } fn rep_1_or_more(&mut self, pat: Pattern, greedy: GreedyBehavior) -> CompileResult { let base = self.instrs.len(); self.compile(pat)?; let exit = self.instrs.len() + 1; self.instrs.push(fork(base, exit, greedy)); Ok(()) } fn rep_0_or_1(&mut self, pat: Pattern, greedy: GreedyBehavior) -> CompileResult { let base = self.instrs.len(); self.instrs.push(Instr::Jump(u32::MAX)); self.compile(pat)?; self.instrs[base] = fork(base + 1, self.instrs.len(), greedy); Ok(()) } fn rep_any_amt(&mut self, pat: Pattern, greedy: GreedyBehavior) -> CompileResult { let base = self.instrs.len(); self.instrs.push(Instr::Jump(u32::MAX)); self.compile(pat)?; let fork_pos = self.instrs.len(); let after = fork_pos + 1; self.instrs.push(fork(base + 1, after, greedy)); self.instrs[base] = Instr::Jump(fork_pos as JumpTarget); Ok(()) } fn compile(&mut self, pat: Pattern) -> CompileResult { match pat { Pattern::Byte(x) => self.instrs.push(Instr::Consume(ByteRange::new_single(x))), Pattern::Range(a, b) => self.instrs.push(Instr::Consume(ByteRange::new_range(a, b))), Pattern::CharacterClass(cc) => { self.instrs.push(Instr::Class(cc)); } Pattern::Alt(patterns) => { let branch_factor = patterns.len(); assert!(branch_factor > 0); let base = self.instrs.len(); // placeholders to later place in forks for _ in 0..patterns.len() - 1 { self.instrs.push(Instr::Jump(u32::MAX)); } let mut enter_pats = Vec::new(); let mut leave_pats = Vec::new(); for pat in patterns.into_iter() { enter_pats.push(self.instrs.len()); self.compile(pat)?; leave_pats.push(self.instrs.len()); // placeholder to place in join self.instrs.push(Instr::Jump(u32::MAX)); } self.instrs.pop(); // remove last jump let join_point = self.instrs.len(); // link forks for i in 0..branch_factor - 1 { let a = enter_pats[i]; let b = if i == branch_factor - 2 { enter_pats[i + 1] } else { base + i + 1 }; self.instrs[base + i] = Instr::Fork(a as JumpTarget, b as JumpTarget); } // link joins for i in 0..branch_factor - 1 { self.instrs[leave_pats[i]] = Instr::Jump(join_point as JumpTarget); } } Pattern::Concat(patterns) => { for pat in patterns.into_iter() { self.compile(pat)?; } } Pattern::Rep(pat, 0, None, greed) => { self.rep_any_amt(*pat, greed)?; } Pattern::Rep(pat, min, None, greed) => { let pat = *pat; for _ in 1..min { self.compile(pat.clone())?; } self.rep_1_or_more(pat, greed)?; } Pattern::Rep(pat, min, Some(max), greed) => { let pat = *pat; let opt = max - min; for _ in 0..min { self.compile(pat.clone())?; } for _ in 0..opt { self.rep_0_or_1(pat.clone(), greed)?; } } Pattern::Assertion(look_direction, look_polarity, pattern) => { let ins = (self.assertion_handler)(look_direction, look_polarity, *pattern)?; self.instrs.push(ins); } Pattern::Nothing => {} Pattern::Submatch(pat) => { let i = self.submatch_count as u32 * 2; self.submatch_count += 1; if let Some(ins) = F::save(i) { self.instrs.push(Instr::Custom(ins)); } self.compile(*pat)?; if let Some(ins) = F::save(i + 1) { self.instrs.push(Instr::Custom(ins)); } } } Ok(()) } fn compile_and_memoize(&mut self, pat: Pattern) -> CompileResult { if let Some(&jt) = self.map.get(&pat) { return Ok(jt); } let begin = self.instrs.len() as JumpTarget; self.compile(pat.clone())?; let end = self.instrs.len() as JumpTarget; self.instrs.push(Instr::Class(CharacterClass::Nothing)); let bounds = CompiledSnippet { begin, end }; self.map.insert(pat, bounds); Ok(bounds) } fn finalize_assertion_forks(&mut self) { let fork_targets: Vec = self.map.values().map(|v| v.begin).collect(); let fork_begin = self.instrs.len() as JumpTarget; match fork_targets.len() { 0 => { self.instrs[self.assertion_fork_base] = Instr::Class(CharacterClass::Nothing); } 1 => { self.instrs[self.assertion_fork_base] = Instr::Jump(fork_targets[0]); } 2 => { self.instrs[self.assertion_fork_base] = Instr::Fork(fork_targets[0], fork_targets[1]); } n => { self.instrs[self.assertion_fork_base] = Instr::Fork(fork_targets[0], fork_begin); for i in 1..n - 1 { let fork = if i == n - 2 { Instr::Fork(fork_targets[i], fork_targets[i + 1]) } else { Instr::Fork(fork_targets[i], self.instrs.len() as JumpTarget + 1) }; self.instrs.push(fork); } } } } } fn assertion_compiler() -> Compiler<'static, AssertionFlavor> { let mut c = Compiler::new(|_, _, _| Err(RegexCompilationError::NestedLookaroundNotSupported)); c.rep_any_amt( Pattern::CharacterClass(CharacterClass::Everything), GreedyBehavior::NonGreedy, ) .expect("characterclass should always compile"); c.assertion_fork_base = c.instrs.len(); c.instrs.push(Instr::Jump(u32::MAX)); // in the end this gets replaced by a jump to a fork-list for all the assertions c } #[derive(Clone, Debug)] pub enum RegexCompilationError { NestedLookaroundNotSupported, } pub type CompileResult = Result; impl TryFrom for BytecodeCompiledRegex { type Error = RegexCompilationError; fn try_from(value: Pattern) -> Result { let mut neg = assertion_compiler(); let mut pos = assertion_compiler(); let (final_state, instrs, submatch_count) = { let mut main: Compiler = Compiler::new(|dir, pol, pat| { let target = match dir { LookDirection::Ahead => pos.compile_and_memoize(pat.reverse()), LookDirection::Behind => neg.compile_and_memoize(pat), }? .end; Ok(Instr::Custom(MainInstr::Join(Assertion { target, dir, pol, }))) }); main.compile(value)?; let end = main.instrs.len(); main.instrs.push(Instr::Class(CharacterClass::Nothing)); (end, main.instrs, main.submatch_count) }; neg.finalize_assertion_forks(); pos.finalize_assertion_forks(); let mut accepting = BitSet::new(instrs.len()); accepting.set(final_state, true); Ok(Self { no_lookbehind: neg.map.is_empty(), instrs0: neg.instrs.into(), instrs1: instrs.into(), instrs2: pos.instrs.into(), accepting, submatch_count, }) } } impl RegexEngine for BytecodeCompiledRegex { type CompileError = RegexCompilationError; fn compile(pat: Pattern) -> Result { Self::try_from(pat) } fn run(&self, input: &[u8]) -> Option { self.re_match(input) } } #[cfg(test)] mod tests { use super::*; use crate::parse::Parse; fn regex(s: &str) -> BytecodeCompiledRegex { let pat = Pattern::parse_from_bytes(s.as_bytes()).unwrap(); let compiled = BytecodeCompiledRegex::try_from(pat).unwrap(); compiled } #[test] fn print_compiled_vm() { let compiled = regex("a?b?"); println!("{compiled:#?}"); assert_eq!(compiled.matches(b"ab"), true); assert_eq!(compiled.matches(b"a"), true); assert_eq!(compiled.matches(b"b"), true); assert_eq!(compiled.matches(b""), true); } #[test] fn nongreedy_star() { let re = regex("(ab*?)bb*"); assert_eq!( re.re_match(b"abbb").unwrap().submatches[0].clone().unwrap(), 0..1 ); assert_eq!( re.re_match(b"abbbbb").unwrap().submatches[0] .clone() .unwrap(), 0..1 ); } #[test] fn greedy_star() { let re = regex("(ab*)bb*"); assert_eq!( re.re_match(b"abbb").unwrap().submatches[0].clone().unwrap(), 0..3 ); assert_eq!( re.re_match(b"abbbbb").unwrap().submatches[0] .clone() .unwrap(), 0..5 ); } #[test] fn nongreedy_plus() { let re = regex("(ab+?)bb*"); assert_eq!( re.re_match(b"abbbb").unwrap().submatches[0] .clone() .unwrap(), 0..2 ); assert_eq!( re.re_match(b"abbbbb").unwrap().submatches[0] .clone() .unwrap(), 0..2 ); } #[test] fn greedy_plus() { let re = regex("(ab+)bb*"); assert_eq!( re.re_match(b"abbb").unwrap().submatches[0].clone().unwrap(), 0..3 ); assert_eq!( re.re_match(b"abbbbb").unwrap().submatches[0] .clone() .unwrap(), 0..5 ); } #[test] fn nongreedy_qm() { let re = regex("(ab??)bb*"); assert_eq!( re.re_match(b"abbbb").unwrap().submatches[0] .clone() .unwrap(), 0..1 ); assert_eq!( re.re_match(b"abbbbb").unwrap().submatches[0] .clone() .unwrap(), 0..1 ); } #[test] fn greedy_qm() { let re = regex("(ab?)bb*"); assert_eq!( re.re_match(b"abbb").unwrap().submatches[0].clone().unwrap(), 0..2 ); assert_eq!( re.re_match(b"abbbbb").unwrap().submatches[0] .clone() .unwrap(), 0..2 ); } }