@@ -7,6 +7,8 @@ From MetaCoq.Template Require Import Ast.
77Require Import ssreflect ssrbool.
88Require Import ZArith.
99
10+ #[local] Set Universe Polymorphism.
11+
1012(** Raw term printing *)
1113
1214Module string_of_term_tree.
@@ -464,3 +466,254 @@ Ltac nth_leb_simpl :=
464466 replace n' with n in H' by lia; rewrite -> H in H'; injection H'; intros; subst
465467 | _ => lia || congruence || solve [repeat (f_equal; try lia)]
466468 end .
469+
470+ (** * Traversal functions. *)
471+
472+ Section TraverseWithBinders.
473+ Context {Acc : Type} {A : Type} (a : A) (lift : aname -> A -> A).
474+
475+ Definition lift_names : list aname -> A -> A :=
476+ fun names a => List.fold_right lift a names.
477+
478+ Definition map_predicate_with_binders (f : A -> term -> term) (p : predicate term) : predicate term :=
479+ let a_return := lift_names p.(pcontext) a in
480+ {| puinst := p.(puinst)
481+ ; pparams := List.map (f a) p.(pparams)
482+ ; pcontext := p.(pcontext)
483+ ; preturn := f a_return p.(preturn) |}.
484+
485+ Definition map_branch_with_binders (f : A -> term -> term) (b : branch term) : branch term :=
486+ let a_body := lift_names b.(bcontext) a in
487+ {| bcontext := b.(bcontext) ; bbody := f a_body b.(bbody) |}.
488+
489+ (** [map_term_with_binders a lift f t] maps [f] on the immediate subterms of [t].
490+ It carries an extra data [a] (typically a lift index) which is processed by [lift]
491+ (which typically add 1 to [a]) at each binder traversal.
492+ It is not recursive and the order in which subterms are processed is not specified. *)
493+ Definition map_term_with_binders (f : A -> term -> term) (t : term) : term :=
494+ match t with
495+ | tRel _ | tVar _ | tSort _ | tConst _ _ | tInd _ _ | tConstruct _ _ _ | tInt _ | tFloat _ | tString _ => t
496+ | tEvar n ts => tEvar n (List.map (f a) ts)
497+ | tCast b k t => tCast (f a b) k (f a t)
498+ | tProd name ty body => tProd name (f a ty) (f (lift name a) body)
499+ | tLambda name ty body => tLambda name (f a ty) (f (lift name a) body)
500+ | tLetIn name def ty body => tLetIn name (f a def) (f a ty) (f (lift name a) body)
501+ | tApp func args => tApp (f a func) (List.map (f a) args)
502+ | tProj proj t => tProj proj (f a t)
503+ (* For [tFix] and [tCoFix] we have to take care to lift [a]
504+ only when processing the body of the (co)fixpoint. *)
505+ | tFix defs i =>
506+ let a_body := lift_names (List.map dname defs) a in
507+ let on_def := map_def (f a) (f a_body) in
508+ tFix (List.map on_def defs) i
509+ | tCoFix defs i =>
510+ let a_body := lift_names (List.map dname defs) a in
511+ let on_def := map_def (f a) (f a_body) in
512+ tCoFix (List.map on_def defs) i
513+ | tCase ci pred x branches =>
514+ tCase ci (map_predicate_with_binders f pred) (f a x) (List.map (map_branch_with_binders f) branches)
515+ | tArray l t def ty => tArray l (List.map (f a) t) (f a def) (f a ty)
516+ end .
517+
518+ Definition fold_predicate_with_binders (f : A -> Acc -> term -> Acc) (acc : Acc) (p : predicate term) : Acc :=
519+ let a_return := lift_names p.(pcontext) a in
520+ let acc := List.fold_left (f a) p.(pparams) acc in
521+ f a_return acc p.(preturn).
522+
523+ Definition fold_branch_with_binders (f : A -> Acc -> term -> Acc) (acc : Acc) (b : branch term) : Acc :=
524+ let a_body := lift_names b.(bcontext) a in
525+ f a_body acc b.(bbody).
526+
527+ (** Fold version of [map_term_with_binders]. *)
528+ Definition fold_term_with_binders (f : A -> Acc -> term -> Acc) (acc : Acc) (t : term) : Acc :=
529+ match t with
530+ | tRel _ | tVar _ | tSort _ | tConst _ _ | tInd _ _
531+ | tConstruct _ _ _ | tInt _ | tFloat _ | tString _ => acc
532+ | tEvar n ts => List.fold_left (f a) ts acc
533+ | tCast b k t => let acc := f a acc b in f a acc t
534+ | tProd name ty body => let acc := f a acc ty in f (lift name a) acc body
535+ | tLambda name ty body => let acc := f a acc ty in f (lift name a) acc body
536+ | tLetIn name def ty body =>
537+ let acc := f a acc def in
538+ let acc := f a acc ty in
539+ f (lift name a) acc body
540+ | tApp func args => List.fold_left (f a) (func :: args) acc
541+ | tProj proj t => f a acc t
542+ | tFix defs i =>
543+ let a_body := lift_names (List.map dname defs) a in
544+ let acc := List.fold_left (f a) (List.map dtype defs) acc in
545+ List.fold_left (f a_body) (List.map dbody defs) acc
546+ | tCoFix defs i =>
547+ let a_body := lift_names (List.map dname defs) a in
548+ let acc := List.fold_left (f a) (List.map dtype defs) acc in
549+ List.fold_left (f a_body) (List.map dbody defs) acc
550+ | tCase ci pred x branches =>
551+ let acc := fold_predicate_with_binders f acc pred in
552+ let acc := f a acc x in
553+ List.fold_left (fold_branch_with_binders f) branches acc
554+ | tArray l t def ty =>
555+ let acc := List.fold_left (f a) t acc in
556+ let acc := f a acc def in
557+ f a acc ty
558+ end .
559+
560+ End TraverseWithBinders.
561+
562+ Section TraverseWithBindersM.
563+ Import MCMonadNotation.
564+
565+ Context {M : Type -> Type} `{Monad M} {Acc : Type} {A : Type} {a : A} {liftM : aname -> A -> M A}.
566+
567+ Definition lift_namesM (names : list aname) (a : A) : M A :=
568+ let fix loop names a :=
569+ match names with
570+ | [] => ret a
571+ | n :: names => loop names =<< liftM n a
572+ end
573+ in
574+ loop (List.rev names) a.
575+
576+ Definition map_defM {A B} (tyf bodyf : A -> M B) (d : def A) : M (def B) :=
577+ mlet dtype <- tyf d.(dtype) ;;
578+ mlet dbody <- bodyf d.(dbody) ;;
579+ ret (mkdef _ d.(dname) dtype dbody d.(rarg)).
580+
581+ Definition map_predicate_with_bindersM (f : A -> term -> M term) (p : predicate term) : M (predicate term) :=
582+ mlet a_return <- lift_namesM p.(pcontext) a ;;
583+ mlet pparams <- monad_map (f a) p.(pparams) ;;
584+ mlet preturn <- f a_return p.(preturn) ;;
585+ ret (mk_predicate p.(puinst) pparams p.(pcontext) preturn).
586+
587+ Definition map_branch_with_bindersM (f : A -> term -> M term) (b : branch term) : M (branch term) :=
588+ mlet a_body <- lift_namesM b.(bcontext) a ;;
589+ mlet bbody <- f a_body b.(bbody) ;;
590+ ret (mk_branch b.(bcontext) bbody).
591+
592+ (** Monadic variant of [map_term_with_binders]. *)
593+ Definition map_term_with_bindersM (f : A -> term -> M term) (t : term) : M term :=
594+ match t with
595+ | tRel _ | tVar _ | tSort _ | tConst _ _ | tInd _ _
596+ | tConstruct _ _ _ | tInt _ | tFloat _ | tString _ => ret t
597+ | tEvar n ts =>
598+ mlet ts <- monad_map (f a) ts ;;
599+ ret (tEvar n ts)
600+ | tCast b k t =>
601+ mlet b <- f a b ;;
602+ mlet t <- f a t ;;
603+ ret (tCast b k t)
604+ | tProd name ty body =>
605+ mlet ty <- f a ty ;;
606+ mlet a_body <- liftM name a ;;
607+ mlet body <- f a_body body ;;
608+ ret (tProd name ty body)
609+ | tLambda name ty body =>
610+ mlet ty <- f a ty ;;
611+ mlet a_body <- liftM name a ;;
612+ mlet body <- f a_body body ;;
613+ ret (tLambda name ty body)
614+ | tLetIn name def ty body =>
615+ mlet def <- f a def ;;
616+ mlet ty <- f a ty ;;
617+ mlet a_body <- liftM name a ;;
618+ mlet body <- f a_body body ;;
619+ ret (tLetIn name def ty body)
620+ | tApp func args =>
621+ mlet func <- f a func ;;
622+ mlet args <- monad_map (f a) args ;;
623+ ret (tApp func args)
624+ | tProj proj t =>
625+ mlet t <- f a t ;;
626+ ret (tProj proj t)
627+ (* For [tFix] and [tCoFix] we have to take care to lift [a]
628+ only when processing the body of the (co)fixpoint. *)
629+ | tFix defs i =>
630+ mlet a_body <- lift_namesM (List.map dname defs) a ;;
631+ let on_def := map_defM (f a) (f a_body) in
632+ mlet defs <- monad_map on_def defs ;;
633+ ret (tFix defs i)
634+ | tCoFix defs i =>
635+ mlet a_body <- lift_namesM (List.map dname defs) a ;;
636+ let on_def := map_defM (f a) (f a_body) in
637+ mlet defs <- monad_map on_def defs ;;
638+ ret (tCoFix defs i)
639+ | tCase ci pred x branches =>
640+ mlet pred <- map_predicate_with_bindersM f pred ;;
641+ mlet x <- f a x ;;
642+ mlet branches <- monad_map (map_branch_with_bindersM f) branches ;;
643+ ret (tCase ci pred x branches)
644+ | tArray l t def ty =>
645+ mlet t <- monad_map (f a) t ;;
646+ mlet def <- f a def ;;
647+ mlet ty <- f a ty ;;
648+ ret (tArray l t def ty)
649+ end .
650+
651+ Definition fold_predicate_with_bindersM (f : A -> Acc -> term -> M Acc) (acc : Acc) (p : predicate term) : M Acc :=
652+ mlet a_return <- lift_namesM p.(pcontext) a ;;
653+ mlet acc <- monad_fold_left (f a) p.(pparams) acc ;;
654+ f a_return acc p.(preturn).
655+
656+ Definition fold_branch_with_bindersM (f : A -> Acc -> term -> M Acc) (acc : Acc) (b : branch term) : M Acc :=
657+ mlet a_body <- lift_namesM b.(bcontext) a ;;
658+ f a_body acc b.(bbody).
659+
660+ (** Monadic variant of [fold_term_with_binders]. *)
661+ Definition fold_term_with_bindersM (f : A -> Acc -> term -> M Acc) (acc : Acc) (t : term) : M Acc :=
662+ match t with
663+ | tRel _ | tVar _ | tSort _ | tConst _ _ | tInd _ _
664+ | tConstruct _ _ _ | tInt _ | tFloat _ | tString _ => ret acc
665+ | tEvar n ts => monad_fold_left (f a) ts acc
666+ | tCast b k t => mlet acc <- f a acc b ;; f a acc t
667+ | tProd name ty body =>
668+ mlet a_body <- liftM name a ;;
669+ mlet acc <- f a acc ty ;;
670+ f a_body acc body
671+ | tLambda name ty body =>
672+ mlet a_body <- liftM name a ;;
673+ mlet acc <- f a acc ty ;;
674+ f a_body acc body
675+ | tLetIn name def ty body =>
676+ mlet a_body <- liftM name a ;;
677+ mlet def <- f a acc def ;;
678+ mlet acc <- f a acc ty ;;
679+ f a_body acc body
680+ | tApp func args => monad_fold_left (f a) (func :: args) acc
681+ | tProj proj t => f a acc t
682+ | tFix defs i =>
683+ mlet a_body <- lift_namesM (List.map dname defs) a ;;
684+ mlet acc <- monad_fold_left (f a) (List.map dtype defs) acc ;;
685+ monad_fold_left (f a_body) (List.map dbody defs) acc
686+ | tCoFix defs i =>
687+ mlet a_body <- lift_namesM (List.map dname defs) a ;;
688+ mlet acc <- monad_fold_left (f a) (List.map dtype defs) acc ;;
689+ monad_fold_left (f a_body) (List.map dbody defs) acc
690+ | tCase ci pred x branches =>
691+ mlet acc <- fold_predicate_with_bindersM f acc pred ;;
692+ mlet acc <- f a acc x ;;
693+ monad_fold_left (fold_branch_with_bindersM f) branches acc
694+ | tArray l t def ty =>
695+ mlet acc <- monad_fold_left (f a) t acc ;;
696+ mlet acc <- f a acc def ;;
697+ f a acc ty
698+ end .
699+
700+ End TraverseWithBindersM.
701+
702+
703+ (** [map_term f t] maps [f] on the immediate subterms of [t].
704+ It is not recursive and the order in which subterms are processed is not specified. *)
705+ Definition map_term (f : term -> term) (t : term) : term :=
706+ @map_term_with_binders unit tt (fun _ _ => tt) (fun _ => f) t.
707+
708+ (** Monadic variant of [map_term]. *)
709+ Definition map_termM {M} `{Monad M} (f : term -> M term) (t : term) : M term :=
710+ @map_term_with_bindersM M _ unit tt (fun _ _ => ret tt) (fun _ => f) t.
711+
712+ (** [fold_term f acc t] folds [f] on the immediate subterms of [t].
713+ It is not recursive and the order in which subterms are processed is not specified. *)
714+ Definition fold_term {Acc} (f : Acc -> term -> Acc) (acc : Acc) (t : term) : Acc :=
715+ @fold_term_with_binders Acc unit tt (fun _ _ => tt) (fun _ => f) acc t.
716+
717+ (** Monadic variant of [fold_term]. *)
718+ Definition fold_termM {M} `{Monad M} {Acc} (f : Acc -> term -> M Acc) (acc : Acc) (t : term) : M Acc :=
719+ @fold_term_with_bindersM M _ Acc unit tt (fun _ _ => ret tt) (fun _ => f) acc t.
0 commit comments