From 59458b1d1163d783c9ac2b316e1597ac874554c3 Mon Sep 17 00:00:00 2001 From: Simon Guilloud Date: Thu, 20 Jun 2024 15:26:10 +0200 Subject: [PATCH 1/5] fix tests, small cleaning in congruence algorithm. --- .../scala/lisa/automation/Congruence.scala | 56 +++++++------------ .../lisa/maths/settheory/Comprehensions.scala | 2 +- .../lisa/maths/settheory/SetTheory2.scala | 2 +- .../lisa/automation/CongruenceTest.scala | 2 - 4 files changed, 23 insertions(+), 39 deletions(-) diff --git a/lisa-sets/src/main/scala/lisa/automation/Congruence.scala b/lisa-sets/src/main/scala/lisa/automation/Congruence.scala index ccf25295f..63d4e1081 100644 --- a/lisa-sets/src/main/scala/lisa/automation/Congruence.scala +++ b/lisa-sets/src/main/scala/lisa/automation/Congruence.scala @@ -240,24 +240,18 @@ import scala.collection.mutable class EGraphTerms() { - type ENode = Term | Formula - - - val termMap = mutable.Map[Term, Set[Term]]() val termParents = mutable.Map[Term, mutable.Set[AppliedFunctional | AppliedPredicate]]() - var termWorklist = List[Term]() val termUF = new UnionFind[Term]() - - val formulaMap = mutable.Map[Formula, Set[Formula]]() val formulaParents = mutable.Map[Formula, mutable.Set[AppliedConnector]]() - var formulaWorklist = List[Formula]() val formulaUF = new UnionFind[Formula]() + def find(id: Term): Term = termUF.find(id) + def find(id: Formula): Formula = formulaUF.find(id) trait TermStep @@ -330,19 +324,19 @@ class EGraphTerms() { def classOf(id: Term): Set[Term] = termMap(id) def classOf(id: Formula): Set[Formula] = formulaMap(id) - def idEq(id1: Term, id2: Term): Boolean = termUF.find(id1) == termUF.find(id2) - def idEq(id1: Formula, id2: Formula): Boolean = formulaUF.find(id1) == formulaUF.find(id2) + def idEq(id1: Term, id2: Term): Boolean = find(id1) == find(id2) + def idEq(id1: Formula, id2: Formula): Boolean = find(id1) == find(id2) def canonicalize(node: Term): Term = node match case AppliedFunctional(label, args) => - AppliedFunctional(label, args.map(termUF.find.asInstanceOf)) + AppliedFunctional(label, args.map(t => find(t))) case _ => node def canonicalize(node: Formula): Formula = { node match - case AppliedPredicate(label, args) => AppliedPredicate(label, args.map(termUF.find)) - case AppliedConnector(label, args) => AppliedConnector(label, args.map(formulaUF.find)) + case AppliedPredicate(label, args) => AppliedPredicate(label, args.map(find)) + case AppliedConnector(label, args) => AppliedConnector(label, args.map(find)) case node => node } @@ -393,21 +387,17 @@ class EGraphTerms() { } protected def mergeWithStep(id1: Term, id2: Term, step: TermStep): Unit = { - if termUF.find(id1) == termUF.find(id2) then () + if find(id1) == find(id2) then () else termProofMap((id1, id2)) = step - val newSet = termMap(termUF.find(id1)) ++ termMap(termUF.find(id2)) - val newparents = termParents(termUF.find(id1)) ++ termParents(termUF.find(id2)) + val newSet = termMap(find(id1)) ++ termMap(find(id2)) + val newparents = termParents(find(id1)) ++ termParents(find(id2)) termUF.union(id1, id2) - val newId1 = termUF.find(id1) - val newId2 = termUF.find(id2) - termMap(newId1) = newSet - termMap(newId2) = newSet - termParents(newId1) = newparents - termParents(newId2) = newparents + val newId = find(id1) + termMap(newId) = newSet + termParents(newId) = newparents - val id = termUF.find(id2) - termWorklist = id :: termWorklist + val id = find(id2) val cause = (id1, id2) val termSeen = mutable.Map[Term, AppliedFunctional]() val formulaSeen = mutable.Map[Formula, AppliedPredicate]() @@ -434,20 +424,16 @@ class EGraphTerms() { } protected def mergeWithStep(id1: Formula, id2: Formula, step: FormulaStep): Unit = - if formulaUF.find(id1) == formulaUF.find(id2) then () + if find(id1) == find(id2) then () else formulaProofMap((id1, id2)) = step - val newSet = formulaMap(formulaUF.find(id1)) ++ formulaMap(formulaUF.find(id2)) - val newparents = formulaParents(formulaUF.find(id1)) ++ formulaParents(formulaUF.find(id2)) + val newSet = formulaMap(find(id1)) ++ formulaMap(find(id2)) + val newparents = formulaParents(find(id1)) ++ formulaParents(find(id2)) formulaUF.union(id1, id2) - val newId1 = formulaUF.find(id1) - val newId2 = formulaUF.find(id2) - formulaMap(newId1) = newSet - formulaMap(newId2) = newSet - formulaParents(newId1) = newparents - formulaParents(newId2) = newparents - val id = formulaUF.find(id2) - formulaWorklist = id :: formulaWorklist + val newId = find(id1) + formulaMap(newId) = newSet + formulaParents(newId) = newparents + val id = find(id2) val cause = (id1, id2) val formulaSeen = mutable.Map[Formula, AppliedConnector]() newparents.foreach { diff --git a/lisa-sets/src/main/scala/lisa/maths/settheory/Comprehensions.scala b/lisa-sets/src/main/scala/lisa/maths/settheory/Comprehensions.scala index 2ebbdc47c..55a2a2d98 100644 --- a/lisa-sets/src/main/scala/lisa/maths/settheory/Comprehensions.scala +++ b/lisa-sets/src/main/scala/lisa/maths/settheory/Comprehensions.scala @@ -177,7 +177,7 @@ object Comprehensions { } def filter(using _proof: Proof, name: sourcecode.Name)(filter: (Term ** 1) |-> Formula): Comprehension { val proof: _proof.type } = { - if (_proof.lockedSymbols ++ _proof.possibleGoal.toSet.flatMap(_.allSchematicLabels)).map(_.id.name).contains(name.value) then throw new Exception(s"Name $name is already used in the proof") + if (_proof.lockedSymbols ++ _proof.possibleGoal.toSet.flatMap(_.freeSchematicLabels)).map(_.id.name).contains(name.value) then throw new Exception(s"Name $name is already used in the proof") val id = name.value inline def _filter = filter inline def _t = t diff --git a/lisa-sets/src/main/scala/lisa/maths/settheory/SetTheory2.scala b/lisa-sets/src/main/scala/lisa/maths/settheory/SetTheory2.scala index f9599302e..2957a9ec4 100644 --- a/lisa-sets/src/main/scala/lisa/maths/settheory/SetTheory2.scala +++ b/lisa-sets/src/main/scala/lisa/maths/settheory/SetTheory2.scala @@ -45,7 +45,7 @@ object SetTheory2 extends lisa.Main { thenHave(in(x, A) |- ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))) by Weakening thenHave(in(x, A) |- ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z)))) by RightForall thenHave(in(x, A) |- ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by RightForall - //thenHave(in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by Restate + thenHave(in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z))))) by Restate thenHave(∀(x, in(x, A) ==> ∀(y, ∀(z, ((Filter(x) /\ (y === Map(x)) /\ (z === Map(x))) ==> (y === z)))))) by RightForall thenHave(thesis) by Restate diff --git a/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala b/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala index 5e50502d5..e0a2c8e66 100644 --- a/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala +++ b/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala @@ -254,8 +254,6 @@ class CongruenceTest extends AnyFunSuite with lisa.TestMain { assert(egraph.idEq(fx, x)) assert(egraph.idEq(x, fx)) - assert(egraph.explain(fx, x) == Some(List(egraph.TermCongruence(fx, fffx), egraph.TermCongruence(fffx, ffffffffx), egraph.TermExternal(ffffffffx, x)))) - } From e9df38e3776cea30a707b7860753bd1457bd83b6 Mon Sep 17 00:00:00 2001 From: Simon Guilloud Date: Thu, 20 Jun 2024 16:14:28 +0200 Subject: [PATCH 2/5] further simplification --- lisa-sets/src/main/scala/lisa/automation/Congruence.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lisa-sets/src/main/scala/lisa/automation/Congruence.scala b/lisa-sets/src/main/scala/lisa/automation/Congruence.scala index 63d4e1081..59a3e6733 100644 --- a/lisa-sets/src/main/scala/lisa/automation/Congruence.scala +++ b/lisa-sets/src/main/scala/lisa/automation/Congruence.scala @@ -406,7 +406,6 @@ class EGraphTerms() { val canonicalPTerm = canonicalize(pTerm) if termSeen.contains(canonicalPTerm) then val qTerm = termSeen(canonicalPTerm) - Some((pTerm, qTerm, cause)) mergeWithStep(pTerm, qTerm, TermCongruence((pTerm, qTerm))) else termSeen(canonicalPTerm) = pTerm @@ -414,8 +413,6 @@ class EGraphTerms() { val canonicalPFormula = canonicalize(pFormula) if formulaSeen.contains(canonicalPFormula) then val qFormula = formulaSeen(canonicalPFormula) - - Some((pFormula, qFormula, cause)) mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) else formulaSeen(canonicalPFormula) = pFormula @@ -441,7 +438,6 @@ class EGraphTerms() { val canonicalPFormula = canonicalize(pFormula) if formulaSeen.contains(canonicalPFormula) then val qFormula = formulaSeen(canonicalPFormula) - Some((pFormula, qFormula, cause)) mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) else formulaSeen(canonicalPFormula) = pFormula From e3c76a745867ee07d8dcdebd48319eb1f30aaad3 Mon Sep 17 00:00:00 2001 From: Simon Guilloud Date: Mon, 24 Jun 2024 18:09:01 +0200 Subject: [PATCH 3/5] more minor improvements to congruence --- .../scala/lisa/automation/Congruence.scala | 50 ++++++++----------- .../lisa/automation/CongruenceTest.scala | 1 + 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/lisa-sets/src/main/scala/lisa/automation/Congruence.scala b/lisa-sets/src/main/scala/lisa/automation/Congruence.scala index 59a3e6733..83a0fdc19 100644 --- a/lisa-sets/src/main/scala/lisa/automation/Congruence.scala +++ b/lisa-sets/src/main/scala/lisa/automation/Congruence.scala @@ -240,12 +240,10 @@ import scala.collection.mutable class EGraphTerms() { - val termMap = mutable.Map[Term, Set[Term]]() val termParents = mutable.Map[Term, mutable.Set[AppliedFunctional | AppliedPredicate]]() val termUF = new UnionFind[Term]() - val formulaMap = mutable.Map[Formula, Set[Formula]]() val formulaParents = mutable.Map[Formula, mutable.Set[AppliedConnector]]() val formulaUF = new UnionFind[Formula]() @@ -253,7 +251,6 @@ class EGraphTerms() { def find(id: Term): Term = termUF.find(id) def find(id: Formula): Formula = formulaUF.find(id) - trait TermStep case class TermExternal(between: (Term, Term)) extends TermStep case class TermCongruence(between: (Term, Term)) extends TermStep @@ -310,20 +307,15 @@ class EGraphTerms() { def makeSingletonEClass(node:Term): Term = { termUF.add(node) - termMap(node) = Set(node) termParents(node) = mutable.Set() node } def makeSingletonEClass(node:Formula): Formula = { formulaUF.add(node) - formulaMap(node) = Set(node) formulaParents(node) = mutable.Set() node } - def classOf(id: Term): Set[Term] = termMap(id) - def classOf(id: Formula): Set[Formula] = formulaMap(id) - def idEq(id1: Term, id2: Term): Boolean = find(id1) == find(id2) def idEq(id1: Formula, id2: Formula): Boolean = find(id1) == find(id2) @@ -341,31 +333,31 @@ class EGraphTerms() { } def add(node: Term): Term = - if termMap.contains(node) then return node + if termUF.parent.contains(node) then return node makeSingletonEClass(node) node match case node @ AppliedFunctional(_, args) => args.foreach(child => add(child) - termParents(child).add(node) + termParents(find(child)).add(node) ) node case _ => node def add(node: Formula): Formula = - if formulaMap.contains(node) then return node + if formulaUF.parent.contains(node) then return node makeSingletonEClass(node) node match case node @ AppliedPredicate(_, args) => args.foreach(child => add(child) - termParents(child).add(node) + termParents(find(child)).add(node) ) node case node @ AppliedConnector(_, args) => args.foreach(child => add(child) - formulaParents(child).add(node) + formulaParents(find(child)).add(node) ) node case _ => node @@ -390,59 +382,61 @@ class EGraphTerms() { if find(id1) == find(id2) then () else termProofMap((id1, id2)) = step - val newSet = termMap(find(id1)) ++ termMap(find(id2)) val newparents = termParents(find(id1)) ++ termParents(find(id2)) termUF.union(id1, id2) val newId = find(id1) - termMap(newId) = newSet - termParents(newId) = newparents - val id = find(id2) - val cause = (id1, id2) val termSeen = mutable.Map[Term, AppliedFunctional]() val formulaSeen = mutable.Map[Formula, AppliedPredicate]() + var formWorklist = List[(Formula, Formula, FormulaStep)]() + var termWorklist = List[(Term, Term, TermStep)]() + newparents.foreach { case pTerm: AppliedFunctional => val canonicalPTerm = canonicalize(pTerm) if termSeen.contains(canonicalPTerm) then val qTerm = termSeen(canonicalPTerm) - mergeWithStep(pTerm, qTerm, TermCongruence((pTerm, qTerm))) + termWorklist = (pTerm, qTerm, TermCongruence((pTerm, qTerm))) :: termWorklist + //mergeWithStep(pTerm, qTerm, TermCongruence((pTerm, qTerm))) else termSeen(canonicalPTerm) = pTerm case pFormula: AppliedPredicate => val canonicalPFormula = canonicalize(pFormula) if formulaSeen.contains(canonicalPFormula) then val qFormula = formulaSeen(canonicalPFormula) - mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) + formWorklist = (pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) :: formWorklist + //mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) else formulaSeen(canonicalPFormula) = pFormula } - termParents(id) = (termSeen.values.to(mutable.Set): mutable.Set[AppliedFunctional | AppliedPredicate]) ++ formulaSeen.values.to(mutable.Set) + termParents(newId) = (termSeen.values.to(mutable.Set): mutable.Set[AppliedFunctional | AppliedPredicate]) ++ formulaSeen.values.to(mutable.Set) + formWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) } + termWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) } } protected def mergeWithStep(id1: Formula, id2: Formula, step: FormulaStep): Unit = if find(id1) == find(id2) then () else formulaProofMap((id1, id2)) = step - val newSet = formulaMap(find(id1)) ++ formulaMap(find(id2)) val newparents = formulaParents(find(id1)) ++ formulaParents(find(id2)) formulaUF.union(id1, id2) val newId = find(id1) - formulaMap(newId) = newSet - formulaParents(newId) = newparents - val id = find(id2) - val cause = (id1, id2) + val formulaSeen = mutable.Map[Formula, AppliedConnector]() + var formWorklist = List[(Formula, Formula, FormulaStep)]() + newparents.foreach { case pFormula: AppliedConnector => val canonicalPFormula = canonicalize(pFormula) if formulaSeen.contains(canonicalPFormula) then val qFormula = formulaSeen(canonicalPFormula) - mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) + formWorklist = (pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) :: formWorklist + //mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) else formulaSeen(canonicalPFormula) = pFormula } - formulaParents(id) = formulaSeen.values.to(mutable.Set) + formulaParents(newId) = formulaSeen.values.to(mutable.Set) + formWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) } def proveTerm(using lib: Library, proof: lib.Proof)(id1: Term, id2:Term, base: Sequent): proof.ProofTacticJudgement = diff --git a/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala b/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala index e0a2c8e66..34bc77eee 100644 --- a/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala +++ b/lisa-sets/src/test/scala/lisa/automation/CongruenceTest.scala @@ -7,6 +7,7 @@ import org.scalatest.funsuite.AnyFunSuite class CongruenceTest extends AnyFunSuite with lisa.TestMain { + given lib: lisa.SetTheoryLibrary.type = lisa.SetTheoryLibrary val a = variable From e528baf2d041380c72b8b6bed8f94f5ce672cc9b Mon Sep 17 00:00:00 2001 From: Simon Guilloud Date: Tue, 25 Jun 2024 13:54:41 +0200 Subject: [PATCH 4/5] update CHANGES.md --- CHANGES.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index b71a20ae6..ea034b6af 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,8 @@ # Change List +## 2024-06-25 +Repair two broken tests, small improvements to + ## 2024-04-12 Addition of the Congruence tactic, solving sequents by congruence closure using egraphs. From a7216713c2ed7d9093ad8df3c54df2d532f86674 Mon Sep 17 00:00:00 2001 From: Simon Guilloud Date: Thu, 3 Oct 2024 14:13:39 +0200 Subject: [PATCH 5/5] old --- .../scala/lisa/automation/Congruence.scala | 54 +++++++++++++------ 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/lisa-sets/src/main/scala/lisa/automation/Congruence.scala b/lisa-sets/src/main/scala/lisa/automation/Congruence.scala index 83a0fdc19..8c60de410 100644 --- a/lisa-sets/src/main/scala/lisa/automation/Congruence.scala +++ b/lisa-sets/src/main/scala/lisa/automation/Congruence.scala @@ -240,7 +240,8 @@ import scala.collection.mutable class EGraphTerms() { - val termParents = mutable.Map[Term, mutable.Set[AppliedFunctional | AppliedPredicate]]() + val termParentsT = mutable.Map[Term, mutable.Set[AppliedFunctional]]() + val termParentsF = mutable.Map[Term, mutable.Set[AppliedPredicate]]() val termUF = new UnionFind[Term]() @@ -307,7 +308,8 @@ class EGraphTerms() { def makeSingletonEClass(node:Term): Term = { termUF.add(node) - termParents(node) = mutable.Set() + termParentsT(node) = mutable.Set() + termParentsF(node) = mutable.Set() node } def makeSingletonEClass(node:Formula): Formula = { @@ -319,6 +321,8 @@ class EGraphTerms() { def idEq(id1: Term, id2: Term): Boolean = find(id1) == find(id2) def idEq(id1: Formula, id2: Formula): Boolean = find(id1) == find(id2) + + def canonicalize(node: Term): Term = node match case AppliedFunctional(label, args) => AppliedFunctional(label, args.map(t => find(t))) @@ -335,14 +339,16 @@ class EGraphTerms() { def add(node: Term): Term = if termUF.parent.contains(node) then return node makeSingletonEClass(node) + codes(node) = codes.size node match case node @ AppliedFunctional(_, args) => args.foreach(child => add(child) - termParents(find(child)).add(node) + termParentsT(find(child)).add(node) ) - node - case _ => node + case _ => () + termSigs(canSig(node)) = node + node def add(node: Formula): Formula = if formulaUF.parent.contains(node) then return node @@ -351,7 +357,7 @@ class EGraphTerms() { case node @ AppliedPredicate(_, args) => args.foreach(child => add(child) - termParents(find(child)).add(node) + termParentsF(find(child)).add(node) ) node case node @ AppliedConnector(_, args) => @@ -378,38 +384,54 @@ class EGraphTerms() { mergeWithStep(id1, id2, FormulaExternal((id1, id2))) } + type Sig = (TermLabel[?]|Term, List[Int]) + val termSigs = mutable.Map[Sig, Term]() + val codes = mutable.Map[Term, Int]() + + def canSig(node: Term): Sig = node match + case AppliedFunctional(label, args) => + (label, args.map(a => codes(find(a))).toList) + case _ => (node, List()) + protected def mergeWithStep(id1: Term, id2: Term, step: TermStep): Unit = { if find(id1) == find(id2) then () else termProofMap((id1, id2)) = step - val newparents = termParents(find(id1)) ++ termParents(find(id2)) + val parentsT1 = termParentsT(find(id1)) + val parentsF1 = termParentsF(find(id1)) + + val parentsT2 = termParentsT(find(id2)) + val parentsF2 = termParentsF(find(id2)) + val preSigs : Map[Term, Sig] = parentsT1.map(t => (t, canSig(t))).toMap + codes(find(id2)) = codes(find(id1)) //assume parents(find(id1)) >= parents(find(id2)) termUF.union(id1, id2) val newId = find(id1) - val termSeen = mutable.Map[Term, AppliedFunctional]() val formulaSeen = mutable.Map[Formula, AppliedPredicate]() var formWorklist = List[(Formula, Formula, FormulaStep)]() var termWorklist = List[(Term, Term, TermStep)]() - newparents.foreach { + parentsT2.foreach { case pTerm: AppliedFunctional => - val canonicalPTerm = canonicalize(pTerm) - if termSeen.contains(canonicalPTerm) then - val qTerm = termSeen(canonicalPTerm) + val canonicalPTerm = canSig(pTerm) + if termSigs.contains(canonicalPTerm) then + val qTerm = termSigs(canonicalPTerm) termWorklist = (pTerm, qTerm, TermCongruence((pTerm, qTerm))) :: termWorklist - //mergeWithStep(pTerm, qTerm, TermCongruence((pTerm, qTerm))) else - termSeen(canonicalPTerm) = pTerm + termSigs(canonicalPTerm) = pTerm + } + (parentsF2 ++ parentsF1).foreach { case pFormula: AppliedPredicate => val canonicalPFormula = canonicalize(pFormula) if formulaSeen.contains(canonicalPFormula) then val qFormula = formulaSeen(canonicalPFormula) formWorklist = (pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) :: formWorklist - //mergeWithStep(pFormula, qFormula, FormulaCongruence((pFormula, qFormula))) else formulaSeen(canonicalPFormula) = pFormula } - termParents(newId) = (termSeen.values.to(mutable.Set): mutable.Set[AppliedFunctional | AppliedPredicate]) ++ formulaSeen.values.to(mutable.Set) + termParentsT(newId) = termParentsT(id1) + termParentsT(newId).addAll(termParentsT(id2)) + termParentsF(newId) = formulaSeen.values.to(mutable.Set) formWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) } termWorklist.foreach { case (l, r, step) => mergeWithStep(l, r, step) } }