Skip to content

Commit

Permalink
Merge pull request #391 from mrjazzybread/bool
Browse files Browse the repository at this point in the history
Bool
  • Loading branch information
n-osborne authored Oct 4, 2024
2 parents 85dc67d + 4023535 commit f1f1860
Show file tree
Hide file tree
Showing 18 changed files with 1,472 additions and 1,880 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
[#406] (https://github.com/ocaml-gospel/gospel/pull/406)
- Display an error message when encoutering a Functor application
[#404] (https://github.com/ocaml-gospel/gospel/pull/404)
- Changed the gospel typechecker to use bool as the type of logical formulae
[\#391](https://github.com/ocaml-gospel/gospel/pull/391)

# 0.3

Expand Down
4 changes: 2 additions & 2 deletions src/coercion.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type coercion = {

let ty_of ls =
match (ls.ls_args, ls.ls_value) with
| [ { ty_node = Tyapp (ty1, _) } ], Some { ty_node = Tyapp (ty2, _) } ->
| [ { ty_node = Tyapp (ty1, _) } ], { ty_node = Tyapp (ty2, _) } ->
(ty1.ts_ident.id_str, ty2.ts_ident.id_str)
| _ -> assert false

Expand All @@ -47,7 +47,7 @@ let empty = Mts.empty

let create_crc ls =
match (ls.ls_args, ls.ls_value) with
| [ { ty_node = Tyapp (ts1, tl1) } ], Some { ty_node = Tyapp (ts2, tl2) }
| [ { ty_node = Tyapp (ts1, tl1) } ], { ty_node = Tyapp (ts2, tl2) }
when not (ts_equal ts1 ts2) ->
{
crc_kind = CRCleaf ls;
Expand Down
108 changes: 42 additions & 66 deletions src/dterm.ml
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ let specialize_ls ls =
| Tyvar tv -> find_tv tv
| Tyapp (ts, tyl) -> Tapp (ts, List.map spec tyl)
in
(List.map spec ls.ls_args, Option.map spec ls.ls_value)
(List.map spec ls.ls_args, spec ls.ls_value)

let specialize_cs ~loc cs =
if cs.ls_constr = false then
W.error ~loc (W.Not_a_constructor cs.ls_name.id_str);
let dtyl, dty = specialize_ls cs in
(dtyl, Option.get dty)
(dtyl, dty)

(* terms *)

Expand Down Expand Up @@ -142,17 +142,17 @@ let rec unify_dty_ty dty ty =
match (head dty, ty.ty_node) with
| Tvar tvar, _ -> tvar.dtv_def <- Some (Tty ty)
| Tty ty1, _ when ty_equal ty1 ty -> ()
| Tapp (ts1, dl), Tyapp (ts2, tl) when ts_equal ts1 ts2 -> (
try List.iter2 unify_dty_ty dl tl with Invalid_argument _ -> raise Exit)
| Tapp (ts1, dl), Tyapp (ts2, tl) when ts_equal ts1 ts2 ->
List.iter2 unify_dty_ty dl tl
| _ -> raise Exit

let rec unify dty1 dty2 =
match (head dty1, head dty2) with
| Tvar { dtv_id = id1; _ }, Tvar { dtv_id = id2; _ } when id1 = id2 -> ()
| Tvar tvar, dty | dty, Tvar tvar ->
if occur tvar dty then raise Exit else tvar.dtv_def <- Some dty
| Tapp (ts1, dtyl1), Tapp (ts2, dtyl2) when ts_equal ts1 ts2 -> (
try List.iter2 unify dtyl1 dtyl2 with Invalid_argument _ -> raise Exit)
| Tapp (ts1, dtyl1), Tapp (ts2, dtyl2) when ts_equal ts1 ts2 ->
List.iter2 unify dtyl1 dtyl2
| Tty ty, dty | dty, Tty ty -> unify_dty_ty dty ty
| _ -> raise Exit

Expand Down Expand Up @@ -186,19 +186,17 @@ let dty_unify ~loc dty1 dty2 =
let dterm_unify dt dty =
match dt.dt_dty with
| Some dt_dty -> dty_unify ~loc:dt.dt_loc dt_dty dty
| None -> (
try unify dty_bool dty
with Exit -> W.error ~loc:dt.dt_loc W.Term_expected)
| None -> assert false

let dfmla_unify dt =
match dt.dt_dty with
| None -> ()
| None -> assert false
| Some dt_dty -> (
try unify dt_dty dty_bool
with Exit -> W.error ~loc:dt.dt_loc W.Formula_expected)

let unify dt dty =
match dty with None -> dfmla_unify dt | Some dt_dty -> dterm_unify dt dt_dty
match dty with None -> assert false | Some dt_dty -> dterm_unify dt dt_dty

(* environment *)

Expand Down Expand Up @@ -229,7 +227,7 @@ let apply_coercion l dt =
let apply dt ls =
let dtyl, dty = specialize_ls ls in
dterm_unify dt (List.hd dtyl);
{ dt_node = DTapp (ls, [ dt ]); dt_dty = dty; dt_loc = dt.dt_loc }
{ dt_node = DTapp (ls, [ dt ]); dt_dty = Some dty; dt_loc = dt.dt_loc }
in
List.fold_left apply dt l

Expand Down Expand Up @@ -278,15 +276,6 @@ let max_dty crcmap dtl =
in
if l = [] then (List.hd dtl).dt_dty else aux l

let max_dty crcmap dtl =
match max_dty crcmap dtl with
| Some (Tty ty)
when ty_equal ty ty_bool
&& List.exists (fun { dt_dty; _ } -> dt_dty = None) dtl ->
(* favor prop over bool *)
None
| dty -> dty

let dterm_expected crcmap dt dty =
try
let ts1, ts2 = (ts_of_dty dt.dt_dty, ts_of_dty dty) in
Expand All @@ -302,7 +291,7 @@ let dterm_expected_op crcmap dt dty =
unify dt dty;
dt

let dfmla_expected crcmap dt = dterm_expected_op crcmap dt None
let dfmla_expected crcmap dt = dterm_expected_op crcmap dt (Some dty_bool)
let dterm_expected crcmap dt dty = dterm_expected_op crcmap dt (Some dty)

(** dterm to tterm *)
Expand Down Expand Up @@ -335,73 +324,61 @@ let pattern dp =
let p = pattern_node dp in
(p, !vars)

let rec term env prop dt =
let rec term env dt =
let loc = dt.dt_loc in
let t = term_node ~loc env prop dt.dt_dty dt.dt_node in
match t.t_ty with
| Some _ when prop -> (
try t_equ t (t_bool_true loc) loc
with TypeMismatch (ty1, ty2) ->
let t1 = Fmt.str "%a" print_ty ty1 in
let t2 = Fmt.str "%a" print_ty ty2 in
W.error ~loc (W.Bad_type (t1, t2)))
| None when not prop -> t_if t (t_bool_true loc) (t_bool_false loc) loc
| _ -> t

and term_node ~loc env prop dty dterm_node =
term_node ~loc env dt.dt_dty dt.dt_node

and term_node ~loc env dty dterm_node =
match dterm_node with
| DTvar pid ->
let vs = denv_find ~loc:pid.pid_loc pid.pid_str env in
(* TODO should I match vs.vs_ty with dty? *)
t_var vs loc
| DTconst c -> t_const c (ty_of_dty (Option.get dty)) loc
| DTapp (ls, []) when ls_equal ls fs_bool_true ->
if prop then t_true loc else t_bool_true loc
| DTapp (ls, []) when ls_equal ls fs_bool_false ->
if prop then t_false loc else t_bool_false loc
| DTapp (ls, [ dt1; dt2 ]) when ls_equal ls ps_equ ->
if dt1.dt_dty = None || dt2.dt_dty = None then
f_iff (term env true dt1) (term env true dt2) loc
else t_equ (term env false dt1) (term env false dt2) loc
| DTapp (ls, []) when ls_equal ls fs_bool_true -> t_true loc
| DTapp (ls, []) when ls_equal ls fs_bool_false -> t_false loc
| DTapp (ls, [ dt1 ]) when ls.ls_field ->
t_field (term env false dt1) ls (Option.map ty_of_dty dty) loc
t_field (term env dt1) ls
(Option.fold ~some:ty_of_dty ~none:ty_bool dty)
loc
| DTapp (ls, dtl) ->
t_app ls (List.map (term env false) dtl) (Option.map ty_of_dty dty) loc
t_app ls
(List.map (term env) dtl)
(Option.fold ~some:ty_of_dty ~none:ty_bool dty)
loc
| DTif (dt1, dt2, dt3) ->
let prop = prop || dty = None in
t_if (term env true dt1) (term env prop dt2) (term env prop dt3) loc
t_if (term env dt1) (term env dt2) (term env dt3) loc
| DTlet (pid, dt1, dt2) ->
let prop = prop || dty = None in
let t1 = term env false dt1 in
let vs = create_vsymbol pid (t_type t1) in
let t1 = term env dt1 in
let vs = create_vsymbol pid t1.t_ty in
let env = Mstr.add pid.pid_str vs env in
let t2 = term env prop dt2 in
let t2 = term env dt2 in
t_let vs t1 t2 loc
| DTbinop (b, dt1, dt2) ->
let t1, t2 = (term env true dt1, term env true dt2) in
let t1, t2 = (term env dt1, term env dt2) in
t_binop b t1 t2 loc
| DTnot dt -> t_not (term env true dt) loc
| DTtrue -> if prop then t_true loc else t_bool_true loc
| DTfalse -> if prop then t_false loc else t_bool_false loc
| DTnot dt -> t_not (term env dt) loc
| DTtrue -> t_true loc
| DTfalse -> t_false loc
| DTattr (dt, at) ->
let t = term env prop dt in
let t = term env dt in
t_attr_set at t
| DTold dt -> t_old (term env prop dt) loc
| DTold dt -> t_old (term env dt) loc
| DTquant (q, bl, dt) ->
let add_var (env, vsl) (pid, dty) =
let vs = create_vsymbol pid (ty_of_dty dty) in
(Mstr.add pid.pid_str vs env, vs :: vsl)
in
let env, vsl = List.fold_left add_var (env, []) bl in
let t = term env prop dt in
t_quant q (List.rev vsl) t (Option.map ty_of_dty dty) loc
let t = term env dt in
t_quant q (List.rev vsl) t (ty_of_dty (Option.get dty)) loc
| DTlambda (dpl, dt) ->
let ty = ty_of_dty_raw dty and pl = List.map pattern dpl in
let env =
let join _ _ vs = Some vs in
List.fold_left (fun env (_, vs) -> Mstr.union join env vs) env pl
in
let t = term env false dt in
let t = term env dt in
(* Are the patterns exhaustive? *)
List.iter
(fun (p, _) ->
Expand All @@ -410,16 +387,16 @@ and term_node ~loc env prop dty dterm_node =
[ (p, None, t (* [t] is really just a place holder *)) ]
~loc)
pl;
t_lambda (List.map fst pl) t (Some ty) loc
t_lambda (List.map fst pl) t ty loc
| DTcase (dt, ptl) ->
let t = term env false dt in
let t = term env dt in
let branch (dp, guard, dt) =
let p, vars = pattern dp in
let join _ _ vs = Some vs in
let env = Mstr.union join env vars in
let dt = term env false dt in
let dt = term env dt in
let guard =
match guard with None -> None | Some g -> Some (term env true g)
match guard with None -> None | Some g -> Some (term env g)
in
(p, guard, dt)
in
Expand All @@ -428,5 +405,4 @@ and term_node ~loc env prop dty dterm_node =
Patmat.checks ty pl ~loc;
t_case t pl loc

let fmla env dt = term env true dt
let term env dt = term env false dt
let term env dt = term env dt
3 changes: 1 addition & 2 deletions src/dterm.mli
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ val dty_of_dterm : dterm -> dty
val dty_of_ty : Ttypes.ty -> dty
val dty_fresh : unit -> dty
val max_dty : Coercion.t -> dterm list -> dty option
val specialize_ls : lsymbol -> dty list * dty option
val specialize_ls : lsymbol -> dty list * dty
val specialize_cs : loc:Location.t -> lsymbol -> dty list * dty
val dty_unify : loc:Location.t -> dty -> dty -> unit
val dterm_unify : dterm -> dty -> unit
Expand All @@ -81,5 +81,4 @@ val is_in_denv : denv -> string -> bool
val denv_add_var : denv -> string -> dty -> denv
val denv_add_var_quant : denv -> (Identifier.Preid.t * dty) list -> denv
val term : vsymbol Mstr.t -> dterm -> term
val fmla : vsymbol Mstr.t -> dterm -> term
val pattern : dpattern -> Tterm.pattern * vsymbol Mstr.t
4 changes: 1 addition & 3 deletions src/patmat.ml
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,7 @@ end = struct
match pmat.mat with
| [ e ] ->
let args, l = split [] 0 e in
let hd =
mk_pattern (Papp (ck, args)) (Option.get ck.ls_value) Location.none
in
let hd = mk_pattern (Papp (ck, args)) ck.ls_value Location.none in
{ rows = 1; cols = pmat.cols - ak + 1; mat = [ hd :: l ] }
| _ -> assert false
end
Expand Down
10 changes: 5 additions & 5 deletions src/symbols.ml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ module Mvs = Map.Make (Vs)
type lsymbol = {
ls_name : Ident.t;
ls_args : ty list;
ls_value : ty option;
ls_value : ty;
ls_constr : bool;
(* true if it is a construct, false otherwise*)
ls_field : bool; (* true if it is a record/model field *)
Expand All @@ -57,19 +57,19 @@ let lsymbol ?(constr = false) ~field ls_name ls_args ls_value =
{ ls_name; ls_args; ls_value; ls_constr = constr; ls_field = field }

let fsymbol ?(constr = false) ~field nm tyl ty =
lsymbol ~constr ~field nm tyl (Some ty)
lsymbol ~constr ~field nm tyl ty

let psymbol nm ty = lsymbol ~field:false nm ty None
let psymbol nm ty = lsymbol ~field:false nm ty ty_bool

let ls_subst_ts old_ts new_ts ({ ls_name; ls_constr; ls_field; _ } as ls) =
let ls_args = List.map (ty_subst_ts old_ts new_ts) ls.ls_args in
let ls_value = Option.map (ty_subst_ts old_ts new_ts) ls.ls_value in
let ls_value = ty_subst_ts old_ts new_ts ls.ls_value in
lsymbol ls_name ls_args ls_value ~constr:ls_constr ~field:ls_field

let ls_subst_ty old_ts new_ts new_ty ls =
let subst ty = ty_subst_ty old_ts new_ts new_ty ty in
let ls_args = List.map subst ls.ls_args in
let ls_value = Option.map subst ls.ls_value in
let ls_value = subst ls.ls_value in
lsymbol ls.ls_name ls_args ls_value ~constr:ls.ls_constr ~field:ls.ls_field

(** buil-in lsymbols *)
Expand Down
Loading

0 comments on commit f1f1860

Please sign in to comment.