(* OptLambda - lambda code optimiser. *)

(*
$File: Compiler/OptLambda.sml $
$Date: 1992/09/17 14:18:01 $
$Version$
$Locker:  $
*)

(*$OptLambda: LVARS LAMBDA_EXP BASIC_IO FLAGS FINMAP OPT_LAMBDA*)

functor OptLambda(structure Lvars: LVARS

		  structure LambdaExp: LAMBDA_EXP
		    sharing type LambdaExp.lvar = Lvars.lvar

		  structure BasicIO: BASIC_IO
		  structure Flags: FLAGS

		  structure FinMap: FINMAP
		    sharing type LambdaExp.map = FinMap.map
		 ): OPT_LAMBDA =
  struct
    (* When true, the exported `optimise' is the identity function. *)
    val DISABLE = false

    open LambdaExp

    (* pr - print a trace message (preceded by a space), but only when
	    Flags.DEBUG_OPTIMISER is set. *)
    fun pr x = if Flags.DEBUG_OPTIMISER
	       then BasicIO.print(" " ^ x)
	       else ()

   (* Each of the individual optimiser functions works on the topmost
      node only; a generic tree walker does the rest. An optimiser
      calls `tick' if it does something. `reset' clears the flag before
      a sweep and `test' reports whether any optimiser fired since. *)

    local
      val flag = ref false
    in
      fun reset() = (flag := false)
      fun tick msg = (pr msg; flag := true)
      fun test() = !flag
    end

    (* pass - apply an optimising function (in fact, any function of type
	      LambdaExp -> LambdaExp) over an entire lambda expression.
	      `f' is applied to a node first, and the walk then descends
	      into the children of the node that `f' returned. Note the
	      definition of `passSwitch', which is parameterised on the
	      function to apply so that it works for every kind of switch. *)

    fun pass f lamb =
      let
	fun passSwitch f (SWITCH{arg, selections, wildcard}) =
	  SWITCH{arg=f arg,
		 selections=FinMap.composemap f selections,
		 wildcard=case wildcard
		       of Some lamb => Some(f lamb)
			| None => None
		}
      in
	case f lamb
	  of FN(lv, lamb) => FN(lv, pass f lamb)
	   | FIX(lvs, binds, scope) =>
	       FIX(lvs, map (pass f) binds, pass f scope)

	   | APP(lamb1, lamb2) => APP(pass f lamb1, pass f lamb2)
	   | PRIM_APP(n, lamb) => PRIM_APP(n, pass f lamb)
	   | VECTOR lambs => VECTOR(map (pass f) lambs)
	   | SELECT(i, lamb) => SELECT(i, pass f lamb)
	   | SWITCH_I switch => SWITCH_I(passSwitch (pass f) switch)
	   | SWITCH_S switch => SWITCH_S(passSwitch (pass f) switch)
	   | SWITCH_R switch => SWITCH_R(passSwitch (pass f) switch)
	   | RAISE lamb => RAISE(pass f lamb)
	   | HANDLE(lamb1, lamb2) => HANDLE(pass f lamb1, pass f lamb2)
	   | REF lamb => REF(pass f lamb)
	   | x => x
      end

   (* countLvar - count the number of (non-binding) occurrences of an lvar in
		  a LambdaExp. We can use `pass' for this, but only if we pass
		  a side-effecting counting function. The rebuilt tree that
		  `pass' returns is discarded. *)

    fun countLvar lvar lamb =
      let
	val n = ref 0

	fun f lamb =
	  case lamb
	    of VAR lv => (if lv = lvar then n := !n + 1 else (); lamb)
	     | _ => lamb
      in
	pass f lamb; !n
      end

   (* replaceLvar - replace every occurrence of an lvar with a LambdaExp. *)

    fun replaceLvar (lvar, replacement) lamb =
      let
	fun f lamb =
	  case lamb
	    of VAR lv => if lv = lvar then replacement else lamb
	     | _ => lamb
      in
	pass f lamb
      end

   (* containsREF - does a LambdaExp contain a REF node anywhere, including
		    inside function bodies? This is deliberately
		    conservative. It is used to stop BETA2 from moving a
		    REF to a point where it might be evaluated a different
		    number of times (e.g. into the body of an FN). Like
		    countLvar, it uses `pass' with a side-effecting
		    function. *)

    fun containsREF lamb =
      let
	val found = ref false

	fun f lamb =
	  case lamb
	    of REF _ => (found := true; lamb)
	     | _ => lamb
      in
	pass f lamb; !found
      end

   (* isSafe - determines whether a lambda expression is safe enough to
	       be optimised out (if unused), or beta-expanded. Function
	       applications are never safe - they may cause side-effects
	       and hence evaluation order is important. RAISE expressions
	       are similarly restricted. A FIX is only as safe as its
	       bindings and its scope, since the scope is evaluated as part
	       of the FIX expression. *)

    fun isSafe lamb =
      let
	fun allSafe lambs =
	  List.foldL (fn a => fn b => isSafe a andalso b) true lambs

	fun isSafeSwitch(SWITCH{arg, selections, wildcard}) =
	  isSafe arg
	  andalso FinMap.Fold (fn ((_, lamb), x) => isSafe lamb andalso x)
			      true selections
	  andalso (case wildcard
		     of Some lamb => isSafe lamb
		      | None => true
		  )
      in
	case lamb
	  of VAR _	     => true
	   | INTEGER _	     => true
	   | STRING _	     => true
	   | REAL _	     => true
	   | FN _	     => true
	   | FIX(_, binds, scope) =>
	       (* Previously `FIX _ => true', which ignored the scope and so
		  let UNUSED/BETA2 drop or move side-effecting code. *)
	       allSafe binds andalso isSafe scope

	   | APP _	     => false
	   | PRIM_APP _	     => false
	   | VECTOR lambs    => allSafe lambs
	   | SELECT(_, lamb) => isSafe lamb
	   | SWITCH_I sw     => isSafeSwitch sw
	   | SWITCH_S sw     => isSafeSwitch sw
	   | SWITCH_R sw     => isSafeSwitch sw
	   | RAISE _	     => false
	   | HANDLE(lamb, _) => isSafe lamb
		(* if `lamb' is safe, then the actual handler can never be
		   activated. If `lamb' is unsafe, then the entire expression
		   is unsafe anyway. *)
	   | REF lamb	     => isSafe lamb
		(* Safe to drop if unused, but NOT safe to move into a
		   function body; BETA2 checks containsREF separately. *)
	   | VOID	     => true
      end

   (* isSmall - expressions which are small and simple can be expanded in
		place (beta-reduction) regardless of the number of occurrences
		of the lvar. It is not a requirement that small expressions
		be safe (BETA1 checks safety separately), but all the cases
		accepted here happen to be safe. *)

    fun isSmall lamb =
      case lamb
	of VAR _	=> true		(* alpha-reduction here...! *)
	 | INTEGER _	=> true
	 | STRING _	=> true
	 | REAL _	=> true
	 | VOID		=> true
	 | _		=> false


   (* Here are the optimisers. *)

   (* optimiseUNUSED - remove a binding of an lvar to a safe expression
		       if the lvar isn't referenced. *)

    fun optimiseUNUSED lamb =
      case lamb
	of APP(FN(lv, scope), bind) =>
	     if countLvar lv scope = 0 andalso isSafe bind
	     then (tick "UNUSED"; scope)
	     else lamb

	 | _ => lamb

   (* optimiseBETA1 - for bindings of the form `lv1 = bind', with a trivial
		      (and safe!) bind, replace lv1 with the bind in the
		      scope. This gives us alpha-reduction (when bind is an
		      lvar). *)

    fun optimiseBETA1 lamb =
      case lamb
	of APP(FN(lv1, scope), bind) =>
	     if isSmall bind andalso isSafe bind then
	       (tick "BETA1"; replaceLvar (lv1, bind) scope)
	     else
	       lamb

	 | _ => lamb

   (* optimiseBETA2 - for bindings `lv = bind in scope' with bind safe and
		      lv occurring once in the scope, replace the occurrence
		      of lv with bind. The single occurrence may lie inside
		      an FN body, where it could be evaluated many times, so
		      binds that contain a REF are excluded: re-evaluating
		      them would create a fresh reference each time. *)

    fun optimiseBETA2 lamb =
      case lamb
	of APP(FN(lv, scope), bind) =>
	     if countLvar lv scope = 1
		andalso isSafe bind
		andalso not(containsREF bind)
	     then (tick "BETA2"; replaceLvar (lv, bind) scope)
	     else lamb

	 | _ => lamb

   (* optimiseHOISTFIX - in any FIX declaration, any function whose body
			 doesn't refer to any of the FIX identifiers can
			 be hoisted and made non-recursive (and hence
			 open to further optimisations). Only the first
			 such function is hoisted per call; the outer
			 fixpoint loop in `optimise' picks up the rest. *)

    fun optimiseHOISTFIX lamb =
      case lamb
	of FIX(lvars, bodies, scope) =>
	     let
	       fun canHoistBody body =
		 List.foldL (fn lv => fn result =>
			       result andalso (countLvar lv body = 0)
			    ) true lvars

	       (* Index (from 0) of the first hoistable body, if any. *)
	       fun findHoistable(b :: rest) =
		     if canHoistBody b then Some 0
		     else (case findHoistable rest
			     of Some i => Some(i + 1)
			      | None => None
			  )

		 | findHoistable nil = None

	       fun hoistNth n =
		 let
		   val (lvN, lvars') = List.removeNth n lvars
		   val (bodyN, bodies') = List.removeNth n bodies
		 in
		   Let((lvN, bodyN),
		       FIX(lvars', bodies', scope)
		      )
		 end
	     in
	       case findHoistable bodies
		 of Some n => (tick "HOISTFIX"; hoistNth n)
		  | None => lamb
	     end

	 | _ => lamb

   (* optimiseFIX0 - since we can lift things out of FIX declarations, this
		     means we might end up with empty ones. We remove them. *)

    fun optimiseFIX0 lamb =
      case lamb
	of FIX(nil, nil, scope) => (tick "FIX0"; scope)
	 | _ => lamb

   (* optimiseSWITCH - removes switches with no discriminants. It can only
		       do so if there's a wildcard present; it cannot
		       generate a Match/Bind exception itself. It also does
		       nothing if the selector is unsafe (e.g. a function
		       application), since removing the switch would
		       discard the selector's evaluation. *)

    fun optimiseSWITCH lamb =
      let
	fun optSwitch switch =
	  case switch
	    of SWITCH{arg, selections, wildcard=Some wildcard} =>
		 if isSafe arg andalso EqSet.isEmpty(FinMap.dom selections)
		 then (tick "SWITCH"; wildcard)
		 else lamb

	     | _ => lamb
      in
	case lamb
	  of SWITCH_I switch => optSwitch switch
	   | SWITCH_S switch => optSwitch switch
	   | SWITCH_R switch => optSwitch switch
	   | _ => lamb
      end

    (* allOptimisers - every single-node optimiser composed into one
       function. The fold yields FIX0 o HOISTFIX o ... o UNUSED, so
       UNUSED is applied to a node first. *)
    val allOptimisers =
      List.foldL (General.curry op o) (fn x => x)
	[optimiseUNUSED, optimiseBETA1, optimiseBETA2, optimiseSWITCH,
	 optimiseHOISTFIX, optimiseFIX0
	]

    (* optimise - repeatedly sweep allOptimisers over the whole expression
       until a sweep makes no change (no optimiser ticked). *)
    fun optimise lamb =
      let
	val _ = (pr "(Reset)"; reset())
	val lamb' = pass allOptimisers lamb
      in
	if test() then optimise lamb'
	else (if Flags.DEBUG_OPTIMISER then BasicIO.println "" else ();
	      (* No optimiser fired, so lamb' is a structural copy of lamb;
		 returning the original avoids the needless rebuild. *)
	      lamb
	     )
      end

    val optimise =
      if DISABLE then fn x => x else optimise
  end;
