aboutsummaryrefslogtreecommitdiffstats
path: root/src/regex/mod.rs
diff options
context:
space:
mode:
authorJonas Maier <jonas@x77.dev>2026-06-06 12:15:52 +0200
committerJonas Maier <jonas@x77.dev>2026-06-06 12:15:52 +0200
commit53980774c327675e886179c0a2c140744dcf9b95 (patch)
treeca1fdcc9938fce2c10c51e0a51659c6ba38ac5ba /src/regex/mod.rs
parent75e0c29cf91ddc6299c14a94a038c3e3df3d2805 (diff)
downloadpish-53980774c327675e886179c0a2c140744dcf9b95.tar.gz
special cased regex for performance
Diffstat (limited to 'src/regex/mod.rs')
-rw-r--r--src/regex/mod.rs344
1 files changed, 344 insertions, 0 deletions
diff --git a/src/regex/mod.rs b/src/regex/mod.rs
new file mode 100644
index 0000000..be3026f
--- /dev/null
+++ b/src/regex/mod.rs
@@ -0,0 +1,344 @@
+use crate::parse::Parse;
+
+pub mod bc;
+mod byte_range;
+pub mod dfa;
+pub mod enfa;
+pub mod simple;
+
/// Direction component of a [`Pattern::Assertion`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LookDirection {
    /// The assertion looks ahead of the current position.
    Ahead,
    /// The assertion looks behind the current position.
    Behind,
}

impl LookDirection {
    /// Returns the opposite direction: `Ahead` becomes `Behind` and vice versa.
    pub fn reverse(self) -> Self {
        if self == Self::Ahead {
            Self::Behind
        } else {
            Self::Ahead
        }
    }
}
+
/// Polarity component of a [`Pattern::Assertion`]: whether the assertion is
/// a positive or a negative one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LookPolarity {
    /// Positive assertion.
    Positive,
    /// Negative assertion.
    Negative,
}
+
/// Predefined sets of bytes that can be tested with [`CharacterClass::matches`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CharacterClass {
    /// Contains every byte.
    Everything,
    /// Contains no byte.
    Nothing,
    /// ASCII whitespace as defined by [`u8::is_ascii_whitespace`]
    /// (note that this definition does not include vertical tab, `0x0B`).
    Whitespace,
    /// ASCII letters as defined by [`u8::is_ascii_alphabetic`].
    Alphabetic,
    /// ASCII letters and digits as defined by [`u8::is_ascii_alphanumeric`].
    Alphanumeric,
}

impl CharacterClass {
    /// Reports whether `byte` is a member of this class.
    pub fn matches(self, byte: u8) -> bool {
        match self {
            Self::Everything => true,
            Self::Nothing => false,
            Self::Whitespace => byte.is_ascii_whitespace(),
            Self::Alphabetic => byte.is_ascii_alphabetic(),
            Self::Alphanumeric => byte.is_ascii_alphanumeric(),
        }
    }
}
+
+#[derive(PartialEq, Eq, Hash, Debug, Clone)]
+pub enum Pattern {
+ Byte(u8),
+ Range(u8, u8),
+ CharacterClass(CharacterClass),
+ Alt(Vec<Pattern>),
+ Concat(Vec<Pattern>),
+ Rep(Box<Pattern>, u32, Option<u32>, GreedyBehavior),
+ Assertion(LookDirection, LookPolarity, Box<Pattern>),
+ Submatch(Box<Pattern>),
+ Nothing,
+}
+
/// Upper bound on the number of input bytes a pattern can consume.
///
/// Values are totally ordered: every `Bounded` value is smaller than
/// `Unbounded`, and two `Bounded` values compare by their byte count.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum ByteConsumption {
    /// At most this many bytes are consumed.
    Bounded(usize),
    /// No finite bound is known (or the bound does not fit into a `usize`).
    Unbounded,
}

impl ByteConsumption {
    /// The bound of something that consumes no input at all.
    pub fn zero() -> Self {
        Self::Bounded(0)
    }

    /// The bound of something that consumes exactly one byte.
    pub fn one() -> Self {
        Self::Bounded(1)
    }
}

impl Ord for ByteConsumption {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match (self, other) {
            (Self::Bounded(a), Self::Bounded(b)) => a.cmp(b),
            (Self::Bounded(_), Self::Unbounded) => std::cmp::Ordering::Less,
            (Self::Unbounded, Self::Bounded(_)) => std::cmp::Ordering::Greater,
            (Self::Unbounded, Self::Unbounded) => std::cmp::Ordering::Equal,
        }
    }
}

impl PartialOrd for ByteConsumption {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl std::ops::Add for ByteConsumption {
    type Output = Self;

    /// Adds two bounds. The result is `Unbounded` if either operand is
    /// `Unbounded` or if the sum does not fit into a `usize`.
    fn add(self, rhs: Self) -> Self::Output {
        match (self, rhs) {
            // An overflowing sum cannot be represented as a finite bound, so
            // report it as `Unbounded` instead of panicking or wrapping around
            // to a bound that would be too small.
            (Self::Bounded(a), Self::Bounded(b)) => {
                a.checked_add(b).map_or(Self::Unbounded, Self::Bounded)
            }
            _ => Self::Unbounded,
        }
    }
}

impl std::ops::Mul<usize> for ByteConsumption {
    type Output = Self;

    /// Scales a bound by `rhs`. `Unbounded` stays `Unbounded`; a product that
    /// does not fit into a `usize` becomes `Unbounded`.
    fn mul(self, rhs: usize) -> Self::Output {
        match self {
            Self::Bounded(x) => x.checked_mul(rhs).map_or(Self::Unbounded, Self::Bounded),
            Self::Unbounded => Self::Unbounded,
        }
    }
}

impl std::iter::Sum for ByteConsumption {
    /// Sums all bounds, starting from [`ByteConsumption::zero`].
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(Self::zero(), |total, next| total + next)
    }
}
+
+impl Pattern {
+ pub fn max_byte_consumption(&self) -> ByteConsumption {
+ match self {
+ Pattern::Byte(_) => ByteConsumption::one(),
+ Pattern::Range(_, _) => ByteConsumption::one(),
+ Pattern::CharacterClass(_) => ByteConsumption::one(),
+ Pattern::Alt(patterns) => patterns
+ .iter()
+ .map(Self::max_byte_consumption)
+ .max()
+ .unwrap_or(ByteConsumption::zero()),
+ Pattern::Concat(patterns) => patterns.iter().map(Self::max_byte_consumption).sum(),
+ Pattern::Rep(pattern, _, Some(max_reps), _) => {
+ pattern.max_byte_consumption() * (*max_reps as usize)
+ }
+ Pattern::Rep(_, _, None, _) => ByteConsumption::Unbounded,
+ Pattern::Assertion(_, _, _) => ByteConsumption::zero(),
+ Pattern::Nothing => ByteConsumption::zero(),
+ Pattern::Submatch(pat) => pat.max_byte_consumption(),
+ }
+ }
+
+ pub fn reverse(self) -> Self {
+ use Pattern::*;
+ match self {
+ Byte(_) | Nothing | Range(..) | CharacterClass(_) => self,
+ Alt(patterns) => Alt(patterns.into_iter().map(Self::reverse).collect()),
+ Concat(patterns) => Concat(patterns.into_iter().map(Self::reverse).rev().collect()),
+ Rep(pattern, min, max, greedy) => Rep(Box::new(pattern.reverse()), min, max, greedy),
+ Assertion(dir, pol, pat) => Assertion(dir.reverse(), pol, Box::new(pat.reverse())),
+ Submatch(pat) => Submatch(Box::new(pat.reverse())),
+ }
+ }
+}
+
/// Greediness flag carried by [`Pattern::Rep`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GreedyBehavior {
    /// Greedy repetition.
    Greedy,
    /// Non-greedy repetition.
    NonGreedy,
}
+
+pub enum CompiledPattern {
+ Dfa(dfa::DFA),
+ Bytecode(bc::BytecodeCompiledRegex),
+ Anything(simple::Anything),
+ Nothing(simple::Nothing),
+ Exact(simple::Exact),
+}
+
+impl CompiledPattern {
+ fn is_dfa(&self) -> bool {
+ matches!(self, Self::Dfa(_))
+ }
+ fn is_bytecode(&self) -> bool {
+ matches!(self, Self::Bytecode(_))
+ }
+ fn is_simple(&self) -> bool {
+ matches!(self, Self::Anything(_) | Self::Nothing(_) | Self::Exact(_))
+ }
+}
+
/// Result of a successful [`RegexEngine::run`].
#[derive(Debug, PartialEq, Eq)]
pub struct Match {
    /// One optional range per submatch slot.
    // NOTE(review): presumably byte offsets into the matched input, with
    // `None` for a submatch that did not participate - confirm against the engines.
    pub submatches: Box<[Option<core::ops::Range<usize>>]>,
}

impl Match {
    /// Creates a match that carries no submatch slots.
    pub fn new_empty() -> Self {
        let submatches = Vec::new().into_boxed_slice();
        Self { submatches }
    }
}
+
+pub trait RegexEngine: Sized {
+ type CompileError: std::fmt::Debug + Clone;
+
+ fn compile(pat: Pattern) -> Result<Self, Self::CompileError>;
+
+ fn run(&self, input: &[u8]) -> Option<Match>;
+
+ fn matches(&self, input: &[u8]) -> bool {
+ self.run(input).is_some()
+ }
+}
+
+impl RegexEngine for CompiledPattern {
+ type CompileError = bc::RegexCompilationError;
+
+ fn compile(pat: Pattern) -> Result<Self, Self::CompileError> {
+ if let Ok(c) = simple::Anything::compile(pat.clone()) {
+ Ok(Self::Anything(c))
+ } else if let Ok(c) = simple::Nothing::compile(pat.clone()) {
+ Ok(Self::Nothing(c))
+ } else if let Ok(c) = simple::Exact::compile(pat.clone()) {
+ Ok(Self::Exact(c))
+ } else if let Ok(c) = dfa::DFA::compile(pat.clone()) {
+ Ok(Self::Dfa(c))
+ } else {
+ bc::BytecodeCompiledRegex::compile(pat).map(Self::Bytecode)
+ }
+ }
+
+ fn run(&self, input: &[u8]) -> Option<Match> {
+ match self {
+ CompiledPattern::Dfa(x) => x.run(input),
+ CompiledPattern::Bytecode(x) => x.run(input),
+ CompiledPattern::Anything(x) => x.run(input),
+ CompiledPattern::Nothing(x) => x.run(input),
+ CompiledPattern::Exact(x) => x.run(input),
+ }
+ }
+}
+
+macro_rules! all_engines {
+ ($ty_name:ident, $($x:ident : $ty:ty,)*) => {
+ pub struct $ty_name {
+ $($x: Option<$ty>,)*
+ }
+ impl RegexEngine for $ty_name {
+ type CompileError = ();
+
+ fn compile(pat: Pattern) -> Result<Self, Self::CompileError> {
+ let x = Self {
+ $($x: RegexEngine::compile(pat.clone()).ok(),)*
+ };
+ if $(x.$x.is_none())&&* {
+ Err(())
+ } else {
+ Ok(x)
+ }
+ }
+
+ fn run(&self, input: &[u8]) -> Option<Match> {
+ $(let $x = self.$x.as_ref().map(|x| x.run(input));)*
+ let mut result = None;
+ $(
+ if let Some(res) = $x {
+ if let Some(result) = result {
+ assert_eq!(res, result, concat!("engine ", stringify!($x), " does not agree with previously run engines."));
+ }
+ result = Some(res)
+ }
+ )*
+ result.unwrap()
+ }
+ }
+ }
+}
+
+all_engines!(
+ AllEngines,
+ dfa: dfa::DFA,
+ bc: bc::BytecodeCompiledRegex,
+ any: simple::Anything,
+ nothing: simple::Nothing,
+ exact: simple::Exact,
+);
+
+impl Pattern {
+ pub fn try_compile(self) -> Result<CompiledPattern, bc::RegexCompilationError> {
+ CompiledPattern::compile(self)
+ }
+}
+
/// Test helper: parses and compiles `$regex` (panicking if either step
/// fails) and asserts that `matches` on `$input` returns `$expected`.
#[cfg(test)]
macro_rules! regex_matches {
    ($regex:literal, $input:literal, $expected:literal) => {{
        let compiled = <Pattern as crate::parse::Parse>::parse_from_bytes($regex.as_bytes())
            .unwrap()
            .try_compile()
            .unwrap();
        assert_eq!(compiled.matches($input.as_bytes()), $expected)
    }};
}

/// The pattern `foo` matches the input `foo`.
#[test]
fn foo_matches_foo() {
    regex_matches!("foo", "foo", true);
}
+
/// `.*` must be picked up by one of the special-cased `simple` engines.
#[test]
fn dot_star_is_simple() {
    let parsed = Pattern::parse_from_bytes(b".*").unwrap();
    let compiled = parsed.try_compile().unwrap();
    assert!(compiled.is_simple());
}
+
/// A pattern with a parenthesised group must end up on the bytecode engine.
#[test]
fn match_is_bytecode() {
    let parsed = Pattern::parse_from_bytes(b".*(ele.*phant).*").unwrap();
    let compiled = parsed.try_compile().unwrap();
    assert!(compiled.is_bytecode());
}
+
/// A pattern made only of literal bytes (including single-element bracket
/// groups) must be compiled by the `Exact` engine specifically, not merely
/// by any of the `simple` engines.
#[test]
fn simple_word_is_exact() {
    let compiled = Pattern::parse_from_bytes(b"Gnu[ ]plus[ ]Linux")
        .unwrap()
        .try_compile()
        .unwrap();
    // `is_simple()` would also accept `Anything` and `Nothing`, which would
    // hide a wrong engine selection; check the variant the test name promises.
    assert!(matches!(compiled, CompiledPattern::Exact(_)));
}
+
/// A pattern without groups that is not special-cased must be compiled by
/// the DFA engine.
#[test]
fn no_match_is_dfa() {
    let parsed = Pattern::parse_from_bytes(b".*Gnu.*plus.*Linux.*").unwrap();
    let compiled = parsed.try_compile().unwrap();
    assert!(compiled.is_dfa());
}