import sys

from regex import Literal, Concatenation, Alternation, Star, lit
from nfa import NFA, prettyprint
from parse_regex import parse, ParseError


def to_nfa(regex):
    """Build an NFA from a regex syntax tree.

    Every node of the tree is turned into a fragment that has exactly one
    entry state and one exit state. Fragments are glued together with
    transitions labelled ``lit('')`` (presumably the NFA's empty-input
    move -- confirm against the ``nfa`` module).

    Args:
        regex: Root node of the tree. Every node must be exactly a
            ``Literal``, ``Concatenation``, ``Alternation`` or ``Star``.

    Returns:
        ``NFA(entry, [exit], table)``, where ``table`` maps each state (an
        integer, numbered from 0 in allocation order) to a
        ``{target_state: label}`` dict.

    Raises:
        ValueError: If a node of any other type is encountered.
    """
    table = {}
    next_id = 0

    def fresh():
        """Allocate the next integer state and give it an empty row."""
        nonlocal next_id
        state = next_id
        next_id += 1
        table[state] = {}
        return state

    def glue(src, dst):
        """Add a ``lit('')`` transition from *src* to *dst*."""
        # A new label object is created per edge, never shared.
        table[src][dst] = lit('')

    def literal_fragment(node):
        #            text
        # (entry) ---------> (exit)
        entry = fresh()
        exit_ = fresh()
        table[entry][exit_] = lit(node.text)
        return (entry, exit_)

    def concatenation_fragment(node):
        # (entry) → […] → […] → […]
        entry = fresh()
        tail = entry
        for child in node.elements:
            child_entry, child_exit = build(child)
            # Hook the child after whatever has been chained so far; the
            # child's exit then becomes the new end of the chain.
            glue(tail, child_entry)
            tail = child_exit
        # With no children the fragment is the single state `entry`.
        return (entry, tail)

    def alternation_fragment(node):
        #           +-> […] --+
        #           |         |
        # (entry) --+-> […] --+-> (exit)
        #           |         |
        #           +-> […] --+
        entry = fresh()
        exit_ = fresh()
        for child in node.elements:
            child_entry, child_exit = build(child)
            glue(entry, child_entry)
            glue(child_exit, exit_)
        return (entry, exit_)

    def star_fragment(node):
        #     +-- […] <--+
        #     |          |
        #     v          |
        #  (entry) ------+--> (exit)
        entry = fresh()
        exit_ = fresh()
        child_entry, child_exit = build(node.element)
        glue(entry, child_entry)   # enter the repeated part
        glue(child_exit, entry)    # loop back for another repetition
        glue(entry, exit_)         # leave the loop
        return (entry, exit_)

    def build(node):
        """Return ``(entry, exit)`` of the fragment built for *node*."""
        # Exact-type dispatch on purpose: subclasses of the node classes
        # are rejected, not treated as their base class.
        kind = type(node)
        if kind == Literal:
            return literal_fragment(node)
        if kind == Concatenation:
            return concatenation_fragment(node)
        if kind == Alternation:
            return alternation_fragment(node)
        if kind == Star:
            return star_fragment(node)
        raise ValueError('node has to be Literal, Concatenation, Alternation, or Star')

    entry, exit_ = build(regex)
    return NFA(entry, [exit_], table)


def main():
    """Prompt for a regex on stdin and pretty-print the NFA built from it.

    A ``ParseError`` from the parser is reported on stderr, prefixed with
    the program name, instead of propagating.
    """
    try:
        tree = parse(input('regex> '))
    except ParseError as err:
        print('%s: Error: %s' % (sys.argv[0], str(err)), file=sys.stderr)
        return
    prettyprint(to_nfa(tree))


if __name__ == '__main__':
    main()