Skip to content

Commit

Permalink
API: configurable calc
Browse files Browse the repository at this point in the history
  • Loading branch information
gares committed Nov 30, 2023
1 parent 88288be commit 2c9b588
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 193 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ API:
(experimental)
- New `BuiltInPredicate.FullHO` for higher order external predicates
- New `BuiltInPredicate.HOAdaptors` for `map` and `filter` like HO predicates
- New `Calc.register` to register operators for `calc` (aka infix `is`)

Library:
- New `std.fold-right`
Expand Down
211 changes: 204 additions & 7 deletions src/API.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ module Setup = struct
type state_descriptor = Data.State.descriptor
type quotations_descriptor = Data.QuotationHooks.descriptor ref
type hoas_descriptor = Data.HoasHooks.descriptor ref
type calc_descriptor = Data.CalcHooks.descriptor ref

let default_state_descriptor = Data.State.new_descriptor ()
let default_quotations_descriptor = Data.QuotationHooks.new_descriptor ()
let default_hoas_descriptor = Data.HoasHooks.new_descriptor ()
let default_calc_descriptor = Data.CalcHooks.new_descriptor ()

type builtins = Compiler.builtins
type elpi = {
Expand All @@ -40,7 +42,7 @@ type elpi = {
}
type flags = Compiler.flags

let init ?(flags=Compiler.default_flags) ?(state=default_state_descriptor) ?(quotations=default_quotations_descriptor) ?(hoas=default_hoas_descriptor) ~builtins ?file_resolver ?(legacy_parser=false) () : elpi =
let init ?(flags=Compiler.default_flags) ?(state=default_state_descriptor) ?(quotations=default_quotations_descriptor) ?(hoas=default_hoas_descriptor) ?(calc=default_calc_descriptor) ~builtins ?file_resolver ?(legacy_parser=false) () : elpi =
(* At the moment we can only init the parser once *)
let file_resolver =
match file_resolver with
Expand All @@ -60,7 +62,7 @@ let init ?(flags=Compiler.default_flags) ?(state=default_state_descriptor) ?(quo
(* This is a bit ugly, since we print and then parse... *)
let b = Buffer.create 1024 in
let fmt = Format.formatter_of_buffer b in
Data.BuiltInPredicate.document fmt decls;
Data.BuiltInPredicate.document fmt decls (List.rev !calc);
Format.pp_print_flush fmt ();
let text = Buffer.contents b in
let lexbuf = Lexing.from_string text in
Expand All @@ -76,7 +78,7 @@ let init ?(flags=Compiler.default_flags) ?(state=default_state_descriptor) ?(quo
Util.Loc.(loc.source_stop - loc.line_starts_at));
Util.anomaly ~loc msg) in
let header =
try Compiler.header_of_ast ~flags ~parser state !quotations !hoas builtins (List.concat header_src)
try Compiler.header_of_ast ~flags ~parser state !quotations !hoas !calc builtins (List.concat header_src)
with Compiler.CompileError(loc,msg) -> Util.anomaly ?loc msg in
{ parser; header; resolver = file_resolver }

Expand Down Expand Up @@ -976,13 +978,13 @@ end
module BuiltIn = struct
include ED.BuiltInPredicate
let declare ~file_name l = file_name, l
let document_fmt fmt (_,l) =
ED.BuiltInPredicate.document fmt l
let document_file ?(header="") (name,l) =
let document_fmt fmt ?(calc=Setup.default_calc_descriptor) (_,l) =
ED.BuiltInPredicate.document fmt l (List.rev !calc)
let document_file ?(header="") ?(calc=Setup.default_calc_descriptor) (name,l) =
let oc = open_out name in
let fmt = Format.formatter_of_out_channel oc in
Format.fprintf fmt "%s%!" header;
ED.BuiltInPredicate.document fmt l;
ED.BuiltInPredicate.document fmt l (List.rev !calc);
Format.pp_print_flush fmt ();
close_out oc
end
Expand Down Expand Up @@ -1068,6 +1070,201 @@ module Quotation = struct

end

module Calc = struct

let new_calc_descriptor = ED.CalcHooks.new_descriptor

type operation_declaration = {
symbol : string;
infix : bool;
args : string list list;
code : ED.term list -> ED.term;
}

let compile_operation_declaration { symbol; infix; args; code } =
let c = ED.Global_symbols.declare_global_symbol symbol in
let ty_decl args =
if infix then
Printf.sprintf "type (%s) %s." symbol (String.concat " -> " args)
else
Printf.sprintf "type %s %s." symbol (String.concat " -> " args) in
c, { ED.CalcHooks.ty_decl = List.map ty_decl args |> String.concat "\n"; code }

let register ~descriptor d =
let e = compile_operation_declaration d in
descriptor := e :: !descriptor

let register_eval n (symbol,tys) code =
let infix, n = n < 0, abs n in
let args = tys |> List.map (fun ty -> List.init (n+1) (fun _ -> ty)) in
[{ symbol; infix; args; code }]

let register_eval_ty symbol ty code =
let infix = false in
let args = [ty] in
[{ symbol; infix; args; code }]


let register_evals n l f = List.map (fun i -> register_eval n i f) l |> List.flatten

let default_calc =
let open Util in
let open RawOpaqueData in
List.flatten [
register_evals ~-2 [ "-",["A"] ; "i-",["int"] ; "r-",["float"] ] (function
| [ CData x; CData y ] when ty2 int x y -> (morph2 int (-) x y)
| [ CData x; CData y ] when ty2 float x y -> (morph2 float (-.) x y)
| _ -> type_error "Wrong arguments to -/i-/r-") ;
register_evals ~-2 [ "+",["int";"float"] ; "i+",["int"] ; "r+",["float"] ] (function
| [ CData x; CData y ] when ty2 int x y -> (morph2 int (+) x y)
| [ CData x; CData y ] when ty2 float x y -> (morph2 float (+.) x y)
| _ -> type_error "Wrong arguments to +/i+/r+") ;
register_eval ~-2 ("*",["int";"float"]) (function
| [ CData x; CData y ] when ty2 int x y -> (morph2 int ( * ) x y)
| [ CData x; CData y] when ty2 float x y -> (morph2 float ( *.) x y)
| _ -> type_error "Wrong arguments to *") ;
register_eval ~-2 ("/",["float"]) (function
| [ CData x; CData y] when ty2 float x y -> (morph2 float ( /.) x y)
| _ -> type_error "Wrong arguments to /") ;
register_eval ~-2 ("mod",["int"]) (function
| [ CData x; CData y ] when ty2 int x y -> (morph2 int (mod) x y)
| _ -> type_error "Wrong arguments to mod") ;
register_eval ~-2 ("div",["int"]) (function
| [ CData x; CData y ] when ty2 int x y -> (morph2 int (/) x y)
| _ -> type_error "Wrong arguments to div") ;
register_eval ~-2 ("^",["string"]) (function
| [ CData x; CData y ] when ty2 string x y ->
of_string (to_string x ^ to_string y)
| _ -> type_error "Wrong arguments to ^") ;
register_evals ~-1 [ "~",["int";"float"] ; "i~",["int"] ; "r~",["float"] ] (function
| [ CData x ] when is_int x -> (morph1 int (~-) x)
| [ CData x ] when is_float x -> (morph1 float (~-.) x)
| _ -> type_error "Wrong arguments to ~/i~/r~") ;
register_evals 1 [ "abs",["int";"float"] ; "iabs",["int"] ; "rabs",["float"] ] (function
| [ CData x ] when is_int x -> (map int int abs x)
| [ CData x ] when is_float x -> (map float float abs_float x)
| _ -> type_error "Wrong arguments to abs/iabs/rabs") ;
register_evals 2 [ "max",["int";"float"]] (function
| [ CData x; CData y ] when ty2 int x y -> (morph2 int max x y)
| [ CData x; CData y ] when ty2 float x y -> (morph2 float max x y)
| _ -> type_error "Wrong arguments to abs/iabs/rabs") ;
register_evals 2 [ "min",["int";"float"]] (function
| [ CData x; CData y ] when ty2 int x y -> (morph2 int min x y)
| [ CData x; CData y ] when ty2 float x y -> (morph2 float min x y)
| _ -> type_error "Wrong arguments to abs/iabs/rabs") ;
register_eval 1 ("sqrt",["float"]) (function
| [ CData x ] when is_float x -> (map float float sqrt x)
| _ -> type_error "Wrong arguments to sqrt") ;
register_eval 1 ("sin",["float"]) (function
| [ CData x ] when is_float x -> (map float float sqrt x)
| _ -> type_error "Wrong arguments to sin") ;
register_eval 1 ("cos",["float"]) (function
| [ CData x ] when is_float x -> (map float float cos x)
| _ -> type_error "Wrong arguments to cosin") ;
register_eval 1 ("arctan",["float"]) (function
| [ CData x ] when is_float x -> (map float float atan x)
| _ -> type_error "Wrong arguments to arctan") ;
register_eval 1 ("ln",["float"]) (function
| [ CData x ] when is_float x -> (map float float log x)
| _ -> type_error "Wrong arguments to ln") ;
register_eval_ty "int_to_real" ["int";"float"] (function
| [ CData x ] when is_int x -> (map int float float_of_int x)
| _ -> type_error "Wrong arguments to int_to_real") ;
register_eval_ty "floor" ["float";"int"] (function
| [ CData x ] when is_float x ->
(map float int (fun x -> int_of_float (floor x)) x)
| _ -> type_error "Wrong arguments to floor") ;
register_eval_ty "ceil" ["float";"int"] (function
| [ CData x ] when is_float x ->
(map float int (fun x -> int_of_float (ceil x)) x)
| _ -> type_error "Wrong arguments to ceil") ;
register_eval_ty "truncate" ["float";"int"] (function
| [ CData x ] when is_float x -> (map float int truncate x)
| _ -> type_error "Wrong arguments to truncate") ;
register_eval_ty "size" ["string";"int"] (function
| [ CData x ] when is_string x ->
of_int (String.length (to_string x))
| _ -> type_error "Wrong arguments to size") ;
register_eval_ty "chr" ["int";"string"] (function
| [ CData x ] when is_int x ->
of_string (String.make 1 (char_of_int (to_int x)))
| _ -> type_error "Wrong arguments to chr") ;
register_eval_ty "rhc" ["string";"int"] (function
| [ CData x ] when is_string x && String.length (to_string x) = 1 ->
of_int (int_of_char (to_string x).[0])
| _ -> type_error "Wrong arguments to rhc") ;
register_eval_ty "string_to_int" ["string";"int"] (function
| [ CData x ] when is_string x -> of_int (int_of_string (to_string x))
| _ -> type_error "Wrong arguments to string_to_int") ;
register_eval_ty "int_to_string" ["int";"string"] (function
| [ CData x ] when is_int x ->
of_string (string_of_int (to_int x))
| _ -> type_error "Wrong arguments to int_to_string") ;
register_eval_ty "substring" ["string";"int";"int";"string"] (function
| [ CData x ; CData i ; CData j ] when is_string x && ty2 int i j ->
let x = to_string x and i = to_int i and j = to_int j in
if i >= 0 && j >= 0 && String.length x >= i+j then
of_string (String.sub x i j)
else type_error "Wrong arguments to substring"
| _ -> type_error "Wrong argument type to substring") ;
register_eval_ty "real_to_string" ["float";"string"] (function
| [ CData x ] when is_float x ->
of_string (string_of_float (to_float x))
| _ -> type_error "Wrong arguments to real_to_string")
]

let () = List.iter (register ~descriptor:Setup.default_calc_descriptor) default_calc

let eval ~depth state x =
let table = ED.State.get ED.CalcHooks.eval state in
let lookup_eval c = ED.Constants.Map.find c table in
let module R = (val !r) in let open R in
let rec eval depth t =
match deref_head ~depth t with
| Lam _ -> Util.type_error "Evaluation of a lambda abstraction"
| Builtin _ -> Util.type_error "Evaluation of built-in predicate"
| App (hd,arg,args) ->
let f =
try lookup_eval hd
with Not_found ->
function
| [] -> assert false
| x::xs -> ED.mkApp hd x xs in
let args = List.map (fun x -> eval depth x) (arg::args) in
f args
| AppUVar _ | UVar _ | Discard -> Util.error "Evaluation of a non closed term. Maybe delay this predicate call and declare a constraint."
| Arg _ | AppArg _ -> Util.anomaly "Evaluation of a stack term"
| Const hd as x ->
let f =
try lookup_eval hd
with Not_found -> fun _ -> x in
f []
| (Nil | Cons _ as x) -> Util.type_error ("Lists cannot be evaluated: " ^ ED.show_term x)
| CData _ as x -> x
in
eval depth x

let calc =
let open BuiltIn in
let open ContextualConversion in
let open BuiltInPredicate.Notation in
[
LPDoc " -- Evaluation --";

LPCode "pred (is) o:A, i:A.";
LPCode "X is Y :- calc Y X.";

MLCode(Pred("calc",
In(BuiltInData.poly "A", "Expr",
Out(BuiltInData.poly "A", "Out",
Read(unit_ctx, "unifies Out with the value of Expr. It can be used in tandem with spilling, eg [f {calc (N + 1)}]"))),
(fun t _ ~depth _ _ state -> !: (eval ~depth state t))),
DocAbove);
]

end


module Utils = struct
let lp_list_to_list ~depth t =
let module R = (val !r) in let open R in
Expand Down
34 changes: 32 additions & 2 deletions src/API.mli
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ module Setup : sig
(* Built-in predicates, see {!module:BuiltIn} *)
type builtins

(* Operations avaiable via calc (infix is), see {!module:Calc} *)
type calc_descriptor

(* Compilation flags, see {!module:Compile} *)
type flags

Expand All @@ -72,6 +75,7 @@ module Setup : sig
?state:state_descriptor ->
?quotations:quotations_descriptor ->
?hoas:hoas_descriptor ->
?calc:calc_descriptor ->
builtins:builtins list ->
?file_resolver:(?cwd:string -> unit:string -> unit -> string) ->
?legacy_parser:bool ->
Expand Down Expand Up @@ -728,8 +732,8 @@ module BuiltIn : sig
(** Prints in LP syntax the "external" declarations.
* The file builtin.elpi is generated by calling this API on the
* declaration list from elpi_builtin.ml *)
val document_fmt : Format.formatter -> Setup.builtins -> unit
val document_file : ?header:string -> Setup.builtins -> unit
val document_fmt : Format.formatter -> ?calc:Setup.calc_descriptor -> Setup.builtins -> unit
val document_file : ?header:string -> ?calc:Setup.calc_descriptor -> Setup.builtins -> unit

end

Expand Down Expand Up @@ -987,6 +991,32 @@ module RawOpaqueData : sig
val to_loc : t -> Ast.Loc.t
val of_loc : Ast.Loc.t -> Data.term

end

module Calc : sig

type operation_declaration = {
symbol : string;
infix : bool; (* used for the doc *)
args : string list list; (* multiple types for the same symbol *)
code : Data.term list -> Data.term;
}

(** Registering an operation *)
val register : descriptor:Setup.calc_descriptor -> operation_declaration -> unit

(** An empty descriptor for registering operations *)
val new_calc_descriptor : unit -> Setup.calc_descriptor

(** Standard operations *)
val default_calc : operation_declaration list

(** The [calc] and [is] declarations *)
val calc : BuiltIn.declaration list

(** for use in other builtins *)
val eval : depth:int -> State.t -> Data.term -> Data.term

end

(** This module exposes the low level representation of terms.
Expand Down
14 changes: 5 additions & 9 deletions src/builtin.elpi
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ stop :- halt.

% -- Evaluation --

pred (is) o:A, i:A.

X is Y :- calc Y X.

% [calc Expr Out] unifies Out with the value of Expr. It can be used in
% tandem with spilling, eg [f {calc (N + 1)}]
external pred calc i:A, o:A.

pred (is) o:A, i:A.

X is Y :- calc Y X.
% --- Operators ---

type (-) A -> A -> A.

Expand All @@ -83,15 +85,13 @@ type (i-) int -> int -> int.
type (r-) float -> float -> float.

type (+) int -> int -> int.

type (+) float -> float -> float.

type (i+) int -> int -> int.

type (r+) float -> float -> float.

type (*) int -> int -> int.

type (*) float -> float -> float.

type (/) float -> float -> float.
Expand All @@ -103,27 +103,23 @@ type (div) int -> int -> int.
type (^) string -> string -> string.

type (~) int -> int.

type (~) float -> float.

type (i~) int -> int.

type (r~) float -> float.

type abs int -> int.

type abs float -> float.

type iabs int -> int.

type rabs float -> float.

type max int -> int -> int.

type max float -> float -> float.

type min int -> int -> int.

type min float -> float -> float.

type sqrt float -> float.
Expand Down
Loading

0 comments on commit 2c9b588

Please sign in to comment.