diff options
Diffstat (limited to 'src/regex/bc.rs')
| -rw-r--r-- | src/regex/bc.rs | 752 |
1 files changed, 752 insertions, 0 deletions
diff --git a/src/regex/bc.rs b/src/regex/bc.rs new file mode 100644 index 0000000..4a79485 --- /dev/null +++ b/src/regex/bc.rs @@ -0,0 +1,752 @@ +use std::collections::{HashMap, VecDeque}; + +use super::{ + CharacterClass, GreedyBehavior, LookDirection, LookPolarity, Match, Pattern, RegexEngine, + byte_range::ByteRange, +}; +use crate::bitset::BitSet; + +trait Flavor: Clone { + type CustomInstr: Copy + Clone + std::fmt::Debug; + type ThreadData: Clone; + type StepData<'a, 'b> + where + 'b: 'a; + + fn accepts<'a, 'b>( + thread: &mut Thread<Self>, + instr: Self::CustomInstr, + sd: &mut Self::StepData<'a, 'b>, + ) -> bool; + + fn save(x: u32) -> Option<Self::CustomInstr>; +} + +#[derive(Copy, Clone, Debug)] +struct MainFlavor; +impl Flavor for MainFlavor { + type CustomInstr = MainInstr; + type ThreadData = Box<[Option<usize>]>; + type StepData<'a, 'b> + = (usize, &'a BitSet, &'a mut LookaheadVM<'b>) + where + 'b: 'a; + + fn accepts<'a, 'b>( + thread: &mut Thread<Self>, + instr: Self::CustomInstr, + data: &mut Self::StepData<'a, 'b>, + ) -> bool { + match instr { + MainInstr::Save(reg) => { + thread.data[reg as usize] = Some(data.0); + true + } + MainInstr::Join(assertion) => { + let should_match = assertion.pol == LookPolarity::Positive; + let state = assertion.target as usize; + let is_matching = match assertion.dir { + LookDirection::Ahead => data.2.get_state(data.0, state), + LookDirection::Behind => data.1.get(state), + }; + is_matching == should_match + } + } + } + + fn save(x: u32) -> Option<Self::CustomInstr> { + Some(MainInstr::Save(x)) + } +} + +#[derive(Copy, Clone, Debug)] +enum Nothing {} + +#[derive(Copy, Clone, Debug)] +struct AssertionFlavor; +impl Flavor for AssertionFlavor { + type CustomInstr = Nothing; + type ThreadData = (); + type StepData<'a, 'b> + = () + where + 'b: 'a; + + fn accepts(_thread: &mut Thread<Self>, instr: Self::CustomInstr, _sd: &mut ()) -> bool { + match instr {} + } + + fn save(_: u32) -> Option<Self::CustomInstr> { + None + } +} + +type JumpTarget = u32; +type Register = u32; + +#[derive(Copy, Clone, Debug)] +struct Assertion { + target: JumpTarget, + dir: LookDirection, + pol: LookPolarity, +} + +#[derive(Copy, Clone, Debug)] +enum Instr<F: Flavor> { + Class(CharacterClass), + Consume(ByteRange), + Jump(JumpTarget), + Fork(JumpTarget, JumpTarget), + Custom(F::CustomInstr), +} + +#[derive(Copy, Clone, Debug)] +enum MainInstr { + Save(Register), + Join(Assertion), +} + +#[derive(Clone)] +struct Thread<F: Flavor> { + pc: JumpTarget, + data: F::ThreadData, +} + +struct VM<'p, F: Flavor> { + instr: &'p [Instr<F>], + passive_threads: VecDeque<Thread<F>>, + active_threads: VecDeque<Thread<F>>, + hot: BitSet, + warm: BitSet, +} + +impl<'p, F: Flavor> VM<'p, F> { + fn new(instr: &'p [Instr<F>], starting_thread: Thread<F>) -> Self { + Self { + instr, + passive_threads: vec![starting_thread].into(), + active_threads: VecDeque::new(), + hot: BitSet::new(instr.len()), + warm: BitSet::new(instr.len()), + } + } + + fn step_epsilon<'a>(&mut self, sd: &mut F::StepData<'a, 'p>) { + std::mem::swap(&mut self.active_threads, &mut self.passive_threads); + self.hot.set_all(false); + self.warm.set_all(false); + + macro_rules! add_thread { + ($t:expr) => {{ + let t = $t; + let bit = t.pc as usize; + if !self.warm.get(bit) { + self.warm.set(bit, true); + self.active_threads.push_front(t); + } + }}; + } + + while let Some(mut thread) = self.active_threads.pop_front() { + match self.instr[thread.pc as usize] { + Instr::Class(_) | Instr::Consume(_) => { + if !self.hot.get(thread.pc as usize) { + self.hot.set(thread.pc as usize, true); + self.passive_threads.push_back(thread); + } + } + Instr::Jump(j) => { + thread.pc = j; + add_thread!(thread); + } + Instr::Fork(a, b) => { + add_thread!(Thread { + pc: b, + data: thread.data.clone(), + }); + add_thread!(Thread { + pc: a, + data: thread.data.clone(), + }); + } + Instr::Custom(instr) => { + if F::accepts(&mut thread, instr, sd) { + thread.pc += 1; + add_thread!(thread); + } + } + } + } + } + + fn step_consume(&mut self, byte: u8) { + self.hot.set_all(false); + self.passive_threads + .retain_mut(|thread| match self.instr[thread.pc as usize] { + Instr::Class(class) => { + if class.matches(byte) { + thread.pc += 1; + self.hot.set(thread.pc as usize, true); + true + } else { + false + } + } + Instr::Consume(bytes) => { + if bytes.contains(byte) { + thread.pc += 1; + self.hot.set(thread.pc as usize, true); + true + } else { + false + } + } + _ => false, + }); + } +} + +struct LookaheadVM<'a> { + vm: VM<'a, AssertionFlavor>, + data: &'a [u8], + cached: bool, + cache_data: Vec<BitSet>, + loc_offset: usize, +} + +impl<'a> LookaheadVM<'a> { + fn new(vm: VM<'a, AssertionFlavor>, data: &'a [u8]) -> Self { + Self { + vm, + data, + cached: false, + cache_data: Vec::new(), + loc_offset: 0, + } + } + + fn get_state(&mut self, loc: usize, state: usize) -> bool { + if !self.cached { + assert!(self.cache_data.is_empty()); + assert_eq!(self.loc_offset, 0); + self.loc_offset = loc; + self.vm.step_epsilon(&mut ()); + self.cache_data.push(self.vm.hot.clone()); + for i in (loc..self.data.len()).rev() { + self.vm.step_consume(self.data[i]); + self.vm.step_epsilon(&mut ()); + self.cache_data.push(self.vm.hot.clone()); + } + self.cache_data.reverse(); + self.cached = true; + } + + assert!( + loc >= self.loc_offset, + "get_state must be called with non-decreasing arguments." + ); + self.cache_data[loc - self.loc_offset].get(state) + } +} + +struct VirtualMachine<'a> { + vm0: VM<'a, AssertionFlavor>, + vm1: VM<'a, MainFlavor>, + vm2: LookaheadVM<'a>, + accepting: &'a BitSet, +} + +impl<'a> VirtualMachine<'a> { + fn step_epsilon_1(&mut self, loc: usize) { + self.vm1 + .step_epsilon(&mut (loc, &self.vm0.hot, &mut self.vm2)); + } + + fn step_epsilon(&mut self, loc: usize) { + self.vm0.step_epsilon(&mut ()); + self.step_epsilon_1(loc); + } + + fn step_consume(&mut self, byte: u8) { + self.vm0.step_consume(byte); + self.vm1.step_consume(byte); + } + + fn step(&mut self, byte: u8, loc: usize) { + self.step_epsilon(loc); + self.step_consume(byte); + } + + fn extract_match(&self) -> Option<Match> { + self.vm1 + .passive_threads + .iter() + .filter(|t| self.accepting.get(t.pc as usize)) + .map(|t| { + let submatches: Vec<_> = t.data.windows(2).map(|x| Some(x[0]?..x[1]?)).collect(); + + Match { + submatches: submatches.into(), + } + }) + .next() + } +} + +fn fmt_instructions<F: std::fmt::Debug + Flavor>( + f: &mut std::fmt::Formatter<'_>, + label: &str, + insns: &[Instr<F>], +) -> std::fmt::Result { + writeln!(f, "# {label}")?; + for (idx, ins) in insns.iter().enumerate() { + writeln!(f, "{idx}: {ins:?}")?; + } + Ok(()) +} + +pub struct BytecodeCompiledRegex { + instrs0: Box<[Instr<AssertionFlavor>]>, + instrs1: Box<[Instr<MainFlavor>]>, + instrs2: Box<[Instr<AssertionFlavor>]>, + no_lookbehind: bool, + submatch_count: usize, + accepting: BitSet, +} + +impl std::fmt::Debug for BytecodeCompiledRegex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt_instructions(f, "behind", &self.instrs0)?; + fmt_instructions(f, "ahead", &self.instrs2)?; + fmt_instructions(f, "main", &self.instrs1)?; + writeln!(f, "accepting: {:?}", self.accepting) + } +} + +impl BytecodeCompiledRegex { + pub fn re_match(&self, data: &[u8]) -> Option<Match> { + let vm0 = VM::new(&self.instrs0, Thread { pc: 0, data: () }); + let vm1 = VM::new( + &self.instrs1, + Thread { + pc: 0, + data: vec![None; 2 * self.submatch_count].into(), + }, + ); + let vm2 = VM::new(&self.instrs2, Thread { pc: 0, data: () }); + let vm2 = LookaheadVM::new(vm2, data); + let mut vm = VirtualMachine { + vm0, + vm1, + vm2, + accepting: &self.accepting, + }; + if self.no_lookbehind { + for (i, ch) in data.iter().cloned().enumerate() { + vm.step_epsilon_1(i); + vm.vm1.step_consume(ch); + } + vm.step_epsilon(data.len()); + } else { + for (i, ch) in data.iter().cloned().enumerate() { + vm.step(ch, i); + } + vm.step_epsilon(data.len()); + } + vm.extract_match() + } + + pub fn matches(&self, data: &[u8]) -> bool { + self.re_match(data).is_some() + } +} + +type AssertionHandler<'a, F> = + Box<dyn 'a + FnMut(LookDirection, LookPolarity, Pattern) -> CompileResult<Instr<F>>>; + +#[derive(Copy, Clone)] +struct CompiledSnippet { + begin: JumpTarget, + end: JumpTarget, +} + +struct Compiler<'a, F: Flavor> { + instrs: Vec<Instr<F>>, + map: HashMap<Pattern, CompiledSnippet>, + assertion_handler: AssertionHandler<'a, F>, + assertion_fork_base: usize, + submatch_count: usize, +} + +fn fork<F: Flavor>(repeat: usize, exit: usize, greedy: GreedyBehavior) -> Instr<F> { + let repeat = repeat as JumpTarget; + let exit = exit as JumpTarget; + match greedy { + GreedyBehavior::Greedy => Instr::Fork(repeat, exit), + GreedyBehavior::NonGreedy => Instr::Fork(exit, repeat), + } +} + +impl<'a, F: Flavor> Compiler<'a, F> { + fn new( + assertion_handler: impl 'a + + FnMut(LookDirection, LookPolarity, Pattern) -> CompileResult<Instr<F>>, + ) -> Self { + Self { + instrs: Vec::new(), + map: HashMap::new(), + assertion_handler: Box::new(assertion_handler), + assertion_fork_base: usize::MAX, + submatch_count: 0, + } + } + + fn rep_1_or_more(&mut self, pat: Pattern, greedy: GreedyBehavior) -> CompileResult { + let base = self.instrs.len(); + self.compile(pat)?; + let exit = self.instrs.len() + 1; + self.instrs.push(fork(base, exit, greedy)); + Ok(()) + } + + fn rep_0_or_1(&mut self, pat: Pattern, greedy: GreedyBehavior) -> CompileResult { + let base = self.instrs.len(); + self.instrs.push(Instr::Jump(u32::MAX)); + self.compile(pat)?; + self.instrs[base] = fork(base + 1, self.instrs.len(), greedy); + Ok(()) + } + + fn rep_any_amt(&mut self, pat: Pattern, greedy: GreedyBehavior) -> CompileResult { + let base = self.instrs.len(); + self.instrs.push(Instr::Jump(u32::MAX)); + self.compile(pat)?; + let fork_pos = self.instrs.len(); + let after = fork_pos + 1; + self.instrs.push(fork(base + 1, after, greedy)); + self.instrs[base] = Instr::Jump(fork_pos as JumpTarget); + Ok(()) + } + + fn compile(&mut self, pat: Pattern) -> CompileResult { + match pat { + Pattern::Byte(x) => self.instrs.push(Instr::Consume(ByteRange::new_single(x))), + Pattern::Range(a, b) => self.instrs.push(Instr::Consume(ByteRange::new_range(a, b))), + Pattern::CharacterClass(cc) => { + self.instrs.push(Instr::Class(cc)); + } + Pattern::Alt(patterns) => { + let branch_factor = patterns.len(); + assert!(branch_factor > 0); + + let base = self.instrs.len(); + + // placeholders to later place in forks + for _ in 0..patterns.len() - 1 { + self.instrs.push(Instr::Jump(u32::MAX)); + } + + let mut enter_pats = Vec::new(); + let mut leave_pats = Vec::new(); + for pat in patterns.into_iter() { + enter_pats.push(self.instrs.len()); + self.compile(pat)?; + leave_pats.push(self.instrs.len()); + + // placeholder to place in join + self.instrs.push(Instr::Jump(u32::MAX)); + } + + self.instrs.pop(); // remove last jump + let join_point = self.instrs.len(); + + // link forks + for i in 0..branch_factor - 1 { + let a = enter_pats[i]; + let b = if i == branch_factor - 2 { + enter_pats[i + 1] + } else { + base + i + 1 + }; + self.instrs[base + i] = Instr::Fork(a as JumpTarget, b as JumpTarget); + } + + // link joins + for i in 0..branch_factor - 1 { + self.instrs[leave_pats[i]] = Instr::Jump(join_point as JumpTarget); + } + } + Pattern::Concat(patterns) => { + for pat in patterns.into_iter() { + self.compile(pat)?; + } + } + Pattern::Rep(pat, 0, None, greed) => { + self.rep_any_amt(*pat, greed)?; + } + Pattern::Rep(pat, min, None, greed) => { + let pat = *pat; + for _ in 1..min { + self.compile(pat.clone())?; + } + self.rep_1_or_more(pat, greed)?; + } + Pattern::Rep(pat, min, Some(max), greed) => { + let pat = *pat; + let opt = max - min; + for _ in 0..min { + self.compile(pat.clone())?; + } + for _ in 0..opt { + self.rep_0_or_1(pat.clone(), greed)?; + } + } + Pattern::Assertion(look_direction, look_polarity, pattern) => { + let ins = (self.assertion_handler)(look_direction, look_polarity, *pattern)?; + self.instrs.push(ins); + } + Pattern::Nothing => {} + Pattern::Submatch(pat) => { + let i = self.submatch_count as u32 * 2; + self.submatch_count += 1; + if let Some(ins) = F::save(i) { + self.instrs.push(Instr::Custom(ins)); + } + self.compile(*pat)?; + if let Some(ins) = F::save(i + 1) { + self.instrs.push(Instr::Custom(ins)); + } + } + } + Ok(()) + } + + fn compile_and_memoize(&mut self, pat: Pattern) -> CompileResult<CompiledSnippet> { + if let Some(&jt) = self.map.get(&pat) { + return Ok(jt); + } + let begin = self.instrs.len() as JumpTarget; + self.compile(pat.clone())?; + let end = self.instrs.len() as JumpTarget; + self.instrs.push(Instr::Class(CharacterClass::Nothing)); + let bounds = CompiledSnippet { begin, end }; + self.map.insert(pat, bounds); + Ok(bounds) + } + + fn finalize_assertion_forks(&mut self) { + let fork_targets: Vec<JumpTarget> = self.map.values().map(|v| v.begin).collect(); + let fork_begin = self.instrs.len() as JumpTarget; + match fork_targets.len() { + 0 => { + self.instrs[self.assertion_fork_base] = Instr::Class(CharacterClass::Nothing); + } + 1 => { + self.instrs[self.assertion_fork_base] = Instr::Jump(fork_targets[0]); + } + 2 => { + self.instrs[self.assertion_fork_base] = + Instr::Fork(fork_targets[0], fork_targets[1]); + } + n => { + self.instrs[self.assertion_fork_base] = Instr::Fork(fork_targets[0], fork_begin); + for i in 1..n - 1 { + let fork = if i == n - 2 { + Instr::Fork(fork_targets[i], fork_targets[i + 1]) + } else { + Instr::Fork(fork_targets[i], self.instrs.len() as JumpTarget + 1) + }; + self.instrs.push(fork); + } + } + } + } +} + +fn assertion_compiler() -> Compiler<'static, AssertionFlavor> { + let mut c = Compiler::new(|_, _, _| Err(RegexCompilationError::NestedLookaroundNotSupported)); + c.rep_any_amt( + Pattern::CharacterClass(CharacterClass::Everything), + GreedyBehavior::NonGreedy, + ) + .expect("characterclass should always compile"); + c.assertion_fork_base = c.instrs.len(); + c.instrs.push(Instr::Jump(u32::MAX)); // in the end this gets replaced by a jump to a fork-list for all the assertions + c +} + +#[derive(Clone, Debug)] +pub enum RegexCompilationError { + NestedLookaroundNotSupported, +} + +pub type CompileResult<T = ()> = Result<T, RegexCompilationError>; + +impl TryFrom<Pattern> for BytecodeCompiledRegex { + type Error = RegexCompilationError; + + fn try_from(value: Pattern) -> Result<Self, Self::Error> { + let mut neg = assertion_compiler(); + let mut pos = assertion_compiler(); + let (final_state, instrs, submatch_count) = { + let mut main: Compiler<MainFlavor> = Compiler::new(|dir, pol, pat| { + let target = match dir { + LookDirection::Ahead => pos.compile_and_memoize(pat.reverse()), + LookDirection::Behind => neg.compile_and_memoize(pat), + }? + .end; + + Ok(Instr::Custom(MainInstr::Join(Assertion { + target, + dir, + pol, + }))) + }); + main.compile(value)?; + let end = main.instrs.len(); + main.instrs.push(Instr::Class(CharacterClass::Nothing)); + (end, main.instrs, main.submatch_count) + }; + neg.finalize_assertion_forks(); + pos.finalize_assertion_forks(); + + let mut accepting = BitSet::new(instrs.len()); + accepting.set(final_state, true); + + Ok(Self { + no_lookbehind: neg.map.is_empty(), + instrs0: neg.instrs.into(), + instrs1: instrs.into(), + instrs2: pos.instrs.into(), + accepting, + submatch_count, + }) + } +} + +impl RegexEngine for BytecodeCompiledRegex { + type CompileError = RegexCompilationError; + + fn compile(pat: Pattern) -> Result<Self, Self::CompileError> { + Self::try_from(pat) + } + + fn run(&self, input: &[u8]) -> Option<Match> { + self.re_match(input) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parse::Parse; + fn regex(s: &str) -> BytecodeCompiledRegex { + let pat = Pattern::parse_from_bytes(s.as_bytes()).unwrap(); + let compiled = BytecodeCompiledRegex::try_from(pat).unwrap(); + compiled + } + + #[test] + fn print_compiled_vm() { + let compiled = regex("a?b?"); + println!("{compiled:#?}"); + assert_eq!(compiled.matches(b"ab"), true); + assert_eq!(compiled.matches(b"a"), true); + assert_eq!(compiled.matches(b"b"), true); + assert_eq!(compiled.matches(b""), true); + } + + #[test] + fn nongreedy_star() { + let re = regex("(ab*?)bb*"); + assert_eq!( + re.re_match(b"abbb").unwrap().submatches[0].clone().unwrap(), + 0..1 + ); + assert_eq!( + re.re_match(b"abbbbb").unwrap().submatches[0] + .clone() + .unwrap(), + 0..1 + ); + } + + #[test] + fn greedy_star() { + let re = regex("(ab*)bb*"); + assert_eq!( + re.re_match(b"abbb").unwrap().submatches[0].clone().unwrap(), + 0..3 + ); + assert_eq!( + re.re_match(b"abbbbb").unwrap().submatches[0] + .clone() + .unwrap(), + 0..5 + ); + } + + #[test] + fn nongreedy_plus() { + let re = regex("(ab+?)bb*"); + assert_eq!( + re.re_match(b"abbbb").unwrap().submatches[0] + .clone() + .unwrap(), + 0..2 + ); + assert_eq!( + re.re_match(b"abbbbb").unwrap().submatches[0] + .clone() + .unwrap(), + 0..2 + ); + } + + #[test] + fn greedy_plus() { + let re = regex("(ab+)bb*"); + assert_eq!( + re.re_match(b"abbb").unwrap().submatches[0].clone().unwrap(), + 0..3 + ); + assert_eq!( + re.re_match(b"abbbbb").unwrap().submatches[0] + .clone() + .unwrap(), + 0..5 + ); + } + + #[test] + fn nongreedy_qm() { + let re = regex("(ab??)bb*"); + assert_eq!( + re.re_match(b"abbbb").unwrap().submatches[0] + .clone() + .unwrap(), + 0..1 + ); + assert_eq!( + re.re_match(b"abbbbb").unwrap().submatches[0] + .clone() + .unwrap(), + 0..1 + ); + } + + #[test] + fn greedy_qm() { + let re = regex("(ab?)bb*"); + assert_eq!( + re.re_match(b"abbb").unwrap().submatches[0].clone().unwrap(), + 0..2 + ); + assert_eq!( + re.re_match(b"abbbbb").unwrap().submatches[0] + .clone() + .unwrap(), + 0..2 + ); + } +} |
