(*  Title:      HOL/Tools/SMT/verit_replay.ML
    Author:     Mathias Fleury, MPII

VeriT proof parsing and replay.
*)
(* Public interface of the veriT replay module. *)
signature VERIT_REPLAY =
sig
  (* replay outer_ctxt replay_data output: takes a proof context, the replay data
     produced by the SMT translation and the solver output as a list of strings,
     and returns a theorem. *)
  val replay: Proof.context -> SMT_Translate.replay_data -> string list -> thm
end;
(* Implementation of VERIT_REPLAY.
   NOTE(review): the struct keyword opening this body and its matching end are not
   visible in this part of the file -- confirm that both are present. *)
structure Verit_Replay: VERIT_REPLAY =
(* subst_only_free pairs: returns a term transformer that replaces every subterm
   occurring as a key of the term table pairs by the associated value.

   The lookup is attempted on the whole subterm first; only when it fails does the
   traversal descend below abstractions and into both sides of applications.
   A replacement is returned as is and is not traversed again. *)
fun subst_only_free pairs =
  let
    fun substf u =
      (case Termtab.lookup pairs u of
        SOME u' => u'
      | NONE =>
          (case u of
            Abs (a, T, t) => Abs (a, T, substf t)
          | t $ u' => substf t $ substf u'
          | u => u))
  in substf end;
(* under_fixes f unchanged_prems (prems, nthms) names args insts decls (concl, ctxt):
   collects the theorems handed to the method f and calls it.

   The theorem list passed to f is, in this order: unchanged_prems, then prems after
   SMT_Replay.varify, then the theorem component of every pair in nthms.
   names is only used in the diagnostic output below. *)
fun under_fixes f unchanged_prems (prems, nthms) names args insts decls (concl, ctxt) =
  let
    val thms1 = unchanged_prems @ map (SMT_Replay.varify ctxt) prems
    val _ = SMT_Config.verit_msg ctxt (fn () => \<^print> ("names =", names))
    val thms2 = map snd nthms
    (* diagnostic output, emitted through SMT_Config.verit_msg *)
    val _ = SMT_Config.verit_msg ctxt (fn () => \<^print> ("prems=", prems))
    val _ = SMT_Config.verit_msg ctxt (fn () => \<^print> ("nthms=", nthms))
    val _ = SMT_Config.verit_msg ctxt (fn () => \<^print> ("thms1=", thms1))
    val _ = SMT_Config.verit_msg ctxt (fn () => \<^print> ("thms2=", thms2))
  in (f ctxt (thms1 @ thms2) args insts decls concl) end
(** Replaying **)
(* replay_thm: replays a single proof node and returns the resulting theorem.

   For a node whose rule is Verit_Proof.input_rule the theorem is looked up by id in
   the table assumed; Fail is raised when there is no entry.
   For any other rule the method method_for rule is run through under_fixes on the
   post-processed conclusion, and the result is simplified with rewrite_rules. *)
fun replay_thm method_for rewrite_rules ll_defs ctxt assumed unchanged_prems prems nthms
    concl_transformation global_transformation args insts
    (Verit_Proof.VeriT_Replay_Node {id, rule, concl, bounds, declarations = decls, ...}) =
  let
    val _ = SMT_Config.verit_msg ctxt (fn () => \<^print> id)
    (* rewriting applied to the arguments and instantiations of the step *)
    val rewrite = let val thy = Proof_Context.theory_of (empty_simpset ctxt) in
        Raw_Simplifier.rewrite_term thy rewrite_rules []
        #> not (null ll_defs andalso Verit_Proof.keep_raw_lifting rule) ? SMTLIB_Isar.unlift_term ll_defs
      end
    (* when Verit_Proof.keep_app_symbols holds for the rule, only the rewrite rules whose
       conclusion could unify with that of SMT.fun_app_def are kept for the conclusion *)
    val rewrite_concl = if Verit_Proof.keep_app_symbols rule then
          filter (curry Term.could_unify (Thm.concl_of @{thm SMT.fun_app_def}) o Thm.concl_of) rewrite_rules
        else rewrite_rules
    val post = let val thy = Proof_Context.theory_of (empty_simpset ctxt) in
        Raw_Simplifier.rewrite_term thy rewrite_concl []
        #> Object_Logic.atomize_term ctxt
        #> not (null ll_defs) ? SMTLIB_Isar.unlift_term ll_defs
        #> SMTLIB_Isar.unskolemize_names ctxt
        #> HOLogic.mk_Trueprop
      end
    (* both substitutions are applied before post-processing *)
    val concl = concl
      |> Term.subst_free concl_transformation
      |> subst_only_free global_transformation
      |> post
  in
    if rule = Verit_Proof.input_rule then
      (case Symtab.lookup assumed id of
        SOME (_, thm) => thm
      | _ => raise Fail ("assumption " ^ @{make_string} id ^ " not found"))
    else
      under_fixes (method_for rule) unchanged_prems
        (prems, nthms) (map fst bounds)
        (map rewrite args)
        (Symtab.map (K rewrite) insts)
        (* NOTE(review): decls argument reconstructed from the parameter list of
           under_fixes -- confirm *)
        decls
        (concl, ctxt)
      |> Simplifier.simplify (empty_simpset ctxt addsimps rewrite_rules)
  end
(* add_used_asserts_in_step step acc: adds to acc, without duplicates, the assertion
   indices referenced by the premises of step and, recursively, by the steps of its
   subproof. Premise names for which SMTLIB_Interface.assert_index_of_name raises an
   exception are skipped. *)
fun add_used_asserts_in_step (Verit_Proof.VeriT_Replay_Node {prems,
    subproof = (_, _, _, subproof), ...}) =
  let
    val own_indices = map_filter (try SMTLIB_Interface.assert_index_of_name) prems
    val nested_indices =
      flat (map (fn sub_step => add_used_asserts_in_step sub_step []) subproof)
  in union (op =) (own_indices @ nested_indices) end
(* remove_rewrite_rules_from_rules n step: option-valued filter for map_filter.
   Returns NONE when the id of step denotes an assertion index smaller than n;
   otherwise, including when the id is not an assertion name at all, returns SOME step. *)
fun remove_rewrite_rules_from_rules n (step as Verit_Proof.VeriT_Replay_Node {id, ...}) =
  (case try SMTLIB_Interface.assert_index_of_name id of
    SOME index => if index >= n then SOME step else NONE
  | NONE => SOME step)
(* replay_theorem_step: replays one non-definition proof node, including its subproof.

   state is the tuple (proofs, stats, ctxt, concl_tranformation, global_transformation).
   The subproof steps are replayed first, in a context extended with the fixes and the
   assumptions of the subproof; the node itself is then replayed with replay_thm under
   the time limit SMT_Config.reconstruction_step_timeout.

   Returns the state with the new theorem stored under id in proofs, the elapsed time
   recorded under rule in stats, and the global transformation that results from
   replaying the subproof. Raises SMT_Failure.SMT SMT_Failure.Time_Out on timeout. *)
fun replay_theorem_step rewrite_rules ll_defs assumed inputs proof_prems
    (step as Verit_Proof.VeriT_Replay_Node {id, rule, prems, bounds, args, insts,
      subproof = (fixes, assms, input, subproof), concl, ...}) state =
  let
    val (proofs, stats, ctxt, concl_tranformation, global_transformation) = state
    (* fix the bound names and declare the conclusion in the context *)
    val (_, ctxt) = Variable.variant_fixes (map fst bounds) ctxt
      |> (fn (names, ctxt) => (names,
        fold Variable.declare_term [SMTLIB_Isar.unskolemize_names ctxt concl] ctxt))

    (* context of the subproof: the fixes get fresh names *)
    val (names, sub_ctxt) = Variable.variant_fixes (map fst fixes) ctxt
      ||> fold Variable.declare_term (map Free fixes)
    (* maps every fixed variable of the subproof to its renamed counterpart *)
    val export_vars = concl_tranformation @
      (ListPair.zip (map Free fixes, map Free (ListPair.zip (names, map snd fixes))))

    val post = let val thy = Proof_Context.theory_of (empty_simpset ctxt) in
        Raw_Simplifier.rewrite_term thy ((if Verit_Proof.keep_raw_lifting rule then tl rewrite_rules else rewrite_rules)) []
        #> Object_Logic.atomize_term ctxt
        #> not (null ll_defs andalso Verit_Proof.keep_raw_lifting rule) ? SMTLIB_Isar.unlift_term ll_defs
        #> SMTLIB_Isar.unskolemize_names ctxt
        #> HOLogic.mk_Trueprop
      end
    val assms = map (subst_only_free global_transformation o Term.subst_free (export_vars) o post) assms
    val input = map (subst_only_free global_transformation o Term.subst_free (export_vars) o post) input
    (* NOTE(review): context argument reconstructed as sub_ctxt, the context the
       assumptions are certified in -- confirm *)
    val (all_proof_prems', sub_ctxt2) =
      Assumption.add_assumes (map (Thm.cterm_of sub_ctxt) (assms @ input)) sub_ctxt
    fun is_refl thm = Thm.concl_of thm |> (fn (_ $ t) => t) |> HOLogic.dest_eq |> (op =) handle TERM _=> false
    (* NOTE(review): left operand reconstructed as all_proof_prems' -- confirm *)
    val all_proof_prems' =
      all_proof_prems'
      |> filter_out is_refl
    (* NOTE(review): the split below uses length assms although reflexive equations
       were filtered out just above -- confirm the split is still the intended one *)
    val proof_prems' = take (length assms) all_proof_prems'
    val input = drop (length assms) all_proof_prems'
    val all_proof_prems = proof_prems @ proof_prems'
    val replay = replay_theorem_step rewrite_rules ll_defs assumed (input @ inputs) all_proof_prems
    val (proofs', stats, _, _, sub_global_rew) =
      fold replay subproof (proofs, stats, sub_ctxt2, export_vars, global_transformation)
    val export_thm = singleton (Proof_Context.export sub_ctxt2 ctxt)

    (*for sko_ex and sko_forall, assumptions are in proofs', but the definition of the skolem
      function is in proofs *)
    val nthms = prems
      |> map (apsnd export_thm) o map_filter (Symtab.lookup (if (null subproof) then proofs else proofs'))
    val nthms' = (if Verit_Proof.is_skolemization rule
        then prems else [])
      |> map_filter (Symtab.lookup proofs)
    val args = map (Term.subst_free concl_tranformation o subst_only_free global_transformation) args
    val insts = Symtab.map (K (Term.subst_free concl_tranformation o subst_only_free global_transformation)) insts
    val proof_prems =
      if Verit_Replay_Methods.requires_subproof_assms prems rule then proof_prems else []
    val local_inputs =
      if Verit_Replay_Methods.requires_local_input prems rule then input @ inputs else []
    val replay = Timing.timing (replay_thm Verit_Replay_Methods.method_for rewrite_rules ll_defs
      ctxt assumed [] (proof_prems @ local_inputs) (nthms @ nthms') concl_tranformation
      global_transformation args insts)

    val ({elapsed, ...}, thm) =
      SMT_Config.with_time_limit ctxt SMT_Config.reconstruction_step_timeout replay step
        handle Timeout.TIMEOUT _ => raise SMT_Failure.SMT SMT_Failure.Time_Out
    val stats' = Symtab.cons_list (rule, Time.toNanoseconds elapsed) stats
    (* Debugging aid, disabled:
    val _ = ((Time.toMilliseconds elapsed > 40) ? @{print})
      ("WARNING slow " ^ id ^ @{make_string} rule ^ ": " ^ string_of_int (Time.toMilliseconds elapsed)) *)
    val proofs = Symtab.update (id, (map fst bounds, thm)) proofs
  in (proofs, stats', ctxt,
       concl_tranformation, sub_global_rew) end
(* replay_definition_step: replays a definition node.

   state is the tuple (proofs, stats, ctxt, concl_tranformation, global_transformation).
   Every declaration (name, term) of the node is rewritten, the name is fixed in the
   context under a possibly fresh variant, and the equation between the name and the
   term is added to the context as an assumption. The symmetric form of the assumed
   equation is stored in proofs under the id of the node. Renamings from old to new
   names are added to the global transformation unless the old name is already mapped.

   The three ignored arguments keep the parameter list aligned with
   replay_theorem_step. Raises Fail when the node has a non-empty subproof. *)
fun replay_definition_step rewrite_rules ll_defs _ _ _
    (Verit_Proof.VeriT_Replay_Node {id, declarations = raw_declarations, subproof = (_, _, _, subproof), ...}) state =
  let
    val _ = if null subproof then ()
      else raise (Fail ("unrecognized veriT proof, definition has a subproof"))
    val (proofs, stats, ctxt, concl_tranformation, global_transformation) = state
    val global_transformer = subst_only_free global_transformation
    val rewrite = let val thy = Proof_Context.theory_of ctxt in
        Raw_Simplifier.rewrite_term thy (rewrite_rules) []
        #> not (null ll_defs) ? SMTLIB_Isar.unlift_term ll_defs
      end
    val start0 = Timing.start ()
    val declaration = map (apsnd (rewrite o global_transformer)) raw_declarations
    val (names, ctxt) = Variable.variant_fixes (map fst declaration) ctxt
      ||> fold Variable.declare_term (map Free (map (apsnd fastype_of) declaration))
    val old_names = map Free (map (fn (a, b) => (a, fastype_of b)) declaration)
    val new_names = map Free (ListPair.zip (names, map (fastype_of o snd) declaration))
    (* record a renaming only if it is not the identity and the key is still unmapped *)
    fun update_mapping (a, b) tab =
      if a <> b andalso Termtab.lookup tab a = NONE
      then Termtab.update_new (a, b) tab else tab
    val global_transformation = global_transformation
      |> fold update_mapping (ListPair.zip (old_names, new_names))
    (* rebuilt so that the renamings recorded above are applied to the definitions *)
    val global_transformer = subst_only_free global_transformation
    val generate_definition =
      (fn (name, term) => (HOLogic.mk_Trueprop
        (Const(\<^const_name>\<open>HOL.eq\<close>, fastype_of term --> fastype_of term --> @{typ bool}) $
          Free (name, fastype_of term) $ term)))
      #> global_transformer
      #> Thm.cterm_of ctxt
    val decls = map generate_definition declaration
    val (defs, ctxt) = Assumption.add_assumes decls ctxt
    val thms_with_old_name = ListPair.zip (map fst declaration, defs)
    (* every entry is stored under the same id, so a later entry replaces an earlier one *)
    val proofs = fold
      (fn (name, thm) => Symtab.update (id, ([name], @{thm sym} OF [thm])))
      thms_with_old_name proofs
    val total = Time.toNanoseconds (#elapsed (Timing.result start0))
    val stats = Symtab.cons_list ("choice", total) stats
  in (proofs, stats, ctxt, concl_tranformation, global_transformation) end
(* replay_assumed assms ll_defs rewrite_rules stats ctxt term: proves the rewritten
   term with SMT_Replay_Methods.prove, using a tactic that inserts the theorem
   components of assms and then runs Classical.fast_tac.

   Returns the theorem together with stats extended by the elapsed time, recorded
   under Verit_Proof.input_rule. The proof runs under the time limit
   SMT_Config.reconstruction_step_timeout; a timeout is reported as
   SMT_Failure.SMT SMT_Failure.Time_Out. *)
fun replay_assumed assms ll_defs rewrite_rules stats ctxt term =
  let
    val rewrite = let val thy = Proof_Context.theory_of (empty_simpset ctxt) in
        Raw_Simplifier.rewrite_term thy rewrite_rules []
        #> not (null ll_defs) ? SMTLIB_Isar.unlift_term ll_defs
      end
    val replay = Timing.timing (SMT_Replay_Methods.prove ctxt (rewrite term))
    val ({elapsed, ...}, thm) =
      SMT_Config.with_time_limit ctxt SMT_Config.reconstruction_step_timeout replay
        (fn _ => Method.insert_tac ctxt (map snd assms) THEN' Classical.fast_tac ctxt)
        handle Timeout.TIMEOUT _ => raise SMT_Failure.SMT SMT_Failure.Time_Out
    val stats' = Symtab.cons_list (Verit_Proof.input_rule, Time.toNanoseconds elapsed) stats
  in
    (thm, stats')
  end
(* replay_step: dispatches on the rule of the node. Nodes whose rule is
   Verit_Proof.veriT_def go to replay_definition_step, all other nodes go to
   replay_theorem_step; the arguments are passed through unchanged. *)
fun replay_step rewrite_rules ll_defs assumed inputs proof_prems
    (step as Verit_Proof.VeriT_Replay_Node {rule, ...}) state =
  (case rule = Verit_Proof.veriT_def of
    true =>
      replay_definition_step rewrite_rules ll_defs assumed inputs proof_prems step state
  | false =>
      replay_theorem_step rewrite_rules ll_defs assumed inputs proof_prems step state)
(* replay outer_ctxt replay_data output: parses the solver output and replays the
   resulting proof steps one after the other.

   Outline, following the bindings below:
   - rewrite rules that could unify with verit_eq_true_simplify are dropped;
   - the output is parsed with Verit_Proof.parse_replay and the parsing time recorded;
   - the assertion indices used by the proof are collected; indices at or above the
     number of ll_defs give input steps built from assms, indices below it are
     proved with replay_assumed;
   - all steps are replayed with replay_step, with intermediate statistics printed
     along the way;
   - the theorem of the last step is exported to outer_ctxt and passed to
     Verit_Replay_Methods.discharge together with the term False. *)
fun replay outer_ctxt
    ({context = ctxt, typs, terms, rewrite_rules, assms, ll_defs, ...} : SMT_Translate.replay_data)
    output =
  let
    (* NOTE(review): list operand reconstructed as rewrite_rules -- confirm *)
    val rewrite_rules =
      filter_out (fn thm => Term.could_unify (Thm.prop_of @{thm verit_eq_true_simplify},
          Thm.prop_of thm))
        rewrite_rules
    val num_ll_defs = length ll_defs
    (* assertion ids are shifted by the number of ll_defs relative to indices into assms *)
    val index_of_id = Integer.add (~ num_ll_defs)
    val id_of_index = Integer.add num_ll_defs

    val start0 = Timing.start ()
    val (actual_steps, ctxt2) =
      Verit_Proof.parse_replay typs terms output ctxt
    val parsing_time = Time.toNanoseconds (#elapsed (Timing.result start0))

    (* builds an input-rule node for the assumption with index j *)
    fun step_of_assume (j, (_, th)) =
      Verit_Proof.VeriT_Replay_Node {
        id = SMTLIB_Interface.assert_name_of_index (id_of_index j),
        rule = Verit_Proof.input_rule,
        args = [],
        prems = [],
        proof_ctxt = [],
        concl = Thm.prop_of th
          |> Raw_Simplifier.rewrite_term (Proof_Context.theory_of
            (empty_simpset ctxt addsimps rewrite_rules)) rewrite_rules [],
        bounds = [],
        insts = Symtab.empty,
        declarations = [],
        subproof = ([], [], [], [])}
    val used_assert_ids = fold add_used_asserts_in_step actual_steps []
    fun normalize_tac ctxt = let val thy = Proof_Context.theory_of (empty_simpset ctxt) in
      Raw_Simplifier.rewrite_term thy rewrite_rules [] end
    (* NOTE(review): list operand reconstructed as used_assert_ids -- confirm *)
    val used_assm_js =
      map_filter (fn id => let val i = index_of_id id in if i >= 0 then SOME (i, nth assms i)
          else NONE end)
        used_assert_ids

    val assm_steps = map step_of_assume used_assm_js

    fun extract (Verit_Proof.VeriT_Replay_Node {id, rule, concl, bounds, ...}) =
      (id, rule, concl, map fst bounds)
    fun cond rule = rule = Verit_Proof.input_rule
    val add_asssert = SMT_Replay.add_asserted Symtab.update Symtab.empty extract cond
    val ((_, _), (ctxt3, assumed)) =
      add_asssert outer_ctxt rewrite_rules assms
        (map_filter (remove_rewrite_rules_from_rules num_ll_defs) assm_steps) ctxt2

    (* NOTE(review): list operand reconstructed as used_assert_ids -- confirm *)
    val used_rew_js =
      map_filter (fn id => let val i = index_of_id id in if i < 0
          then SOME (id, normalize_tac ctxt (nth ll_defs id)) else NONE end)
        used_assert_ids
    (* prove the used ll_defs and add them to the table of assumed theorems; the
       statistics table starts with the parsing time *)
    val (assumed, stats) = fold (fn ((id, thm)) => fn (assumed, stats) =>
        let val (thm, stats) = replay_assumed assms ll_defs rewrite_rules stats ctxt thm
        in (Symtab.update (SMTLIB_Interface.assert_name_of_index id, ([], thm)) assumed, stats)
        end)
      used_rew_js (assumed, Symtab.cons_list ("parsing", parsing_time) Symtab.empty)

    (* NOTE(review): pipeline source reconstructed as ctxt3 -- confirm *)
    val ctxt4 =
      ctxt3
      |> put_simpset (SMT_Replay.make_simpset ctxt3 [])
      |> Config.put SAT.solver (Config.get ctxt3 SMT_Config.sat_solver)
    val len = Verit_Proof.number_of_steps actual_steps
    (* pairs every top-level step with the running total of Verit_Proof.number_of_steps
       up to and including that step *)
    fun steps_with_depth _ [] = []
      | steps_with_depth i (p :: ps) = (i + Verit_Proof.number_of_steps [p], p) ::
          steps_with_depth (i + Verit_Proof.number_of_steps [p]) ps
    val actual_steps = steps_with_depth 0 actual_steps
    val start = Timing.start ()
    val print_runtime_statistics = SMT_Replay.intermediate_statistics ctxt4 start len
    (* prints intermediate statistics once the step counter passes the threshold next,
       then moves the threshold to the current counter plus 10 *)
    fun blockwise f (i, x) (next, y) =
      (if i > next then print_runtime_statistics i else ();
       (if i > next then i + 10 else next, f x y))

    val global_transformation : term Termtab.table = Termtab.empty

    val (_, (proofs, stats, ctxt5, _, _)) =
      fold (blockwise (replay_step rewrite_rules ll_defs assumed [] [])) actual_steps
        (1, (assumed, stats, ctxt4, [], global_transformation))

    val total = Time.toMilliseconds (#elapsed (Timing.result start))
    (* the theorem of the last step is the result of the whole proof *)
    val (_, (_, Verit_Proof.VeriT_Replay_Node {id, ...})) = split_last actual_steps
    val _ = print_runtime_statistics len
    val thm_with_defs = Symtab.lookup proofs id |> the |> snd
      |> singleton (Proof_Context.export ctxt5 outer_ctxt)
    val _ = SMT_Config.statistics_msg ctxt5
      (Pretty.string_of o SMT_Replay.pretty_statistics "veriT" total) stats
    val _ = SMT_Replay.spying (SMT_Config.spy_verit ctxt) ctxt
      (fn () => SMT_Replay.print_stats (Symtab.dest stats)) "spy_verit"
  in
    Verit_Replay_Methods.discharge ctxt [thm_with_defs] @{term False}
  end
