regexen_nfae/regex_to_nfa.py

113 lines
2.5 KiB
Python

import sys
from regex import Literal, Concatenation, Alternation, Star, lit
from nfa import NFA, prettyprint
from parse_regex import parse, ParseError
def to_nfa(regex):
    """Compile a regex syntax tree into an NFA.

    The construction is compositional: every node becomes a fragment with
    exactly one start state and one end state, and fragments are wired
    together with empty-string transitions (``lit('')``).

    States are named by consecutive integers starting at 0. ``transitions``
    maps each state to a dict ``{target_state: label}``, where the label is
    produced by ``lit(...)``. Because a state's outgoing transitions are
    keyed by target, at most one transition can exist per ordered pair of
    states; the fragments built here never need more than one.

    Args:
        regex: root node; every node in the tree must be a ``Literal``,
            ``Concatenation``, ``Alternation`` or ``Star`` (or a subclass).

    Returns:
        ``NFA(start_state, [end_state], transitions)`` with a single
        accepting state.

    Raises:
        ValueError: if any node in the tree has some other type.
    """
    state_name_counter = 0
    transitions = {}

    def new_state():
        """Allocate a fresh state with no outgoing transitions; return its name."""
        nonlocal state_name_counter  # rebound below, so nonlocal is required
        state_name = state_name_counter
        state_name_counter += 1
        transitions[state_name] = {}
        return state_name

    def add_epsilon(source, target):
        """Add an empty-string transition ``source -> target``."""
        transitions[source][target] = lit('')

    # NOTE: recursion depth grows with the nesting depth of the syntax tree,
    # so extremely deep trees can hit Python's recursion limit.
    def worker(node):
        """Build the fragment for ``node``; return ``(start_state, end_state)``."""
        if isinstance(node, Literal):
            # (start) --text--> (end)
            start_state = new_state()
            end_state = new_state()
            transitions[start_state][end_state] = lit(node.text)
            return (start_state, end_state)
        if isinstance(node, Concatenation):
            # (start) -> [e1] -> [e2] -> ...
            # The fragment ends at the last element's end state; with no
            # elements it ends at ``start`` itself (matches the empty string).
            start_state = new_state()
            prev_state = start_state
            for element in node.elements:
                inner_start, inner_end = worker(element)
                add_epsilon(prev_state, inner_start)
                prev_state = inner_end
            return (start_state, prev_state)
        if isinstance(node, Alternation):
            #           +-> [e1] --+
            # (start) --+-> [e2] --+-> (end)
            #           +-> [e3] --+
            # With no elements, ``end`` is unreachable (matches nothing).
            start_state = new_state()
            end_state = new_state()
            for element in node.elements:
                inner_start, inner_end = worker(element)
                add_epsilon(start_state, inner_start)
                add_epsilon(inner_end, end_state)
            return (start_state, end_state)
        if isinstance(node, Star):
            #   +-- [inner] <--+
            #   v              |
            # (start) ---------+     (start) --> (end)
            # ``start`` is the loop hub: each pass through ``inner`` returns
            # to it, and leaving it for ``end`` stops the repetition.
            start_state = new_state()
            end_state = new_state()
            inner_start, inner_end = worker(node.element)
            add_epsilon(start_state, inner_start)
            add_epsilon(inner_end, start_state)  # loop back: repeat again
            add_epsilon(start_state, end_state)  # exit (also zero repetitions)
            return (start_state, end_state)
        raise ValueError('node has to be Literal, Concatenation, Alternation, or Star')

    start_state, end_state = worker(regex)
    return NFA(start_state, [end_state], transitions)
def main():
    """Read one regex from stdin, compile it to an NFA, and pretty-print it.

    Errors are reported on stderr, prefixed with the program name.

    Returns:
        Process exit status: 0 on success, 1 if no input was available or
        the regex could not be parsed.
    """
    try:
        source = input('regex> ')
    except EOFError:
        # stdin was closed before a line was entered (Ctrl-D, empty pipe);
        # report it instead of dumping a traceback.
        print(f'{sys.argv[0]}: Error: no input', file=sys.stderr)
        return 1
    try:
        regex = parse(source)
    except ParseError as err:
        print(f'{sys.argv[0]}: Error: {err}', file=sys.stderr)
        return 1
    prettyprint(to_nfa(regex))
    return 0


if __name__ == '__main__':
    # Propagate main()'s status so shells/scripts can detect failures.
    sys.exit(main())