#!/usr/bin/env python3
"""Pattern expansion and substitution built on a Lark grammar.

A pattern is parsed by ``parser`` and expanded by ``Expand`` into a list of
strings (its alternatives).  ``replace`` expands a search pattern, locates
its alternatives in a subject string, and substitutes the alternative with
the same index from an expanded replacement pattern.
"""
from lark import Lark, Transformer, ParseError, LexError, Tree


# Grammar of the pattern language.
# NOTE(review): inside a Lark grammar ``//`` starts a comment, so the last
# alternative of ``char`` is the empty expansion; ``Expand.char`` handles that
# case by producing ''.
# NOTE(review): ``ambiguity='resolve__antiscore_sum'`` appears to be accepted
# only by older Lark releases -- confirm against the pinned version.
parser = Lark(r'''
    DIGIT : /[0-9]/
    DIGITS : DIGIT+
    NUMBER : "-"? DIGITS

    ?bracketliteral : "\\" /./
                    | /[^\]\-]/
    ?range : bracketliteral -> char
           | bracketliteral "-" bracketliteral
    brackets : "[" range+ "]" -> either

    char : /[^\\\[\]\{\}\(\)\|\^~\?!]/
         | "\\" /\D/
         | //
    numrange : DIGITS
             | DIGITS "-" DIGITS

    ?unit : parens | brackets | char

    ?concat_func : unit
                 | concat_func "{" DIGITS "}" -> concat_repeat
                 | concat_func "?" -> zero_or_one
                 | concat_func "~" -> reverse
                 | concat_func "~" NUMBER -> roll
                 | concat_func "~{" NUMBER ["," DIGITS] "}" -> roll
                 | concat_func "!" -> collapse
                 | concat_func "!" DIGIT+ -> collapse
                 | concat_func "!{" numrange ("," numrange)* "}" -> collapse_ranges
                 | concat_func "\\" DIGIT+ -> index
                 | concat_func "\\{" numrange ("," numrange)* "}" -> index_ranges
    ?concat : concat_func+

    ?choice_func : concat
                 | choice_func "^{" DIGITS "}" -> weave_repeat
                 | choice_func "|{" DIGITS "}" -> either_repeat
    ?choice : choice_func
            | choice ("^" choice_func)+ -> weave
            | choice ("|" choice_func)+ -> either

    ?parens : "(" choice ")"
''', start='choice', ambiguity='resolve__antiscore_sum')


def _inclusive_span(first, last):
    """Return the integers from *first* to *last* inclusive.

    Counts upwards when ``first <= last`` and downwards otherwise, so both
    ``3-5`` and ``5-3`` style ranges are supported.
    """
    step = 1 if first <= last else -1
    return list(range(first, last + step, step))


class Expand(Transformer):
    """Turn a parse tree produced by ``parser`` into a list of strings.

    Every callback receives the already-expanded children of its rule and
    returns a list of strings (``numrange`` returns a list of ints that its
    parent rule consumes).

    Args:
        amp: Optional list of strings substituted wherever the single
            character ``&`` occurs in the pattern.
    """

    def __init__(self, amp=None):
        super().__init__()
        self.amp = amp

    def char(self, args):
        """Expand a single (possibly empty) character to a one-item list."""
        c = args[0].value if args else ''
        if self.amp and c == '&':
            return self.amp
        return [c]

    def range(self, args):
        """Expand ``a-b`` inside brackets to every character in between."""
        a, b = map(ord, args)
        return [chr(code) for code in _inclusive_span(a, b)]

    def zero_or_one(self, args):
        """Expand ``x?`` to the empty string followed by the alternatives of x."""
        return self.either([[''], args[0]])

    def either(self, args):
        """Concatenate the alternative lists of all operands, in order."""
        result = []
        for alternatives in args:
            result.extend(alternatives)
        return result

    def concat(self, args):
        """Return every left-to-right combination of the operands' strings."""
        result = ['']
        for alternatives in args:
            result = [head + tail for head in result for tail in alternatives]
        return result

    def weave(self, args):
        """Interleave the operands: first items of each, then second items, ..."""
        result = []
        # default=0 keeps a zero-operand weave (e.g. from ``^{0}``) from raising.
        longest = max(map(len, args), default=0)
        for i in range(longest):
            for alternatives in args:
                if i < len(alternatives):
                    result.append(alternatives[i])
        return result

    def roll(self, args):
        """Rotate the alternatives left by ``args[1]`` within fixed-size groups.

        ``args[2]``, when present, is the group size; otherwise the whole list
        is treated as a single group.  The last group may be shorter and is
        rotated modulo its own length.
        """
        items = args[0]
        if not items:
            return []
        shift = int(args[1].value)
        # An omitted optional group size is either absent from ``args`` or
        # passed as None, depending on Lark's ``maybe_placeholders`` setting.
        size_token = args[2] if len(args) == 3 else None
        size = len(items) if size_token is None else int(size_token.value)
        if size <= 0:
            # A zero group size would divide by zero; treat it as "no grouping".
            size = len(items)
        result = []
        for start in range(0, len(items), size):
            group = items[start:start + size]
            result.extend(
                group[(k + shift) % len(group)] for k in range(len(group))
            )
        return result

    def reverse(self, args):
        """Return the alternatives in reverse order."""
        return args[0][::-1]

    def numrange(self, args):
        """Expand ``n`` or ``n-m`` to the list of ints it denotes (inclusive)."""
        first = int(args[0].value)
        last = int(args[1].value) if len(args) > 1 else first
        return _inclusive_span(first, last)

    def index(self, args):
        """Select alternatives by single-digit indices, wrapping around."""
        x = args[0]
        if not x:
            return []  # nothing to select from; avoids modulo by zero
        return [x[int(token.value) % len(x)] for token in args[1:]]

    def index_ranges(self, args):
        """Select alternatives by the indices of one or more numranges."""
        x = args[0]
        if not x:
            return []  # nothing to select from; avoids modulo by zero
        return [x[i % len(x)] for span in args[1:] for i in span]

    def collapse(self, args):
        """Join alternatives into one string.

        Without indices all alternatives are joined; with single-digit
        indices only the selected ones (wrapping around) are joined.
        """
        x, picks = args[0], args[1:]
        if not picks:
            return [''.join(x)]
        if not x:
            return ['']  # nothing to select from; avoids modulo by zero
        return [''.join(x[int(token.value) % len(x)] for token in picks)]

    def collapse_ranges(self, args):
        """Join the alternatives selected by one or more numranges."""
        x = args[0]
        if not x:
            return ['']  # nothing to select from; avoids modulo by zero
        return [''.join(x[i % len(x)] for span in args[1:] for i in span)]

    def concat_repeat(self, args):
        """Concatenate the operand with itself ``args[1]`` times."""
        return self.concat([args[0]] * int(args[1].value))

    def either_repeat(self, args):
        """Repeat the operand's alternatives ``args[1]`` times in sequence."""
        return self.either([args[0]] * int(args[1].value))

    def weave_repeat(self, args):
        """Weave the operand with itself ``args[1]`` times."""
        return self.weave([args[0]] * int(args[1].value))


def lookup(choices):
    """Build a character trie over *choices*.

    Each node is a dict mapping a character to a child node.  A node that
    ends a choice additionally maps ``None`` to that choice's index; when the
    same string occurs more than once the last index wins.
    """
    trie = {}
    for n, choice in enumerate(choices):
        node = trie
        for c in choice:
            node = node.setdefault(c, {})
        node[None] = n
    return trie


def findall(lookup, string):
    """Scan *string* left to right for choices stored in the trie *lookup*.

    Returns a list of ``(choice_index, start, end)`` tuples for
    non-overlapping matches in increasing order.  At each position the
    longest complete choice is taken.  If the trie contains the empty
    choice, an empty match is reported at every position whose character
    starts no choice, and once more at the end of the string.
    """
    i, result = 0, []
    while i < len(string):
        c = string[i]
        if c in lookup:
            node = lookup[c]
            j = i + 1
            # Remember the last node that completed a choice, so that a longer
            # candidate which ultimately fails does not hide a shorter match.
            best = (node[None], j) if None in node else None
            while j < len(string) and string[j] in node:
                node = node[string[j]]
                j += 1
                if None in node:
                    best = (node[None], j)
            if best is None:
                i += 1
            else:
                n, end = best
                result.append((n, i, end))
                i = end
        elif None in lookup:
            result.append((lookup[None], i, i))
            i += 1
        else:
            i += 1
    if None in lookup:
        i = len(string)
        result.append((lookup[None], i, i))
    return result


def replace(a, b, s):
    """Replace occurrences of pattern *a* in *s* using pattern *b*.

    Both patterns are expanded to lists of alternatives.  A match of the
    n-th alternative of *a* is replaced by the n-th alternative of *b*
    (wrapping around when *b* has fewer).  ``&`` in *b* stands for the
    alternatives of *a*.

    Returns:
        The substituted string, or ``''`` when either pattern fails to
        parse, nothing in *s* matches, or *b* expands to no alternatives.
    """
    try:
        search_tree = parser.parse(a)
        replacement_tree = parser.parse(b)
    except (ParseError, LexError):
        # Unexpected-character errors derive from LexError, not ParseError.
        return ''
    search = Expand().transform(search_tree)
    locs = findall(lookup(search), s)
    if not locs:
        return ''
    replacements = Expand(amp=search).transform(replacement_tree)
    if not replacements:
        # e.g. a ``|{0}`` replacement: nothing to substitute with.
        return ''
    # Work backwards so earlier match offsets stay valid as s changes length.
    for n, i, j in reversed(locs):
        s = s[:i] + replacements[n % len(replacements)] + s[j:]
    return s


if __name__ == '__main__':
    import sys

    if len(sys.argv) < 2:
        sys.exit('usage: {} PATTERN'.format(sys.argv[0]))
    p = parser.parse(sys.argv[1])
    print(p.pretty())
    print(Expand().transform(p))