diff --git a/check-docs.py b/check-docs.py index 39e9627..3ea9c67 100755 --- a/check-docs.py +++ b/check-docs.py @@ -21,7 +21,8 @@ def parse_dl(path): parse_dl(os.path.join(dldir, 'base.sam')) declared.add("ASSIGN(String Type, Object value, Type result)") -declared.add("MATCH(Object a, Object b, Object result)") +declared.add("MATCH(Object a, Object b)") +declared.add("MATCH_TO(Object a, Object b, Object result)") documented.add("NOT_EQUAL(Object a, Object b)") undocumented = declared - documented diff --git a/doc/base.rst b/doc/base.rst index d47e5a7..324f9d3 100644 --- a/doc/base.rst +++ b/doc/base.rst @@ -150,14 +150,15 @@ These are not relations, so you can't enumerate all their values, but you can us ASSIGN("String", any(Value)) -> any(String) ASSIGN("String", any(int)) -> nothing -.. function:: MATCH(Object a, Object b, Object result) +.. function:: MATCH(Object a, Object b) +.. function:: MATCH_TO(Object a, Object b, Object result) - Tests whether `a = b`, taking account of the fact that either may be an `any` value. The `result` is the intersection + Tests whether `a = b`, taking account of the fact that either may be an `any` value. The `result` (in the three-argument form) is the intersection of the possible values. e.g.:: - MATCH("foo", "foo") -> "foo" - MATCH("foo", "bar") -> nothing - MATCH(any(String), any(Value)) -> any(String) + MATCH_TO("foo", "foo") -> "foo" + MATCH_TO("foo", "bar") -> nothing + MATCH_TO(any(String), any(Value)) -> any(String) .. function:: MAKE_OBJECT(String nameHint, String invocation, Ref child) diff --git a/doc/examples/includes/rbacTabs.sam b/doc/examples/includes/rbacTabs.sam index 16e49f7..f82b125 100644 --- a/doc/examples/includes/rbacTabs.sam +++ b/doc/examples/includes/rbacTabs.sam @@ -30,7 +30,7 @@ declare hasRole(Ref object, Ref target, String role). hasRole(?Object, ?Target, ?Role) :- hasIdentity(?Object, ?Identity1), grantsRole(?Target, ?Role, ?Identity2), - MATCH(?Identity1, ?Identity2, ?Identity). // (may grant to any(String)) + MATCH(?Identity1, ?Identity2). // (may grant to any(String)) guiObjectTab(50, "Has roles", "hasRole/3", "object"). guiObjectTab(60, "Grants roles", "grantsRole/3", "target"). diff --git a/src/main/java/eu/serscis/sam/Constants.java b/src/main/java/eu/serscis/sam/Constants.java index 5d0e48d..0419aff 100644 --- a/src/main/java/eu/serscis/sam/Constants.java +++ b/src/main/java/eu/serscis/sam/Constants.java @@ -37,7 +37,8 @@ public class Constants { static public IPredicate ASSIGNP = BASIC.createPredicate("ASSIGN", 3); - static public IPredicate MATCHP = BASIC.createPredicate("MATCH", 3); + static public IPredicate MATCH2P = BASIC.createPredicate("MATCH", 2); + static public IPredicate MATCH_TOP = BASIC.createPredicate("MATCH_TO", 3); static public IPredicate EQUALP = BASIC.createPredicate("EQUAL", 2); static public IPredicate accessControlOnP = BASIC.createPredicate("accessControlOn", 0); static public IPredicate expectFailureP = BASIC.createPredicate("expectFailure", 0); diff --git a/src/main/java/eu/serscis/sam/Eval.java b/src/main/java/eu/serscis/sam/Eval.java index 2b9f280..c2588d5 100644 --- a/src/main/java/eu/serscis/sam/Eval.java +++ b/src/main/java/eu/serscis/sam/Eval.java @@ -62,7 +62,7 @@ public IRule process(IRule rule) throws RuleUnsafeException { for (ILiteral lit : rule.getBody()) { String p = lit.getAtom().getPredicate().getPredicateSymbol(); if (p.equals("TO_STRING") || p.equals("STRING_CONCAT") || - p.equals("MAKE_OBJECT") || p.equals("MATCH") || p.equals("ASSIGN")) { + p.equals("MAKE_OBJECT") || p.equals("MATCH_TO") || p.equals("ASSIGN")) { return rule; } } diff --git a/src/main/java/eu/serscis/sam/MatchBuiltin.java b/src/main/java/eu/serscis/sam/MatchBuiltin.java index 4ebd094..90571a8 100644 --- a/src/main/java/eu/serscis/sam/MatchBuiltin.java +++ b/src/main/java/eu/serscis/sam/MatchBuiltin.java @@ -28,6 +28,7 @@ package eu.serscis.sam; +import org.deri.iris.builtins.BooleanBuiltin; import static org.deri.iris.factory.Factory.BASIC; import org.deri.iris.EvaluationException; @@ -40,21 +41,25 @@ public class MatchBuiltin extends FunctionalBuiltin { - private static final String PREDICATE_STRING = "MATCH"; - private static final IPredicate PREDICATE = BASIC.createPredicate(PREDICATE_STRING, -1); + private static final String PREDICATE_STRING = "MATCH_TO"; + private static final IPredicate PREDICATE = Constants.MATCH_TOP; public MatchBuiltin(ITerm... terms) { super(BASIC.createPredicate(PREDICATE_STRING, terms.length), terms); - if (terms.length < 2 || terms.length > 3) { - throw new IllegalArgumentException("The amount of terms <" + terms.length + "> must be 2 or 3"); + if (terms.length != 3) { + throw new IllegalArgumentException("The amount of terms <" + terms.length + "> must be 3"); } } protected ITerm computeResult(ITerm[] terms) throws EvaluationException { - ITerm first = terms[0]; - ITerm second = terms[1]; + if (terms.length != 3) { + throw new IllegalArgumentException("The amount of terms <" + terms.length + "> must be 3"); + } + return MatchBuiltin.computeResult(terms[0], terms[1]); + } + private static ITerm computeResult(ITerm first, ITerm second) throws EvaluationException { ITerm result; /* If both terms are not any(), they match only if they are equal. @@ -75,7 +80,7 @@ protected ITerm computeResult(ITerm[] terms) throws EvaluationException { return result; } - private ITerm matchAny(AnyTerm any, ITerm other) { + private static ITerm matchAny(AnyTerm any, ITerm other) { if (other instanceof AnyTerm) { Type intersection = any.type.intersect(((AnyTerm) other).type); if (intersection == null) { @@ -92,4 +97,21 @@ private ITerm matchAny(AnyTerm any, ITerm other) { return null; // e.g. !MATCH(any(int), "hi", ?Result) } + + /* Two-argument form that just tests whether a match exists. */ + public static class MatchBuiltinBoolean extends BooleanBuiltin { + private static final IPredicate PREDICATE = Constants.MATCH2P; + + public MatchBuiltinBoolean(final ITerm... t) { + super(PREDICATE, t); + } + + public boolean computeResult(ITerm[] terms) { + try { + return MatchBuiltin.computeResult(terms[0], terms[1]) != null; + } catch (EvaluationException ex) { + throw new RuntimeException(ex); + } + } + } } diff --git a/src/main/java/eu/serscis/sam/Model.java b/src/main/java/eu/serscis/sam/Model.java index c1b8176..2428668 100644 --- a/src/main/java/eu/serscis/sam/Model.java +++ b/src/main/java/eu/serscis/sam/Model.java @@ -102,6 +102,7 @@ public class Model { builtinRegister.registerBuiltin(new AssignBuiltin(t1, t2, t3)); builtinRegister.registerBuiltin(new MatchBuiltin(t1, t2, t3)); + builtinRegister.registerBuiltin(new MatchBuiltin.MatchBuiltinBoolean(t1, t2)); } public Model(Configuration configuration) { diff --git a/src/main/java/eu/serscis/sam/TypeChecker.java b/src/main/java/eu/serscis/sam/TypeChecker.java index 2866d93..bb3e90c 100644 --- a/src/main/java/eu/serscis/sam/TypeChecker.java +++ b/src/main/java/eu/serscis/sam/TypeChecker.java @@ -190,7 +190,9 @@ public void check(ILiteral lit, boolean head, List tokens) throws ParserE new TermDefinition(Type.ObjectT) }; } - } else if ((predicate.equals(Constants.EQUALP) || predicate.equals(Constants.MATCHP)) && !head) { + } else if ((predicate.equals(Constants.EQUALP) || + predicate.equals(Constants.MATCH2P) || + predicate.equals(Constants.MATCH_TOP)) && !head) { updateTypesEqual(tuple, tokens); return; } else if ((predicate.getPredicateSymbol().equals("error")) && tuple.size() > 0) { diff --git a/src/main/sam/groupBy.sam b/src/main/sam/groupBy.sam index 54fce87..18f9179 100644 --- a/src/main/sam/groupBy.sam +++ b/src/main/sam/groupBy.sam @@ -49,10 +49,10 @@ didCall(?Caller, ?CallerInvocation, ?CallSite, ?Target, ?TargetInvocation, ?Meth maySend(?Caller, ?CallerInvocation, ?CallSite, ?SendPos, ?SentValue), hasParam(?Method, ?ParamType, ?AnyParam, ?ParamPos), ASSIGN(?ParamType, ?SentValue, ?ReceivedValue), - MATCH(?SendPos, ?ParamPos, ?DontCarePos), + MATCH(?SendPos, ?ParamPos), GroupByArgAt(?Method, ?ParamPos), GroupByArgCase(?Method, ?CaseValue, ?TargetInvocation), - MATCH(?ReceivedValue, ?CaseValue, ?DontCareValue). + MATCH(?ReceivedValue, ?CaseValue). didCall(?Caller, ?CallerInvocation, ?CallSite, ?Target, ?TargetContext, ?Method) :- didCall(?Caller, ?CallerInvocation, ?CallSite, ?Target, ?Method), @@ -70,10 +70,10 @@ didReceive(?Target, ?TargetInvocation, ?Method, ?ParamPos, ?ReceivedMatchedValue maySend(?Caller, ?CallerInvocation, ?CallSite, ?SendPos, ?SentValue), hasParam(?Method, ?ParamType, ?AnyParam, ?ParamPos), ASSIGN(?ParamType, ?SentValue, ?ReceivedValue), - MATCH(?SendPos, ?ParamPos, ?DontCarePos), + MATCH(?SendPos, ?ParamPos), GroupByArgAt(?Method, ?ParamPos), GroupByArgCase(?Method, ?CaseValue, ?TargetInvocation), - MATCH(?ReceivedValue, ?CaseValue, ?ReceivedMatchedValue). + MATCH_TO(?ReceivedValue, ?CaseValue, ?ReceivedMatchedValue). // the other args @@ -82,14 +82,14 @@ didReceive(?Target, ?TargetInvocation, ?Method, ?ParamPos, ?ReceivedValue) :- maySend(?Caller, ?CallerInvocation, ?CallSite, ?SendPos, ?SentValue), hasParam(?Method, ?ParamType, ?AnyParam, ?ParamPos), ASSIGN(?ParamType, ?SentValue, ?ReceivedValue), - MATCH(?SendPos, ?ParamPos, ?DontCarePos), + MATCH(?SendPos, ?ParamPos), GroupByArgAt(?Method, ?GroupByArgAt), ?GroupByArgAt != ?ParamPos. error("missing GroupByArgCase", ?Method, ?ReceivedValue) :- GroupByArgAt(?Method, ?ParamPos), didCall(?Caller, ?CallerInvocation, ?CallSite, ?Target, ?Method), maySend(?Caller, ?CallerInvocation, ?CallSite, ?SendPos, ?SentValue), - MATCH(?SendPos, ?ParamPos, ?DontCarePos), + MATCH(?SendPos, ?ParamPos), hasParam(?Method, ?ParamType, ?AnyParam, ?ParamPos), ASSIGN(?ParamType, ?SentValue, ?ReceivedValue), !didReceive(?Target, ?TargetInvocation, ?Method, ?ParamPos, ?ReceivedValue). @@ -128,7 +128,7 @@ didCall(?Caller, ?CallerInvocation, ?CallSite, ?Target, ?TargetContext, ?Method) didReceive(?Target, ?TargetInvocation, ?Method, ?ParamPos, ?ReceivedValue) :- didCall(?Caller, ?CallerInvocation, ?CallSite, ?Target, ?TargetInvocation, ?Method), maySend(?Caller, ?CallerInvocation, ?CallSite, ?SendPos, ?SentValue), - MATCH(?SendPos, ?ParamPos, ?DontCarePos), + MATCH(?SendPos, ?ParamPos), hasParam(?Method, ?ParamType, ?Param, ?ParamPos), ASSIGN(?ParamType, ?SentValue, ?ReceivedValue), GroupAs(?Method, ?TargetInvocation). diff --git a/src/main/sam/system.sam b/src/main/sam/system.sam index 5241403..bced9a6 100644 --- a/src/main/sam/system.sam +++ b/src/main/sam/system.sam @@ -46,7 +46,7 @@ didCall(?Caller, ?CallerInvocation, ?CallSite, ?Target, ?Method) :- hasMethod(?TargetType, ?Method), callsMethod(?CallSite, ?MethodName1), methodName(?Method, ?MethodName2), - MATCH(?MethodName1, ?MethodName2, ?MethodName), + MATCH(?MethodName1, ?MethodName2), accessAllowed(?Caller, ?Target, ?Method). // For dynamic receivers @@ -71,7 +71,7 @@ didCall(?Caller, ?CallerInvocation, ?CallSite, ?Target, ?Method) :- callsMethodInLocal(?CallSite, ?LocalVarName), local(?Caller, ?CallerInvocation, ?LocalVarName, ?MethodName1), methodName(?Method, ?MethodName2), - MATCH(?MethodName1, ?MethodName2, ?MethodName), + MATCH(?MethodName1, ?MethodName2), accessAllowed(?Caller, ?Target, ?Method). didCall(?Caller, ?CallerInvocation, ?CallSite, ?Target, ?CallerInvocation, ?Method) :- @@ -146,7 +146,7 @@ maySend(?Caller, ?CallerInvocation, ?CallSite, ?Pos, ?ArgValue) :- didReceive(?Target, ?TargetInvocation, ?Method, ?ParamPos, ?ReceivedValue) :- didCall(?Caller, ?CallerInvocation, ?CallSite, ?Target, ?TargetInvocation, ?Method), maySend(?Caller, ?CallerInvocation, ?CallSite, ?SendPos, ?SentValue), - MATCH(?SendPos, ?ParamPos, ?DontCarePos), + MATCH(?SendPos, ?ParamPos), hasParam(?Method, ?ParamType, ?Param, ?ParamPos), ASSIGN(?ParamType, ?SentValue, ?ReceivedValue), !methodDoesContextMapping(?Method). @@ -165,7 +165,7 @@ hasIdentity(?Child, ?Identity) :- accessAllowed(?Caller, ?Target, ?Method) :- hasIdentity(?Caller, ?CallerIdentity), // Caller's identity is Identity grantsRole(?Target, ?Role, ?GrantedIdentity), // Target grants Role to Identity - MATCH(?CallerIdentity, ?GrantedIdentity, ?UnusedResult), // Might be any() type + MATCH(?CallerIdentity, ?GrantedIdentity), // Might be any() type PermittedRole(?Method, ?Role). // Role is allowed to invoke the method /* Objects may call others with the same identity. */ diff --git a/src/test/testMatch.sam b/src/test/testMatch.sam index 5fa3c4e..751a411 100644 --- a/src/test/testMatch.sam +++ b/src/test/testMatch.sam @@ -39,7 +39,7 @@ declare xfer(Ref a, Ref b, int pos). xfer(?A, ?B, ?Pos) :- sends(?A, ?PosA), accepts(?B, ?PosB), - MATCH(?PosA, ?PosB, ?Pos). + MATCH_TO(?PosA, ?PosB, ?Pos). declare foo(String c, Object x). declare bar(String c, Object y). @@ -65,7 +65,7 @@ bar("e", ). foo("f", any(Ref)). bar("f", any(Object)). -match(?C, ?Z) :- foo(?C, ?X), bar(?C, ?Y), MATCH(?X, ?Y, ?Z). +match(?C, ?Z) :- foo(?C, ?X), bar(?C, ?Y), MATCH_TO(?X, ?Y, ?Z). assert match("a", any(boolean)). assert match("b", any(boolean)).