Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/strpool #184

Merged
merged 76 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
cc3fc26
wip: StrPool
leissa Feb 17, 2023
6dee755
fine-tuning string pool
leissa Feb 22, 2023
3b33413
clang-format
leissa Feb 22, 2023
c24f639
Some refactoring
leissa Feb 24, 2023
0e43cbd
Pos: use 0 to encode invalid
leissa Feb 24, 2023
cce4e0f
wip: StrPool
leissa Feb 17, 2023
220511c
clang-format
leissa Feb 17, 2023
87cd40e
remove git conflict
leissa Feb 24, 2023
066bb1d
compile fix
leissa Feb 24, 2023
381951d
wip: removing dbg field
leissa Feb 25, 2023
09e3c01
clang-format
leissa Feb 25, 2023
cb24b25
clang-format
leissa Feb 25, 2023
3d52608
wip: adjusting to new API
leissa Feb 25, 2023
72b397f
clang-format
leissa Feb 25, 2023
c0c5fb2
wip: porting
leissa Feb 26, 2023
93f5d68
clang-format
leissa Feb 26, 2023
4f53a5b
wip: porting to new API
leissa Feb 26, 2023
0b2d5e3
clang-format
leissa Feb 26, 2023
333d23b
compiles again bug segfaults
leissa Feb 26, 2023
d48c1fb
clang-format
leissa Feb 26, 2023
8bbcebd
fix segfault
leissa Feb 26, 2023
7dc7860
clang-format
leissa Feb 26, 2023
16229a2
all empty strings/nullptrs etc are now handled as Sym()
leissa Feb 27, 2023
c6f5769
clang-format
leissa Feb 27, 2023
a0dd8c3
more porting
leissa Feb 28, 2023
d03a8e0
clang-format
leissa Feb 28, 2023
b83b8c9
it compiles again
leissa Mar 1, 2023
da6e4d8
clang-format
leissa Mar 1, 2023
93b620f
removed unused code
leissa Mar 1, 2023
5d5f10b
refactor
leissa Mar 1, 2023
67ca27b
clang-format
leissa Mar 1, 2023
1243678
polish
leissa Mar 1, 2023
e53cf57
clang-format
leissa Mar 1, 2023
d6ade83
bug fix + support for automatic Loc setting in World
leissa Mar 3, 2023
8d1d61d
misc
leissa Mar 3, 2023
5047a59
clang-format
leissa Mar 3, 2023
541260d
local static - maybe this works
leissa Mar 3, 2023
b831b8f
fiddling around with Doxygen
leissa Mar 4, 2023
ca19e79
first lit tests are working again
leissa Mar 5, 2023
14404f0
clang-format
leissa Mar 5, 2023
262115f
bug fix
leissa Mar 5, 2023
d8bad0e
propagate meta through rebuild, passes, etc
leissa Mar 5, 2023
cc116b8
clang-format
leissa Mar 5, 2023
6871d6b
removed fields and World::sym2tuple - not used
leissa Mar 5, 2023
b5e1598
fix lit/affine/lower_for.thorin
leissa Mar 5, 2023
04fde90
all lit test cases work again
leissa Mar 5, 2023
fc67b53
compile fix in release build
leissa Mar 6, 2023
a92a5d4
fix warning
leissa Mar 6, 2023
5a69b2b
set world's name
leissa Mar 6, 2023
b8c8d75
whitespace
leissa Mar 6, 2023
f71f6c3
fixing move semantics
leissa Mar 6, 2023
9d423e6
cleanup
leissa Mar 7, 2023
0c0affc
track field names of nom sigma in meta field
leissa Mar 7, 2023
959dd95
make implicit part of the funciton type
leissa Mar 7, 2023
fe134f7
clang-format
leissa Mar 7, 2023
b0fb508
bug fix in World::iapp: callee may actually sth weird
leissa Mar 7, 2023
75a8121
fix test case
leissa Mar 7, 2023
b532dea
error in case of unhandled def in LLVM backend
leissa Mar 7, 2023
aaf8cc7
removed Def::meta and friends
leissa Mar 7, 2023
05e833b
Def::isa_lit_arity & Def::as_lit_arity
leissa Mar 7, 2023
c6c4e45
clang-format
leissa Mar 7, 2023
2997a1d
Introduced Def::external_ flag
leissa Mar 8, 2023
4d9a345
removed unused code
leissa Mar 8, 2023
3f6234c
place config.h into build/include/thorin
leissa Mar 8, 2023
0d210c4
Hopefuly, make MSVC happy
leissa Mar 8, 2023
29eeefb
does this make MSVC happy? + formmatting
leissa Mar 8, 2023
cf5a8af
cleanup
leissa Mar 8, 2023
690090f
probing msvc
leissa Mar 8, 2023
c0fc248
hopefully this makes msvc happy
leissa Mar 8, 2023
6847eaa
fix
leissa Mar 8, 2023
2b1efbc
formatting
leissa Mar 8, 2023
b1ed3df
docs
leissa Mar 8, 2023
99d0a2e
fixed incorrect renaming
leissa Mar 8, 2023
40002b8
removed "performance hack"
leissa Mar 9, 2023
59f75fd
resolving comment
leissa Mar 9, 2023
cf4439d
resolving comment
leissa Mar 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "thorin/config.h"
#include "thorin/dialects.h"
#include "thorin/driver.h"

#include "thorin/be/dot/dot.h"
#include "thorin/fe/parser.h"
Expand All @@ -26,7 +27,7 @@ int main(int argc, char** argv) {
try {
static const auto version = "thorin command-line utility version " THORIN_VER "\n";

World::State state;
Driver driver;
bool show_help = false;
bool show_version = false;
std::string input, prefix;
Expand All @@ -37,7 +38,7 @@ int main(int argc, char** argv) {
int verbose = 0;
int opt = 2;
auto inc_verbose = [&](bool) { ++verbose; };
auto& flags = state.pod.flags;
auto& flags = driver.flags;

// clang-format off
auto cli = lyra::cli()
Expand Down Expand Up @@ -78,9 +79,9 @@ int main(int argc, char** argv) {
}

#if THORIN_ENABLE_CHECKS
for (auto b : breakpoints) state.breakpoints.emplace(b);
for (auto b : breakpoints) driver.breakpoints.emplace(b);
#endif
World world(state);
World& world = driver.world;
world.log().ostream = &std::cerr;
world.log().level = (Log::Level)verbose;
// prepare output files and streams
Expand Down Expand Up @@ -124,9 +125,11 @@ int main(int argc, char** argv) {
}

for (const auto& dialect : dialects)
fe::Parser::import_module(world, dialect.name(), dialect_paths, &normalizers);
fe::Parser::import_module(world, world.sym(dialect.name()), dialect_paths, &normalizers);

fe::Parser parser(world, input, ifs, dialect_paths, &normalizers, os[Md]);
auto sym = world.sym(std::move(input));
world.set(sym);
fe::Parser parser(world, sym, ifs, dialect_paths, &normalizers, os[Md]);
parser.parse_module();

if (os[H]) {
Expand Down
4 changes: 2 additions & 2 deletions cmake/Thorin.cmake
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# clear globals
SET(THORIN_DIALECT_LIST "" CACHE INTERNAL "THORIN_DIALECT_LIST")
SET(THORIN_DIALECT_LAYOUT "" CACHE INTERNAL "THORIN_DIALECT_LAYOUT")
SET(THORIN_DIALECT_LIST "" CACHE INTERNAL "THORIN_DIALECT_LIST")
SET(THORIN_DIALECT_LAYOUT "" CACHE INTERNAL "THORIN_DIALECT_LAYOUT")

if(NOT THORIN_TARGET_NAMESPACE)
set(THORIN_TARGET_NAMESPACE "")
Expand Down
22 changes: 13 additions & 9 deletions dialects/affine/passes/lower_for.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@ const Def* LowerFor::rewrite(const Def* def) {

auto for_pi = for_ax->callee_type();
DefArray for_dom{for_pi->num_doms() - 2, [&](size_t i) { return for_pi->dom(i); }};
auto for_lam = w.nom_lam(w.cn(for_dom), w.dbg("for"));
auto for_lam = w.nom_lam(w.cn(for_dom))->set("for");

auto body = for_ax->arg(for_ax->num_args() - 2, w.dbg("body"));
auto brk = for_ax->arg(for_ax->num_args() - 1, w.dbg("break"));
auto body = for_ax->arg(for_ax->num_args() - 2)->set("body");
auto brk = for_ax->arg(for_ax->num_args() - 1)->set("break");

auto body_type = body->type()->as<Pi>();
auto yield_pi = body_type->doms().back()->as<Pi>();
auto yield_lam = w.nom_lam(yield_pi, w.dbg("yield"));
auto yield_lam = w.nom_lam(yield_pi)->set("yield");

{ // construct yield
auto [iter, end, step, acc] = for_lam->vars<4>({w.dbg("begin"), w.dbg("end"), w.dbg("step"), w.dbg("acc")});
auto yield_acc = yield_lam->var();
auto [iter, end, step, acc] = for_lam->vars<4>();
iter->set("iter");
end->set("end");
step->set("step");
acc->set("acc");
auto yield_acc = yield_lam->var();

auto add = w.call(core::wrap::add, 0_n, Defs{iter, step});
yield_lam->app(false, for_lam, {add, end, step, yield_acc});
Expand All @@ -37,20 +41,20 @@ const Def* LowerFor::rewrite(const Def* def) {

// reduce the body to remove the cn parameter
auto nom_body = body->as_nom<Lam>();
auto new_body = nom_body->stub(w, w.cn(w.sigma()), body->dbg());
auto new_body = nom_body->stub(w, w.cn(w.sigma()))->set(body->dbg());
new_body->set(nom_body->reduce(w.tuple({iter, acc, yield_lam})));

// break
auto if_else_cn = w.cn(w.sigma());
auto if_else = w.nom_lam(if_else_cn, nullptr);
auto if_else = w.nom_lam(if_else_cn);
if_else->app(false, brk, acc);

auto cmp = w.call(core::icmp::ul, Defs{iter, end});
for_lam->branch(false, cmp, new_body, if_else, w.tuple());
}

DefArray for_args{for_ax->num_args() - 2, [&](size_t i) { return for_ax->arg(i); }};
return rewritten_[def] = w.app(for_lam, for_args, for_ax->dbg());
return rewritten_[def] = w.app(for_lam, for_args);
}

return def;
Expand Down
14 changes: 8 additions & 6 deletions dialects/autodiff/auxiliary/autodiff_aux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
#include "dialects/autodiff/autodiff.h"
#include "dialects/mem/mem.h"

using namespace std::literals;

namespace thorin::autodiff {

const Def* id_pullback(const Def* A) {
auto& world = A->world();
auto arg_pb_ty = pullback_type(A, A);
auto id_pb = world.nom_lam(arg_pb_ty, world.dbg("id_pb"));
auto id_pb_scalar = id_pb->var((nat_t)0, world.dbg("s"));
auto id_pb = world.nom_lam(arg_pb_ty)->set("id_pb");
auto id_pb_scalar = id_pb->var(0_s)->set("s");
id_pb->app(true,
id_pb->var(1), // can not use ret_var as the result might be higher order
id_pb_scalar);
Expand All @@ -24,7 +26,7 @@ const Def* zero_pullback(const Def* E, const Def* A) {
auto& world = A->world();
auto A_tangent = tangent_type_fun(A);
auto pb_ty = pullback_type(E, A);
auto pb = world.nom_lam(pb_ty, world.dbg("zero_pb"));
auto pb = world.nom_lam(pb_ty)->set("zero_pb");
world.DLOG("zero_pullback for {} resp. {} (-> {})", E, A, A_tangent);
pb->app(true, pb->var(1), op_zero(A_tangent));
return pb;
Expand Down Expand Up @@ -129,7 +131,7 @@ const Def* zero_def(const Def* T) {
return zero_arr;
} else if (Idx::size(T)) {
// TODO: real
auto zero = world.lit(T, 0, world.dbg("zero"));
auto zero = world.lit(T, 0)->set("zero");
world.DLOG("zero_def for int is {}", zero);
return zero;
} else if (auto sig = T->isa<Sigma>()) {
Expand Down Expand Up @@ -234,8 +236,8 @@ const Def* compose_continuation(const Def* f, const Def* g) {
auto H = world.cn({A, world.cn(C)});
auto Hcont = world.cn(B);

auto h = world.nom_lam(H, world.dbg("comp_" + f->name() + "_" + g->name()));
auto hcont = world.nom_lam(Hcont, world.dbg("comp_" + f->name() + "_" + g->name() + "_cont"));
auto h = world.nom_lam(H)->set("comp_"s + *f->sym() + "_"s + *g->sym());
auto hcont = world.nom_lam(Hcont)->set("comp_"s + *f->sym() + "_"s + *g->sym() + "_cont"s);

h->app(true, g, {h->var((nat_t)0), hcont});

Expand Down
28 changes: 15 additions & 13 deletions dialects/autodiff/auxiliary/autodiff_rewrite_inner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "dialects/core/core.h"
#include "dialects/direct/direct.h"

using namespace std::literals;

namespace thorin::autodiff {

// TODO remove macro
Expand Down Expand Up @@ -44,8 +46,8 @@ const Def* AutoDiffEval::augment_lam(Lam* lam, Lam* f, Lam* f_diff) {
}
// TODO: better fix (another pass as analysis?)
// TODO: handle open functions
if (is_open_continuation(lam) || lam->name().find("ret") != std::string::npos ||
lam->name().find("_cont") != std::string::npos) {
if (is_open_continuation(lam) || lam->sym()->find("ret") != std::string::npos ||
lam->sym()->find("_cont") != std::string::npos) {
// A open continuation behaves the same as return:
// ```
// cont: Cn[X]
Expand All @@ -62,7 +64,7 @@ const Def* AutoDiffEval::augment_lam(Lam* lam, Lam* f, Lam* f_diff) {
world.DLOG("pb type is {}", pb_ty);
auto aug_ty = world.cn({aug_dom, pb_ty});
world.DLOG("augmented type is {}", aug_ty);
auto aug_lam = world.nom_lam(aug_ty, world.dbg("aug_" + lam->name()));
auto aug_lam = world.nom_lam(aug_ty)->set("aug_"s + *lam->sym());
auto aug_var = aug_lam->var((nat_t)0);
augmented[lam->var()] = aug_var;
augmented[lam] = aug_lam; // TODO: only one of these two
Expand Down Expand Up @@ -115,10 +117,10 @@ const Def* AutoDiffEval::augment_extract(const Extract* ext, Lam* f, Lam* f_diff
assert(partial_pullback.count(aug_tuple));
auto tuple_pb = partial_pullback[aug_tuple];
auto pb_ty = pullback_type(ext->type(), f_arg_ty);
auto pb_fun = world.nom_lam(pb_ty, world.dbg("extract_pb"));
auto pb_fun = world.nom_lam(pb_ty)->set("extract_pb");
world.DLOG("Pullback: {} : {}", pb_fun, pb_fun->type());
auto pb_tangent = pb_fun->var((nat_t)0, world.dbg("s"));
auto tuple_tan = world.insert(op_zero(aug_tuple->type()), aug_index, pb_tangent, world.dbg("tup_s"));
auto pb_tangent = pb_fun->var(0_s)->set("s");
auto tuple_tan = world.insert(op_zero(aug_tuple->type()), aug_index, pb_tangent)->set("tup_s");
pb_fun->app(true, tuple_pb,
{
tuple_tan,
Expand Down Expand Up @@ -152,12 +154,12 @@ const Def* AutoDiffEval::augment_tuple(const Tuple* tup, Lam* f, Lam* f_diff) {
// ((cps2ds e0*) (s#0), ..., (cps2ds em*) (s#m))
// ```
auto pb_ty = pullback_type(tup->type(), f_arg_ty);
auto pb = world.nom_lam(pb_ty, world.dbg("tup_pb"));
auto pb = world.nom_lam(pb_ty)->set("tup_pb");
world.DLOG("Augmented tuple: {} : {}", aug_tup, aug_tup->type());
world.DLOG("Tuple Pullback: {} : {}", pb, pb->type());
world.DLOG("shadow pb: {} : {}", shadow_pb, shadow_pb->type());

auto pb_tangent = pb->var((nat_t)0, world.dbg("tup_s"));
auto pb_tangent = pb->var(0_s)->set("tup_s");

DefArray tangents(pbs.size(),
[&](nat_t i) { return world.app(direct::op_cps2ds_dep(pbs[i]), world.extract(pb_tangent, i)); });
Expand Down Expand Up @@ -188,7 +190,7 @@ const Def* AutoDiffEval::augment_pack(const Pack* pack, Lam* f, Lam* f_diff) {
world.DLOG("shadow pb of pack: {} : {}", pb_pack, pb_pack->type());

auto pb_type = pullback_type(pack->type(), f_arg_ty);
auto pb = world.nom_lam(pb_type, world.dbg("pack_pb"));
auto pb = world.nom_lam(pb_type)->set("pack_pb");

world.DLOG("pb of pack: {} : {}", pb, pb_type);

Expand Down Expand Up @@ -305,7 +307,7 @@ const Def* AutoDiffEval::augment_app(const App* app, Lam* f, Lam* f_diff) {
world.DLOG("ret_g_deriv_ty: {} ", ret_g_deriv_ty);
auto c1_ty = ret_g_deriv_ty->as<Pi>();
world.DLOG("c1_ty: (cn[X, cn[X+, cn E+]]) {}", c1_ty);
auto c1 = world.nom_lam(c1_ty, world.dbg("c1"));
auto c1 = world.nom_lam(c1_ty)->set("c1");
auto res = c1->var((nat_t)0);
auto r_pb = c1->var(1);
c1->app(true, aug_cont, {res, compose_continuation(e_pb, r_pb)});
Expand Down Expand Up @@ -364,14 +366,14 @@ const Def* AutoDiffEval::augment_(const Def* def, Lam* f, Lam* f_diff) {
world.DLOG("Augment axiom: {} : {}", ax, ax->type());
world.DLOG("axiom curry: {}", ax->curry());
world.DLOG("axiom flags: {}", ax->flags());
std::string diff_name = ax->name();
std::string diff_name = ax->sym();
findAndReplaceAll(diff_name, ".", "_");
findAndReplaceAll(diff_name, "%", "");
diff_name = "internal_diff_" + diff_name;
world.DLOG("axiom name: {}", ax->name());
world.DLOG("axiom name: {}", ax->sym());
world.DLOG("axiom function name: {}", diff_name);

auto diff_fun = world.lookup(diff_name);
auto diff_fun = world.lookup(world.sym(diff_name));
if (!diff_fun) {
world.ELOG("derivation not found: {}", diff_name);
auto expected_type = autodiff_type_fun(ax->type());
Expand Down
4 changes: 2 additions & 2 deletions dialects/autodiff/auxiliary/autodiff_rewrite_toplevel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const Def* AutoDiffEval::derive_(const Def* def) {
auto lam = def->as_nom<Lam>(); // TODO check if nominal
world.DLOG("Derive lambda: {}", def);
auto deriv_ty = autodiff_type_fun_pi(lam->type());
auto deriv = world.nom_lam(deriv_ty, world.dbg(lam->name() + "_deriv"));
auto deriv = world.nom_lam(deriv_ty)->set(*lam->sym() + "_deriv");

// We first pre-register the derivatives.
// This knowledge is needed for recursion.
Expand All @@ -20,7 +20,7 @@ const Def* AutoDiffEval::derive_(const Def* def) {

auto [arg_ty, ret_pi] = lam->type()->doms<2>();
auto deriv_all_args = deriv->var();
const Def* deriv_arg = deriv->var((nat_t)0, world.dbg("arg"));
const Def* deriv_arg = deriv->var(0_s)->set("arg");

// We generate the shadow pullbacks dynamically to save work and avoid code duplication.
// Only the toplevel pullback for arguments and return continuation is special cased.
Expand Down
22 changes: 11 additions & 11 deletions dialects/autodiff/normalizers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,31 @@ namespace thorin::autodiff {

/// Currently this normalizer does nothin.
/// TODO: Maybe we want to handle trivial lookup replacements here.
const Def* normalize_ad(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
Ref normalize_ad(Ref type, Ref callee, Ref arg) {
auto& world = type->world();
return world.raw_app(type, callee, arg, dbg);
return world.raw_app(type, callee, arg);
}

const Def* normalize_AD(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
Ref normalize_AD(Ref type, Ref callee, Ref arg) {
auto& world = type->world();
auto ad_ty = autodiff_type_fun(arg);
if (ad_ty) return ad_ty;
return world.raw_app(type, callee, arg, dbg);
return world.raw_app(type, callee, arg);
}

const Def* normalize_Tangent(const Def*, const Def*, const Def* arg, const Def*) { return tangent_type_fun(arg); }
Ref normalize_Tangent(Ref, Ref, Ref arg) { return tangent_type_fun(arg); }

/// Currently this normalizer does nothing.
/// We usually want to keep zeros as long as possible to avoid unnecessary allocations.
/// A high-level addition with zero can be shortened directly.
const Def* normalize_zero(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
Ref normalize_zero(Ref type, Ref callee, Ref arg) {
auto& world = type->world();
return world.raw_app(type, callee, arg, dbg);
return world.raw_app(type, callee, arg);
}

/// Currently resolved the full addition.
/// There is no benefit in keeping additions around longer than necessary.
const Def* normalize_add(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
Ref normalize_add(Ref type, Ref callee, Ref arg) {
auto& world = type->world();

// TODO: add tuple -> tuple of adds
Expand Down Expand Up @@ -85,10 +85,10 @@ const Def* normalize_add(const Def* type, const Def* callee, const Def* arg, con
}
// TODO: mem stays here (only resolved after direct simplification)

return world.raw_app(type, callee, arg, dbg);
return world.raw_app(type, callee, arg);
}

const Def* normalize_sum(const Def* type, const Def* callee, const Def* arg, const Def* dbg) {
Ref normalize_sum(Ref type, Ref callee, Ref arg) {
auto& world = type->world();

auto [count, T] = callee->as<App>()->args<2>();
Expand All @@ -105,7 +105,7 @@ const Def* normalize_sum(const Def* type, const Def* callee, const Def* arg, con
}
assert(0);

return world.raw_app(type, callee, arg, dbg);
return world.raw_app(type, callee, arg);
}

THORIN_autodiff_NORMALIZER_IMPL
Expand Down
2 changes: 1 addition & 1 deletion dialects/autodiff/passes/autodiff_ext_cleanup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace thorin::autodiff {

void AutoDiffExternalCleanup::enter() {
Lam* lam = curr_nom();
if (lam->name().starts_with("internal_diff_")) {
if (lam->sym()->starts_with("internal_diff_")) {
lam->make_internal();
world().DLOG("internalized {}", lam);
}
Expand Down
Loading