remove curried functions; multiple arguments use tuples

This commit is contained in:
darkf 2013-10-22 15:59:05 -07:00
parent 8b41c05b94
commit d9e1a7bdc1
3 changed files with 42 additions and 59 deletions

5
ast.hs
View File

@ -9,12 +9,12 @@ data AST = Add AST AST
| Mul AST AST | Mul AST AST
| Div AST AST | Div AST AST
| Block [AST] | Block [AST]
| FunDef String ([Pattern], AST) | FunDef String (Pattern, AST)
| Defun String AST | Defun String AST
| Def String AST | Def String AST
| Var String | Var String
| Lambda [(Pattern, AST)] | Lambda [(Pattern, AST)]
| Call String [AST] | Call String AST
| UnitConst | UnitConst
| Cons AST AST | Cons AST AST
| TupleConst [AST] | TupleConst [AST]
@ -27,5 +27,6 @@ data Pattern = VarP String
| IntP Integer | IntP Integer
| UnitP | UnitP
| ConsP Pattern Pattern | ConsP Pattern Pattern
| TupleP [Pattern]
| ListP [Pattern] | ListP [Pattern]
deriving (Show, Eq) deriving (Show, Eq)

View File

@ -126,35 +126,14 @@ eval (Var var) = get >>= \(_,env) ->
eval (Defun name fn) = do eval (Defun name fn) = do
(s,env) <- get (s,env) <- get
case lookup env name of case lookup env name of
Nothing -> -- bind new fn Nothing -> -- bind new fn
eval fn >>= \fn' -> eval fn >>= \fn' ->
put (s, bind env name fn') >> return fn' put (s, bind env name fn') >> return fn'
Just oldfn -> -- add pattern to old fn Just oldfn -> -- add pattern to old fn
let FnV oldpatterns = oldfn let FnV oldpats = oldfn
newfn = merge fn (Lambda oldpatterns) in Lambda [(pat, body)] = fn
put (s, bind env name newfn) >> return newfn newfn = FnV (oldpats ++ [(pat, body)]) in
-- newfn = FnV (oldpats ++ [(pat, body)]) in put (s, bind env name newfn) >> return newfn
where
-- takes a lambda and a list of patterns and merges their
------- patterns recursively, forming a new function
mergePatterns :: AST -> AST -> Value
mergePatterns (Lambda [newpat]) (Lambda oldpatterns@(oldpat:oldpats)) =
if fst newpat /= fst oldpat then
-- we've diverged, so let's add it here
FnV (oldpatterns ++ [newpat])
else
-- we're still equal, keep going
mergePatterns (snd newpat) (snd oldpat)
mergePatterns _ (Lambda b) = FnV b
mergePatterns a@(Lambda _) _ = error "k"
merge = mergePatterns
{-
mergePatterns(a, b):
if any pats(b) == pat(a),
just \(pats(b) ++ (pat(a) -> bod(a)))
else, nothing -}
eval (Def name v') = do eval (Def name v') = do
v <- eval v' v <- eval v'
@ -170,15 +149,14 @@ eval (Sub l r) = do { l <- eval l; r <- eval r; return $ l -$ r }
eval (Mul l r) = do { l <- eval l; r <- eval r; return $ l *$ r } eval (Mul l r) = do { l <- eval l; r <- eval r; return $ l *$ r }
eval (Div l r) = do { l <- eval l; r <- eval r; return $ l /$ r } eval (Div l r) = do { l <- eval l; r <- eval r; return $ l /$ r }
eval (Call name args) = get >>= \(_,env) -> eval (Call name arg) = get >>= \(_,env) ->
case lookup env name of case lookup env name of
Just fn@(FnV _) -> Just fn@(FnV _) -> eval arg >>= apply fn
do Just fn@(Builtin _) -> eval arg >>= apply fn
xargs <- mapM eval args
applyMany fn xargs
Just fn@(Builtin _) -> mapM eval args >>= applyMany fn
Nothing -> error $ "call: name " ++ name ++ " doesn't exist or is not a function" Nothing -> error $ "call: name " ++ name ++ " doesn't exist or is not a function"
eval x = error $ "eval: unhandled: " ++ show x
patternBindings :: Pattern -> Value -> Maybe Env patternBindings :: Pattern -> Value -> Maybe Env
patternBindings (VarP n) v = Just $ M.fromList [(n, v)] patternBindings (VarP n) v = Just $ M.fromList [(n, v)]
@ -198,6 +176,7 @@ patternBindings (ConsP xp xsp) (ListV (x:xs)) =
Just $ M.union xe xse Just $ M.union xe xse
patternBindings (ConsP _ _) _ = Nothing patternBindings (ConsP _ _) _ = Nothing
-- lists
patternBindings (ListP []) (ListV (x:xs)) = Nothing -- not enough patterns patternBindings (ListP []) (ListV (x:xs)) = Nothing -- not enough patterns
patternBindings (ListP (_:_)) (ListV []) = Nothing -- not enough values patternBindings (ListP (_:_)) (ListV []) = Nothing -- not enough values
patternBindings (ListP []) (ListV []) = Just M.empty -- base case patternBindings (ListP []) (ListV []) = Just M.empty -- base case
@ -208,16 +187,16 @@ patternBindings (ListP (x:xs)) (ListV (y:ys)) =
Just $ M.union env' env Just $ M.union env' env
patternBindings (ListP _) _ = Nothing -- not a list patternBindings (ListP _) _ = Nothing -- not a list
-- applies many arguments to a function -- tuples
applyMany :: Value -> [Value] -> InterpState Value patternBindings (TupleP []) (TupleV (x:_)) = Nothing -- not enough patterns
applyMany fn@(FnV _) (arg:xs) = patternBindings (TupleP (_:_)) (TupleV []) = Nothing -- not enough values
apply fn arg >>= \value -> patternBindings (TupleP []) (TupleV []) = Just M.empty -- base case
applyMany value xs patternBindings (TupleP (x:xs)) (TupleV (y:ys)) =
applyMany (Builtin (BIF fn)) (arg:xs) = do
fn arg >>= \value -> env <- patternBindings x y
applyMany value xs env' <- patternBindings (TupleP xs) (TupleV ys)
applyMany value [] = return value Just $ M.union env' env
applyMany _ xs = error "couldn't apply all arguments" patternBindings (TupleP _) _ = Nothing -- not a tuple
-- applies a function -- applies a function
apply :: Value -> Value -> InterpState Value apply :: Value -> Value -> InterpState Value
@ -235,6 +214,8 @@ apply (FnV pats) arg =
Nothing -> -- doesn't satisfy this pattern Nothing -> -- doesn't satisfy this pattern
apply' xs apply' xs
apply (Builtin (BIF fn)) arg = fn arg
evalProgram :: [AST] -> Value -- fold the state from each node and return the result evalProgram :: [AST] -> Value -- fold the state from each node and return the result
evalProgram nodes = evalState (foldr1 (>>) $ map eval nodes) initialState evalProgram nodes = evalState (foldr1 (>>) $ map eval nodes) initialState

View File

@ -89,6 +89,8 @@ consPattern = do
return $ ConsP x y return $ ConsP x y
pattern = try consPattern pattern = try consPattern
<|> try (emptyTuple TupleP)
<|> try (tupleSeq pattern TupleP)
<|> listPattern <|> listPattern
<|> varPattern <|> varPattern
<|> intPattern <|> intPattern
@ -99,27 +101,26 @@ funDef = do
name <- identifier name <- identifier
symbol "(" symbol "("
pats <- patterns pats <- patterns
let pats' = if pats == [] then [UnitP] else pats -- at least Unit let pat = (case pats of
[] -> UnitP
[a] -> a
otherwise -> TupleP pats)
symbol ")" symbol ")"
symbol "->" symbol "->"
lst <- exprparser body <- exprparser
return $ rewriteFun (FunDef name (pats', lst)) return $ Defun name $ Lambda [(pat, body)]
-- curry FunDef to a definition of lambdas
rewriteFun (FunDef name (patterns, body)) =
Defun name lam
where
-- curry it
lam = foldr (\pat lam -> Lambda [(pat, lam)]) body patterns
call = do call = do
name <- identifier name <- identifier
whiteSpace whiteSpace
symbol "(" symbol "("
args <- sepBy exprparser (symbol ",") args <- sepBy exprparser (symbol ",")
let args' = if args == [] then [UnitConst] else args -- at least Unit let arg = (case args of
[] -> UnitConst
[a] -> a
otherwise -> TupleConst args)
symbol ")" symbol ")"
return $ Call name args' return $ Call name arg
consExpr = do consExpr = do
x <- expr' x <- expr'