Skip to content

Commit

Permalink
fix bug in mint inference (#4562)
Browse files Browse the repository at this point in the history
This fixes an issue where AddSortInjections would not pick the right
sort parameters for parametric productions of sort MInt.

We do this by:

1. Changing the z3 inferencer to annotate parametric productions
involving parametric sorts with casts.
2. Fixing a bug in the sort graph provided to the disambiguation
pipeline
3. Fixing a bug in AddSortInjections.lub

---------

Co-authored-by: rv-jenkins <[email protected]>
  • Loading branch information
Dwight Guth and rv-jenkins authored Aug 8, 2024
1 parent 4da0743 commit 7e2efdd
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 14 deletions.
1 change: 1 addition & 0 deletions k-distribution/include/kframework/builtin/domains.md
Original file line number Diff line number Diff line change
Expand Up @@ -3015,6 +3015,7 @@ than the input.

```k
syntax {Width1, Width2} MInt{Width1} ::= roundMInt(MInt{Width2}) [function, total, hook(MINT.round)]
syntax {Width1, Width2} MInt{Width1} ::= signExtendMInt(MInt{Width2}) [function, total, hook(MINT.sext)]
```

```k
Expand Down
7 changes: 7 additions & 0 deletions k-distribution/tests/regression-new/mint-llvm-3/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
DEF=test
EXT=test
TESTDIR=.
KOMPILE_BACKEND=llvm
KOMPILE_FLAGS=--syntax-module TEST

include ../../../include/kframework/ktest.mak
19 changes: 19 additions & 0 deletions k-distribution/tests/regression-new/mint-llvm-3/test.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module TEST
imports BOOL
imports MINT

syntax MInt{64}
syntax MInt{32}

syntax KItem ::= foo(MInt{64})

syntax MInt{64} ::= m64() [function]
rule m64() => 0p64
syntax MInt{32} ::= m32() [function]
rule m32() => 0p32

rule foo(X) => .K
requires (X +MInt m64()) <=uMInt (roundMInt(m32()) <<MInt 0p64)

rule true => foo(0p64)
endmodule
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,9 @@ private static Sort lub(
s ->
mod.subsorts().lessThanEq(s, Sorts.KBott())
|| mod.subsorts().greaterThan(s, Sorts.K()));
if (expectedSort != null && !expectedSort.name().equals(SORTPARAM_NAME)) {
if (expectedSort != null
&& expectedSort.head().params() == 0
&& !expectedSort.name().equals(SORTPARAM_NAME)) {
bounds.removeIf(s -> !mod.subsorts().lessThanEq(s, expectedSort));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,15 +429,6 @@ public static Tuple3<Module, Module, Module> getCombinedGrammarImpl(
stream(mod.allSorts())
.filter(s -> (!isParserSort(s) || s.equals(Sorts.KItem()) || s.equals(Sorts.K())))
.toList();
for (SortHead sh : mutable(mod.definedInstantiations()).keySet()) {
for (Sort s : mutable(mod.definedInstantiations().apply(sh))) {
// syntax MInt{K} ::= MInt{6}
Production p1 =
Production(
Option.empty(), Seq(), Sort(s.name(), Sorts.K()), Seq(NonTerminal(s)), Att.empty());
prods.add(p1);
}
}
for (Production p : iterable(mod.productions())) {
if (p.params().nonEmpty()) {
if (p.params().contains(p.sort())) { // case 1
Expand Down Expand Up @@ -634,6 +625,17 @@ public static Tuple3<Module, Module, Module> getCombinedGrammarImpl(
}

disambProds = new HashSet<>(parseProds);

for (SortHead sh : mutable(mod.definedInstantiations()).keySet()) {
for (Sort s : mutable(mod.definedInstantiations().apply(sh))) {
// syntax MInt{K} ::= MInt{6}
Production p1 =
Production(
Option.empty(), Seq(), Sort(s.name(), Sorts.K()), Seq(NonTerminal(s)), Att.empty());
parseProds.add(p1);
}
}

if (mod.importedModuleNames().contains(PROGRAM_LISTS)) {
Set<Sentence> prods3 = new HashSet<>();
// if no start symbol has been defined in the configuration, then use K
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ public Either<Set<KEMException>, Term> apply(Term term) {
return typeError(pr, expectedSort, inferred);
}
// well typed, so add a cast and return
return wrapTermWithCast((Constant) pr, inferred);
return wrapTermWithCast(pr, inferred);
}
// compute the instantiated production with its sort parameters
Production substituted = pr.production().substitute(inferencer.getArgs(pr));
Expand Down Expand Up @@ -275,12 +275,26 @@ public Either<Set<KEMException>, Term> apply(Term term) {
j++;
}
}
if (pr.production().params().nonEmpty() && hasParametricSort(pr.production())) {
return wrapTermWithCast(tc, substituted.sort());
}
return Right.apply(tc);
}
return Right.apply(pr);
}

private Either<Set<KEMException>, Term> wrapTermWithCast(Constant c, Sort declared) {
private boolean hasParametricSort(Production prod) {
if (prod.sort().head().params() != 0) {
return true;
}
if (stream(prod.nonterminals()).anyMatch(nt -> nt.sort().head().params() != 0)) {
return true;
}
return false;
}

private Either<Set<KEMException>, Term> wrapTermWithCast(
ProductionReference pr, Sort declared) {
if (castContext != CastContext.SEMANTIC) {
// There isn't an existing :Sort, so add one
Production cast =
Expand All @@ -289,9 +303,10 @@ private Either<Set<KEMException>, Term> wrapTermWithCast(Constant c, Sort declar
.productionsFor()
.apply(KLabel("#SemanticCastTo" + declared.toString()))
.head();
return Right.apply(TermCons.apply(ConsPStack.singleton(c), cast, c.location(), c.source()));
return Right.apply(
TermCons.apply(ConsPStack.singleton(pr), cast, pr.location(), pr.source()));
}
return Right.apply(c);
return Right.apply(pr);
}
}
}

0 comments on commit 7e2efdd

Please sign in to comment.