aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorJonas Maier <jonas@x77.dev>2026-06-02 15:52:36 +0200
committerJonas Maier <jonas@x77.dev>2026-06-02 15:52:36 +0200
commit93a7ccabdf85ab2733b2b67810750a97bf3509cb (patch)
tree0e42faacb6b84446128a8ca6acd123e32d811ca1 /src
parent1ec7a9d2d3bc77b07c97a07e896be05b4099cf9f (diff)
downloadpish-93a7ccabdf85ab2733b2b67810750a97bf3509cb.tar.gz
refactor enfa for better type safety
Diffstat (limited to 'src')
-rw-r--r--src/parse/regex/dfa.rs22
-rw-r--r--src/parse/regex/enfa.rs235
2 files changed, 147 insertions, 110 deletions
diff --git a/src/parse/regex/dfa.rs b/src/parse/regex/dfa.rs
index 2fd1935..78888a2 100644
--- a/src/parse/regex/dfa.rs
+++ b/src/parse/regex/dfa.rs
@@ -3,7 +3,7 @@ use std::collections::HashMap;
use super::{
byte_range::ByteRange,
- enfa::{ENFA, MultiState},
+ enfa::{ENFA, Epsilon, MultiState, Resolved},
};
pub type StateId = usize;
@@ -60,17 +60,9 @@ impl DFA {
}
}
-impl From<ENFA> for DFA {
- fn from(mut nfa: ENFA) -> Self {
- nfa.simplify();
-
- for s in nfa.states.iter() {
- if s.trans.iter().any(|t| t.is_epsilon()) {
- panic!(
- "NFA simplification did not remove epsilon transitions - cannot proceed with powerset construction."
- );
- }
- }
+impl From<ENFA<Resolved>> for DFA {
+ fn from(mut nfa: ENFA<Resolved>) -> Self {
+ nfa.remove_unreachable();
let mut multi_states = nfa.all_multi_states();
multi_states.insert(nfa.void_multi_state());
@@ -109,3 +101,9 @@ impl From<ENFA> for DFA {
}
}
}
+
+impl From<ENFA<Epsilon>> for DFA {
+ fn from(value: ENFA<Epsilon>) -> Self {
+ Self::from(value.resolve_epsilon())
+ }
+}
diff --git a/src/parse/regex/enfa.rs b/src/parse/regex/enfa.rs
index 6b7b7bb..88f4536 100644
--- a/src/parse/regex/enfa.rs
+++ b/src/parse/regex/enfa.rs
@@ -3,19 +3,35 @@ use std::{
hash::{DefaultHasher, Hash, Hasher},
};
+pub trait Stage : Clone + std::fmt::Debug + std::hash::Hash + Eq {
+ type Consume: Clone + std::fmt::Debug + std::hash::Hash + Eq;
+}
+
+#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
+pub struct Epsilon;
+impl Stage for Epsilon {
+ type Consume = Option<ByteRange>;
+}
+
+#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
+pub struct Resolved;
+impl Stage for Resolved {
+ type Consume = ByteRange;
+}
+
use super::Pattern;
use super::byte_range::ByteRange;
/// NFA with epsilon transitions
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
-pub struct ENFA {
- pub states: Vec<EState>,
+pub struct ENFA<S: Stage> {
+ pub states: Vec<EState<S>>,
}
#[derive(Clone)]
pub struct MultiState<'a> {
- nfa: &'a ENFA,
+ nfa: &'a ENFA<Resolved>,
states: Vec<StateId>,
accept: bool,
hash: u64,
@@ -23,7 +39,7 @@ pub struct MultiState<'a> {
impl<'a> PartialEq for MultiState<'a> {
fn eq(&self, other: &Self) -> bool {
- (self.nfa as *const ENFA as u64) == (other.nfa as *const ENFA as u64)
+ (self.nfa as *const ENFA<Resolved> as u64) == (other.nfa as *const ENFA<Resolved> as u64)
&& self.states == other.states
&& self.accept == other.accept
&& self.hash == other.hash
@@ -32,7 +48,7 @@ impl<'a> PartialEq for MultiState<'a> {
impl<'a> Eq for MultiState<'a> {}
impl<'a> MultiState<'a> {
- pub fn new(nfa: &'a ENFA, mut states: Vec<StateId>) -> Self {
+ pub fn new(nfa: &'a ENFA<Resolved>, mut states: Vec<StateId>) -> Self {
states.sort();
states.dedup();
states.shrink_to_fit();
@@ -55,7 +71,7 @@ impl<'a> MultiState<'a> {
let mut vec: Vec<_> = self
.states
.iter()
- .flat_map(|&i| self.nfa.states[i].trans.iter().map(|x| x.consumes.unwrap()))
+ .flat_map(|&i| self.nfa.states[i].trans.iter().map(|x| x.consumes))
.collect();
vec = ByteRange::non_overlapping(vec);
vec.sort();
@@ -70,7 +86,7 @@ impl<'a> MultiState<'a> {
.iter()
.flat_map(|&s| {
self.nfa.states[s].trans.iter().filter_map(|t| {
- if t.consumes.unwrap().overlaps(ch) {
+ if t.consumes.overlaps(ch) {
Some(t.to)
} else {
None
@@ -106,8 +122,8 @@ macro_rules! set {
}};
}
-impl ENFA {
- fn shift(self, amt: usize) -> Vec<EState> {
+impl<S: Stage> ENFA<S> {
+ fn shift(self, amt: usize) -> Vec<EState<S>> {
let mut s = self.states;
for state in s.iter_mut() {
@@ -118,50 +134,7 @@ impl ENFA {
s
}
- fn epsilon_dfs(&self, i: StateId, visited: &mut [bool]) {
- if visited[i] {
- return;
- }
- visited[i] = true;
- for t in self.states[i].trans.iter() {
- if t.is_epsilon() {
- self.epsilon_dfs(t.to, visited);
- }
- }
- }
-
- fn resolve_epsilon(&mut self) {
- // state X --> { state Y, Z, W which get inlined into X }
- let includes: Vec<Vec<StateId>> = (0..self.states.len())
- .map(|i| {
- let mut reach = vec![false; self.states.len()];
- self.epsilon_dfs(i, &mut reach);
- reach
- .into_iter()
- .enumerate()
- .filter_map(|(x, r)| if r { Some(x) } else { None })
- .collect()
- })
- .collect();
-
- // clear epsilons
- for s in self.states.iter_mut() {
- s.trans.retain(|t| !t.is_epsilon());
- }
-
- // inline real transitions
- for i in 0..self.states.len() {
- for &k in includes[i].iter() {
- let new = self.states[k].trans.clone();
- self.states[i].trans.extend(new);
- if self.states[k].accept {
- self.states[i].accept = true;
- }
- }
- }
- }
-
- fn remove_unreachable(&mut self) {
+ pub fn remove_unreachable(&mut self) {
let mut used = vec![false; self.states.len()];
used[0] = true;
for s in self.states.iter() {
@@ -187,38 +160,54 @@ impl ENFA {
s.remap(|i| remap[i]);
}
}
+}
- pub fn simplify(&mut self) {
- self.resolve_epsilon();
- self.remove_unreachable();
- }
-
- pub fn start_multi_state<'a>(&'a self) -> MultiState<'a> {
- MultiState::new(self, vec![0])
- }
-
- pub fn void_multi_state<'a>(&'a self) -> MultiState<'a> {
- MultiState::new(self, vec![])
+impl ENFA<Epsilon> {
+ fn epsilon_dfs(&self, i: StateId, visited: &mut [bool]) {
+ if visited[i] {
+ return;
+ }
+ visited[i] = true;
+ for t in self.states[i].trans.iter() {
+ if t.is_epsilon() {
+ self.epsilon_dfs(t.to, visited);
+ }
+ }
}
- pub fn all_multi_states<'a>(&'a self) -> HashSet<MultiState<'a>> {
- let mut states = set![self.start_multi_state()];
- let mut q = vec![self.start_multi_state()];
-
- while let Some(state) = q.pop() {
- let chars = state.possible_transitions();
+ pub fn resolve_epsilon(self) -> ENFA<Resolved> {
+ // state X --> { state Y, Z, W which get inlined into X }
+ let includes: Vec<Vec<StateId>> = (0..self.states.len())
+ .map(|i| {
+ let mut reach = vec![false; self.states.len()];
+ self.epsilon_dfs(i, &mut reach);
+ reach
+ .into_iter()
+ .enumerate()
+ .filter_map(|(x, r)| if r { Some(x) } else { None })
+ .collect()
+ })
+ .collect();
- for chr in chars {
- let new = state.transition(chr);
+ // states without epsilon transitions
+ let mut states: Vec<EState<Resolved>> = self
+ .states
+ .into_iter()
+ .map(EState::remove_epsilon)
+ .collect();
- if !states.contains(&new) {
- states.insert(new.clone());
- q.push(new);
+ // inline real transitions
+ for i in 0..states.len() {
+ for &k in includes[i].iter() {
+ let new = states[k].trans.clone();
+ states[i].trans.extend(new);
+ if states[k].accept {
+ states[i].accept = true;
}
}
}
- states
+ ENFA { states }
}
fn looping(self) -> Self {
@@ -256,7 +245,7 @@ impl ENFA {
};
}
- let mut states: Vec<EState> = Vec::new();
+ let mut states: Vec<EState<Epsilon>> = Vec::new();
for nfa in nfas.into_iter() {
let len = states.len();
let mut ns = nfa.shift(len);
@@ -274,7 +263,37 @@ impl ENFA {
}
}
-impl std::fmt::Debug for ENFA {
+impl ENFA<Resolved> {
+ pub fn start_multi_state<'a>(&'a self) -> MultiState<'a> {
+ MultiState::new(self, vec![0])
+ }
+
+ pub fn void_multi_state<'a>(&'a self) -> MultiState<'a> {
+ MultiState::new(self, vec![])
+ }
+
+ pub fn all_multi_states<'a>(&'a self) -> HashSet<MultiState<'a>> {
+ let mut states = set![self.start_multi_state()];
+ let mut q = vec![self.start_multi_state()];
+
+ while let Some(state) = q.pop() {
+ let chars = state.possible_transitions();
+
+ for chr in chars {
+ let new = state.transition(chr);
+
+ if !states.contains(&new) {
+ states.insert(new.clone());
+ q.push(new);
+ }
+ }
+ }
+
+ states
+ }
+}
+
+impl std::fmt::Debug for ENFA<Epsilon> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "NFA {{")?;
for (i, s) in self.states.iter().enumerate() {
@@ -299,12 +318,12 @@ impl std::fmt::Debug for ENFA {
pub type StateId = usize;
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
-pub struct Transition {
+pub struct Transition<S: Stage> {
to: StateId,
- consumes: Option<ByteRange>,
+ consumes: S::Consume,
}
-impl Transition {
+impl Transition<Epsilon> {
fn new(consumes: ByteRange, to: StateId) -> Self {
let consumes = Some(consumes);
Self { to, consumes }
@@ -314,6 +333,18 @@ impl Transition {
Self { to, consumes: None }
}
+ pub fn is_epsilon(&self) -> bool {
+ self.consumes.is_none()
+ }
+
+ fn non_epsilon(self) -> Option<Transition<Resolved>> {
+ let consumes = self.consumes?;
+ let to = self.to;
+ Some(Transition { consumes, to })
+ }
+}
+
+impl<S: Stage> Transition<S> {
fn remap(&mut self, mut f: impl FnMut(StateId) -> StateId) {
self.to = f(self.to);
}
@@ -321,19 +352,35 @@ impl Transition {
fn reachable_states(&self) -> impl Iterator<Item = StateId> {
[self.to].into_iter()
}
-
- pub fn is_epsilon(&self) -> bool {
- self.consumes.is_none()
- }
}
#[derive(Debug, Clone)]
-pub struct EState {
- pub trans: HashSet<Transition>,
+pub struct EState<S: Stage> {
+ pub trans: HashSet<Transition<S>>,
pub accept: bool,
}
-impl EState {
+impl EState<Epsilon> {
+ fn remove_epsilon(self) -> EState<Resolved> {
+ let trans = self
+ .trans
+ .into_iter()
+ .filter_map(Transition::non_epsilon)
+ .collect();
+ let accept = self.accept;
+ EState { trans, accept }
+ }
+
+ fn set_epsilon_transitions(&mut self, trans: impl IntoIterator<Item = Transition<Epsilon>>) {
+ self.trans.retain(|t| t.consumes.is_some());
+ for transition in trans.into_iter() {
+ assert!(transition.consumes.is_none());
+ self.trans.insert(transition);
+ }
+ }
+}
+
+impl<S: Stage> EState<S> {
fn start() -> Self {
Self {
trans: HashSet::new(),
@@ -347,14 +394,6 @@ impl EState {
}
}
- fn set_epsilon_transitions(&mut self, trans: impl IntoIterator<Item = Transition>) {
- self.trans.retain(|t| t.consumes.is_some());
- for transition in trans.into_iter() {
- assert!(transition.consumes.is_none());
- self.trans.insert(transition);
- }
- }
-
fn remap(&mut self, mut f: impl FnMut(StateId) -> StateId) {
self.trans = self
.trans
@@ -377,7 +416,7 @@ pub enum EnfaTranslationError {
AssertionsNotSupported,
}
-impl TryFrom<Pattern> for ENFA {
+impl TryFrom<Pattern> for ENFA<Epsilon> {
type Error = EnfaTranslationError;
fn try_from(value: Pattern) -> Result<Self, Self::Error> {
@@ -393,7 +432,7 @@ impl TryFrom<Pattern> for ENFA {
],
},
Pattern::Alt(alts) => {
- let nfas: Vec<ENFA> = alts
+ let nfas: Vec<ENFA<Epsilon>> = alts
.into_iter()
.map(Self::try_from)
.collect::<Result<_, _>>()?;