Transform Pass: Constant Folding


In this pass, we evaluate at compile time every instruction whose operands are all constants, replace its uses with the result, and remove the instruction. The pass handles arithmetic and boolean operators.

The idea is really simple: scan basic blocks, and if an instruction can be immediately evaluated, do so.

Note that for this pass to be as easy as it is, SSA is crucial.

Consider this snippet of code:

define x;
assign x := 10;
assign x := x + 42;
assign x := x * 10;
return x;

and the associated load/store-based IR:

entry:  default.0
    x := alloc
    _ := store 10# in %x
    x.load := load %x
    tmp.0 := add %x.load 42#
    _.1 := store %tmp.0 in %x
    x.load.1 := load %x
    tmp.1 := mul %x.load.1 10#
    _.2 := store %tmp.1 in %x

We cannot simply replace x with 10 due to the mutation happening on x!

Now, consider the SSA form of the same computation:

entry:  default.0
    tmp.0 := add 10# 42#
    tmp.1 := mul %tmp.0 10#

Due to the immutable nature of SSA, we are guaranteed that we can replace all occurrences of a variable with its RHS, and the semantics of the program will remain the same! (This is also known as equational reasoning.)

This is enormously powerful because it allows us to replace values with wild abandon :).
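To make the substitution argument concrete, here is a minimal standalone sketch that folds the two SSA instructions above. It does not use the IR module of this pass: `Operand`, `ToyInst`, `resolve`, and `foldToy` are hypothetical names invented for this illustration, and the sketch assumes a single straight-line block in which every name is defined before it is used.

```haskell
import qualified Data.Map.Strict as Map

-- Hypothetical toy IR for illustration only; not the IR used by this pass.
data Operand = Const Int | Ref String deriving (Eq, Show)
data Op = Add | Mul deriving (Eq, Show)

-- | A named SSA instruction: name := op lhs rhs
data ToyInst = ToyInst String Op Operand Operand deriving (Eq, Show)

-- | Replace a reference with its constant value if that value is already known.
resolve :: Map.Map String Int -> Operand -> Operand
resolve env (Ref n) = maybe (Ref n) Const (Map.lookup n env)
resolve _ c = c

-- | Walk the instructions in order. Whenever both operands resolve to
-- constants, evaluate the instruction and record the value under its name.
-- This is safe only because each SSA name is assigned exactly once.
foldToy :: [ToyInst] -> Map.Map String Int
foldToy = foldl step Map.empty
  where
    step env (ToyInst name op a b) =
      case (resolve env a, resolve env b) of
        (Const i, Const j) -> Map.insert name (eval op i j) env
        _ -> env
    eval Add = (+)
    eval Mul = (*)

-- | The SSA snippet from the text: tmp.0 := add 10# 42#, tmp.1 := mul %tmp.0 10#
example :: [ToyInst]
example =
  [ ToyInst "tmp.0" Add (Const 10) (Const 42)
  , ToyInst "tmp.1" Mul (Ref "tmp.0") (Const 10)
  ]

-- foldToy example == Map.fromList [("tmp.0", 52), ("tmp.1", 520)]
```

Folding `tmp.0` to 52 is what makes `tmp.1` foldable in turn. In the load/store IR the same trick fails, because `%x` holds a different value after each store.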

Key Takeaway of this pass

{-# LANGUAGE ViewPatterns #-}

module TransformConstantFolding where
import qualified OrderedMap as M
import Control.Monad.State.Strict
import Data.Traversable
import Data.Foldable
import Control.Applicative
import qualified Data.List.NonEmpty as NE
import IR
import BaseIR
import Data.Text.Prettyprint.Doc as PP
import PrettyUtils

boolToInt :: Bool -> Int
boolToInt False = 0
boolToInt True = 1

-- | Fold all possible arithmetic / boolean ops
tryFoldInst :: Inst -> Maybe Value
tryFoldInst (InstAdd (ValueConstInt i) (ValueConstInt j)) = 
    Just $ ValueConstInt (i + j)
tryFoldInst (InstMul (ValueConstInt i) (ValueConstInt j)) = 
    Just $ ValueConstInt (i * j)
tryFoldInst (InstL (ValueConstInt i) (ValueConstInt j)) = 
    Just $ ValueConstInt $ boolToInt (i < j)

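-- Booleans are encoded as 0 / 1 (see boolToInt), so multiplication acts as
-- logical AND as long as both operands are 0 or 1.
tryFoldInst (InstAnd (ValueConstInt i) (ValueConstInt j)) = 
    Just $ ValueConstInt (i * j)
-- Anything else cannot be folded.
tryFoldInst _ = Nothing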

collectFoldableInsts :: Named Inst -> [(Label Inst, Value)]
collectFoldableInsts (Named name (tryFoldInst -> Just v)) = [(name, v)]
collectFoldableInsts _ = []

-- | Apply f repeatedly until the result stops changing.
runTillStable :: Eq a => (a -> a) -> a -> a
runTillStable f a = let a' = f a in
    if a' == a
    then a'
    else runTillStable f a'

transformConstantFold :: IRProgram -> IRProgram
transformConstantFold = dceProgram . runTillStable foldProgram where

    -- | Collection of instruction names and values
    foldableInsts :: IRProgram -> [(Label Inst, Value)]
    foldableInsts p = foldMapProgramBBs (foldMapBB (collectFoldableInsts) (const mempty)) p

    -- | Program after constant folding
    foldProgram :: IRProgram -> IRProgram
    foldProgram program = foldl (\p (name, v) -> replaceUsesOfInst name v p) program (foldableInsts program)

    -- | program after dead code elimination
    dceProgram :: IRProgram -> IRProgram
    dceProgram program =
        foldl (\p name -> filterProgramInsts (not . hasName name) p) program (map fst (foldableInsts program))
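
Folding has to be repeated because one substitution can expose another: once uses of tmp.0 are replaced by 52#, the instruction tmp.1 becomes mul 52# 10# and can be folded in the next round. This is why the pass runs foldProgram through runTillStable. The following standalone sketch shows the intended "iterate until nothing changes" behaviour on plain integers; `fixpoint` and `halveIfEven` are hypothetical names used only for this illustration and are not part of the module above.

```haskell
-- | Apply a step function repeatedly until its output equals its input.
-- Terminates only if the step function eventually reaches a fixed point.
fixpoint :: Eq a => (a -> a) -> a -> a
fixpoint f a =
  let a' = f a
  in if a' == a then a else fixpoint f a'

-- | Hypothetical single step: halve even numbers, leave odd numbers alone.
halveIfEven :: Int -> Int
halveIfEven n
  | even n = n `div` 2
  | otherwise = n

-- One step is not enough:   halveIfEven 40          == 20
-- Iterating to stability:   fixpoint halveIfEven 40 == 5   (40 -> 20 -> 10 -> 5)
```

Comparing whole values with `==` on every round is simple but can be expensive for large programs. A step function that reports whether it changed anything would avoid the comparison, at the cost of a slightly more involved interface.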