(* Title: HOL/Tools/BNF/bnf_fp_n2m.ML Author: Dmitriy Traytel, TU Muenchen Copyright 2013
Flattening of nested to mutual (co)recursion.
*)
signature BNF_FP_N2M =
sig
  (* Flattens nested (co)recursion into a mutual (co)datatype fixpoint and returns the
     resulting fp_result together with the updated local theory.
     Arguments, in order (names as bound by the implementation):
       fp                -- least or greatest fixpoint selector
       mutual_cliques    -- clique index for each fixpoint type
       fpTs              -- the fixpoint types
       indexed_fp_ress   -- existing fixpoint results, each with an index into them
       bs                -- bindings for the new constants
       resBs             -- type variables of the result
       (resDs, Dss)      -- dead type variables: global ones, and one list per BNF
                            (hence typ list list)
       bnfs, absT_infos  -- pre-BNFs and their abstraction-type information
       lthy              -- the local theory being extended *)
  val construct_mutualized_fp: BNF_Util.fp_kind -> int list -> typ list ->
    (int * BNF_FP_Util.fp_result) list -> binding list -> (string * sort) list ->
    typ list * typ list list -> BNF_Def.bnf list -> BNF_Comp.absT_info list -> local_theory ->
    BNF_FP_Util.fp_result * local_theory
end;
structure BNF_FP_N2M : BNF_FP_N2M = struct
open BNF_Def open BNF_Util open BNF_Comp open BNF_FP_Util open BNF_FP_Def_Sugar open BNF_Tactics open BNF_FP_N2M_Tactics
(* Proves the n-ary argument congruence rule for the term t:
     x1 = y1 ==> ... ==> xn = yn ==> t x1 ... xn = t y1 ... yn
   The variables x1..xn and y1..yn are fresh frees typed by the first n binder types of
   t's type. The proof substitutes the equational premises away and closes the
   remaining goal by reflexivity. It goes through Goal.prove_sorry, and the derivation
   is closed before returning. *)
fun mk_arg_cong ctxt n t =
  let
    val (argTs, _) = strip_typeN n (fastype_of t);
    val ((lefts, rights), _) = ctxt |> mk_Frees "x" argTs ||>> mk_Frees "y" argTs;
    val arg_eqs = @{map 2} (curry mk_Trueprop_eq) lefts rights;
    val app_eq = mk_Trueprop_eq (list_comb (t, lefts), list_comb (t, rights));
    val goal = Logic.list_implies (arg_eqs, app_eq);
    val frees = Variable.add_free_names ctxt goal [];
    fun prove_tac {context = goal_ctxt, prems = _} =
      HEADGOAL (hyp_subst_tac goal_ctxt THEN' rtac goal_ctxt refl);
  in
    Thm.close_derivation \<^here> (Goal.prove_sorry ctxt frees [] goal prove_tac)
  end;
(* Name prefix for cached auxiliary definitions: the i-th one is prefixed "cache<i>_". *)
val cacheN = "cache";
fun mk_cacheN i = String.concat [cacheN, string_of_int i, "_"];

(* Size bound (configuration option, default 200): terms whose size is below it are
   never turned into cached definitions. *)
val cache_threshold =
  Attrib.setup_config_int \<^binding>\<open>bnf_n2m_cache_threshold\<close> (K 200);

(* A cache pairs a counter (used to number new definitions) with a table mapping a type
   to the defined constant and its defining theorem. *)
type cache = int * (term * thm) Typtab.table
val empty_cache = (0, Typtab.empty)

(* Returns t unchanged if it is smaller than the threshold. Otherwise it defines a
   fresh constant for t under a concealed "cache<i>_"-prefixed binding, records
   (constant, theorem) under key TU, bumps the counter, and returns the constant with
   the updated cache and local theory. *)
fun update_cache b0 TU t (cache as (i, tab), lthy) =
  if size_of_term t < Config.get lthy cache_threshold then (t, (cache, lthy))
  else
    let
      val b = Binding.prefix_name (mk_cacheN i) b0;
      val def_b = Binding.concealed (Thm.def_binding b);
      val ((c, (_, thm)), lthy') = Local_Theory.define ((b, NoSyn), ((def_b, []), t)) lthy;
      val tab' = Typtab.update (TU, (c, thm)) tab;
    in
      (c, ((i + 1, tab'), lthy'))
    end;
(* Looks up the cached constant for type TU. When a binding is supplied (SOME _), the
   table is not consulted and NONE is returned, so the caller takes its
   definition-producing path. *)
fun lookup_cache b_opt TU ((_, tab), _) =
  (case b_opt of
    SOME _ => NONE
  | NONE => Option.map fst (Typtab.lookup tab TU));
fun construct_mutualized_fp fp mutual_cliques fpTs indexed_fp_ress bs resBs (resDs, Dss) bnfs
(absT_infos : absT_info list) lthy = let val time = time lthy; val timer = time (Timer.startRealTimer ());
val b_names = map Binding.name_of bs; val b_name = mk_common_name b_names; val b = Binding.name b_name;
fun of_fp_res get = map (uncurry nth o swap o apsnd get) indexed_fp_ress; fun mk_co_algT T U = case_fp fp (T --> U) (U --> T); fun co_swap pair = case_fp fp I swap pair; val mk_co_comp = curry (HOLogic.mk_comp o co_swap);
val dest_co_algT = co_swap o dest_funT; val co_alg_argT = case_fp fp range_type domain_type; val co_alg_funT = case_fp fp domain_type range_type; val rewrite_comp_comp = case_fp fp @{thm rewriteL_comp_comp} @{thm rewriteR_comp_comp};
val fp_absT_infos = of_fp_res #absT_infos; val fp_bnfs = of_fp_res #bnfs; val fp_pre_bnfs = of_fp_res #pre_bnfs;
val fp_absTs = map #absT fp_absT_infos; val fp_repTs = map #repT fp_absT_infos; val fp_abss = map #abs fp_absT_infos; val fp_reps = map #rep fp_absT_infos; val fp_type_definitions = map #type_definition fp_absT_infos;
val absTs = map #absT absT_infos; val repTs = map #repT absT_infos; val absTs' = map (Logic.type_map (singleton (Variable.polymorphic lthy))) absTs; val repTs' = map (Logic.type_map (singleton (Variable.polymorphic lthy))) repTs; val abss = map #abs absT_infos; val reps = map #rep absT_infos; val abs_inverses = map #abs_inverse absT_infos; val type_definitions = map #type_definition absT_infos;
val n = length bnfs; val deads = fold (union (op =)) Dss resDs; val As = subtract (op =) deads (map TFree resBs); val names_lthy = fold Variable.declare_typ (As @ deads) lthy; val m = length As; val live = m + n;
val (((Xs, Ys), Bs), names_lthy) = names_lthy
|> mk_TFrees n
||>> mk_TFrees n
||>> mk_TFrees m;
val allAs = As @ Xs; val allBs = Bs @ Xs; val phiTs = map2 mk_pred2T As Bs; val thetaBs = As ~~ Bs; val fpTs' = map (Term.typ_subst_atomic thetaBs) fpTs; val fold_thetaAs = Xs ~~ fpTs; val fold_thetaBs = Xs ~~ fpTs'; val pre_phiTs = map2 mk_pred2T fpTs fpTs';
val ((ctors, dtors), (xtor's, xtors)) = let val ctors = map2 (force_typ names_lthy o (fn T => dummyT --> T)) fpTs (of_fp_res #ctors); val dtors = map2 (force_typ names_lthy o (fn T => T --> dummyT)) fpTs (of_fp_res #dtors); in
((ctors, dtors), `(map (Term.subst_atomic_types thetaBs)) (case_fp fp ctors dtors)) end;
val absATs = map (domain_type o fastype_of) ctors; val absBTs = map (Term.typ_subst_atomic thetaBs) absATs; val xTs = map (domain_type o fastype_of) xtors; val yTs = map (domain_type o fastype_of) xtor's;
val absAs = @{map 3} (fn Ds => mk_abs o mk_T_of_bnf Ds allAs) Dss bnfs abss; val absBs = @{map 3} (fn Ds => mk_abs o mk_T_of_bnf Ds allBs) Dss bnfs abss; val fp_repAs = map2 mk_rep absATs fp_reps; val fp_repBs = map2 mk_rep absBTs fp_reps;
val typ_subst_nonatomic_sorted = fold_rev (typ_subst_nonatomic o single); val sorted_theta = sort (int_ord o apply2 (Term.size_of_typ o fst)) (fpTs ~~ Xs) val sorted_fpTs = map fst sorted_theta;
val nesting_bnfs = nesting_bnfs lthy
[[map (typ_subst_nonatomic_sorted (rev sorted_theta) o range_type o fastype_of) fp_repAs]]
allAs; val fp_or_nesting_bnfs = distinct (op = o apply2 T_of_bnf) (fp_bnfs @ nesting_bnfs);
val rel_unfolds = maps (no_refl o single o rel_def_of_bnf) fp_pre_bnfs; val rel_xtor_co_inducts = of_fp_res (split_conj_thm o #xtor_rel_co_induct)
|> map (zero_var_indexes o unfold_thms lthy (id_apply :: rel_unfolds));
val rel_defs = map rel_def_of_bnf bnfs; val rel_monos = map rel_mono_of_bnf bnfs;
fun cast castA castB pre_rel = let val castAB = mk_vimage2p (Term.subst_atomic_types fold_thetaAs castA)
(Term.subst_atomic_types fold_thetaBs castB); in
fold_rev (fold_rev Term.absdummy) [phiTs, pre_phiTs]
(castAB $ Term.list_comb (pre_rel, map Bound (live - 1 downto 0))) end;
val castAs = map2 (curry HOLogic.mk_comp) absAs fp_repAs; val castBs = map2 (curry HOLogic.mk_comp) absBs fp_repBs;
val fp_or_nesting_rel_eqs = no_refl (map rel_eq_of_bnf fp_or_nesting_bnfs); val fp_or_nesting_rel_monos = map rel_mono_of_bnf fp_or_nesting_bnfs;
fun mutual_instantiate ctxt inst = let val thetas = AList.group (op =) (mutual_cliques ~~ inst); in
map2 (infer_instantiate ctxt o the o AList.lookup (op =) thetas) mutual_cliques end;
val rel_xtor_co_inducts_inst = let val extract =
case_fp fp (snd o Term.dest_comb) (snd o Term.dest_comb o fst o Term.dest_comb); val raw_phis = map (extract o HOLogic.dest_Trueprop o Thm.concl_of) rel_xtor_co_inducts; val inst = map (fn (t, u) => (#1 (dest_Var t), Thm.cterm_of lthy u)) (raw_phis ~~ pre_phis); in
mutual_instantiate lthy inst rel_xtor_co_inducts end
val thy = Proof_Context.theory_of lthy; fun mk_absT_fp_repT repT absT = mk_absT thy repT absT ooo mk_repT;
fun mk_un_fold b_opt ss un_folds cache_lthy TU =
(case lookup_cache b_opt TU cache_lthy of
SOME t => ((t, Drule.dummy_thm), cache_lthy)
| NONE => let val x = co_alg_argT TU; val i = find_index (fn T => x = T) Xs; val TUfold =
(case find_first (fn f => body_fun_type (fastype_of f) = TU) un_folds of
NONE => force_fold i TU (nth fp_un_folds i)
| SOME f => f);
val TUs = binder_fun_types (fastype_of TUfold);
fun mk_s TU' cache_lthy = let val i = find_index (fn T => co_alg_argT TU' = T) Xs; val fp_abs = nth fp_abss i; val fp_rep = nth fp_reps i; val abs = nth abss i; val rep = nth reps i; val sF = co_alg_funT TU'; val sF' =
mk_absT_fp_repT (nth repTs' i) (nth absTs' i) (nth fp_absTs i) (nth fp_repTs i) sF handle Term.TYPE _ => sF; val F = nth fold_preTs i; val s = nth ss i; in if sF = F then (s, cache_lthy) elseif sF' = F then (mk_co_comp_abs_rep sF sF' fp_abs fp_rep abs rep s, cache_lthy) else let val smapT = replicate live dummyT ---> mk_co_algT sF' F; fun hidden_to_unit t =
Term.subst_TVars (map (rpair HOLogic.unitT) (Term.add_tvar_names t [])) t; val smap = map_of_bnf (nth bnfs i)
|> force_typ names_lthy smapT
|> hidden_to_unit; val smap_argTs = strip_typeN live (fastype_of smap) |> fst; fun mk_smap_arg T_to_U cache_lthy =
(if domain_type T_to_U = range_type T_to_U then
(HOLogic.id_const (domain_type T_to_U), cache_lthy) else
mk_un_fold NONE ss un_folds cache_lthy T_to_U |>> fst); val (smap_args, cache_lthy') = fold_map mk_smap_arg smap_argTs cache_lthy; in
(mk_co_comp_abs_rep sF sF' fp_abs fp_rep abs rep
(mk_co_comp s (Term.list_comb (smap, smap_args))), cache_lthy') end end; val (args, cache_lthy) = fold_map mk_s TUs cache_lthy; val t = Term.list_comb (TUfold, args); in
(case b_opt of
NONE => update_cache b TU t cache_lthy |>> rpair Drule.dummy_thm
| SOME b => cache_lthy
|-> (fn cache => let val S = HOLogic.mk_tupleT fold_strTs; val s = HOLogic.mk_tuple ss; val u = Const (\<^const_name>\<open>Let\<close>, S --> (S --> TU) --> TU) $ s $ absdummy S t; in
Local_Theory.define ((b, NoSyn), ((Binding.concealed (Thm.def_binding b), []), u))
#>> apsnd snd ##> pair cache end)) end);
val un_foldN = case_fp fp ctor_foldN dtor_unfoldN; fun mk_un_folds (ss_names, lthy) = letval ss = map2 (curry Free) ss_names fold_strTs; in
fold2 (fn TU => fn b => fn ((un_folds, defs), cache_lthy) =>
mk_un_fold (SOME b) (map2 (curry Free) ss_names fold_strTs) un_folds cache_lthy TU
|>> (fn (f, d) => (f :: un_folds, d :: defs)))
resTs (map (Binding.suffix_name ("_" ^ un_foldN)) bs) (([], []), (empty_cache, lthy))
|>> map_prod rev rev
|>> pair ss end; val ((ss, (un_folds, un_fold_defs0)), (cache, (lthy, raw_lthy))) = lthy
|> (snd o Local_Theory.begin_nested)
|> Variable.add_fixes (mk_names n "s")
|> mk_un_folds
||> apsnd (`(Local_Theory.end_nested));
val un_fold_defs = map (unfold_thms raw_lthy @{thms Let_const}) un_fold_defs0;
val cache_defs = snd cache |> Typtab.dest |> map (snd o snd);
val phi = Proof_Context.export_morphism raw_lthy lthy;
val xtor_un_folds = map (head_of o Morphism.term phi) un_folds; val xtor_un_fold_defs = map (Drule.abs_def o Morphism.thm phi) un_fold_defs; val xtor_cache_defs = map (Drule.abs_def o Morphism.thm phi) cache_defs; val xtor_un_folds' = map2 (fn raw => fn t => Const (dest_Const_name t, fold_strTs ---> fastype_of raw))
un_folds xtor_un_folds;
val xtor_un_fold_thms = let val pre_fold_maps = mk_pre_fold_maps un_folds; fun mk_goals f xtor s smap fp_abs fp_rep abs rep = let val lhs = mk_co_comp f xtor; val rhs = mk_co_comp s smap; in
HOLogic.mk_eq (lhs,
mk_co_comp_abs_rep (co_alg_funT (fastype_of lhs)) (co_alg_funT (fastype_of rhs))
fp_abs fp_rep abs rep rhs) end;
val goals =
@{map 8} mk_goals un_folds xtors ss pre_fold_maps fp_abss fp_reps abss reps;
val fp_un_folds = map (mk_pointfree2 lthy) (of_fp_res #xtor_un_fold_thms);
Die Informationen auf dieser Webseite wurden
nach bestem Wissen sorgfältig zusammengestellt. Es wird jedoch weder Vollständigkeit noch Richtigkeit
noch Qualität der bereitgestellten Informationen zugesichert.
Bemerkung:
Die farbliche Syntaxdarstellung ist noch experimentell.