diff --git a/src/causality.ml b/src/causality.ml index 8b389b1b083de7c1fdc1dfe6f36d9366d55b0331..e4df998c59d044dfbea936d51e080aa67d235324 100644 --- a/src/causality.ml +++ b/src/causality.ml @@ -101,14 +101,20 @@ module ExprDep = struct inputs/mems. a mem read var represents a mem at the beginning of a cycle *) let mk_read_var id = Format.sprintf "#%s" id - (* instance vars represent node instance calls, they are not part of the + (* instance vars represent node instance calls and returns, they are not part of the program/schedule, but used to simplify causality analysis *) - let mk_instance_var id = + let mk_instance_var eq f = incr instance_var_cpt; - Format.sprintf "!%s_%d" id !instance_var_cpt + Format.sprintf "%s_%d" f (fst eq.eq_loc).Lexing.pos_lnum + + let mk_call_instance_var id = + Format.sprintf "?%s" id + + let mk_return_instance_var id = + Format.sprintf "!%s" id let is_read_var v = v.[0] = '#' - let is_instance_var v = v.[0] = '!' + let is_instance_var v = v.[0] = '!' || v.[0] = '?' let is_ghost_var v = is_instance_var v || is_read_var v let undo_read_var id = @@ -132,6 +138,20 @@ module ExprDep = struct in match_mem eq.eq_lhs eq.eq_rhs mems + let eq_call_output_variables call_outputs eq = + let match_call_outputs lhs rhs call_outputs = + match rhs.expr_desc with + | Expr_appl (f, _, None) when not (Basic_library.is_expr_internal_fun rhs) -> + let f_var = mk_instance_var eq f + in List.fold_right (fun e -> IMap.add e f_var) lhs call_outputs + | Expr_appl (f, _, Some _) -> + let f_var = mk_instance_var eq f + in List.fold_right (fun e -> IMap.add e f_var) lhs call_outputs + | _ -> + call_outputs + in + match_call_outputs eq.eq_lhs eq.eq_rhs call_outputs + let node_memory_variables nd = List.fold_left eq_memory_variables ISet.empty (get_node_eqs nd) @@ -148,6 +168,9 @@ module ExprDep = struct ISet.empty (if nd.node_iscontract then [] else nd.node_outputs) + let node_call_output_variables nd = + List.fold_left eq_call_output_variables IMap.empty (get_node_eqs nd) + let node_local_variables nd = List.fold_left (fun locals (v, _) -> ISet.add v.var_id locals) @@ -207,7 +230,7 @@ module ExprDep = struct lhs g, g') | (false, false) -> (add_edges lhs [x] g, g') | (true , false) -> (add_edges lhs [x] g, g') | (true , true ) -> (g, add_edges [x] lhs g') *) - let add_eq_dependencies mems inputs node_vars eq (g, g') = + let add_eq_dependencies mems inputs node_vars call_output_vars eq (g, g') = let add_var lhs_is_mem lhs x (g, g') = if is_instance_var x || ISet.mem x node_vars then if ISet.mem x mems then @@ -215,8 +238,14 @@ module ExprDep = struct if lhs_is_mem then g, add_edges [ x ] lhs g' else add_edges [ x ] lhs g, g' else - let x = if ISet.mem x inputs then mk_read_var x else x in - add_edges lhs [ x ] g, g' + let dep_x = + if ISet.mem x inputs then [ mk_read_var x ] + else + try + let f_var = IMap.find x call_output_vars + in [ x; mk_return_instance_var f_var ] + with Not_found -> [ x ] in + add_edges lhs dep_x g, g' else add_edges lhs [ mk_read_var x ] g, g' (* x is a global constant, treated as a read var *) in @@ -238,15 +267,11 @@ module ExprDep = struct let rec add_dep lhs_is_mem lhs rhs g = (* Add mashup dependencies for a user-defined node instance [lhs] = [f]([e]) *) (* i.e every input is connected to every output, through a ghost var *) - let mashup_appl_dependencies f e g = - let f_var = - mk_instance_var - (Format.sprintf "%s_%d" f (fst eq.eq_loc).Lexing.pos_lnum) - in + let mashup_appl_dependencies call_var e g = List.fold_right - (fun rhs -> add_dep lhs_is_mem (adjust_tuple f_var rhs) rhs) + (fun rhs -> add_dep lhs_is_mem (adjust_tuple call_var rhs) rhs) (expr_list_of_expr e) - (add_var lhs_is_mem lhs f_var g) + (add_var lhs_is_mem lhs call_var g) in let g = add_clock lhs_is_mem lhs rhs.expr_clock g in match rhs.expr_desc with @@ -284,15 +309,25 @@ module ExprDep = struct add_dep lhs_is_mem lhs e2 (add_dep lhs_is_mem lhs e1 g) | Expr_when (e, c, _) -> add_dep lhs_is_mem lhs e (add_var lhs_is_mem lhs c g) - | Expr_appl (f, e, None) -> + | Expr_appl (_, e, None) -> if Basic_library.is_expr_internal_fun rhs (* tuple component-wise dependency for internal operators *) then List.fold_right (add_dep lhs_is_mem lhs) (expr_list_of_expr e) g (* mashed up dependency for user-defined operators *) - else mashup_appl_dependencies f e g - | Expr_appl (f, e, Some c) -> - mashup_appl_dependencies f e (add_dep lhs_is_mem lhs c g) + else let f_var = + try + IMap.find (List.hd lhs) call_output_vars + with Not_found -> assert false in + let call_var = mk_call_instance_var f_var + in mashup_appl_dependencies call_var e g + | Expr_appl (_, e, Some c) -> + let f_var = + try + IMap.find (List.hd lhs) call_output_vars + with Not_found -> assert false in + let call_var = mk_call_instance_var f_var + in mashup_appl_dependencies call_var e (add_dep lhs_is_mem lhs c g) in let g = List.fold_left @@ -304,14 +339,14 @@ module ExprDep = struct in add_dep false eq.eq_lhs eq.eq_rhs (g, g') - (* Returns the dependence graph for node [n] *) - let dependence_graph mems inputs node_vars n = + (* Returns the dependency graph for node [n] *) + let dependence_graph mems inputs node_vars call_output_vars n = instance_var_cpt := 0; let g = new_graph (), new_graph () in (* Basic dependencies *) let g = List.fold_right - (add_eq_dependencies mems inputs node_vars) + (add_eq_dependencies mems inputs node_vars call_output_vars) (get_node_eqs n) g in @@ -753,6 +788,9 @@ let merge_with g1 g2 = let world = "!!_world" +let add_call_return_dependency call_output_vars g = + IMap.iter (fun o i -> IdentDepGraph.add_edge g (ExprDep.mk_return_instance_var i) o) call_output_vars + let add_external_dependency outputs mems g = IdentDepGraph.add_vertex g world; ISet.iter (fun o -> IdentDepGraph.add_edge g world o) outputs; @@ -769,8 +807,9 @@ let global_dependency node = in let outputs = ExprDep.node_output_variables node in let node_vars = ExprDep.node_variables node in + let call_output_vars = ExprDep.node_call_output_variables node in let g_non_mems, g_mems = - ExprDep.dependence_graph mems inputs node_vars node + ExprDep.dependence_graph mems inputs node_vars call_output_vars node in (*Format.eprintf "g_non_mems: %a" pp_dep_graph g_non_mems; Format.eprintf "g_mems: %a" pp_dep_graph g_mems;*) @@ -779,6 +818,7 @@ let global_dependency node = let vdecls', eqs', g_mems' = CycleDetection.break_cycles node mems g_mems in (*Format.eprintf "g_mems': %a" pp_dep_graph g_mems';*) merge_with g_non_mems g_mems'; + add_call_return_dependency call_output_vars g_non_mems; add_external_dependency outputs mems g_non_mems; ( { node with diff --git a/src/causality.mli b/src/causality.mli index 4e187c7a67547cf13db9d3fc4c1fdbbe045ba291..08dab51ec0f538c88c00895b8ec4e4c4e09302bc 100644 --- a/src/causality.mli +++ b/src/causality.mli @@ -35,7 +35,9 @@ val slice_graph : IdentDepGraph.t -> ident -> IdentDepGraph.t module ExprDep : sig (* instance vars represent node instance calls, they are not part of the program/schedule, but used to simplify causality analysis *) - val mk_instance_var : ident -> ident + val mk_instance_var : eq -> ident -> ident + val mk_call_instance_var : ident -> ident + val mk_return_instance_var : ident -> ident val mk_read_var : ident -> ident val is_instance_var : ident -> bool val is_ghost_var : ident -> bool @@ -46,6 +48,7 @@ module ExprDep : sig val node_input_variables : node_desc -> ISet.t val node_local_variables : node_desc -> ISet.t val node_output_variables : node_desc -> ISet.t + val node_call_output_variables : node_desc -> ident IMap.t val node_memory_variables : node_desc -> ISet.t end diff --git a/src/scheduling.ml b/src/scheduling.ml index a8efd9350f1d8be1efb668d3f9101a78e1b9b01f..afc8206b8481f68eae71a096db89f208f5e86883 100644 --- a/src/scheduling.ml +++ b/src/scheduling.ml @@ -174,7 +174,7 @@ let remove_node_inlined_locals locals report = IMap.iter (fun v _ -> Hashtbl.remove report.fanin_table v) locals; IMap.iter (fun v _ -> - let iv = ExprDep.mk_instance_var v in + let iv = ExprDep.(mk_call_instance_var (mk_call_instance_var v)) in Liveness.replace_in_dep_graph v iv report.dep_graph) locals; { report with schedule = schedule' }