{-# LANGUAGE GeneralizedNewtypeDeriving #-}
import qualified Data.Map.Strict as M

-- | This file can be copy-pasted and will run!

-- | Symbols: the names under which nodes and variables are stored.
type Sym = String
-- | Environments: maps from symbols to values of type @a@.
type E a = M.Map Sym a
-- | The scalar type used for node values.
type F = Float
-- | Newtype to represent derivative values, keeping them distinct from
-- plain values of type 'F'.
newtype Der = Der { under :: F } deriving(Show, Num)

infixl 7 !#
-- | Index an environment at a "hash", i.e. a 'Sym'.
--
-- The lookup is partial: it fails if the symbol is not present in the map.
(!#) :: E a -> Sym -> a
env !# sym = env M.! sym

-- | A node in the computation graph.
--
-- A node carries its own name, the nodes it depends on, and two functions
-- that are evaluated against the environments built up while the graph is
-- being run.
data Node =
  Node { name :: Sym        -- ^ Name of the node
       , ins  :: [Node]     -- ^ Inputs to the node
       , out  :: E F -> F   -- ^ Output of the node, given the value environment
       , der  :: (E F, E (Sym -> Der)) -> Sym -> Der
         -- ^ Derivative with respect to a name, given the value and
         -- derivative environments
       }

-- | Index an environment at a node. The @ sign looks like a "circle", which
-- is a node; the entry is looked up under the node's 'name'.
(!@) :: E a -> Node -> a
env !@ node = env M.! name node

-- | Given the current environments of values and derivatives, compute
-- the new value and derivative for a node.
--
-- The node's inputs are run first, left to right, threading both
-- environments through. The node's own output and derivative function are
-- then inserted under its name.
run_ :: (E F, E (Sym -> Der)) -> Node -> (E F, E (Sym -> Der))
run_ ein (Node nm inputs outF derF) =
  -- The pattern variables are deliberately not called name/ins/out/der so
  -- they do not shadow the record selectors of 'Node'.
  let (e', ed') = foldl run_ ein inputs  -- run all the inputs
      v  = outF e'                       -- compute the output
      dv = derF (e', ed')                -- and the derivative
  in (M.insert nm v e', M.insert nm dv ed')  -- and insert them

-- | Run the program rooted at a node, starting from the given value
-- environment and an empty derivative environment.
run :: E F -> Node -> (E F, E (Sym -> Der))
run vals = run_ (vals, M.empty)

-- | Let's build nodes.
--
-- Constant: the output ignores the environment, and the derivative with
-- respect to every name is zero.
nconst :: Sym -> F -> Node
nconst n f =
  Node { name = n
       , ins  = []
       , out  = const f
       , der  = \_ _ -> 0
       }

-- | Variable: the output is read from the value environment under the
-- variable's own name; the derivative is 1 with respect to that name and 0
-- with respect to any other.
nvar :: Sym -> Node
nvar n =
  Node { name = n
       , ins  = []
       , out  = \e -> e !# n
       , der  = \_ n' -> if n == n' then 1 else 0
       }

-- | Binary operation.
--
-- The output is computed from the values of the two input nodes. The
-- derivative is computed from the value and derivative of each input,
-- where both input derivatives are taken with respect to the requested name.
nbinop :: (F -> F -> F)                  -- ^ output computation from the input values
       -> (F -> Der -> F -> Der -> Der)  -- ^ derivative computation from each input's value and derivative
       -> Sym                            -- ^ Name
       -> (Node, Node)                   -- ^ input nodes
       -> Node
nbinop f df n (in1, in2) =
  Node { name = n
       , ins  = [in1, in2]
       , out  = \e -> f (e !# name in1) (e !# name in2)
       , der  = \(e, ed) n' ->
           let (name1, name2) = (name in1, name in2)
               (v1, v2)       = (e !# name1, e !# name2)
               -- each entry of ed is a function from a name to the
               -- derivative with respect to that name
               (dv1, dv2)     = ((ed !# name1) n', (ed !# name2) n')
           in df v1 dv1 v2 dv2
       }

-- | Addition node: the derivative of a sum is the sum of the input
-- derivatives; the input values themselves are not needed.
nadd :: Sym -> (Node, Node) -> Node
nadd = nbinop (+) sumRule
  where
    sumRule _ dl _ dr = dl + dr

-- | Multiplication node, differentiated with the product rule:
-- @d(v * v') = v * dv' + v' * dv@.
nmul :: Sym -> (Node, Node) -> Node
nmul = nbinop (*) (\v (Der dv) v' (Der dv') -> Der $ (v * dv') + (v' * dv))

-- | Build the graph for @x*x + 10 + y@, run it with @x = 2@ and @y = 3@,
-- then print the value environment followed by the derivative of the final
-- node with respect to @"x"@ and with respect to @"y"@.
main :: IO ()
main = do
  let x = nvar "x" :: Node
  let y = nvar "y"
  let xsq = nmul "xsq" (x, x)
  let ten = nconst "10" 10
  let xsq_plus_10 = nadd "xsq_plus_10" (xsq, ten)
  let xsq_plus_10_plus_y = nadd "xsq_plus_10_plus_y" (xsq_plus_10, y)
  let (e, de) = run (M.fromList $ [("x", 2.0), ("y", 3.0)]) xsq_plus_10_plus_y
  putStrLn $ show e
  putStrLn $ show $ de !@ xsq_plus_10_plus_y $ "x"
  putStrLn $ show $ de !@ xsq_plus_10_plus_y $ "y"

-- Yeah, in ~80 lines of code, you can basically build an autograd engine.
-- Isn't Haskell so rad?