Skip to content

Commit

Permalink
No longer translate literal patterns to == applications
Browse files Browse the repository at this point in the history
  • Loading branch information
chengluyu committed Mar 15, 2023
1 parent aab8d27 commit 2263d9f
Show file tree
Hide file tree
Showing 12 changed files with 245 additions and 182 deletions.
45 changes: 27 additions & 18 deletions shared/src/main/scala/mlscript/ucs/Clause.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,46 @@ abstract class Clause {
* @return
*/
val locations: Ls[Loc]

protected final def bindingsToString: String =
(if (bindings.isEmpty) "" else " with " + Clause.showBindings(bindings))
}

object Clause {
final case class MatchLiteral(
scrutinee: Scrutinee,
literal: SimpleTerm
)(override val locations: Ls[Loc]) extends Clause {
override def toString(): String = s"«$scrutinee is $literal" + bindingsToString
}

final case class MatchClass(
scrutinee: Scrutinee,
className: Var,
fields: Ls[Str -> Var]
)(override val locations: Ls[Loc]) extends Clause
)(override val locations: Ls[Loc]) extends Clause {
override def toString(): String = s"«$scrutinee is $className»" + bindingsToString
}

final case class MatchTuple(
scrutinee: Scrutinee,
arity: Int,
fields: Ls[Str -> Var]
)(override val locations: Ls[Loc]) extends Clause
)(override val locations: Ls[Loc]) extends Clause {
override def toString(): String = s"«$scrutinee is Tuple#$arity»" + bindingsToString
}

final case class BooleanTest(test: Term)(override val locations: Ls[Loc]) extends Clause
final case class BooleanTest(test: Term)(
override val locations: Ls[Loc]
) extends Clause {
override def toString(): String = s"«$test»" + bindingsToString
}

final case class Binding(name: Var, term: Term)(override val locations: Ls[Loc]) extends Clause
final case class Binding(name: Var, term: Term)(
override val locations: Ls[Loc]
) extends Clause {
override def toString(): String = s"«$name = $term»" + bindingsToString
}

def showBindings(bindings: Ls[(Bool, Var, Term)]): Str =
bindings match {
Expand All @@ -48,20 +70,7 @@ object Clause {
}.mkString("(", ", ", ")")
}


def showClauses(clauses: Iterable[Clause]): Str = {
clauses.iterator.map { clause =>
(clause match {
case Clause.BooleanTest(test) => s"«$test»"
case Clause.MatchClass(scrutinee, Var(className), fields) =>
s"«$scrutinee is $className»"
case Clause.MatchTuple(scrutinee, arity, fields) =>
s"«$scrutinee is Tuple#$arity»"
case Clause.Binding(Var(name), term) =>
s"«$name = $term»"
}) + (if (clause.bindings.isEmpty) "" else " with " + showBindings(clause.bindings))
}.mkString("", " and ", "")
}
def showClauses(clauses: Iterable[Clause]): Str = clauses.mkString("", " and ", "")

def print(println: (=> Any) => Unit, conjunctions: Iterable[Conjunction -> Term]): Unit = {
println("Flattened conjunctions")
Expand Down
14 changes: 11 additions & 3 deletions shared/src/main/scala/mlscript/ucs/Conjunction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mlscript.ucs
import mlscript._, utils._, shorthands._
import Clause._, helpers._
import scala.collection.mutable.Buffer
import scala.annotation.tailrec

/**
* A `Conjunction` represents a list of `Clause`s.
Expand Down Expand Up @@ -53,13 +54,20 @@ final case class Conjunction(clauses: Ls[Clause], trailingBindings: Ls[(Bool, Va
def +(lastBinding: (Bool, Var, Term)): Conjunction =
Conjunction(clauses, trailingBindings :+ lastBinding)

def separate(expectedScrutinee: Scrutinee): Opt[(MatchClass, Conjunction)] = {
def rec(past: Ls[Clause], upcoming: Ls[Clause]): Opt[(Ls[Clause], MatchClass, Ls[Clause])] = {
def separate(expectedScrutinee: Scrutinee): Opt[(MatchClass \/ MatchLiteral, Conjunction)] = {
@tailrec
def rec(past: Ls[Clause], upcoming: Ls[Clause]): Opt[(Ls[Clause], MatchClass \/ MatchLiteral, Ls[Clause])] = {
upcoming match {
case Nil => N
case (head @ MatchLiteral(scrutinee, _)) :: tail =>
if (scrutinee === expectedScrutinee) {
S((past, R(head), tail))
} else {
rec(past :+ head, tail)
}
case (head @ MatchClass(scrutinee, _, _)) :: tail =>
if (scrutinee === expectedScrutinee) {
S((past, head, tail))
S((past, L(head), tail))
} else {
rec(past :+ head, tail)
}
Expand Down
145 changes: 93 additions & 52 deletions shared/src/main/scala/mlscript/ucs/Desugarer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,13 @@ class Desugarer extends TypeDefs { self: Typer =>
def makeScrutinee(term: Term, matchRootLoc: Opt[Loc])(implicit ctx: Ctx): Scrutinee =
traceUCS(s"Making a scrutinee for `$term`") {
term match {
case _: SimpleTerm => Scrutinee(N, term)(matchRootLoc)
case _ => Scrutinee(S(makeLocalizedName(term)), term)(matchRootLoc)
case _: Var =>
printlnUCS(s"The scrutinee does not need an alias.")
Scrutinee(N, term)(matchRootLoc)
case _ =>
val localizedName = makeLocalizedName(term)
printlnUCS(s"The scrutinee needs an alias: $localizedName")
Scrutinee(S(localizedName), term)(matchRootLoc)
}
}()

Expand Down Expand Up @@ -160,9 +165,13 @@ class Desugarer extends TypeDefs { self: Typer =>
case Var("_") => Nil
// This case handles literals.
// x is true | x is false | x is 0 | x is "text" | ...
case literal @ (Var("true") | Var("false") | _: Lit) =>
val test = mkBinOp(scrutinee.reference, Var("=="), literal)
val clause = Clause.BooleanTest(test)(scrutinee.term.toLoc.toList ::: literal.toLoc.toList)
case literal: Var if literal.name === "true" || literal.name === "false" =>
val clause = Clause.MatchLiteral(scrutinee, literal)(scrutinee.term.toLoc.toList ::: literal.toLoc.toList)
clause.bindings = scrutinee.asBinding.toList
printlnUCS(s"Add bindings to the clause: ${scrutinee.asBinding}")
clause :: Nil
case literal: Lit =>
val clause = Clause.MatchLiteral(scrutinee, literal)(scrutinee.term.toLoc.toList ::: literal.toLoc.toList)
clause.bindings = scrutinee.asBinding.toList
printlnUCS(s"Add bindings to the clause: ${scrutinee.asBinding}")
clause :: Nil
Expand Down Expand Up @@ -515,22 +524,34 @@ class Desugarer extends TypeDefs { self: Typer =>
*/
type ExhaustivenessMap = Map[Str \/ Int, Map[Var, MutCase]]

def getScurtineeKey(scrutinee: Scrutinee)(implicit ctx: Ctx, raise: Raise): Str \/ Int = {
scrutinee.term match {
// The original scrutinee is an reference.
case v @ Var(name) =>
ctx.env.get(name) match {
case S(VarSymbol(_, defVar)) => defVar.uid.fold[Str \/ Int](L(v.name))(R(_))
case S(_) | N => L(v.name)
}
// Otherwise, the scrutinee has a temporary name.
case _ =>
scrutinee.local match {
case N => throw new Error("check your `makeScrutinee`")
case S(localNameVar) => L(localNameVar.name)
}
}
}
/**
* This method obtains a proper key of the given scrutinee
* for memorizing patterns belongs to the scrutinee.
*
* @param scrutinee the scrutinee
* @param ctx the context
* @param raise we need this to raise errors.
* @return the variable name or the variable ID
*/
def getScurtineeKey(scrutinee: Scrutinee)(implicit ctx: Ctx, raise: Raise): Str \/ Int =
traceUCS(s"[getScrutineeKey] $scrutinee") {
scrutinee.term match {
// The original scrutinee is an reference.
case v @ Var(name) =>
printlnUCS("The original scrutinee is an reference.")
ctx.env.get(name) match {
case S(VarSymbol(_, defVar)) => defVar.uid.fold[Str \/ Int](L(v.name))(R(_))
case S(_) | N => L(v.name)
}
// Otherwise, the scrutinee was localized because it might be effectful.
case _ =>
printlnUCS("The scrutinee was localized because it might be effectful.")
scrutinee.local match {
case N => throw new Error("check your `makeScrutinee`")
case S(localNameVar) => L(localNameVar.name)
}
}
}()

/**
* Check the exhaustiveness of the given `MutCaseOf`.
Expand All @@ -542,10 +563,8 @@ class Desugarer extends TypeDefs { self: Typer =>
def checkExhaustive
(t: MutCaseOf, parentOpt: Opt[MutCaseOf])
(implicit scrutineePatternMap: ExhaustivenessMap, ctx: Ctx, raise: Raise)
: Unit = {
printlnUCS(s"Check exhaustiveness of ${t.describe}")
indent += 1
try t match {
: Unit = traceUCS(s"[checkExhaustive] ${t.describe}") {
t match {
case _: Consequent => ()
case MissingCase =>
parentOpt match {
Expand All @@ -567,18 +586,26 @@ class Desugarer extends TypeDefs { self: Typer =>
case S(_) if default.isDefined =>
printlnUCS("The match has a default branch. So, it is always safe.")
case S(patternMap) =>
printlnUCS(s"The exhaustiveness map is ${scrutineePatternMap}")
printlnUCS(s"The exhaustiveness map is")
scrutineePatternMap.foreach { case (key, matches) =>
printlnUCS(s"- $key -> ${matches.keysIterator.mkString(", ")}")
}
printlnUCS(s"The scrutinee key is ${getScurtineeKey(scrutinee)}")
printlnUCS("Pattern map of the scrutinee:")
if (patternMap.isEmpty)
printlnUCS("<Empty>")
else
patternMap.foreach { case (key, mutCase) => printlnUCS(s"- $key => $mutCase")}
// Filter out missing cases in `branches`.
val missingCases = patternMap.removedAll(branches.iterator.map {
case MutCase(classNameVar -> _, _) => classNameVar
val missingCases = patternMap.removedAll(branches.iterator.flatMap {
case MutCase.Literal(tof @ Var(n), _) if n === "true" || n === "false" => Some(tof)
case MutCase.Literal(_, _) => None
case MutCase.Constructor(classNameVar -> _, _) => Some(classNameVar)
})
printlnUCS(s"Number of missing cases: ${missingCases.size}")
printlnUCS("Missing cases")
missingCases.foreach { case (key, m) =>
printlnUCS(s"- $key -> ${m}")
}
if (!missingCases.isEmpty) {
throw new DesugaringException({
val numMissingCases = missingCases.size
Expand All @@ -597,53 +624,67 @@ class Desugarer extends TypeDefs { self: Typer =>
}
}
default.foreach(checkExhaustive(_, S(t)))
branches.foreach { case MutCase(_, consequent) =>
checkExhaustive(consequent, S(t))
branches.foreach { branch =>
checkExhaustive(branch.consequent, S(t))
}
} finally indent -= 1
}
}
}()

def summarizePatterns(t: MutCaseOf)(implicit ctx: Ctx, raise: Raise): ExhaustivenessMap = {
def summarizePatterns(t: MutCaseOf)(implicit ctx: Ctx, raise: Raise): ExhaustivenessMap = traceUCS("[summarizePatterns]") {
val m = MutMap.empty[Str \/ Int, MutMap[Var, MutCase]]
def rec(t: MutCaseOf): Unit = {
printlnUCS(s"Summarize pattern of ${t.describe}")
indent += 1
try t match {
def rec(t: MutCaseOf): Unit = traceUCS(s"[rec] ${t.describe}") {
t match {
case Consequent(term) => ()
case MissingCase => ()
case IfThenElse(_, whenTrue, whenFalse) =>
rec(whenTrue)
rec(whenFalse)
case Match(scrutinee, branches, default) =>
val key = getScurtineeKey(scrutinee)
branches.foreach { mutCase =>
val patternMap = m.getOrElseUpdate( key, MutMap.empty)
if (!patternMap.contains(mutCase.patternFields._1)) {
patternMap += ((mutCase.patternFields._1, mutCase))
}
rec(mutCase.consequent)
val patternMap = m.getOrElseUpdate(key, MutMap.empty)
branches.foreach {
case mutCase @ MutCase.Literal(literal, consequent) =>
literal match {
case tof @ Var(n) if n === "true" || n === "false" =>
if (!patternMap.contains(tof)) {
patternMap += ((tof, mutCase))
}
case _ => () // TODO: Summarize literals.
}
rec(consequent)
case mutCase @ MutCase.Constructor((className, _), consequent) =>
if (!patternMap.contains(className)) {
patternMap += ((className, mutCase))
}
rec(consequent)
}
default.foreach(rec)
} finally indent -= 1
}
}
}()
rec(t)
printlnUCS("Exhaustiveness map")
m.foreach { case (scrutinee, patterns) =>
printlnUCS(s"- $scrutinee => " + patterns.keys.mkString(", "))
}
printlnUCS("Summarized patterns")
if (m.isEmpty)
printlnUCS("<Empty>")
else
m.foreach { case (scrutinee, patterns) =>
printlnUCS(s"- $scrutinee => " + patterns.keysIterator.mkString(", "))
}
Map.from(m.iterator.map { case (key, patternMap) => key -> Map.from(patternMap) })
}
}()

protected def constructTerm(m: MutCaseOf)(implicit ctx: Ctx): Term = {
def rec(m: MutCaseOf)(implicit defs: Set[Var]): Term = m match {
case Consequent(term) => term
case Match(scrutinee, branches, wildcard) =>
def rec2(xs: Ls[MutCase]): CaseBranches =
xs match {
case MutCase(className -> fields, cases) :: next =>
case MutCase.Constructor(className -> fields, cases) :: next =>
// TODO: expand bindings here
val consequent = rec(cases)(defs ++ fields.iterator.map(_._2))
Case(className, mkLetFromFields(scrutinee, fields.toList, consequent), rec2(next))
case MutCase.Literal(literal, cases) :: next =>
val consequent = rec(cases)
Case(literal, consequent, rec2(next))
case Nil =>
wildcard.fold[CaseBranches](NoCases) { rec(_) |> Wildcard }
}
Expand Down
Loading

0 comments on commit 2263d9f

Please sign in to comment.