## § Mutorch

#!/usr/bin/env python3

# x2, w1, w2 are Leaf variables
# x1 = f(w1, w2)
# y = g(x1, x2)
# loss = h(y)

# ==========================
# t is a hallucinated variable.
# y = f(x)
# GIVEN: dt/dy
# TO FIND: dt/dx
# dt/dx = dt/dy * dy/dx
# dt/dloss
# t = loss
# dt/dloss = dloss/dloss = 1

# y1 = f(x1, x2, x3)
# y2 = g(x1, x2, x3)

# FORWARD MODE: [Tangent space] ---- objects of the form (partial f/partial x)
# total gradient of x1: df/dx1 + dg/dx1
# total gradient of x2: df/dx2 + dg/dx2
# total gradient of x3: df/dx3 + dg/dx3

# l = r cos(theta)
# dl = dr cos(theta) - r sin(theta) dtheta
# dl/dtheta = dr/dtheta cos(theta) - r sin(theta) dtheta/dtheta
# dl/dtheta =   0       * cos(theta) - r sin(theta) * 1

# dl/dr = dr/dr cos(theta) - r sin(theta) dtheta/dr
# dl/dr = cos(theta) -      r sin(theta) * 0

# REVERSE MODE: [CoTangent space] --- objects of the form df
# total gradient of y1: dy1 = (df/dx1)dx1 + (df/dx2)dx2  + (df/dx3)dx3
# total gradient of y2: dy2 = (dg/dx1)dx1 + (dg/dx2)dx2  + (dg/dx3)dx3
# HALLUCINATED T:
#    y1 = f(x1, x2, x3)
#    GIVEN:   dt/dy1 [output]
#    TO FIND: dt/dx1, dt/dx2, dt/dx3 [inputs]
#    SOLN:    dt/dxi = dt/dy1 * dy1/dxi
#                    = dt/dy1 * df/dxi
import pudb

class Expr:
    """Common base class for every node of the expression graph.

    It only supplies operator sugar: ``a * b`` on two expressions builds a
    ``Mul`` node instead of multiplying numbers.
    """

    def __mul__(self, other):
        """Return the product node representing ``self * other``."""
        product = Mul(self, other)
        return product

class Var(Expr):
    """Leaf variable of the expression graph.

    Attributes:
        name: label used when the variable is printed.
        val: the value the variable holds.
        grad: running sum of the sensitivities delivered through backprop().
    """

    def __init__(self, name, val):
        self.name = name
        self.val = val
        # Starts at zero so that a variable no expression depends on
        # reports a zero gradient.
        self.grad = 0

    def __str__(self):
        return "(var-%s | %s)" % (self.name, self.val)

    def __repr__(self):
        return self.__str__()

    def backprop(self, dt_doutput):
        """Accumulate one contribution to dt/d(this variable).

        A leaf has no inputs to forward the sensitivity to, so the recursion
        stops here.  A variable that is used several times (e.g. ``x * x``)
        receives one call per use; by the chain rule those contributions add
        up, hence ``+=`` rather than plain assignment.

        Args:
            dt_doutput: sensitivity dt/d(self) flowing in from one consumer.
        """
        self.grad += dt_doutput

class Mul(Expr):
    """Product node; ``val`` is computed eagerly from the operands' values."""

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        self.val = lhs.val * rhs.val

    def __str__(self):
        fields = (self.lhs, self.rhs, self.val)
        return "(* %s %s | %s)" % fields

    def __repr__(self):
        return str(self)

    #         -------- input1
    #   S    /
    #  ---> v
    #  <--output *
    #      ^
    #       \_________ input2
    # Think in terms of sensitivity.
    # - output has sensitivity S to something (the hallucinated t),
    # - output = input1 * input2,
    # - how sensitive is t to input1?  S scaled by how fast the output moves
    #   when input1 moves, and that rate is input2 (and symmetrically input1
    #   for input2).
    # output = f(input1, input2); f(input1, input2) = input1 * input2
    def backprop(self, dt_output):
        """Push the incoming sensitivity down to both factors.

        dt/dinput1 = dt/doutput * d(input1 * input2)/dinput1
                   = dt/doutput * input2
        and likewise dt/dinput2 = dt/doutput * input1.
        """
        left, right = self.lhs, self.rhs
        left.backprop(dt_output * right.val)
        right.backprop(dt_output * left.val)

# a = ...   ^
# b = ...   ^
# c = a + b ^
#
class Add(Expr):
    """Sum node; ``val`` is ``lhs.val + rhs.val``, computed at construction.

    The class header is required: without it the methods below rebind the
    methods of the preceding class instead of defining addition.
    """

    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs
        self.val = self.lhs.val + self.rhs.val

    def __str__(self):
        return "(+ %s %s | %s)" % (self.lhs, self.rhs, self.val)

    def __repr__(self):
        return self.__str__()

    #         -------- input1
    #   S    /
    #  ---> v
    #  <--output
    #      ^
    #       \_________ input2
    # Think in terms of sensitivity.
    # - output has sensitivity S to something,
    # - output = input1 + input2,
    # - how much sensitivity does input1 have to S?
    # - the same (S), because "sensitivity" is linear [a conjecture/axiom].
    # output = f(input1, input2); f(input1, input2) = input1 + input2
    def backprop(self, dt_output):
        """Forward the incoming sensitivity unchanged to both summands.

        dt/dinput1 = dt/doutput * d(input1 + input2)/dinput1
                   = dt/doutput * 1
        and the same holds for input2.
        """
        self.lhs.backprop(dt_output * 1)
        self.rhs.backprop(dt_output * 1)

class Max(Expr):
    """Maximum node; ``val`` is the larger of the two operand values."""

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        self.val = max(lhs.val, rhs.val)

    def __str__(self):
        fields = (self.lhs, self.rhs, self.val)
        return "(max %s %s | %s)" % fields

    def __repr__(self):
        return str(self)

    def backprop(self, dt_output):
        """Route the incoming sensitivity to the operand that won the max.

        dt/dinput1 = dt/doutput * d max(input1, input2)/dinput1, which is
        dt/doutput * 1 when input1 is the selected operand and
        dt/doutput * 0 otherwise.  The operand that was not selected gets no
        call at all.  When both values are equal the left operand is the one
        that receives the sensitivity.
        """
        selected = self.lhs if self.val == self.lhs.val else self.rhs
        selected.backprop(dt_output * 1)
# Small demo: build z1 = x*x + y and print every intermediate node.
x = Var("x", 10)
print("x: %s" % x)
y = Var("y", 20)
p = Var("p", 30)  # deliberately unused by z1, see dz1/dp below
print("y: %s" % y)
z0 = Mul(x, x)
print("z0: %s" % z0)
# z1 used to be printed without ever being defined (NameError).
z1 = Add(z0, y)
print("z1: %s" % z1)

# z1 = x*x+y
# dz1/dx = 2x
# dz1/dy = 1
# dz1/dp = 0