(*  Title:      HOL/Tools/BNF/bnf_lfp_compat.ML
    Author:     Jasmin Blanchette, TU Muenchen
    Copyright   2013, 2014

Compatibility layer with the old datatype package. Partly based on:

    Title:      HOL/Tools/Old_Datatype/old_datatype_data.ML
    Author:     Stefan Berghofer, TU Muenchen
*)
signature BNF_LFP_COMPAT =
sig
  (* preferences accepted (as "preference list") by the lookup and definition functions below *)
  datatype preference = Keep_Nesting | Include_GFPs | Kill_Type_Args

  (* lookup of old-style datatype information *)
  val get_all: theory -> preference list -> Old_Datatype_Aux.info Symtab.table
  val get_info: theory -> preference list -> string -> Old_Datatype_Aux.info option
  val the_info: theory -> preference list -> string -> Old_Datatype_Aux.info
  val the_spec: theory -> string -> (string * sort) list * (string * typ list) list
  val the_descr: theory -> preference list -> string list ->
    Old_Datatype_Aux.descr * (string * sort) list * string list * string
    * (string list * string list) * (typ list * typ list)
  val get_constrs: theory -> string -> (string * typ) list option

  (* interpretation hook over both old-style and new-style datatypes *)
  val interpretation: string -> preference list ->
    (Old_Datatype_Aux.config -> string list -> theory -> theory) -> theory -> theory

  (* registration of new-style datatypes as old-style ones *)
  val datatype_compat: string list -> local_theory -> local_theory
  val datatype_compat_global: string list -> theory -> theory
  val datatype_compat_cmd: string list -> local_theory -> local_theory
  val add_datatype: preference list -> Old_Datatype_Aux.spec list -> theory -> string list * theory

  (* primrec wrappers returning results in the old (terms, flat theorem list) shape *)
  val primrec: (binding * typ option * mixfix) list -> Specification.multi_specs ->
    local_theory -> (term list * thm list) * local_theory
  val primrec_global: (binding * typ option * mixfix) list ->
    Specification.multi_specs -> theory -> (term list * thm list) * theory
  val primrec_overloaded: (string * (string * typ) * bool) list ->
    (binding * typ option * mixfix) list -> Specification.multi_specs -> theory ->
    (term list * thm list) * theory
  val primrec_simple: ((binding * typ) * mixfix) list -> term list ->
    local_theory -> (string list * (term list * thm list)) * local_theory
end;
structure BNF_LFP_Compat : BNF_LFP_COMPAT =
open Ctr_Sugar
open BNF_Util
open BNF_Tactics
open BNF_FP_Util
open BNF_FP_Def_Sugar
open BNF_FP_N2M_Sugar
open BNF_LFP
(* Name fragments used when building bindings in this file: compat_N is used as a qualifier
   prefix for compatibility names, rec_split_N as a name prefix for the split recursors. *)
val compat_N = "compat_";
val rec_split_N = "rec_split_";
(* Preferences that callers pass in "prefs" lists to the functions of this structure. *)
datatype preference = Keep_Nesting | Include_GFPs | Kill_Type_Args;
(* Builds one right-hand side per recursor in "recs".

   Each result abstracts over fresh dummy-typed variables fs' (one per function argument type
   of the first recursor) and applies the recursor to "repaired" versions of these arguments
   (see repair_back_rec_arg). The term is then type-checked with Syntax.check_term.

   The pattern on "recs" requires a nonempty list. *)
fun mk_split_rec_rhs ctxt fpTs Cs (recs as rec1 :: _) =
  let
    (* Groups the variables gs (of types g_Ts) into argument groups:
       - a function-typed g whose result type is a product with first component in fpTs is
         replaced by its fst and snd projections (abstracted over its arguments);
       - a g whose type is in fpTs is grouped with the later variable at the first position whose
         type is in Cs; that variable is removed from the remaining lists;
       - any other g forms a singleton group. *)
    fun repair_rec_arg_args [] [] = []
      | repair_rec_arg_args ((g_T as Type (\<^type_name>\<open>fun\<close>, _)) :: g_Ts) (g :: gs) =
        let
          val (x_Ts, body_T) = strip_type g_T;
        in
          (case try HOLogic.dest_prodT body_T of
            NONE => [g]
          | SOME (fst_T, _) =>
            if member (op =) fpTs fst_T then
              let val (xs, _) = mk_Frees "x" x_Ts ctxt in
                map (fn mk_proj => fold_rev Term.lambda xs (mk_proj (Term.list_comb (g, xs))))
                  [HOLogic.mk_fst, HOLogic.mk_snd]
              end
            else
              [g])
          :: repair_rec_arg_args g_Ts gs
        end
      | repair_rec_arg_args (g_T :: g_Ts) (g :: gs) =
        if member (op =) fpTs g_T then
          let
            (* NOTE(review): find_index yields ~1 when no later type is in Cs; "nth" would then
               fail. Presumably callers guarantee a match -- confirm. *)
            val j = find_index (member (op =) Cs) g_Ts;
            val h = nth gs j;
            val g_Ts' = nth_drop j g_Ts;
            val gs' = nth_drop j gs;
          in
            [g, h] :: repair_rec_arg_args g_Ts' gs'
          end
        else
          [g] :: repair_rec_arg_args g_Ts gs;

    (* Wraps f' in lambdas over fresh variables of f_T's binder types and applies f' to the
       regrouped arguments, flattened by flat_rec_arg_args. *)
    fun repair_back_rec_arg f_T f' =
      let
        val g_Ts = Term.binder_types f_T;
        val (gs, _) = mk_Frees "g" g_Ts ctxt;
      in
        fold_rev Term.lambda gs (Term.list_comb (f',
          flat_rec_arg_args (repair_rec_arg_args g_Ts gs)))
      end;

    val f_Ts = binder_fun_types (fastype_of rec1);
    (* the new arguments get dummy types; Syntax.check_term below infers the actual ones *)
    val (fs', _) = mk_Frees "f" (replicate (length f_Ts) Term.dummyT) ctxt;

    fun mk_rec' recx =
      fold_rev Term.lambda fs' (Term.list_comb (recx, map2 repair_back_rec_arg f_Ts fs'))
      |> Syntax.check_term ctxt;
  in
    map mk_rec' recs
  end;
(* Defines one split recursor per type in fpTs.

   The bindings are derived from the (variant-disambiguated) base names of fpTs: the name is
   prefixed with rec_split_N and qualified with compat_N ^ base name. The right-hand sides come
   from mk_split_rec_rhs; the definitions are made with define_co_rec_as, threading lthy through
   all of them. Returns the list of definition results together with the updated context. *)
fun define_split_recs fpTs Cs recs lthy =
  let
    val b_names = Name.variant_list [] (map base_name_of_typ fpTs);

    fun mk_binding b_name =
      Binding.qualify true (compat_N ^ b_name)
        (Binding.prefix_name rec_split_N (Binding.name b_name));

    val bs = map mk_binding b_names;
    val rhss = mk_split_rec_rhs lthy fpTs Cs recs;
  in
    @{fold_map 3} (define_co_rec_as Least_FP Cs) fpTs bs rhss lthy
  end;
(* Proves the characteristic equations of the split recursors.

   For every recursor (applied to fresh variables fs) and every constructor of the corresponding
   type, the goal "frec (ctr gs) = f fgs" is stated, where fgs are the constructor arguments gs
   interleaved with recursive calls for those arguments whose body type is in Xs. Each goal is
   proved by unfolding o_apply, fst_conv, snd_conv, the definitions rec_defs and the original
   recursor theorems rec0_thmss, followed by reflexivity.

   The pattern on "recs" requires a nonempty list. Returns one theorem list per recursor. *)
fun mk_split_rec_thmss ctxt Xs ctrXs_Tsss ctrss rec0_thmss (recs as rec1 :: _) rec_defs =
  let
    val f_Ts = binder_fun_types (fastype_of rec1);
    val (fs, _) = mk_Frees "f" f_Ts ctxt;
    val frecs = map (fn recx => Term.list_comb (recx, fs)) recs;

    val Xs_frecs = Xs ~~ frecs;
    val fss = unflat ctrss fs;

    (* Builds the recursive call for argument g: one dummy-typed abstraction per function arrow
       in the type, then the recursor associated with the remaining type X applied to g applied
       to the bound variables. *)
    fun mk_rec_call g n (Type (\<^type_name>\<open>fun\<close>, [_, ran_T])) =
        Abs (Name.uu, Term.dummyT, mk_rec_call g (n + 1) ran_T)
      | mk_rec_call g n X =
        let
          val frec = the (AList.lookup (op =) Xs_frecs X);
          val xg = Term.list_comb (g, map Bound (n - 1 downto 0));
        in frec $ xg end;

    (* g itself, followed by its recursive call if its body type is one of Xs *)
    fun mk_rec_arg_arg ctrXs_T g =
      g :: (if member (op =) Xs (body_type ctrXs_T) then [mk_rec_call g 0 ctrXs_T] else []);

    fun mk_goal frec ctrXs_Ts ctr f =
      let
        val ctr_Ts = binder_types (fastype_of ctr);
        val (gs, _) = mk_Frees "g" ctr_Ts ctxt;
        val gctr = Term.list_comb (ctr, gs);
        val fgs = flat_rec_arg_args (map2 mk_rec_arg_arg ctrXs_Ts gs);
      in
        fold_rev (fold_rev Logic.all) [fs, gs]
          (mk_Trueprop_eq (frec $ gctr, Term.list_comb (f, fgs)))
        |> Syntax.check_term ctxt
      end;

    val goalss = @{map 4} (@{map 3} o mk_goal) frecs ctrXs_Tsss ctrss fss;

    fun tac ctxt =
      unfold_thms_tac ctxt (@{thms o_apply fst_conv snd_conv} @ rec_defs @ flat rec0_thmss) THEN
      HEADGOAL (rtac ctxt refl);

    fun prove goal =
      Goal.prove_sorry ctxt [] [] goal (tac o #context)
      |> Thm.close_derivation \<^here>;
  in
    map (map prove) goalss
  end;
(* Defines the split recursors for fpTs and derives the associated theorems.

   The induction rules are rewritten with all_mem_range. The result types Cs are read off the
   body types of the original recursors recs0 (which must be schematic type variables; any
   other shape raises Match), turned into free type variables. Returns
   ((inducts', induct', recs', rec'_thmss), lthy'). *)
fun define_split_rec_derive_induct_rec_thms Xs fpTs ctrXs_Tsss ctrss inducts induct recs0 rec_thmss
    lthy =
  let
    val thy = Proof_Context.theory_of lthy;

    (* imperfect: will not yield the expected theorem for functions taking a large number of
       arguments *)
    val repair_induct = unfold_thms lthy @{thms all_mem_range};

    val inducts' = map repair_induct inducts;
    val induct' = repair_induct induct;

    val Cs = map ((fn TVar ((s, _), S) => TFree (s, S)) o body_type o fastype_of) recs0;
    val recs = map2 (mk_co_rec thy Least_FP Cs) fpTs recs0;
    val ((recs', rec'_defs), lthy') = define_split_recs fpTs Cs recs lthy |>> split_list;
    val rec'_thmss = mk_split_rec_thmss lthy' Xs ctrXs_Tsss ctrss rec_thmss recs' rec'_defs;
  in
    ((inducts', induct', recs', rec'_thmss), lthy')
  end;
(* Index of the recursive occurrence found at the end of a chain of function arrows, as a
   singleton list; the empty list if the chain does not end in a DtRec. *)
fun body_rec_indices dtyp =
  (case dtyp of
    Old_Datatype_Aux.DtRec kk => [kk]
  | Old_Datatype_Aux.DtType (\<^type_name>\<open>fun\<close>, [_, D]) => body_rec_indices D
  | _ => []);
(* Renumbers a descriptor so that its entries carry their indices in ascending order.

   If the indices are already sorted, the descriptor is returned unchanged. Otherwise the
   entries keep their order but are paired with the sorted indices, and every DtRec reference is
   replaced by the position of its old index in the original index list. *)
fun reindex_desc desc =
  let
    val kks = map fst desc;
    val perm_kks = sort int_ord kks;

    fun perm_dtyp (Old_Datatype_Aux.DtType (s, Ds)) = Old_Datatype_Aux.DtType (s, map perm_dtyp Ds)
      | perm_dtyp (Old_Datatype_Aux.DtRec kk) =
        Old_Datatype_Aux.DtRec (find_index (curry (op =) kk) kks)
      | perm_dtyp D = D;
  in
    if perm_kks = kks then
      desc
    else
      perm_kks ~~
      map (fn (_, (s, Ds, sDss)) => (s, map perm_dtyp Ds, map (apsnd (map perm_dtyp)) sDss)) desc
  end;
(* Computes old-style info records for a cluster of new-style datatypes.

   fpT_names0 names the datatypes; the cluster is determined from the first name, and
   check_names compares the given names with the names of that cluster (a failed check is an
   error). Unless Include_GFPs is in prefs, only least fixpoints are accepted.

   Returns (nn, b_names, compat_b_names, lfp_sugar_thms'', infos, lthy'') where nn is the number
   of types in the (possibly unfolded) descriptor, b_names / compat_b_names are the base names
   and their compat_N-prefixed forms, lfp_sugar_thms'' are the optional derived theorems, and
   infos pairs each of the original type names with its old-style info record.

   Errors: "... is not a datatype" and "... is not a complete set of mutually recursive
   datatypes" via "error"; Fail "unknown old-style datatype" from cliquify_descr. *)
fun mk_infos_of_mutually_recursive_new_datatypes prefs check_names fpT_names0 lthy =
  let
    val thy = Proof_Context.theory_of lthy;

    fun not_datatype_name s =
      error (quote s ^ " is not a datatype");
    fun not_mutually_recursive ss =
      error ("{" ^ commas ss ^ "} is not a complete set of mutually recursive datatypes");

    (* fp_sugar lookup that insists on co-induction data and on the permitted fixpoint kind *)
    fun checked_fp_sugar_of s =
      (case fp_sugar_of lthy s of
        SOME (fp_sugar as {fp, fp_co_induct_sugar = SOME _, ...}) =>
        if member (op =) prefs Include_GFPs orelse fp = Least_FP then fp_sugar
        else not_datatype_name s
      | _ => not_datatype_name s);

    (* the types of the cluster of the first given name, with schematic type arguments var_As *)
    val fpTs0 as Type (_, var_As) :: _ =
      map (#T o checked_fp_sugar_of o fst o dest_Type)
        (#Ts (#fp_res (checked_fp_sugar_of (hd fpT_names0))));
    val fpT_names as fpT_name1 :: _ = map (fst o dest_Type) fpTs0;

    val _ = check_names (op =) (fpT_names0, fpT_names) orelse not_mutually_recursive fpT_names0;

    (* replace the schematic type arguments by fresh free type variables *)
    val (As_names, _) = Variable.variant_fixes (map (fn TVar ((s, _), _) => s) var_As) lthy;
    val As = map2 (fn s => fn TVar (_, S) => TFree (s, S)) As_names var_As;
    val fpTs = map (fn s => Type (s, As)) fpT_names;

    val nn_fp = length fpTs;

    val mk_dtyp = Old_Datatype_Aux.dtyp_of_typ (map (apsnd (map Term.dest_TFree) o dest_Type) fpTs);

    fun mk_ctr_descr Ts = mk_ctr Ts #> dest_Const ##> (binder_types #> map mk_dtyp);
    fun mk_typ_descr index (Type (T_name, Ts)) ({ctrs, ...} : ctr_sugar) =
      (index, (T_name, map mk_dtyp Ts, map (mk_ctr_descr Ts) ctrs));

    val fp_sugars as {fp, ...} :: _ = map checked_fp_sugar_of fpT_names;
    val fp_ctr_sugars = map (#ctr_sugar o #fp_ctr_sugar) fp_sugars;
    val orig_descr = @{map 3} mk_typ_descr (0 upto nn_fp - 1) fpTs fp_ctr_sugars;
    val all_infos = Old_Datatype_Data.get_all thy;
    (* with Keep_Nesting the descriptor is used as is; otherwise nested datatypes are unfolded *)
    val (orig_descr' :: nested_descrs) =
      if member (op =) prefs Keep_Nesting then [orig_descr]
      else fst (Old_Datatype_Aux.unfold_datatypes lthy orig_descr all_infos orig_descr nn_fp);

    (* Splits a descriptor into consecutive chunks. The chunk size is nn_fp if the first entry
       is one of the cluster's types; otherwise it is the number of entries without recursive
       type arguments in the registered old-style descriptor of that type. *)
    fun cliquify_descr [] = []
      | cliquify_descr [entry] = [[entry]]
      | cliquify_descr (full_descr as (_, (T_name1, _, _)) :: _) =
        let
          val nn =
            if member (op =) fpT_names T_name1 then
              nn_fp
            else
              (case Symtab.lookup all_infos T_name1 of
                SOME {descr, ...} =>
                length (filter_out (exists Old_Datatype_Aux.is_rec_type o #2 o snd) descr)
              | NONE => raise Fail "unknown old-style datatype");
        in
          chop nn full_descr ||> cliquify_descr |> op ::
        end;

    (* put nested types before the types that nest them, as needed for N2M *)
    val descrs = burrow reindex_desc (orig_descr' :: rev nested_descrs);
    (* tag every descriptor entry with the index of the chunk it belongs to *)
    val (mutual_cliques, descr) =
      split_list (flat (map_index (fn (i, descr) => map (pair i) descr)
        (maps cliquify_descr descrs)));

    val fpTs' = Old_Datatype_Aux.get_rec_types descr;
    val nn = length fpTs';

    val fp_sugars = map (checked_fp_sugar_of o fst o dest_Type) fpTs';
    val ctr_Tsss = map (map (map (Old_Datatype_Aux.typ_of_dtyp descr) o snd) o #3 o snd) descr;
    val kkssss = map (map (map body_rec_indices o snd) o #3 o snd) descr;

    (* one schematic placeholder per type, used to describe the recursive calls *)
    val callers = map (fn kk => Var ((Name.uu, kk), \<^typ>\<open>unit => unit\<close>)) (0 upto nn - 1);

    fun apply_comps n kk =
      mk_partial_compN n (replicate n HOLogic.unitT ---> HOLogic.unitT) (nth callers kk);

    val callssss = map2 (map2 (map2 (map o apply_comps o num_binder_types))) ctr_Tsss kkssss;

    val b_names = Name.variant_list [] (map base_name_of_typ fpTs');
    val compat_b_names = map (prefix compat_N) b_names;
    val compat_bs = map Binding.name compat_b_names;

    (* mutualization is only needed if unfolding added types beyond the original cluster *)
    val ((fp_sugars', (lfp_sugar_thms', _)), lthy') =
      if nn > nn_fp then
        mutualize_fp_sugars (K true) Least_FP mutual_cliques compat_bs fpTs' callers callssss
          fp_sugars lthy
      else
        ((fp_sugars, (NONE, NONE)), lthy);

    fun mk_ctr_of ({fp_ctr_sugar = {ctr_sugar = {ctrs, ...}, ...}, ...} : fp_sugar) (Type (_, Ts)) =
      map (mk_ctr Ts) ctrs;

    val substAT = Term.typ_subst_atomic (var_As ~~ As);

    val Xs' = map #X fp_sugars';
    val ctrXs_Tsss' = map (map (map substAT) o #ctrXs_Tss o #fp_ctr_sugar) fp_sugars';
    val ctrss' = map2 mk_ctr_of fp_sugars' fpTs';
    val {fp_co_induct_sugar = SOME {common_co_inducts = induct :: _, ...}, ...} :: _ = fp_sugars';
    val inducts = map (hd o #co_inducts o the o #fp_co_induct_sugar) fp_sugars';
    val recs = map (#co_rec o the o #fp_co_induct_sugar) fp_sugars';
    val rec_thmss = map (#co_rec_thms o the o #fp_co_induct_sugar) fp_sugars';

    (* a function-typed constructor argument whose final result type is one of Xs' *)
    fun is_nested_rec_type (Type (\<^type_name>\<open>fun\<close>, [_, T])) = member (op =) Xs' (body_type T)
      | is_nested_rec_type _ = false;

    (* split recursors are only derived when nesting is not kept and such arguments exist;
       this derivation is restricted to least fixpoints *)
    val ((lfp_sugar_thms'', (inducts', induct', recs', rec'_thmss)), lthy'') =
      if member (op =) prefs Keep_Nesting orelse
         not (exists (exists (exists is_nested_rec_type)) ctrXs_Tsss') then
        ((lfp_sugar_thms', (inducts, induct, recs, rec_thmss)), lthy')
      else if fp = Least_FP then
        define_split_rec_derive_induct_rec_thms Xs' fpTs' ctrXs_Tsss' ctrss' inducts induct recs
          rec_thmss lthy'
        |>> `(fn (inducts', induct', _, rec'_thmss) =>
          SOME ((inducts', induct', mk_induct_attrs ctrss'), (rec'_thmss, [])))
      else
        not_datatype_name fpT_name1;

    val rec'_names = map (fst o dest_Const) recs';
    val rec'_thms = flat rec'_thmss;

    (* the old-style info record of one type, keyed by the type's name *)
    fun mk_info (kk, {T = Type (T_name0, _), fp_ctr_sugar = {ctr_sugar = {casex, exhaust, nchotomy,
        injects, distincts, case_thms, case_cong, case_cong_weak, split,
        split_asm, ...}, ...}, ...} : fp_sugar) =
      (T_name0,
       {index = kk, descr = descr, inject = injects, distinct = distincts, induct = induct',
        inducts = inducts', exhaust = exhaust, nchotomy = nchotomy, rec_names = rec'_names,
        rec_rewrites = rec'_thms, case_name = fst (dest_Const casex), case_rewrites = case_thms,
        case_cong = case_cong, case_cong_weak = case_cong_weak, split = split,
        split_asm = split_asm});

    (* only the types of the original cluster get an info record *)
    val infos = map_index mk_info (take nn_fp fp_sugars');
  in
    (nn, b_names, compat_b_names, lfp_sugar_thms'', infos, lthy'')
  end;
(* Info records of the mutual cluster containing fpT_name.

   Any ERROR raised while computing the infos is turned into the empty list. If the given
   preferences yield nothing and Keep_Nesting is not among them, the computation is retried
   with Keep_Nesting added. *)
fun infos_of_new_datatype_mutual_cluster lthy prefs fpT_name =
  let
    fun get prefs =
      #5 (mk_infos_of_mutually_recursive_new_datatypes prefs subset [fpT_name] lthy)
      handle ERROR _ => [];
  in
    (case get prefs of
      [] => if member (op =) prefs Keep_Nesting then [] else get (Keep_Nesting :: prefs)
    | infos => infos)
  end;
(* Table of all datatype infos: the registered old-style infos merged with infos computed for
   new-style datatypes.

   New-style names are taken from the fp_sugars whose fp_res_index is 0. With Keep_Nesting in
   prefs the new infos are merged using Symtab.update; otherwise Symtab.default is used. *)
fun get_all thy prefs =
  let
    val ctxt = Proof_Context.init_global thy;
    val old_info_tab = Old_Datatype_Data.get_all thy;
    val new_T_names = BNF_FP_Def_Sugar.fp_sugars_of_global thy
      |> map_filter (try (fn {T = Type (s, _), fp_res_index = 0, ...} => s));
    (* the new infos are always computed with Keep_Nesting *)
    val new_infos =
      maps (infos_of_new_datatype_mutual_cluster ctxt (insert (op =) Keep_Nesting prefs))
        new_T_names;
  in
    fold (if member (op =) prefs Keep_Nesting then Symtab.update else Symtab.default) new_infos
      old_info_tab
  end;
(* Looks up x with two getters, falling back to the second if the first yields NONE.
   The old getter is tried first, unless Keep_Nesting is in prefs, in which case the new getter
   is tried first. *)
fun get_one get_old get_new thy prefs x =
  let
    val (get_fst, get_snd) = (get_old thy, get_new thy) |> member (op =) prefs Keep_Nesting ? swap;
  in
    (case get_fst x of NONE => get_snd x | res => res)
  end;
(* Info record of the new-style datatype T_name, looked up among the infos of its mutual
   cluster; NONE if the cluster yields no entry for T_name. *)
fun get_info_of_new_datatype prefs thy T_name =
  let val ctxt = Proof_Context.init_global thy in
    AList.lookup (op =) (infos_of_new_datatype_mutual_cluster ctxt prefs T_name) T_name
  end;
(* Optional info record for a datatype name, combining the old-style registry with the
   new-style lookup via get_one. *)
fun get_info thy prefs =
  let
    val lookup_old = Old_Datatype_Data.get_info;
    val lookup_new = get_info_of_new_datatype prefs;
  in
    get_one lookup_old lookup_new thy prefs
  end;
(* Like get_info, but a missing entry is reported with "error". *)
fun the_info thy prefs T_name =
  (case get_info thy prefs T_name of
    NONE => error ("Unknown datatype " ^ quote T_name)
  | SOME info => info);
(* Specification of datatype T_name: its type parameters (as free type variable name/sort
   pairs) and its constructors with argument types. The info is looked up with
   [Keep_Nesting, Include_GFPs]; an unknown name is an error (see the_info). *)
fun the_spec thy T_name =
  let
    val {descr, index, ...} = the_info thy [Keep_Nesting, Include_GFPs] T_name;
    val (_, Ds, ctrs0) = the (AList.lookup (op =) descr index);
    val tfrees = map Old_Datatype_Aux.dest_DtTFree Ds;
    val ctrs = map (apsnd (map (Old_Datatype_Aux.typ_of_dtyp descr))) ctrs0;
  in (tfrees, ctrs) end;
(* Descriptor data for the mutually recursive datatypes T_names0 (a nonempty list).

   The descriptor of the first name is split at the first entry that has a type argument other
   than a free type variable: the entries before it are the datatypes proper, the rest are
   auxiliary types. The proper names must coincide (as a set) with T_names0, otherwise an
   error is raised.

   Returns (descr, vs, T_names, prefix, (names, auxnames), (Ts, Us)), where vs are the type
   parameters of the first datatype, names are the base names of T_names, auxnames are fresh
   names for the auxiliary types Us, and prefix joins the base names with "_". *)
fun the_descr thy prefs (T_names0 as T_name01 :: _) =
  let
    fun not_mutually_recursive ss =
      error ("{" ^ commas ss ^ "} is not a complete set of mutually recursive datatypes");

    val info = the_info thy prefs T_name01;
    val descr = #descr info;

    val (_, Ds, _) = the (AList.lookup (op =) descr (#index info));
    val vs = map Old_Datatype_Aux.dest_DtTFree Ds;

    fun is_DtTFree (Old_Datatype_Aux.DtTFree _) = true
      | is_DtTFree _ = false;

    val k = find_index (fn (_, (_, dTs, _)) => not (forall is_DtTFree dTs)) descr;
    val protoTs as (dataTs, _) = chop k descr
      |> (apply2 o map)
        (fn (_, (T_name, Ds, _)) => (T_name, map (Old_Datatype_Aux.typ_of_dtyp descr) Ds));

    val T_names = map fst dataTs;
    val _ = eq_set (op =) (T_names, T_names0) orelse not_mutually_recursive T_names0

    val (Ts, Us) = apply2 (map Type) protoTs;

    val names = map Long_Name.base_name T_names;
    val (auxnames, _) = Name.make_context names
      |> fold_map (Name.variant o Old_Datatype_Aux.name_of_typ) Us;
    val prefix = space_implode "_" names;
  in
    (descr, vs, T_names, prefix, (names, auxnames), (Ts, Us))
  end;
(* Constructors of datatype T_name with their full types, or NONE if the_spec fails for
   T_name. Free type variables are turned into schematic ones with index 0; the result type of
   every constructor is T_name applied to the schematic type parameters. *)
fun get_constrs thy T_name =
  try (the_spec thy) T_name
  |> Option.map (fn (tfrees, ctrs) =>
    let
      fun varify_tfree (s, S) = TVar ((s, 0), S);
      fun varify_typ (TFree x) = varify_tfree x
        | varify_typ T = T;

      val dataT = Type (T_name, map varify_tfree tfrees);

      fun mk_ctr_typ Ts = map (Term.map_atyps varify_typ) Ts ---> dataT;
    in
      map (apsnd mk_ctr_typ) ctrs
    end);
(* Interpretation step for old-style datatypes: applies f unless Keep_Nesting is in prefs and
   every name in T_names has an fp_sugar; in that case the theory is returned unchanged. *)
fun old_interpretation_of prefs f config T_names thy =
  if not (member (op =) prefs Keep_Nesting) orelse
     exists (is_none o fp_sugar_of_global thy) T_names then
    f config T_names thy
  else
    thy;
(* Interpretation step for new-style datatypes: applies f with the default configuration to the
   type names of fp_sugars, provided that
   - Include_GFPs is in prefs or all fp_sugars are least fixpoints, and
   - Keep_Nesting is in prefs or some of the names has no registered old-style info.
   Otherwise the theory is returned unchanged. *)
fun new_interpretation_of prefs f (fp_sugars : fp_sugar list) thy =
  let val T_names = map (fst o dest_Type o #T) fp_sugars in
    if (member (op =) prefs Include_GFPs orelse forall (curry (op =) Least_FP o #fp) fp_sugars)
       andalso (member (op =) prefs Keep_Nesting orelse
         exists (is_none o Old_Datatype_Data.get_info thy) T_names) then
      f Old_Datatype_Aux.default_config T_names thy
    else
      thy
  end;
(* Registers f both as an interpretation of old-style datatypes and, under the given name, as
   an fp_sugars interpretation run on the background theory. *)
fun interpretation name prefs f =
  let
    val interpret_old = old_interpretation_of prefs f;
    val interpret_new = Local_Theory.background_theory o new_interpretation_of prefs f;
  in
    Old_Datatype_Data.interpretation interpret_old
    #> fp_sugars_interpretation name interpret_new
  end;
(* attributes given to the noted recursor theorems *)
val nitpicksimp_simp_attrs = @{attributes [nitpick_simp, simp]};

(* Registers the new-style datatypes fpT_names (which must form a complete mutual cluster, as
   checked with eq_set) as old-style datatypes.

   The infos are registered in the old-style data and passed to its interpretations. If
   theorems were derived, the induction and recursor theorems are noted under compat_N-prefixed
   qualifiers (the common induction rule only if there is more than one type), and the recursor
   theorems are declared as default code equations. *)
fun datatype_compat fpT_names lthy =
  let
    val (nn, b_names, compat_b_names, lfp_sugar_thms, infos, lthy') =
      mk_infos_of_mutually_recursive_new_datatypes [] eq_set fpT_names lthy;

    val (all_notes, rec_thmss) =
      (case lfp_sugar_thms of
        NONE => ([], [])
      | SOME ((inducts, induct, induct_attrs), (rec_thmss, _)) =>
        let
          val common_name = compat_N ^ mk_common_name b_names;

          val common_notes =
            (if nn > 1 then [(inductN, [induct], induct_attrs)] else [])
            |> filter_out (null o #2)
            |> map (fn (thmN, thms, attrs) =>
              ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));

          val notes =
            [(inductN, map single inducts, induct_attrs),
             (recN, rec_thmss, nitpicksimp_simp_attrs)]
            |> filter_out (null o #2)
            |> maps (fn (thmN, thmss, attrs) =>
              (* nothing is noted for a theorem kind without any theorems *)
              if forall null thmss then
                []
              else
                map2 (fn b_name => fn thms =>
                    ((Binding.qualify true b_name (Binding.name thmN), attrs), [(thms, [])]))
                  compat_b_names thmss);
        in
          (common_notes @ notes, rec_thmss)
        end);

    val register_interpret =
      Old_Datatype_Data.register infos
      #> Old_Datatype_Data.interpretation_data (Old_Datatype_Aux.default_config, map fst infos);
  in
    lthy'
    |> Local_Theory.raw_theory register_interpret
    |> Local_Theory.notes all_notes
    |> snd
    |> Code.declare_default_eqns (map (rpair true) (flat rec_thmss))
  end;
(* Theory-level variant of datatype_compat, run through Named_Target.theory_map. *)
fun datatype_compat_global fpT_names = Named_Target.theory_map (datatype_compat fpT_names);
(* Command-level variant of datatype_compat: the raw names are first resolved to type
   constructor names with Proof_Context.read_type_name. *)
fun datatype_compat_cmd raw_fpT_names lthy =
  let
    val fpT_names =
      map (fst o dest_Type o Proof_Context.read_type_name {proper = true, strict = false} lthy)
        raw_fpT_names;
  in
    datatype_compat fpT_names lthy
  end;
(* Defines datatypes given by old-style specifications using the new datatype package.

   The old specifications are translated to new-style ones with empty bindings for all optional
   names; with Kill_Type_Args in prefs, type arguments are translated with NONE instead of
   SOME Binding.empty. Unless Keep_Nesting is in prefs, datatype_compat_global is attempted
   afterwards; its failure is ignored ("perhaps (try ...)").

   Returns the full names of the defined types and the resulting theory. *)
fun add_datatype prefs old_specs thy =
  let
    val fpT_names = map (Sign.full_name thy o #1 o fst) old_specs;

    fun new_type_args_of (s, S) =
      (if member (op =) prefs Kill_Type_Args then NONE else SOME Binding.empty,
       (TFree (s, \<^sort>\<open>type\<close>), S));
    fun new_ctr_spec_of (b, Ts, mx) = (((Binding.empty, b), map (pair Binding.empty) Ts), mx);

    fun new_spec_of ((b, old_tyargs, mx), old_ctr_specs) =
      (((((map new_type_args_of old_tyargs, b), mx), map new_ctr_spec_of old_ctr_specs),
        (Binding.empty, Binding.empty, Binding.empty)), []);

    val new_specs = map new_spec_of old_specs;
  in
    (fpT_names,
     thy
     |> Named_Target.theory_map
       (co_datatypes Least_FP construct_lfp (default_ctr_options, new_specs))
     |> not (member (op =) prefs Keep_Nesting) ? perhaps (try (datatype_compat_global fpT_names)))
  end;
(* Converts a new-style primrec result triple into the old pair shape: the terms are kept, the
   middle component is dropped, and f is applied to the simp rules. *)
fun old_of_new f = fn (ts, _, simpss) => (ts, f simpss);

(* The wrappers below call the corresponding BNF_LFP_Rec_Sugar functions with "false" (and an
   empty option list where applicable) and convert the first component of the result. *)
fun primrec fixes specs lthy =
  apfst (old_of_new flat) (BNF_LFP_Rec_Sugar.primrec false [] fixes specs lthy);

fun primrec_global fixes specs thy =
  apfst (old_of_new flat) (BNF_LFP_Rec_Sugar.primrec_global false [] fixes specs thy);

fun primrec_overloaded ops fixes specs thy =
  apfst (old_of_new flat) (BNF_LFP_Rec_Sugar.primrec_overloaded false [] ops fixes specs thy);

fun primrec_simple fixes ts lthy =
  apfst (apfst fst o apsnd (old_of_new (flat o snd)))
    (BNF_LFP_Rec_Sugar.primrec_simple false fixes ts lthy);
(* Outer syntax: the "datatype_compat" command takes one or more type constructor names and
   runs datatype_compat_cmd on them. *)
val _ =
  let
    val parse_type_names = Scan.repeat1 Parse.type_const;
  in
    Outer_Syntax.local_theory \<^command_keyword>\<open>datatype_compat\<close>
      "register datatypes as old-style datatypes and derive old-style properties"
      (parse_type_names >> datatype_compat_cmd)
  end;
