Skip to content
Snippets Groups Projects
Commit 1837ce98 authored by THIRIOUX Xavier's avatar THIRIOUX Xavier
Browse files

added some infrastructure to ease optimization (reusing vars)

parent 920c31de
No related branches found
No related tags found
No related merge requests found
...@@ -141,6 +141,9 @@ let node_local_variables nd = ...@@ -141,6 +141,9 @@ let node_local_variables nd =
let node_output_variables nd = let node_output_variables nd =
List.fold_left (fun outputs v -> ISet.add v.var_id outputs) ISet.empty nd.node_outputs List.fold_left (fun outputs v -> ISet.add v.var_id outputs) ISet.empty nd.node_outputs
let node_auxiliary_variables nd =
ISet.diff (node_local_variables nd) (node_memory_variables nd)
let node_variables nd = let node_variables nd =
let inputs = node_input_variables nd in let inputs = node_input_variables nd in
let inoutputs = List.fold_left (fun inoutputs v -> ISet.add v.var_id inoutputs) inputs nd.node_outputs in let inoutputs = List.fold_left (fun inoutputs v -> ISet.add v.var_id inoutputs) inputs nd.node_outputs in
...@@ -452,7 +455,7 @@ struct ...@@ -452,7 +455,7 @@ struct
(* map: var |-> list of disjoint vars, sorted in increasing branch length order, (* map: var |-> list of disjoint vars, sorted in increasing branch length order,
maybe removing shorter branches *) maybe removing shorter branches *)
type clock_map = (ident, var_decl list) Hashtbl.t type clock_map = (ident, ident list) Hashtbl.t
let clock_disjoint_map vdecls = let clock_disjoint_map vdecls =
let map = Hashtbl.create 23 in let map = Hashtbl.create 23 in
...@@ -460,8 +463,8 @@ struct ...@@ -460,8 +463,8 @@ struct
List.iter List.iter
(fun v1 -> let disj_v1 = (fun v1 -> let disj_v1 =
List.fold_left List.fold_left
(fun res v2 -> if Clocks.disjoint v1.var_clock v2.var_clock then CISet.add v2 res else res) (fun res v2 -> if Clocks.disjoint v1.var_clock v2.var_clock then ISet.add v2.var_id res else res)
CISet.empty ISet.empty
vdecls in vdecls in
(* disjoint vdecls are stored in increasing branch length order *) (* disjoint vdecls are stored in increasing branch length order *)
Hashtbl.add map v1.var_id disj_v1) Hashtbl.add map v1.var_id disj_v1)
...@@ -470,21 +473,22 @@ struct ...@@ -470,21 +473,22 @@ struct
end end
(* replace variable [v] by [v'] in disjunction [map]. Then: (* replace variable [v] by [v'] in disjunction [map]. Then:
- the mapping v |-> ... disappears - the mapping v' becomes v' |-> (map v) inter (map v')
- the mapping v' becomes v' |-> (map v) inter (map v') - the mapping v |-> ... then disappears
- other mappings become x |-> (map x) \ (if v in x then v else v') - other mappings become x |-> (map x) \ (if v in x then v else v')
*) *)
let replace_in_disjoint_map map v v' = let replace_in_disjoint_map map v v' =
begin begin
Hashtbl.remove map v.var_id; Hashtbl.replace map v' (ISet.inter (Hashtbl.find map v) (Hashtbl.find map v'));
Hashtbl.replace map v'.var_id (CISet.inter (Hashtbl.find map v.var_id) (Hashtbl.find map v'.var_id)); Hashtbl.remove map v;
Hashtbl.iter (fun x map_x -> Hashtbl.replace map x (CISet.remove (if CISet.mem v map_x then v else v') map_x)) map; Hashtbl.iter (fun x map_x -> Hashtbl.replace map x (ISet.remove (if ISet.mem v map_x then v else v') map_x)) map;
end end
let pp_disjoint_map fmt map = let pp_disjoint_map fmt map =
begin begin
Format.fprintf fmt "{ /* disjoint map */@."; Format.fprintf fmt "{ /* disjoint map */@.";
Hashtbl.iter (fun k v -> Format.fprintf fmt "%s # { %a }@." k (Utils.fprintf_list ~sep:", " Printers.pp_var_name) (CISet.elements v)) map; Hashtbl.iter (fun k v -> Format.fprintf fmt "%s # { %a }@." k (Utils.fprintf_list ~sep:", " Format.pp_print_string) (ISet.elements v)) map;
Format.fprintf fmt "}@." Format.fprintf fmt "}@."
end end
end end
......
...@@ -29,9 +29,7 @@ exception Error of Location.t * error ...@@ -29,9 +29,7 @@ exception Error of Location.t * error
module VDeclModule = module VDeclModule =
struct (* Node module *) struct (* Node module *)
type t = var_decl type t = var_decl
let compare v1 v2 = compare v1 v2 let compare v1 v2 = compare v1.var_id v2.var_id
let hash n = Hashtbl.hash n
let equal n1 n2 = n1 = n2
end end
module VMap = Map.Make(VDeclModule) module VMap = Map.Make(VDeclModule)
...@@ -301,9 +299,8 @@ let is_tuple_expr expr = ...@@ -301,9 +299,8 @@ let is_tuple_expr expr =
let expr_list_of_expr expr = let expr_list_of_expr expr =
match expr.expr_desc with match expr.expr_desc with
| Expr_tuple elist -> | Expr_tuple elist -> elist
elist | _ -> [expr]
| _ -> [expr]
let expr_of_expr_list loc elist = let expr_of_expr_list loc elist =
match elist with match elist with
......
...@@ -184,28 +184,52 @@ let replace_in_set s v v' = ...@@ -184,28 +184,52 @@ let replace_in_set s v v' =
let replace_in_death_table death v v' = let replace_in_death_table death v v' =
Hashtbl.iter (fun k dead -> Hashtbl.replace death k (replace_in_set dead v v')) death Hashtbl.iter (fun k dead -> Hashtbl.replace death k (replace_in_set dead v v')) death
let find_compatible_local node var dead = let find_compatible_local node var dead death disjoint policy =
(*Format.eprintf "find_compatible_local %s %s %a@." node.node_id var pp_iset dead;*) (*Format.eprintf "find_compatible_local %s %s %a@." node.node_id var pp_iset dead;*)
let typ = (get_node_var var node).var_type in let typ = (get_node_var var node).var_type in
let eq_var = get_node_eq var node in let eq_var = get_node_eq var node in
let locals = node.node_locals in
let aliasable_inputs = let aliasable_inputs =
match NodeDep.get_callee eq_var.eq_rhs with match NodeDep.get_callee eq_var.eq_rhs with
| None -> [] | None -> []
| Some (_, args) -> List.fold_right (fun e r -> match e.expr_desc with Expr_ident id -> id::r | _ -> r) args [] in | Some (_, args) -> List.fold_right (fun e r -> match e.expr_desc with Expr_ident id -> id::r | _ -> r) args [] in
let filter v = let filter base (v : var_decl) =
let res = let res =
ISet.mem v.var_id dead base v
&& Typing.eq_ground typ v.var_type && Typing.eq_ground typ v.var_type
&& not (Types.is_address_type v.var_type && List.mem v.var_id aliasable_inputs) in && not (Types.is_address_type v.var_type && List.mem v.var_id aliasable_inputs) in
begin begin
(*Format.eprintf "filter %a = %s@." Printers.pp_var_name v (if res then "true" else "false");*) (*Format.eprintf "filter %a = %s@." Printers.pp_var_name v (if res then "true" else "false");*)
res res
end in end in
(*Format.eprintf "reuse %s@." var;*)
try try
Some ((List.find filter node.node_locals).var_id) let disj = Hashtbl.find disjoint var in
with Not_found -> None let reuse = List.find (filter (fun v -> ISet.mem v.var_id disj && not (ISet.mem v.var_id dead))) locals in
(*Format.eprintf "reuse %s by %s@." var reuse.var_id;*)
Disjunction.replace_in_disjoint_map disjoint var reuse.var_id;
(*Format.eprintf "new disjoint:%a@." Disjunction.pp_disjoint_map disjoint;*)
Hashtbl.add policy var reuse.var_id
with Not_found ->
try
let reuse = List.find (filter (fun v -> ISet.mem v.var_id dead)) locals in
(*Format.eprintf "reuse %s by %s@." var reuse.var_id;*)
replace_in_death_table death var reuse.var_id;
(*Format.eprintf "new death:%a@." pp_death_table death;*)
Hashtbl.add policy var reuse.var_id
with Not_found -> ()
let reuse_policy node sort death = (* the reuse policy seeks to use less local variables
by replacing local variables, applying the rules
in the following order:
1) use another clock disjoint still live variable,
with the greatest possible disjoint clock
2) reuse a dead variable
For the sake of safety, we replace variables by others:
- with the same type
- not aliasable (i.e. address type)
*)
let reuse_policy node sort death disjoint =
let dead = ref ISet.empty in let dead = ref ISet.empty in
let policy = Hashtbl.create 23 in let policy = Hashtbl.create 23 in
let sort = ref sort in let sort = ref sort in
...@@ -216,9 +240,7 @@ let reuse_policy node sort death = ...@@ -216,9 +240,7 @@ let reuse_policy node sort death =
begin begin
dead := ISet.union (Hashtbl.find death head) !dead; dead := ISet.union (Hashtbl.find death head) !dead;
end; end;
(match find_compatible_local node head !dead with find_compatible_local node head !dead death disjoint policy;
| None -> ()
| Some l -> replace_in_death_table death head l; Hashtbl.add policy head l);
sort := List.tl !sort; sort := List.tl !sort;
done; done;
policy policy
......
...@@ -377,10 +377,7 @@ let translate_eq node ((m, si, j, d, s) as args) eq = ...@@ -377,10 +377,7 @@ let translate_eq node ((m, si, j, d, s) as args) eq =
| p , Expr_appl (f, arg, r) when not (Basic_library.is_internal_fun f) -> | p , Expr_appl (f, arg, r) when not (Basic_library.is_internal_fun f) ->
let var_p = List.map (fun v -> get_node_var v node) p in let var_p = List.map (fun v -> get_node_var v node) p in
let el = let el = expr_list_of_expr arg in
match arg.expr_desc with
| Expr_tuple el -> el
| _ -> [arg] in
let vl = List.map (translate_expr node args) el in let vl = List.map (translate_expr node args) el in
let node_f = node_from_name f in let node_f = node_from_name f in
let call_f = let call_f =
...@@ -504,6 +501,77 @@ let get_machine_opt name machines = ...@@ -504,6 +501,77 @@ let get_machine_opt name machines =
| None -> if m.mname.node_id = name then Some m else None) | None -> if m.mname.node_id = name then Some m else None)
None machines None machines
(* variable substitution for optimizing purposes *)
(* checks whether an [instr] is skip and can be removed from program *)
let rec instr_is_skip instr =
match instr with
| MLocalAssign (i, LocalVar v) when i = v -> true
| MStateAssign (i, StateVar v) when i = v -> true
| MBranch (g, hl) -> List.for_all (fun (_, il) -> instrs_are_skip il) hl
| _ -> false
and instrs_are_skip instrs =
List.for_all instr_is_skip instrs
let rec instr_remove_skip instr cont =
match instr with
| MLocalAssign (i, LocalVar v) when i = v -> cont
| MStateAssign (i, StateVar v) when i = v -> cont
| MBranch (g, hl) -> MBranch (g, List.map (fun (h, il) -> (h, instrs_remove_skip il [])) hl) :: cont
| _ -> instr::cont
and instrs_remove_skip instrs cont =
List.fold_right instr_remove_skip instrs cont
let rec value_replace_var fvar value =
match value with
| Cst c -> value
| LocalVar v -> LocalVar (fvar v)
| StateVar v -> value
| Fun (id, args) -> Fun (id, List.map (value_replace_var fvar) args)
| Array vl -> Array (List.map (value_replace_var fvar) vl)
| Access (t, i) -> Access(value_replace_var fvar t, i)
| Power (v, n) -> Power(value_replace_var fvar v, n)
let rec instr_replace_var fvar instr =
match instr with
| MLocalAssign (i, v) -> MLocalAssign (fvar i, value_replace_var fvar v)
| MStateAssign (i, v) -> MStateAssign (i, value_replace_var fvar v)
| MReset i -> instr
| MStep (il, i, vl) -> MStep (List.map fvar il, i, List.map (value_replace_var fvar) vl)
| MBranch (g, hl) -> MBranch (value_replace_var fvar g, List.map (fun (h, il) -> (h, instrs_replace_var fvar il)) hl)
and instrs_replace_var fvar instrs =
List.map (instr_replace_var fvar) instrs
let step_replace_var fvar step =
{ step with
step_checks = List.map (fun (l, v) -> (l, value_replace_var fvar v)) step.step_checks;
step_locals = Utils.remove_duplicates (List.map fvar step.step_locals);
step_instrs = instrs_replace_var fvar step.step_instrs;
}
let rec machine_replace_var fvar m =
{ m with
mstep = step_replace_var fvar m.mstep
}
let machine_reuse_var m reuse =
let reuse_vdecl = Hashtbl.create 23 in
begin
Hashtbl.iter (fun v v' -> Hashtbl.add reuse_vdecl (get_node_var v m.mname) (get_node_var v' m.mname)) reuse;
let fvar v =
try
Hashtbl.find reuse_vdecl v
with Not_found -> v in
machine_replace_var fvar m
end
let prog_reuse_var prog node_schs =
List.map
(fun m ->
machine_reuse_var m (Utils.IMap.find m.mname.node_id node_schs).Scheduling.reuse_table
) prog
(* Local Variables: *) (* Local Variables: *)
(* compile-command:"make -C .." *) (* compile-command:"make -C .." *)
......
...@@ -296,6 +296,9 @@ let rec compile basename extension = ...@@ -296,6 +296,9 @@ let rec compile basename extension =
(Utils.fprintf_list ~sep:"@ " Machine_code.pp_machine) (Utils.fprintf_list ~sep:"@ " Machine_code.pp_machine)
machine_code); machine_code);
(* experimental
let machine_code = Machine_code.prog_reuse_var machine_code node_schs in
*)
(* Optimize machine code *) (* Optimize machine code *)
let machine_code = let machine_code =
if !Options.optimization >= 2 then if !Options.optimization >= 2 then
......
...@@ -45,7 +45,7 @@ let update_elim outputs elim instr = ...@@ -45,7 +45,7 @@ let update_elim outputs elim instr =
(* When optimization >= 3, we also inline any basic operator call. (* When optimization >= 3, we also inline any basic operator call.
All those are returning a single ouput *) All those are returning a single ouput *)
| MStep([v], id, vl) when | MStep([v], id, vl) when
List.mem id Basic_library.internal_funs Basic_library.is_internal_fun id
&& !Options.optimization >= 3 && !Options.optimization >= 3
-> assert false -> assert false
(* true, apply elim v (Fun(id, vl))*) (* true, apply elim v (Fun(id, vl))*)
...@@ -53,7 +53,7 @@ let update_elim outputs elim instr = ...@@ -53,7 +53,7 @@ let update_elim outputs elim instr =
| MLocalAssign (v, ((Fun (id, il)) as e)) when | MLocalAssign (v, ((Fun (id, il)) as e)) when
not (List.mem v outputs) not (List.mem v outputs)
&& List.mem id Basic_library.internal_funs (* this will avoid inlining ite *) && Basic_library.is_internal_fun id (* this will avoid inlining ite *)
&& !Options.optimization >= 3 && !Options.optimization >= 3
-> ( -> (
(* Format.eprintf "WE STORE THE EXPRESSION DEFINING %s TO ELIMINATE IT@." v.var_id; *) (* Format.eprintf "WE STORE THE EXPRESSION DEFINING %s TO ELIMINATE IT@." v.var_id; *)
......
...@@ -37,8 +37,8 @@ type schedule_report = ...@@ -37,8 +37,8 @@ type schedule_report =
unused_vars : ISet.t; unused_vars : ISet.t;
(* the table mapping each local var to its in-degree *) (* the table mapping each local var to its in-degree *)
fanin_table : (ident, int) Hashtbl.t; fanin_table : (ident, int) Hashtbl.t;
(* the table mapping each assignment to a set of dead/reusable variables *) (* the table mapping each assignment to a reusable variable *)
death_table : (ident, ISet.t) Hashtbl.t reuse_table : (ident, ident) Hashtbl.t
} }
(* Topological sort with a priority for variables belonging in the same equation lhs. (* Topological sort with a priority for variables belonging in the same equation lhs.
...@@ -162,7 +162,7 @@ let schedule_node n = ...@@ -162,7 +162,7 @@ let schedule_node n =
Disjunction.pp_disjoint_map disjoint Disjunction.pp_disjoint_map disjoint
); );
let reuse = Liveness.reuse_policy n sort death in let reuse = Liveness.reuse_policy n sort death disjoint in
Log.report ~level:5 Log.report ~level:5
(fun fmt -> (fun fmt ->
Format.eprintf Format.eprintf
...@@ -171,7 +171,7 @@ let schedule_node n = ...@@ -171,7 +171,7 @@ let schedule_node n =
Liveness.pp_reuse_policy reuse Liveness.pp_reuse_policy reuse
); );
n', { schedule = sort; unused_vars = unused; fanin_table = fanin; death_table = death } n', { schedule = sort; unused_vars = unused; fanin_table = fanin; reuse_table = reuse }
with (Causality.Cycle v) as exc -> with (Causality.Cycle v) as exc ->
pp_error Format.err_formatter v; pp_error Format.err_formatter v;
raise exc raise exc
......
...@@ -53,6 +53,11 @@ let option_map f o = ...@@ -53,6 +53,11 @@ let option_map f o =
| None -> None | None -> None
| Some e -> Some (f e) | Some e -> Some (f e)
let rec remove_duplicates l =
match l with
| [] -> []
| t::q -> if List.mem t q then remove_duplicates q else t :: remove_duplicates q
let position pred l = let position pred l =
let rec pos p l = let rec pos p l =
match l with match l with
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment