Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small fixes to tests and congruence #223

Merged
merged 6 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Change List

## 2024-06-25
Repair two broken tests, small improvements to
SimonGuilloud marked this conversation as resolved.
Show resolved Hide resolved

## 2024-04-12
Addition of the Congruence tactic, solving sequents by congruence closure using egraphs.

Expand Down
94 changes: 35 additions & 59 deletions lisa-sets/src/main/scala/lisa/automation/Congruence.scala
Original file line number Diff line number Diff line change
Expand Up @@ -240,25 +240,16 @@ import scala.collection.mutable

class EGraphTerms() {

type ENode = Term | Formula



val termMap = mutable.Map[Term, Set[Term]]()
val termParents = mutable.Map[Term, mutable.Set[AppliedFunctional | AppliedPredicate]]()
var termWorklist = List[Term]()
val termUF = new UnionFind[Term]()




val formulaMap = mutable.Map[Formula, Set[Formula]]()
val formulaParents = mutable.Map[Formula, mutable.Set[AppliedConnector]]()
var formulaWorklist = List[Formula]()
val formulaUF = new UnionFind[Formula]()



def find(id: Term): Term = termUF.find(id)
def find(id: Formula): Formula = formulaUF.find(id)

trait TermStep
case class TermExternal(between: (Term, Term)) extends TermStep
Expand Down Expand Up @@ -316,62 +307,57 @@ class EGraphTerms() {

def makeSingletonEClass(node:Term): Term = {
termUF.add(node)
termMap(node) = Set(node)
termParents(node) = mutable.Set()
node
}
def makeSingletonEClass(node:Formula): Formula = {
formulaUF.add(node)
formulaMap(node) = Set(node)
formulaParents(node) = mutable.Set()
node
}

def classOf(id: Term): Set[Term] = termMap(id)
def classOf(id: Formula): Set[Formula] = formulaMap(id)

def idEq(id1: Term, id2: Term): Boolean = termUF.find(id1) == termUF.find(id2)
def idEq(id1: Formula, id2: Formula): Boolean = formulaUF.find(id1) == formulaUF.find(id2)
def idEq(id1: Term, id2: Term): Boolean = find(id1) == find(id2)
def idEq(id1: Formula, id2: Formula): Boolean = find(id1) == find(id2)

def canonicalize(node: Term): Term = node match
case AppliedFunctional(label, args) =>
AppliedFunctional(label, args.map(termUF.find.asInstanceOf))
AppliedFunctional(label, args.map(t => find(t)))
case _ => node


def canonicalize(node: Formula): Formula = {
node match
case AppliedPredicate(label, args) => AppliedPredicate(label, args.map(termUF.find))
case AppliedConnector(label, args) => AppliedConnector(label, args.map(formulaUF.find))
case AppliedPredicate(label, args) => AppliedPredicate(label, args.map(find))
case AppliedConnector(label, args) => AppliedConnector(label, args.map(find))
case node => node
}

def add(node: Term): Term =
if termMap.contains(node) then return node
if termUF.parent.contains(node) then return node
makeSingletonEClass(node)
node match
case node @ AppliedFunctional(_, args) =>
args.foreach(child =>
add(child)
termParents(child).add(node)
termParents(find(child)).add(node)
)
node
case _ => node

def add(node: Formula): Formula =
if formulaMap.contains(node) then return node
if formulaUF.parent.contains(node) then return node
makeSingletonEClass(node)
node match
case node @ AppliedPredicate(_, args) =>
args.foreach(child =>
add(child)
termParents(child).add(node)
termParents(find(child)).add(node)
)
node
case node @ AppliedConnector(_, args) =>
args.foreach(child =>
add(child)
formulaParents(child).add(node)
formulaParents(find(child)).add(node)
)
node
case _ => node
Expand All @@ -393,74 +379,64 @@ class EGraphTerms() {
}

protected def mergeWithStep(id1: Term, id2: Term, step: TermStep): Unit = {
if termUF.find(id1) == termUF.find(id2) then ()
if find(id1) == find(id2) then ()
else
termProofMap((id1, id2)) = step
val newSet = termMap(termUF.find(id1)) ++ termMap(termUF.find(id2))
val newparents = termParents(termUF.find(id1)) ++ termParents(termUF.find(id2))
val newparents = termParents(find(id1)) ++ termParents(find(id2))
termUF.union(id1, id2)
val newId1 = termUF.find(id1)
val newId2 = termUF.find(id2)
termMap(newId1) = newSet
termMap(newId2) = newSet
termParents(newId1) = newparents
termParents(newId2) = newparents
val newId = find(id1)

val id = termUF.find(id2)
termWorklist = id :: termWorklist
val cause = (id1, id2)
val termSeen = mutable.Map[Term, AppliedFunctional]()
val formulaSeen = mutable.Map[Formula, AppliedPredicate]()
var formWorklist = List[(Formula, Formula, FormulaStep)]()
var termWorklist = List[(Term, Term, TermStep)]()

newparents.foreach {
case pTerm: AppliedFunctional =>
val canonicalPTerm = canonicalize(pTerm)
if termSeen.contains(canonicalPTerm) then
val qTerm = termSeen(canonicalPTerm)
Some((pTerm, qTerm, cause))
mergeWithStep(pTerm, qTerm, TermCongruence((pTerm, qTerm)))
termWorklist = (pTerm, qTerm, TermCongruence((pTerm, qTerm))) :: termWorklist
//mergeWithStep(pTerm, qTerm, TermCongruence((pTerm, qTerm)))
else
termSeen(canonicalPTerm) = pTerm
case pFormula: AppliedPredicate =>
val canonicalPFormula = canonicalize(pFormula)
if formulaSeen.contains(canonicalPFormula) then
val qFormula = formulaSeen(canonicalPFormula)

Some((pFormula, qFormula, cause))
mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula)))
formWorklist = (pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) :: formWorklist
//mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula)))
else
formulaSeen(canonicalPFormula) = pFormula
}
termParents(id) = (termSeen.values.to(mutable.Set): mutable.Set[AppliedFunctional | AppliedPredicate]) ++ formulaSeen.values.to(mutable.Set)
termParents(newId) = (termSeen.values.to(mutable.Set): mutable.Set[AppliedFunctional | AppliedPredicate]) ++ formulaSeen.values.to(mutable.Set)
formWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) }
termWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) }
}

protected def mergeWithStep(id1: Formula, id2: Formula, step: FormulaStep): Unit =
if formulaUF.find(id1) == formulaUF.find(id2) then ()
if find(id1) == find(id2) then ()
else
formulaProofMap((id1, id2)) = step
val newSet = formulaMap(formulaUF.find(id1)) ++ formulaMap(formulaUF.find(id2))
val newparents = formulaParents(formulaUF.find(id1)) ++ formulaParents(formulaUF.find(id2))
val newparents = formulaParents(find(id1)) ++ formulaParents(find(id2))
formulaUF.union(id1, id2)
val newId1 = formulaUF.find(id1)
val newId2 = formulaUF.find(id2)
formulaMap(newId1) = newSet
formulaMap(newId2) = newSet
formulaParents(newId1) = newparents
formulaParents(newId2) = newparents
val id = formulaUF.find(id2)
formulaWorklist = id :: formulaWorklist
val cause = (id1, id2)
val newId = find(id1)

val formulaSeen = mutable.Map[Formula, AppliedConnector]()
var formWorklist = List[(Formula, Formula, FormulaStep)]()

newparents.foreach {
case pFormula: AppliedConnector =>
val canonicalPFormula = canonicalize(pFormula)
if formulaSeen.contains(canonicalPFormula) then
val qFormula = formulaSeen(canonicalPFormula)
Some((pFormula, qFormula, cause))
mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula)))
formWorklist = (pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) :: formWorklist
//mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula)))
else
formulaSeen(canonicalPFormula) = pFormula
}
formulaParents(id) = formulaSeen.values.to(mutable.Set)
formulaParents(newId) = formulaSeen.values.to(mutable.Set)
formWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) }


def proveTerm(using lib: Library, proof: lib.Proof)(id1: Term, id2:Term, base: Sequent): proof.ProofTacticJudgement =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ object Comprehensions {
}

def filter(using _proof: Proof, name: sourcecode.Name)(filter: (Term ** 1) |-> Formula): Comprehension { val proof: _proof.type } = {
if (_proof.lockedSymbols ++ _proof.possibleGoal.toSet.flatMap(_.allSchematicLabels)).map(_.id.name).contains(name.value) then throw new Exception(s"Name $name is already used in the proof")
if (_proof.lockedSymbols ++ _proof.possibleGoal.toSet.flatMap(_.freeSchematicLabels)).map(_.id.name).contains(name.value) then throw new Exception(s"Name $name is already used in the proof")
val id = name.value
inline def _filter = filter
inline def _t = t
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object SetTheory2 extends lisa.Main {
thenHave(in(x, A) |- ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))) by Weakening
thenHave(in(x, A) |- ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z)))) by RightForall
thenHave(in(x, A) |- ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by RightForall
//thenHave(in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by Restate
thenHave(in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by Restate
thenHave(∀(x, in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z)))))) by RightForall
thenHave(thesis) by Restate

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.scalatest.funsuite.AnyFunSuite

class CongruenceTest extends AnyFunSuite with lisa.TestMain {


given lib: lisa.SetTheoryLibrary.type = lisa.SetTheoryLibrary

val a = variable
Expand Down Expand Up @@ -254,8 +255,6 @@ class CongruenceTest extends AnyFunSuite with lisa.TestMain {
assert(egraph.idEq(fx, x))
assert(egraph.idEq(x, fx))

assert(egraph.explain(fx, x) == Some(List(egraph.TermCongruence(fx, fffx), egraph.TermCongruence(fffx, ffffffffx), egraph.TermExternal(ffffffffx, x))))

}


Expand Down
Loading