type basic_lfp_sugar =
{T: typ,
fp_res_index: int,
C: typ,
fun_arg_Tsss : typ listlistlist,
ctr_sugar: Ctr_Sugar.ctr_sugar,
recx: term,
rec_thms: thm list};
type lfp_rec_extension =
{nested_simps: thm list,
special_endgame_tac: Proof.context -> thm list -> thm list -> thm list -> tactic,
is_new_datatype: Proof.context -> string -> bool,
basic_lfp_sugars_of: binding list -> typ list -> term list ->
(term * term listlist) listlist -> local_theory ->
typ list * int list * basic_lfp_sugar list * thm list * thm list * thm list * thm
* Token.src list * bool * local_theory,
rewrite_nested_rec_call: (Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
term -> term -> term -> term) option};
val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory val default_basic_lfp_sugars_of: binding list -> typ list -> term list ->
(term * term listlist) listlist -> local_theory ->
typ list * int list * basic_lfp_sugar list * thm list * thm list * thm list * thm
* Token.src list * bool * local_theory val rec_specs_of: binding list -> typ list -> typ list -> term list ->
(term * term listlist) listlist -> local_theory ->
(bool * rec_spec list * typ list * thm * thm list * Token.src list * typ list) * local_theory
val lfp_rec_sugar_interpretation: string ->
(BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> local_theory -> local_theory) -> theory -> theory
val primrec: bool -> rec_option list -> (binding * typ option * mixfix) list ->
Specification.multi_specs -> local_theory ->
(term list * thm list * thm listlist) * local_theory val primrec_cmd: bool -> rec_option list -> (binding * stringoption * mixfix) list ->
Specification.multi_specs_cmd -> local_theory ->
(term list * thm list * thm listlist) * local_theory val primrec_global: bool -> rec_option list -> (binding * typ option * mixfix) list ->
Specification.multi_specs -> theory -> (term list * thm list * thm listlist) * theory val primrec_overloaded: bool -> rec_option list -> (string * (string * typ) * bool) list ->
(binding * typ option * mixfix) list ->
Specification.multi_specs -> theory -> (term list * thm list * thm listlist) * theory val primrec_simple: bool -> ((binding * typ) * mixfix) list -> term list -> local_theory ->
((stringlist * (binding -> binding) list)
* (term list * thm list * (int listlist * thm listlist))) * local_theory end;
(* Per-type information about a (least fixpoint) datatype that primrec needs.
   Fields, as filled in by default_basic_lfp_sugars_of and read by rec_specs_of:
     T            - the datatype itself
     fp_res_index - position of T among the fixpoint results; rec_specs_of sorts by it
                    (the default implementation always uses 0)
     C            - result type of the recursor; rec_specs_of treats it as a type variable
                    (dest_TVar) and instantiates it with the user's result type
     fun_arg_Tsss - for each constructor, for each argument, the list of argument types of
                    the corresponding recursor function argument
     ctr_sugar    - the constructor sugar for T
     recx         - the recursor constant (the case combinator in the default)
     rec_thms     - the recursor equations (the case theorems in the default) *)
type basic_lfp_sugar =
  {T: typ,
   fp_res_index: int,
   C: typ,
   fun_arg_Tsss : typ list list list,
   ctr_sugar: ctr_sugar,
   recx: term,
   rec_thms: thm list};
(* Hooks that a richer datatype package can register (via register_lfp_rec_extension) to
   extend primrec beyond the built-in defaults:
     nested_simps            - extra simp rules; [] when nothing is registered
     special_endgame_tac     - extra tactic for the final proof step; no_tac by default
     is_new_datatype         - tells whether a type name is a new-style datatype; the
                               default answers true
     basic_lfp_sugars_of     - computes the basic_lfp_sugar data for the argument types;
                               the default only handles a single non-mutual type
     rewrite_nested_rec_call - optional rewriter for nested recursive calls; when absent,
                               nested recursion is rejected *)
type lfp_rec_extension =
  {nested_simps: thm list,
   special_endgame_tac: Proof.context -> thm list -> thm list -> thm list -> tactic,
   is_new_datatype: Proof.context -> string -> bool,
   basic_lfp_sugars_of: binding list -> typ list -> term list ->
     (term * term list list) list list -> local_theory ->
     typ list * int list * basic_lfp_sugar list * thm list * thm list * thm list * thm
     * Token.src list * bool * local_theory,
   rewrite_nested_rec_call: (Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
     term -> term -> term -> term) option};
(* Theory data holding the (at most one) registered lfp_rec_extension.  It starts as NONE,
   and register_lfp_rec_extension stores SOME extension.  Merging uses merge_options;
   presumably the first SOME wins, but confirm against the library definition. *)
structure Data = Theory_Data
( type T = lfp_rec_extension option; val empty = NONE; val merge = merge_options;
);
(* Installs extension as the theory's lfp_rec_extension, replacing any earlier one. *)
fun register_lfp_rec_extension extension thy = Data.put (SOME extension) thy;
(* Extra simp rules supplied by the registered extension, or [] if none is registered. *)
fun nested_simps ctxt =
  let val extension_opt = Data.get (Proof_Context.theory_of ctxt) in
    (case extension_opt of
      NONE => []
    | SOME {nested_simps = simps, ...} => simps)
  end;
(* The registered extension's endgame tactic specialised to ctxt; without an extension,
   a function that ignores its three theorem lists and yields no_tac. *)
fun special_endgame_tac ctxt =
  let val extension_opt = Data.get (Proof_Context.theory_of ctxt) in
    (case extension_opt of
      NONE => (fn _ => fn _ => fn _ => no_tac)
    | SOME {special_endgame_tac = endgame_tac, ...} => endgame_tac ctxt)
  end;
(* Predicate on type names from the registered extension; every name counts as new when
   no extension is registered. *)
fun is_new_datatype ctxt =
  let val extension_opt = Data.get (Proof_Context.theory_of ctxt) in
    (case extension_opt of
      NONE => (fn _ => true)
    | SOME {is_new_datatype = is_new, ...} => is_new ctxt)
  end;
(* Fallback used when no extension is registered.  It supports exactly one argument type,
   which must be a Type with constructor sugar.  The case combinator serves as the
   "recursor" and the case theorems as its equations, so only non-recursive definitions
   can be handled.  Any other shape of arg_Ts is rejected with an error. *)
fun default_basic_lfp_sugars_of _ [Type (arg_T_name, _)] _ _ ctxt =
    let
      val ctr_sugar =
        (case ctr_sugar_of ctxt arg_T_name of
          NONE => error ("Unsupported type " ^ quote arg_T_name ^ " at this stage")
        | SOME sugar => sugar);
      val {T, ctrs, casex, case_thms, ...} = ctr_sugar;
      (* every constructor argument becomes a singleton list of recursor-argument types *)
      val basic_lfp_sugar =
        {T = T,
         fp_res_index = 0,
         C = body_type (fastype_of casex),
         fun_arg_Tsss = map (fn ctr => map single (binder_types (fastype_of ctr))) ctrs,
         ctr_sugar = ctr_sugar,
         recx = casex,
         rec_thms = case_thms};
    in
      ([], [0], [basic_lfp_sugar], [], [], [], TrueI (*dummy*), [], false, ctxt)
    end
  | default_basic_lfp_sugars_of _ [T] _ _ ctxt =
    error ("Cannot recurse through type " ^ quote (Syntax.string_of_typ ctxt T))
  | default_basic_lfp_sugars_of _ _ _ _ _ = error "Unsupported mutual recursion at this stage";
(* Dispatches to the registered extension's basic_lfp_sugars_of, falling back to
   default_basic_lfp_sugars_of when no extension is registered. *)
fun basic_lfp_sugars_of bs arg_Ts callers callssss lthy =
  let
    val implementation =
      (case Data.get (Proof_Context.theory_of lthy) of
        NONE => default_basic_lfp_sugars_of
      | SOME {basic_lfp_sugars_of = registered, ...} => registered);
  in
    implementation bs arg_Ts callers callssss lthy
  end;
(* The registered nested-call rewriter specialised to ctxt.  Raises the error
   "Unsupported nested recursion" when there is no extension or it provides no rewriter. *)
fun rewrite_nested_rec_call ctxt =
  let val extension_opt = Data.get (Proof_Context.theory_of ctxt) in
    (case extension_opt of
      SOME {rewrite_nested_rec_call = SOME rewriter, ...} => rewriter ctxt
    | _ => error "Unsupported nested recursion")
  end;
(* Plugin registry over fp_rec_sugar values.  Interpretations are added by
   lfp_rec_sugar_interpretation, and the plugin's data function is exported as
   interpret_lfp_rec_sugar. *)
structure LFP_Rec_Sugar_Plugin = Plugin(type T = fp_rec_sugar);
(* Registers f under the plugin name.  Each fp_rec_sugar is transferred to the current
   theory of lthy before f sees it. *)
fun lfp_rec_sugar_interpretation name f =
  let
    fun interpret fp_rec_sugar lthy =
      let val thy = Proof_Context.theory_of lthy in
        f (transfer_fp_rec_sugar thy fp_rec_sugar) lthy
      end;
  in
    LFP_Rec_Sugar_Plugin.interpretation name interpret
  end;
(* Alias for the plugin's data function.  Presumably it runs the registered
   interpretations on a newly produced fp_rec_sugar; confirm against Plugin. *)
val interpret_lfp_rec_sugar = LFP_Rec_Sugar_Plugin.data;
(* Computes the recursion specifications for the functions being defined.  Per the
   signature, the result is (bool * rec_spec list * typ list * thm * thm list *
   Token.src list * typ list) paired with the local theory.
   NOTE(review): this text is incomplete.  basic_lfp_sugars, perm0_kks and common_induct
   are used but not bound here, and the let has no in ... end; the lines that obtain them
   (presumably via basic_lfp_sugars_of) and the final result are missing.  Restore from
   upstream before relying on this block. *)
fun rec_specs_of bs arg_Ts res_Ts callers callssss0 lthy0 = let val thy = Proof_Context.theory_of lthy0;
(* Sort the sugars by fp_res_index; every perm_ list below follows this order. *)
val perm_basic_lfp_sugars = sort (int_ord o apply2 #fp_res_index) basic_lfp_sugars;
val indices = map #fp_res_index basic_lfp_sugars; val perm_indices = map #fp_res_index perm_basic_lfp_sugars;
val perm_ctrss = map (#ctrs o #ctr_sugar) perm_basic_lfp_sugars;
(* nn0: number of argument types; nn: number of types in the sorted sugar list. *)
val nn0 = length arg_Ts; val nn = length perm_ctrss; val kks = 0 upto nn - 1;
(* For type kk, the total number of constructors of the types before it in sorted order. *)
val perm_ctr_offsets = map (fn kk => Integer.sum (map length (take kk perm_ctrss))) kks;
val perm_fpTs = map #T perm_basic_lfp_sugars; val perm_Cs = map #C perm_basic_lfp_sugars; val perm_fun_arg_Tssss = map #fun_arg_Tsss perm_basic_lfp_sugars;
(* Map lists in sorted (perm_) order back to the original order. *)
fun unpermute0 perm0_xs = permute_like_unique (op =) perm0_kks kks perm0_xs; fun unpermute perm_xs = permute_like_unique (op =) perm_indices indices perm_xs;
val inducts = unpermute0 (conj_dests nn common_induct);
val fpTs = unpermute perm_fpTs; val Cs = unpermute perm_Cs; val ctr_offsets = unpermute perm_ctr_offsets;
(* As_rho instantiates the first nn0 fp types to arg_Ts.  Cs_rho maps each C type
   variable to the matching res_Ts entry, with HOLogic.unitT as padding, presumably when
   res_Ts is shorter than nn. *)
val As_rho = tvar_subst thy (take nn0 fpTs) arg_Ts; val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts;
val substA = Term.subst_TVars As_rho; val substAT = Term.typ_subst_TVars As_rho; val substCT = Term.typ_subst_TVars Cs_rho; val substACT = substAT o substCT;
val perm_Cs' = map substCT perm_Cs;
(* Classifies one constructor argument:
     a single index is Nested_Rec if its type mentions one of the Cs, otherwise No_Rec;
     a pair of indices is Mutual_Rec.
   Any other shape fails with Match. *)
fun call_of [i] [T] = (if exists_subtype_in Cs T then Nested_Rec else No_Rec) (i, substACT T)
| call_of [i, i'] [T, T'] = Mutual_Rec ((i, substACT T), (i', substACT T'));
(* Builds the per-constructor record {ctr, offset, calls, rec_thm}, with the constructor's
   types instantiated by substA. *)
fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm = let val (fun_arg_hss, _) = indexedd fun_arg_Tss 0; val fun_arg_hs = flat_rec_arg_args fun_arg_hss; val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss; in
{ctr = substA ctr, offset = offset, calls = map2 call_of fun_arg_iss fun_arg_Tss,
rec_thm = rec_thm} end;
(* Constructor specs for the type at fp_res_index, with offsets numbered from k. *)
fun mk_ctr_specs fp_res_index k ctrs rec_thms =
@{map 4} mk_ctr_spec ctrs (k upto k + length ctrs - 1) (nth perm_fun_arg_Tssss fp_res_index)
rec_thms;
(* Placeholder term (the constant undefined, typed dummyT).  build_rec_arg uses it as the
   recursor argument for a constructor that has no user equation (its NONE case). *)
val undef_const = Const (\<^const_name>\<open>undefined\<close>, dummyT);
(* One user equation  f l1 .. lm (ctr x1 .. xk) r1 .. rn = rhs,  as split by dissect_eqn. *)
type eqn_data = {
  fun_name: string,      (* name of the Free at the head of the lhs *)
  rec_type: typ,         (* body_type of ctr's type, i.e. the type pattern-matched on *)
  ctr: term,             (* constructor in the single non-variable lhs argument *)
  ctr_args: term list,   (* the constructor's argument variables x1 .. xk *)
  left_args: term list,  (* Free arguments before the pattern, l1 .. lm *)
  right_args: term list, (* Free arguments after the pattern, r1 .. rn *)
  res_type: typ,         (* types of left_args @ right_args ---> type of rhs *)
  rhs_term: term,        (* the right-hand side *)
  user_eqn: term         (* the equation exactly as given by the user *)
};
(* Splits one user equation eqn0 of the form  f l1 .. lm (C x1 .. xk) r1 .. rn = rhs
   into an eqn_data record.  Reports an error when:
     - the term is not an equation;
     - the head is not a Free among fun_names;
     - there is not exactly one non-Free argument;
     - the constructor pattern is partially applied or has non-Free arguments;
     - a variable occurs twice on the lhs;
     - the rhs mentions a Free that is neither bound on the lhs, one of fun_names,
       nor fixed in ctxt. *)
fun dissect_eqn ctxt fun_names eqn0 =
  let
    val eqn = drop_all eqn0 |> HOLogic.dest_Trueprop
      handle TERM _ => ill_formed_equation_lhs_rhs ctxt [eqn0];
    val (lhs, rhs) = HOLogic.dest_eq eqn
      handle TERM _ => ill_formed_equation_lhs_rhs ctxt [eqn];
    val (fun_name, args) = strip_comb lhs
      |>> (fn x => if is_Free x then fst (dest_Free x) else ill_formed_equation_head ctxt [eqn]);
    val (left_args, rest) = chop_prefix is_Free args;
    val (nonfrees, right_args) = chop_suffix is_Free rest;
    val num_nonfrees = length nonfrees;
    val _ = num_nonfrees = 1 orelse
      (if num_nonfrees = 0 then missing_pattern ctxt [eqn]
       else more_than_one_nonvar_in_lhs ctxt [eqn]);
    (* ill_formed_equation_head never returns (it is used at type string above), so it is
       called directly like the other checks instead of being wrapped in raise. *)
    val _ = member (op =) fun_names fun_name orelse ill_formed_equation_head ctxt [eqn];
    val (ctr, ctr_args) = strip_comb (the_single nonfrees);
    val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
      partially_applied_ctr_in_pattern ctxt [eqn];
    val _ = check_duplicate_variables_in_lhs ctxt [eqn] (left_args @ ctr_args @ right_args);
    val _ = forall is_Free ctr_args orelse nonprimitive_pattern_in_lhs ctxt [eqn];
    (* Frees of rhs that are neither lhs variables, function names, nor fixed in ctxt *)
    val _ =
      let
        val bads =
          fold_aterms (fn x as Free (v, _) =>
              if not (member (op =) (left_args @ ctr_args @ right_args) x) andalso
                 not (member (op =) fun_names v) andalso
                 not (Variable.is_fixed ctxt v) then
                cons x
              else
                I
            | _ => I) rhs [];
      in
        null bads orelse extra_variable_in_rhs ctxt [eqn] (hd bads)
      end;
  in
    {fun_name = fun_name,
     rec_type = body_type (type_of ctr),
     ctr = ctr,
     ctr_args = ctr_args,
     left_args = left_args,
     right_args = right_args,
     res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
     rhs_term = rhs,
     user_eqn = eqn0}
  end;
(* Returns a function that rewrites a right-hand side, replacing recursive calls whose
   argument y is one of ctr_args:
     - if y's head has a nested_calls entry, the term goes to rewrite_nested_rec_call;
     - otherwise, for a call g .. y where get_ctr_pos gives g a position other than ~1,
       the call must have at least that many leading arguments
       (else too_few_args_in_rec_call), and a mutual_calls entry y' for y turns it into
       y' applied to g's rewritten leading arguments.
   If has_call still finds a call in the result, rec_call_not_apply_to_ctr_arg is
   reported. *)
fun subst_rec_calls ctxt get_ctr_pos has_call ctr_args mutual_calls nested_calls =
  let
    fun try_nested_rec bound_Ts y t =
      AList.lookup (op =) nested_calls y
      |> Option.map (fn y' => rewrite_nested_rec_call ctxt has_call get_ctr_pos bound_Ts y y' t);

    fun subst bound_Ts (t as g' $ y) =
        let
          fun subst_comb (h $ z) = subst bound_Ts h $ subst bound_Ts z
            | subst_comb t = t;

          val y_head = head_of y;
        in
          if not (member (op =) ctr_args y_head) then
            subst_comb t
          else
            (case try_nested_rec bound_Ts y_head t of
              SOME t' => subst_comb t'
            | NONE =>
              let val (g, g_args) = strip_comb g' in
                (case try (get_ctr_pos o fst o dest_Free) g of
                  SOME ~1 => subst_comb t
                | SOME ctr_pos =>
                  (length g_args >= ctr_pos orelse too_few_args_in_rec_call ctxt [] t;
                   (case AList.lookup (op =) mutual_calls y of
                     SOME y' => list_comb (y', map (subst bound_Ts) g_args)
                   | NONE => subst_comb t))
                | NONE => subst_comb t)
              end)
        end
      | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
      | subst bound_Ts t = try_nested_rec bound_Ts (head_of t) t |> the_default t;

    (* any call that survived the rewrite was not applied to a constructor argument *)
    fun subst' t =
      if has_call t then rec_call_not_apply_to_ctr_arg ctxt [] t
      else try_nested_rec [] (head_of t) t |> the_default t;
  in
    subst' o subst []
  end;
(* Builds the recursor argument for one constructor from its optional user equation.
   Returns undef_const when the constructor has no equation. *)
fun build_rec_arg ctxt (funs_data : eqn_data listlist) has_call (ctr_spec : rec_ctr_spec)
(eqn_data_opt : eqn_data option) =
(case eqn_data_opt of
NONE => undef_const
| SOME {ctr_args, left_args, right_args, rhs_term = t, ...} => let val calls = #calls ctr_spec; val n_args = fold (Integer.add o (fn Mutual_Rec _ => 2 | _ => 1)) calls 0;
(* n_args counts recursor argument slots: a Mutual_Rec call takes two, any other call one.
   The calls are then tagged with their position and split into three groups:
     - direct: No_Rec, plus the first half of Mutual_Rec;
     - mutual: the second half of Mutual_Rec;
     - nested: Nested_Rec.
   try drops the calls that do not match each group's pattern. *)
val no_calls' = tag_list 0 calls
|> map_filter (try (apsnd (fn No_Rec p => p | Mutual_Rec (p, _) => p))); val mutual_calls' = tag_list 0 calls
|> map_filter (try (apsnd (fn Mutual_Rec (_, p) => p))); val nested_calls' = tag_list 0 calls
|> map_filter (try (apsnd (fn Nested_Rec p => p)));
(* Replaces t by a variant Free (via Term.variant_bounds) when t already occurs in frees. *)
fun ensure_unique frees t = if member (op =) frees t then Free (the_single (Term.variant_bounds t [dest_Free t])) else t;
(* NOTE(review): from here on the text does not look like part of build_rec_arg.
   n_funs, rec_specs, ctr_spec_eqn_data_list, bs and mxs are not bound in this block, and
   the paren opened before case above is never closed.  This appears to be the body of a
   later definition-building function spliced in; restore from upstream. *)
val recs = take n_funs rec_specs |> map #recx; val rec_args = ctr_spec_eqn_data_list
|> sort (op < o apply2 (#offset o fst) |> make_ord)
|> map (uncurry (build_rec_arg ctxt funs_data has_call) o apsnd (try the_single)); val ctr_poss = map (fn x => if length (distinct (op = o apply2 (length o #left_args)) x) <> 1 then
inconstant_pattern_pos_for_fun ctxt [] (#fun_name (hd x)) else
hd x |> #left_args |> length) funs_data; in
(recs, ctr_poss)
|-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
|> Syntax.check_terms ctxt
|> @{map 3} (fn b => fn mx => fn t =>
((b, mx), ((Binding.concealed (Thm.def_binding b), []), t)))
bs mxs end;
(* For each constructor argument of the equation, collects the recursive calls applied to
   it in rhs_term.  When a head satisfying has_call is applied with the constructor argument
   at position n, the call prefix is recorded via mk_partial_compN; otherwise the search
   continues inside the subterms.  Returns NONE when ctr_args is empty, and otherwise
   SOME (ctr, one call list per constructor argument). *)
fun find_rec_calls has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) =
  let
    fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg
      | find bound_Ts (t as _ $ _) ctr_arg =
        let
          val typof = curry fastype_of1 bound_Ts;
          val (f', args') = strip_comb t;
          val n = find_index (equal ctr_arg o head_of) args';
        in
          if n < 0 then
            find bound_Ts f' ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args'
          else
            let
              (* n is a valid index, so chop leaves the ctr_arg-headed argument first *)
              val (f, args as arg :: _) = chop n args' |>> curry list_comb f';
              val (arg_head, arg_args) = Term.strip_comb arg;
            in
              if has_call f then
                mk_partial_compN (length arg_args) (typof arg_head) f ::
                maps (fn x => find bound_Ts x ctr_arg) args
              else
                find bound_Ts f ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args
            end
        end
      | find _ _ _ = [];
  in
    map (find [] rhs_term) ctr_args
    |> (fn [] => NONE | callss => SOME (ctr, callss))
  end;