(****************************************************************************)
(*                 The Calculus of Inductive Constructions                  *)
(*                                                                          *)
(*                                Projet Coq                                *)
(*                                                                          *)
(*                     INRIA        LRI-CNRS        ENS-CNRS                *)
(*              Rocquencourt         Orsay          Lyon                    *)
(*                                                                          *)
(*                                 Coq V6.3                                 *)
(*                               July 1st 1999                              *)
(*                                                                          *)
(****************************************************************************)
(*                               optimise.ml                                *)
(****************************************************************************)

open Std
open Names
open Mlterm
open Genpp

(* free_var : identifier list -> MLast -> identifier list
 * [free_var env t] collects the names of the variables that occur free
 * in t.  A free de Bruijn index is turned into a name by looking it up,
 * by position, in the environment env. *)

let free_var =
  (* [depth] counts the binders crossed so far; [names] is env extended
   * with the names introduced by those binders. *)
  let rec collect depth names term =
    (* union of the free variables of several terms at the current depth *)
    let collect_all ts =
      List.fold_left union [] (List.map (collect depth names) ts)
    in
    (* free variables of [body], which lives under the binders [bound] *)
    let collect_under bound body =
      collect (depth + List.length bound) (bound @ names) body
    in
    match term with
    | MLrel i when i > depth -> [List.nth names (i-1)]
    | MLrel _ -> []
    | MLapp (head, args) -> collect_all (head :: args)
    | MLlam (id, body) -> collect (depth + 1) (id :: names) body
    | MLcons (_, _, args) -> collect_all args
    | MLcase (scrut, branches) ->
        let in_branch (_, bound, body) = collect_under bound body in
        union
          (collect depth names scrut)
          (it_vect union [] (Array.map in_branch branches))
    | MLfix (_, _, ids, bodies) ->
        List.fold_left union [] (List.map (collect_under ids) bodies)
    | _ -> []
  in
  collect 0
  


(* free_var_list : identifier list -> MLast list -> identifier list
 * [free_var_list env l] is the union of the free variables of all the
 * terms of l, each of them taken in the de Bruijn environment env. *)
let free_var_list idl l =
  let add_term seen t = union seen (free_var idl t) in
  List.fold_left add_term [] l


(* ml_subst_glob : identifier list -> identifier -> MLast -> MLast -> MLast
 * [ml_subst_glob avoid id M t] replaces every (MLglob id) of t by M.
 * Each inserted copy of M goes through rename_bindings, which is given
 * avoid extended with the names bound in t above that occurrence. *)

let ml_subst_glob avoid id m =
  (* [bound] = avoid + the names of the binders crossed so far *)
  let rec subst bound term =
    match term with
    | MLglob id' -> if id = id' then rename_bindings bound m else term
    | MLapp (head, args) ->
        MLapp (subst bound head, List.map (subst bound) args)
    | MLlam (x, body) -> MLlam (x, subst (x :: bound) body)
    | MLcons (i, c, args) -> MLcons (i, c, List.map (subst bound) args)
    | MLcase (scrut, branches) ->
        let in_branch (n, ids, body) = (n, ids, subst (ids @ bound) body) in
        MLcase (subst bound scrut, Array.map in_branch branches)
    | MLfix (i, b, ids, bodies) ->
        MLfix (i, b, ids, List.map (subst (ids @ bound)) bodies)
    | other -> other
  in
  subst avoid



(* norm : identifier list -> MLast -> MLast
 * [norm env M] normalises the term M in the de Bruijn environment env.
 *
 * Rewriting steps:
 *  - an application whose head normalises to an abstraction passes its
 *    first argument to ml_subst1 and goes on with the remaining ones;
 *  - an application whose head normalises to a case is pushed inside
 *    every branch, the arguments going through ml_lift k where k is the
 *    number of names bound by the branch;
 *  - a case on a constructor term is replaced by the selected branch
 *    applied to the constructor arguments;
 *  - a case on a case is commuted: the outer branches are copied (through
 *    ml_liftn_branch) into every branch of the inner case.
 *
 * An application node without any argument whose head normalises to an
 * abstraction yields that abstraction; the argument list is matched
 * explicitly so that this situation does not fail in List.hd. *)

let rec norm idl = function
  | MLapp(t1,argl) ->
      (match norm idl t1 with
       | MLlam(_,t4) as lam ->
           (match argl with
            | [] -> lam   (* nothing to apply: already in normal form *)
            | [m] -> norm idl (ml_subst1 idl m t4)
            | m :: l -> norm idl (MLapp(ml_subst1 idl m t4, l)))
       | MLcase(t4,pv) ->
           let argl' = List.map (norm idl) argl in
           let pv' =
             Array.map
               (fun (n,l,t) ->
                  let k = List.length l in
                  let argl'' = List.map (ml_lift k) argl' in
                  (n,l,norm (l@idl) (MLapp(t,argl''))))
               pv
           in
           norm idl (MLcase(t4,pv'))
       | t1' -> MLapp(t1', List.map (norm idl) argl))

  | MLlam(id,t) -> MLlam(id, norm (id::idl) t)

  | MLcase(t,v) ->
      (match norm idl t with
       | MLcons(j,_,argl) ->
           (* redex: Case (cons_j a1 ... an) of
            *        ... | (cons_j x1 ... xn) => c
            *
            * it is rewritten into ([x1]...[xn]c a1 ... an):
            * simple but effective, it avoids juggling with
            * de Bruijn indices here.
            *)
           let (_,ids,c) = v.(j-1) in
           let c' = List.fold_right (fun id t -> MLlam (id,t)) ids c in
           (* never build an application node with no argument *)
           let c' = match argl with [] -> c' | _ -> MLapp (c',argl) in
           norm idl c'

       | MLcase(t',v') ->
           let c =
             MLcase(t',
                    Array.map
                      (fun (id',ids',t0') ->
                         let k = List.length ids' in
                         let v0 = Array.map (ml_liftn_branch k) v in
                         (id',ids',MLcase(t0',v0)))
                      v')
           in
           norm idl c

       | t2 ->
           MLcase(t2,
                  Array.map
                    (fun (id,ids,c) -> (id,ids,norm (ids@idl) c))
                    v))

  | MLfix(i,b,ids,l) -> MLfix(i,b,ids,List.map (norm (ids@idl)) l)

  | MLcons(i,id,argl) -> MLcons(i,id,List.map (norm idl) argl)

  | t -> t



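(* is_strict : MLast -> bool
 * [is_strict M] tells us if the term M must NOT be expanded.
 * It is defined further down, after the helpers it relies on
 * (ml_size, ml_abs_var and strict_var). *)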

(* ml_size : MLast -> int
 * [ml_size M] is a measure of M: every abstraction and every case
 * counts for 1, an application counts for its number of arguments, and
 * the measures of the sub-terms are added up.  Constructors and
 * fixpoints only contribute their sub-terms; any other node counts 0.
 * [ml_size_list l] is the sum of the measures of the terms of l. *)
let rec ml_size t =
  match t with
  | MLapp (head, args) ->
      List.length args + ml_size head + ml_size_list args
  | MLlam (_, body) -> 1 + ml_size body
  | MLcons (_, _, args) -> ml_size_list args
  | MLfix (_, _, _, bodies) -> ml_size_list bodies
  | MLcase (scrut, branches) ->
      let sizes = map_vect_list (fun (_, _, body) -> ml_size body) branches in
      1 + ml_size scrut + List.fold_left (+) 0 sizes
  | _ -> 0

and ml_size_list l =
  List.fold_left (fun total t -> total + ml_size t) 0 l


(* hd_var : MLast -> bool
 * [hd_var M] is true when the head of M, once all the enclosing
 * applications are stripped, is a de Bruijn variable (MLrel). *)
let rec hd_var t =
  match t with
  | MLrel _ -> true
  | MLapp (head, _) -> hd_var head
  | _ -> false


(* ml_abs_var : MLast -> identifier list
 * [ml_abs_var M] gathers the names abstracted at the top of M: the
 * names of the leading MLlam, also searched for inside the branches of
 * an MLcase (the names bound by a branch are removed from what that
 * branch contributes).  The search stops at any other node. *)
let rec ml_abs_var t =
  match t with
  | MLlam (id, body) -> add_set id (ml_abs_var body)
  | MLcase (_, branches) ->
      let of_branch (_, bound, body) = subtract (ml_abs_var body) bound in
      List.fold_right union (map_vect_list of_branch branches) []
  | _ -> []


(* strict_var : MLast -> identifier list
 * [strict_var M] computes a list of variable names of M, starting from
 * an empty de Bruijn environment:
 *  - a variable, alone or in head position of an application, gives its
 *    own name (the arguments are then not inspected);
 *  - the arguments of an applied global are all inspected;
 *  - for a case, the names found in the scrutinee are joined with the
 *    names found in every branch at once (intersection), the names
 *    bound by a branch being removed from its contribution.
 * NOTE(review): presumably these are the variables M is strict in, as
 * the name suggests -- confirm against the callers' expectations. *)
let strict_var =
  let rec strict idl t =
    match t with
    | MLrel n | MLapp (MLrel n, _) -> [List.nth idl (n-1)]
    | MLapp (MLglob _, args) -> strict_list idl args
    | MLapp (head, args) ->
        if hd_var head then strict idl head
        else union (strict idl head) (strict_list idl args)
    | MLlam (id, body) -> strict (id :: idl) body
    | MLcons (_, _, args) -> strict_list idl args
    | MLcase (scrut, branches) ->
        let of_branch (_, bound, body) =
          subtract (strict (bound @ idl) body) bound
        in
        (* idl is the starting point of the intersection *)
        let common =
          List.fold_right intersect (map_vect_list of_branch branches) idl
        in
        union (strict idl scrut) common
    | MLfix (_, _, ids, bodies) -> strict_list (ids @ idl) bodies
    | _ -> []
  and strict_list idl l =
    List.fold_left (fun acc t -> union acc (strict idl t)) [] l
  in
  strict []

(* is_strict : MLast -> bool
 * [is_strict M] holds when ml_size M is greater than 4 and
 * subtract (ml_abs_var M) (strict_var M) is empty.  keep uses it to
 * decide that a constant must not be expanded.
 * The conjunction is written && (the & spelling is deprecated); the
 * second test is still evaluated only when the first one succeeds. *)
let is_strict t =
  ml_size t > 4
  && [] = subtract (ml_abs_var t) (strict_var t)

(* is_fix : MLast -> bool
 * [is_fix M] is true exactly when M is an MLfix node. *)
let is_fix t =
  match t with
  | MLfix _ -> true
  | _ -> false

(* is_constr : MLast -> bool
 * [is_constr M] is true when M is an MLcons node, possibly placed under
 * a sequence of abstractions. *)
let rec is_constr t =
  match t with
  | MLlam (_, body) -> is_constr body
  | MLcons _ -> true
  | _ -> false

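(* optimise : (parameter record) -> MLdecl list -> MLdecl list
 * [optimise prm mlenv] performs partial evaluation on the ML environment
 * mlenv, i.e. some global terms are expanded in the terms that follow
 * them.  prm is the record of options read by keep (fields expand,
 * needed and expansion). *)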

(* expand : identifier * MLast -> MLdecl -> MLdecl
 * [expand (id,t) d] substitutes t for (MLglob id) in the body of d when
 * d is a DECLglob (with an empty list of names to avoid); any other
 * declaration is returned unchanged. *)
let expand (id,t) decl =
  match decl with
  | DECLglob (name, body) -> DECLglob (name, ml_subst_glob [] id t body)
  | _ -> decl

(* when we must keep a constant
 * [keep prm id t t'] is true when the constant id must stay in the
 * environment instead of being expanded; optimise calls it with t the
 * body of id and t' the result of norm on that body.  It holds when:
 *  - t' is a fixpoint, or
 *  - id is listed in prm.needed, or
 *  - id is not listed in prm.expand and either prm.expansion is off,
 *    or t is not a constructor while t' is strict (see is_strict).
 * Boolean connectives are written && and || (the & and or spellings
 * are deprecated); evaluation order and short-circuiting are the same. *)
let keep prm id t t' =
  let notex = not (is_constr t) && is_strict t' in
  let ex = List.mem id prm.expand in
  is_fix t'
  || List.mem id prm.needed
  || (not prm.expansion && not ex)
  || (prm.expansion && notex && not ex)

(* [optimise prm decls] walks the declarations in order:
 *  - type declarations and abbreviations are kept as they are;
 *  - a global whose body is an MLexn is expanded into the declarations
 *    that follow it and then dropped, without consulting keep;
 *  - a global that ends the list (and is not an MLexn) is normalised
 *    and always kept;
 *  - any other global is normalised, then either kept (see keep) or
 *    expanded into the following declarations, with a warning.
 * NOTE(review): an MLexn global placed last in the list is dropped
 * too, since its case comes first -- confirm this is intended. *)
let optimise prm =
  let rec opt_rec decls =
    match decls with
    | [] -> []

    | ((DECLtype _ | DECLabbrev _) as d) :: rest ->
        d :: opt_rec rest

    | DECLglob (id, (MLexn _ as t)) :: rest ->
        opt_rec (List.map (expand (id,t)) rest)

    | [ DECLglob (id, t) ] ->
        [ DECLglob (id, norm [] t) ]

    | DECLglob (id, t) :: rest ->
        let t' = norm [] t in
        if keep prm id t t' then
          DECLglob (id, t') :: opt_rec rest
        else begin
          Pp.warning ("The constant "^(string_of_id id)^" is expanded.");
          opt_rec (List.map (expand (id,t')) rest)
        end
  in
  opt_rec

(* haskell_optimise : MLdecl list -> MLdecl list
 * [haskell_optimise decls] keeps type declarations, abbreviations and
 * globals whose body is an MLexn untouched.  Every other global is
 * normalised and kept, except the constant named "eq_rec_r": its
 * normalised body is expanded into the declarations that follow it and
 * the constant itself is dropped. *)
let haskell_optimise =
  let rec opt_rec decls =
    match decls with
    | [] -> []

    | ((DECLtype _ | DECLabbrev _ | DECLglob (_, MLexn _)) as d) :: rest ->
        d :: opt_rec rest

    | DECLglob (id, t) :: rest ->
        let t' = norm [] t in
        if string_of_id id = "eq_rec_r" then
          opt_rec (List.map (expand (id,t')) rest)
        else
          DECLglob (id, t') :: opt_rec rest
  in
  opt_rec
  
(* $Id: optimise.ml,v 1.16 1999/06/29 07:48:06 loiseleu Exp $ *)
