From 807309459580c7aed2ab5d7b3d528a116d2a3f2e Mon Sep 17 00:00:00 2001 From: Ricardo Costa Date: Tue, 16 Jun 2026 15:04:48 +0100 Subject: [PATCH] Add VC Function Substitution --- .../opt/VCFunctionSubstitution.java | 156 ++++++++++++++++++ .../rj_language/opt/VCSimplification.java | 5 +- .../opt/VCFunctionSubstitutionTest.java | 75 +++++++++ .../opt/VCImplicationGenerator.java | 61 ++++++- .../VCSimplificationPropertyBasedTest.java | 7 + .../rj_language/opt/VCSimplificationTest.java | 12 ++ 6 files changed, 313 insertions(+), 3 deletions(-) create mode 100644 liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCFunctionSubstitution.java create mode 100644 liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCFunctionSubstitutionTest.java diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCFunctionSubstitution.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCFunctionSubstitution.java new file mode 100644 index 00000000..b36be0f1 --- /dev/null +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCFunctionSubstitution.java @@ -0,0 +1,156 @@ +package liquidjava.rj_language.opt; + +import java.util.Optional; + +import liquidjava.processor.SimplifiedVCImplication; +import liquidjava.processor.VCImplication; +import liquidjava.rj_language.Predicate; +import liquidjava.rj_language.ast.BinaryExpression; +import liquidjava.rj_language.ast.Expression; +import liquidjava.rj_language.ast.FunctionInvocation; +import liquidjava.rj_language.ast.GroupExpression; + +/** + * Simplifies VCImplication chains by propagating exact function invocation equalities + */ +public class VCFunctionSubstitution implements VCSimplificationPass { + + /** + * A substitution discovered from a function invocation equality + */ + private record Substitution(VCImplication node, FunctionInvocation invocation, Expression replacement) { + } + + /** + * Applies one function invocation substitution in a VC chain + */ + @Override + public VCImplication apply(VCImplication implication) { + VCImplication result = implication.clone(); + Optional substitutionOpt = findSubstitution(result); + + if (substitutionOpt.isPresent()) { + Substitution substitution = substitutionOpt.get(); + result = substitute(result, substitution.node(), substitution.invocation(), substitution.replacement()); + } + return result; + } + + /** + * Preserves nodes before the source equality and starts rewriting at the source suffix + */ + private VCImplication substitute(VCImplication implication, VCImplication node, FunctionInvocation invocation, + Expression replacement) { + if (implication == null) + return null; + + // skip the source node to remove it from the chain and start substitution from the next node + if (implication == node) { + VCImplication result = implication.copyWithRefinement(implication.getRefinement().clone()); + result.setNext(substituteSuffix(implication.getNext(), node, invocation, replacement)); + return result; + } + + // preserve the current node and continue rewriting the suffix + VCImplication result = implication.copyWithRefinement(implication.getRefinement().clone()); + result.setNext(substitute(implication.getNext(), node, invocation, replacement)); + return result; + } + + /** + * Rewrites every node after the source equality with one function substitution + */ + private VCImplication substituteSuffix(VCImplication implication, VCImplication source, + FunctionInvocation invocation, Expression replacement) { + if (implication == null) + return null; + + VCImplication result = substituteNode(implication, source, invocation, replacement); + result.setNext(substituteSuffix(implication.getNext(), source, invocation, replacement)); + return result; + } + + /** + * Substitutes one exact function invocation inside one VC node while preserving simplification metadata + */ + private VCImplication substituteNode(VCImplication implication, VCImplication source, FunctionInvocation invocation, + Expression replacement) { + Expression expression = implication.getRefinement().getExpression().clone(); + if (!containsExpression(expression, invocation)) + return implication.copyWithRefinement(new Predicate(expression)); + + Expression substituted = expression.substitute(invocation, replacement.clone()); + return new SimplifiedVCImplication(implication, new Predicate(substituted), source); + } + + /** + * Finds the first function substitution candidate that is used in the remaining suffix + */ + private Optional findSubstitution(VCImplication implication) { + if (implication == null) + return Optional.empty(); + + Optional current = getSubstitution(implication); + if (current.isPresent() && containsExpression(implication.getNext(), current.get().invocation())) + return current; + + return findSubstitution(implication.getNext()); + } + + /** + * Extracts a substitution from one VC node refinement + */ + private Optional getSubstitution(VCImplication implication) { + return getSubstitution(implication, implication.getRefinement().getExpression().clone()); + } + + /** + * Extracts a substitution from a top-level equality or conjunction + */ + private Optional getSubstitution(VCImplication implication, Expression expression) { + if (expression instanceof GroupExpression group) + return getSubstitution(implication, group.getExpression()); + + if (expression instanceof BinaryExpression binary && "&&".equals(binary.getOperator())) { + Optional left = getSubstitution(implication, binary.getFirstOperand()); + if (left.isPresent()) + return left; + return getSubstitution(implication, binary.getSecondOperand()); + } + + if (!(expression instanceof BinaryExpression binary) || !"==".equals(binary.getOperator())) + return Optional.empty(); + + Expression left = binary.getFirstOperand(); + Expression right = binary.getSecondOperand(); + if (left instanceof FunctionInvocation invocation && !containsExpression(right, left)) + return Optional.of(new Substitution(implication, (FunctionInvocation) invocation.clone(), right.clone())); + if (right instanceof FunctionInvocation invocation && !containsExpression(left, right)) + return Optional.of(new Substitution(implication, (FunctionInvocation) invocation.clone(), left.clone())); + + return Optional.empty(); + } + + /** + * Checks whether an expression contains another expression + */ + private boolean containsExpression(Expression expression, Expression target) { + if (expression.equals(target)) + return true; + + for (Expression child : expression.getChildren()) + if (containsExpression(child, target)) + return true; + return false; + } + + /** + * Checks whether a VC suffix contains an expression + */ + private boolean containsExpression(VCImplication implication, Expression target) { + for (VCImplication current = implication; current != null; current = current.getNext()) + if (containsExpression(current.getRefinement().getExpression(), target)) + return true; + return false; + } +} diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCSimplification.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCSimplification.java index b26af0e6..4d47ff17 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCSimplification.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VCSimplification.java @@ -9,8 +9,9 @@ */ public class VCSimplification { - private static final List PASSES = List.of(new VCSubstitution(), new VCBinderSimplification(), - new VCFolding(), new VCArithmeticSimplification(), new VCLogicalSimplification()); + private static final List PASSES = List.of(new VCSubstitution(), new VCFunctionSubstitution(), + new VCBinderSimplification(), new VCFolding(), new VCArithmeticSimplification(), + new VCLogicalSimplification()); /** * Applies all available simplification steps to a VC chain until a fixed point is reached diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCFunctionSubstitutionTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCFunctionSubstitutionTest.java new file mode 100644 index 00000000..5b71075b --- /dev/null +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCFunctionSubstitutionTest.java @@ -0,0 +1,75 @@ +package liquidjava.rj_language.opt; + +import static liquidjava.utils.VCTestUtils.*; + +import liquidjava.processor.VCImplication; +import org.junit.jupiter.api.Test; + +class VCFunctionSubstitutionTest { + + private final VCFunctionSubstitution substitution = new VCFunctionSubstitution(); + + @Test + void substitutesExactFunctionInvocationIntoSuffix() { + VCImplication implication = vc("f(x) == 0", "f(y) == f(x) + 1"); + + assertSimplificationSteps(substitution::apply, implication, + chain(expect("f(x) == 0"), expect("f(y) == 0 + 1", "f(x) == 0"))); + } + + @Test + void substitutesReverseFunctionEquality() { + VCImplication implication = vc("0 == f(x)", "f(y) == f(x) + 1"); + + assertSimplificationSteps(substitution::apply, implication, + chain(expect("0 == f(x)"), expect("f(y) == 0 + 1", "0 == f(x)"))); + } + + @Test + void preservesSourceNode() { + VCImplication implication = vc("f(x) == 0", "f(x) > -1"); + + assertSimplificationSteps(substitution::apply, implication, + chain(expect("f(x) == 0"), expect("0 > -1", "f(x) == 0"))); + } + + @Test + void doesNotRewriteEarlierNodesFromLaterEquality() { + VCImplication implication = vc("f(x) > 0", "f(x) == 1"); + + assertSimplificationSteps(substitution::apply, implication, chain(expect("f(x) > 0"), expect("f(x) == 1"))); + } + + @Test + void skipsUsedUpEqualityAndUsesNextAvailableEquality() { + VCImplication implication = vc("f(x) == 0", "f(y) == f(x) + 1", "f(y) == 1"); + + assertSimplificationSteps(substitution::apply, implication, + chain(expect("f(x) == 0"), expect("f(y) == 0 + 1", "f(x) == 0"), expect("f(y) == 1")), + chain(expect("f(x) == 0"), expect("f(y) == 0 + 1", "f(x) == 0"), + expect("0 + 1 == 1", "f(y) == 0 + 1"))); + } + + @Test + void doesNotGeneralizeAcrossDifferentArguments() { + VCImplication implication = vc("f(x) == 0", "f(y) == 0"); + + assertSimplificationSteps(substitution::apply, implication, chain(expect("f(x) == 0"), expect("f(y) == 0"))); + } + + @Test + void ignoresRecursiveFunctionEquality() { + VCImplication implication = vc("f(x) == f(x) + 1", "f(x) > 0"); + + assertSimplificationSteps(substitution::apply, implication, + chain(expect("f(x) == f(x) + 1"), expect("f(x) > 0"))); + } + + @Test + void extractsEqualityFromTopLevelConjunction() { + VCImplication implication = vc("ok && f(x) == 0", "f(y) == f(x) + 1"); + + assertSimplificationSteps(substitution::apply, implication, + chain(expect("ok && f(x) == 0"), expect("f(y) == 0 + 1", "ok && f(x) == 0"))); + } +} diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCImplicationGenerator.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCImplicationGenerator.java index b7517c49..7ab68894 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCImplicationGenerator.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCImplicationGenerator.java @@ -11,6 +11,7 @@ public class VCImplicationGenerator extends Generator { public static final String[] BINDERS = { "x", "y", "z", "w" }; public static final String[] FREE_VARS = { "a", "b", "c", "d" }; + public static final String[] FUNCTIONS = { "f", "g" }; private static final String[] COMPARISON_OPS = { "==", "!=", ">=", ">", "<=", "<" }; private static final String[] BOOLEAN_OPS = { "&&", "||", "-->", "==", "!=" }; private static final String[] ARITHMETIC_OPS = { "+", "-", "*" }; @@ -21,7 +22,7 @@ public VCImplicationGenerator() { @Override public VCImplication generate(SourceOfRandomness random, GenerationStatus status) { - return switch (random.nextInt(0, 14)) { + return switch (random.nextInt(0, 18)) { case 0 -> vc(substitution(random, "x"), comparison(random, "x")); case 1 -> vc(reverseSubstitution(random, "x"), comparison(random, "x")); case 2 -> vc(nonSubstitution(random, "x"), substitution(random, "y"), comparison(random, "y")); @@ -36,6 +37,11 @@ public VCImplication generate(SourceOfRandomness random, GenerationStatus status case 11 -> vc(logicalIdentity(random)); case 12 -> vc(unusedTrueBinder(random)); case 13 -> vc(falseBinder(random)); + case 14 -> exactFunctionSubstitution(random); + case 15 -> reverseFunctionSubstitution(random); + case 16 -> chainedFunctionSubstitution(random); + case 17 -> differentArgumentFunctionSubstitution(random); + case 18 -> recursiveFunctionSubstitution(random); default -> vc(substitution(random, "x"), substitution(random, "y"), foldableComparison(random)); }; } @@ -62,6 +68,59 @@ private static String nonSubstitution(SourceOfRandomness random, String binder) return "∀" + binder + ":int. " + binder + " == " + binder + " " + signed(random.nextInt(1, 5)); } + private static VCImplication exactFunctionSubstitution(SourceOfRandomness random) { + String function = functionName(random); + return vc(functionSubstitution(random, function, "a"), functionUse(random, function, "a")); + } + + private static VCImplication reverseFunctionSubstitution(SourceOfRandomness random) { + String function = functionName(random); + return vc(reverseFunctionSubstitution(random, function, "a"), functionUse(random, function, "a")); + } + + private static VCImplication chainedFunctionSubstitution(SourceOfRandomness random) { + String function = functionName(random); + return vc(functionSubstitution(random, function, "a"), dependentFunctionSubstitution(random, function), + functionUse(random, function, "b")); + } + + private static VCImplication differentArgumentFunctionSubstitution(SourceOfRandomness random) { + String function = functionName(random); + return vc(functionSubstitution(random, function, "a"), functionUse(random, function, "b")); + } + + private static VCImplication recursiveFunctionSubstitution(SourceOfRandomness random) { + String function = functionName(random); + String invocation = functionInvocation(function, "a"); + return vc(invocation + " == " + invocation + " " + signed(random.nextInt(1, 5)), + functionUse(random, function, "a")); + } + + private static String functionSubstitution(SourceOfRandomness random, String function, String argument) { + return functionInvocation(function, argument) + " == " + value(random); + } + + private static String reverseFunctionSubstitution(SourceOfRandomness random, String function, String argument) { + return value(random) + " == " + functionInvocation(function, argument); + } + + private static String dependentFunctionSubstitution(SourceOfRandomness random, String function) { + return functionInvocation(function, "b") + " == " + functionInvocation(function, "a") + " " + + signed(random.nextInt(-3, 3)); + } + + private static String functionUse(SourceOfRandomness random, String function, String argument) { + return functionInvocation(function, argument) + " " + comparisonOperator(random) + " " + intLiteral(random); + } + + private static String functionInvocation(String function, String argument) { + return function + "(" + argument + ")"; + } + + private static String functionName(SourceOfRandomness random) { + return FUNCTIONS[random.nextInt(0, FUNCTIONS.length - 1)]; + } + private static String[] unusedTrueBinder(SourceOfRandomness random) { return new String[] { "∀x:int. true", comparison(random, "a") }; } diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCSimplificationPropertyBasedTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCSimplificationPropertyBasedTest.java index 8a638ddc..52d495b6 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCSimplificationPropertyBasedTest.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCSimplificationPropertyBasedTest.java @@ -11,6 +11,7 @@ import com.pholser.junit.quickcheck.runner.JUnitQuickcheck; import liquidjava.processor.VCImplication; import liquidjava.processor.context.Context; +import liquidjava.processor.context.GhostFunction; import liquidjava.rj_language.Predicate; import liquidjava.rj_language.ast.BinaryExpression; import liquidjava.rj_language.ast.Expression; @@ -19,12 +20,15 @@ import liquidjava.smt.SMTResult; import liquidjava.utils.TestUtils; import org.junit.runner.RunWith; +import spoon.Launcher; +import spoon.reflect.factory.Factory; @RunWith(JUnitQuickcheck.class) public class VCSimplificationPropertyBasedTest { private static final int TRIALS = 50; // number of random VCs to test private static final int MAX_STEPS = 20; // to prevent infinite loops in case of non-termination + private static final Factory FACTORY = new Launcher().getFactory(); @Property(trials = TRIALS) public void eachSimplificationStepPreservesVcSemantics(@From(VCImplicationGenerator.class) VCImplication vc) { @@ -47,6 +51,9 @@ private static void setUpContext() { TestUtils.addIntVariableToContext(variable); for (String variable : VCImplicationGenerator.FREE_VARS) TestUtils.addIntVariableToContext(variable); + for (String function : VCImplicationGenerator.FUNCTIONS) + Context.getInstance().addGhostFunction( + new GhostFunction(function, List.of("int"), FACTORY.Type().INTEGER_PRIMITIVE, FACTORY, "")); } private static void assertEquivalent(VCImplication unsimplified, VCImplication simplified, int step) { diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCSimplificationTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCSimplificationTest.java index 62514f4e..a07dbac2 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCSimplificationTest.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VCSimplificationTest.java @@ -144,6 +144,18 @@ void simplifyAppliesLongSubstitutionChainBeforeReachingFixedPoint() { chain(expect("3 == 3", "2 + 1 == 3")), chain(expect("true", "3 == 3"))); } + @Test + void simplifyPropagatesFunctionInvocationEqualitiesBeforeReachingFixedPoint() { + VCImplication implication = vc("f(x) == 0", "f(y) == f(x) + 1", "f(y) == 1"); + + assertSimplificationSteps(VCSimplification::simplifyOnce, implication, + chain(expect("f(x) == 0"), expect("f(y) == 0 + 1", "f(x) == 0"), expect("f(y) == 1")), + chain(expect("f(x) == 0"), expect("f(y) == 0 + 1", "f(x) == 0"), expect("0 + 1 == 1", "f(y) == 0 + 1")), + chain(expect("f(x) == 0"), expect("f(y) == 1", "f(y) == 0 + 1"), expect("0 + 1 == 1", "f(y) == 0 + 1")), + chain(expect("f(x) == 0"), expect("f(y) == 1", "f(y) == 0 + 1"), expect("1 == 1", "0 + 1 == 1")), + chain(expect("f(x) == 0"), expect("f(y) == 1", "f(y) == 0 + 1"), expect("true", "1 == 1"))); + } + @Test void simplifyCombinesSubstitutionAndNestedFoldingAcrossFixedPoint() { VCImplication implication = vc("∀x:int. x == 1", "∀y:int. y == x + 2", "y - 1 == 2");