From 51ff18f72bce8abb9d22bfeb11712bc33bec8656 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Sat, 22 Aug 2020 20:48:50 +0200
Subject: [PATCH] overhaul predicates

---
 CHANGELOG.md                              |   4 +-
 docs/coding/predicates.rst                |  68 ++++-----
 docs/index.rst                            |   1 +
 satella/__init__.py                       |   2 +-
 satella/coding/predicates.py              |  83 +++++++++++
 satella/coding/predicates/__init__.py     |   9 --
 satella/coding/predicates/decorators.py   |  40 ------
 satella/coding/predicates/dictionaries.py |  14 --
 satella/coding/predicates/generic.py      |  53 -------
 satella/coding/predicates/number.py       |  36 -----
 satella/coding/predicates/sequences.py    |  39 ------
 tests/test_coding/test_predicates.py      | 160 +++++++++++-----------
 12 files changed, 202 insertions(+), 307 deletions(-)
 create mode 100644 satella/coding/predicates.py
 delete mode 100644 satella/coding/predicates/__init__.py
 delete mode 100644 satella/coding/predicates/decorators.py
 delete mode 100644 satella/coding/predicates/dictionaries.py
 delete mode 100644 satella/coding/predicates/generic.py
 delete mode 100644 satella/coding/predicates/number.py
 delete mode 100644 satella/coding/predicates/sequences.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index dfe71ca7..27335ed2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,4 @@
-# v2.10.1
+# v2.11.0
 
 * `for_argument` can now accept strings
 * changed typing for `for_argument`
@@ -7,4 +7,4 @@
 * added predicate `has_keys`, deprecated
   existing `has_keys`
 * changed ordering of arguments to attribute and item
-
+* overhauled `predicates`
diff --git a/docs/coding/predicates.rst b/docs/coding/predicates.rst
index 9d9f31e4..2d2abd3e 100644
--- a/docs/coding/predicates.rst
+++ b/docs/coding/predicates.rst
@@ -5,60 +5,66 @@ Predicates
 Predicates are functions that take something and return a boolean about truthfulness
 of given statement. Satella contains a bunch of functions to produce these predicates.
 
-These go superbly hand-in-hand with preconditions and postconditions.
+Satella lets you express predicates in a Pythonic way, eg:
 
-Predicates
-----------
+    ::
 
-.. autofunction:: satella.coding.predicates.between
+        p = x == 2
 
-.. autofunction:: satella.coding.predicates.length_is
+        assert(p(2) and not p(1))
 
-.. autofunction:: satella.coding.predicates.length_multiple_of
+        p = x > 2
 
-.. autofunction:: satella.coding.predicates.one_of
+        assert(p(2) and not p(1))
 
-.. autofunction:: satella.coding.predicates.equals
 
-.. autofunction:: satella.coding.predicates.shorter_than
+This behaviour extends to operators, item procurement and attr procurement. The only exception is the length,
+which due to Python limitations (namely __len__ being allowed to return an int only) is called
+via it's method .length(), eg:
 
-.. autofunction:: satella.coding.predicates.longer_than
 
-.. autofunction:: satella.coding.predicates.is_not_none
 
-.. autofunction:: satella.coding.predicates.not_equal
+    ::
 
-.. autofunction:: satella.coding.predicates.has_keys
+        p = x.length() == 2
 
+        assert(p([1, 2]) and not p([3])
 
-Decorators
-----------
 
-Decorators are used to extend given predicates. Eg:
+You can also piece together multiple predicates.
+Because of Python limitations please use & and | operators in place of and and or.
+Also use ^ in place of xor and ~ in place of not.
 
-    ::
-        P = namedtuple('P', ('x', 'y'))
-        p = P(2,5)
-        assert attribute(equals(5), 'y')(p)
 
     ::
 
-        p = [1, 2, 5]
-        assert item(equals(2), 1)(p)
+        p = x > 2 & x < 6
 
-    ::
-        p = [1, 2, 5]
-        assert p_all(item(equals(1), 0), item(equals(2), 1))
+        assert(p(4) and not p(8) and not p(1))
+
+Predicate class is documented here:
+
+.. autoclass:: satella.coding.predicates.Predicate
+
+To use the predicate you are to execute the following import:
 
     ::
-        p = [1, 2, 5]
-        assert p_any(item(equals(1), 0), item(equals(2), 1))
 
+        from satella.coding.predicates import x
+
+        p = x == 2
+
+        assert(p(2))
 
-.. autofunction:: satella.coding.predicates.attribute
+You can also check if a dict has provided keys
 
-.. autofunction:: satella.coding.predicates.item
+::
+    a = {'hello': 'hello', 'world': 'world'}
+    p = x.has_keys('hello', 'world')
+    assert p(a)
 
-.. autofunction:: satella.coding.predicates.p_all
+Or check whether an instance is of provided type
 
-.. autofunction:: satella.coding.predicates.p_any
+::
+    p = x.instanceof(int)
+    assert p(2)
diff --git a/docs/index.rst b/docs/index.rst
index 73a6aae6..7835d8eb 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -14,6 +14,7 @@ Visit the project's page at GitHub_!
            configuration/sources
            coding/functions
            coding/structures
+           coding/predicates
            coding/concurrent
            coding/sequences
            coding/transforms
diff --git a/satella/__init__.py b/satella/__init__.py
index 2c2d3ecf..fb7a749a 100644
--- a/satella/__init__.py
+++ b/satella/__init__.py
@@ -1 +1 @@
-__version__ = '2.10.1_a6'
+__version__ = '2.11.0'
diff --git a/satella/coding/predicates.py b/satella/coding/predicates.py
new file mode 100644
index 00000000..776b3546
--- /dev/null
+++ b/satella/coding/predicates.py
@@ -0,0 +1,83 @@
+import typing as tp
+import operator
+
+__all__ = ['x']
+
+
+def make_operation_two_args(operation_two_args: tp.Callable[[tp.Any, tp.Any], tp.Any]):
+    def operation(self, a) -> 'Predicate':
+        if isinstance(a, Predicate):
+            def op(v):
+                return operation_two_args(self(v), a(v))
+        else:
+            def op(v):
+                return operation_two_args(self(v), a)
+        return Predicate(op)
+    return operation
+
+
+def make_operation_single_arg(operation):
+    def operation_v(self):
+        def operate(v):
+            return operation(v)
+        return Predicate(operate)
+    return operation_v
+
+
+def _has_keys(a, keys):
+    for key in keys:
+        if key not in a:
+            return False
+    return True
+
+
+class Predicate:
+    __slots__ = ('operation', )
+
+    def __init__(self, operation: tp.Callable[[tp.Any], tp.Any]):
+        self.operation = operation
+
+    def __call__(self, v):
+        return self.operation(v)
+
+    def has_keys(self, *keys):
+        """
+        Return a predicate checking whether this value has provided keys
+        """
+        return make_operation_two_args(_has_keys)(self, keys)
+
+    def instanceof(self, instance):
+        """
+        Return a predicate checking whether this value is an instance of instance
+        """
+        return make_operation_two_args(isinstance)(self, instance)
+
+    length = make_operation_single_arg(len)
+
+    __getattr__ = make_operation_two_args(getattr)
+    __getitem__ = make_operation_two_args(lambda a, b: a[b])
+    __eq__ = make_operation_two_args(operator.eq)
+    __ne__ = make_operation_two_args(operator.ne)
+    __lt__ = make_operation_two_args(operator.lt)
+    __gt__ = make_operation_two_args(operator.gt)
+    __le__ = make_operation_two_args(operator.le)
+    __ge__ = make_operation_two_args(operator.ge)
+    __add__ = make_operation_two_args(operator.add)
+    __sub__ = make_operation_two_args(operator.sub)
+    __mul__ = make_operation_two_args(operator.mul)
+    __and__ = make_operation_two_args(operator.and_)
+    __or__ = make_operation_two_args(operator.or_)
+    __xor__ = make_operation_two_args(operator.xor)
+    __neg__ = make_operation_single_arg(lambda y: -y)
+    __invert__ = make_operation_single_arg(operator.invert)
+    __abs__ = make_operation_single_arg(abs)
+    __int__ = make_operation_single_arg(int)
+    __float__ = make_operation_single_arg(float)
+    __complex__ = make_operation_single_arg(complex)
+    __str__ = make_operation_single_arg(str)
+    __truediv__ = make_operation_two_args(operator.__truediv__)
+    __floordiv__ = make_operation_two_args(operator.floordiv)
+    __mod__ = make_operation_two_args(operator.mod)
+
+
+x = Predicate(lambda y: y)
diff --git a/satella/coding/predicates/__init__.py b/satella/coding/predicates/__init__.py
deleted file mode 100644
index 48619537..00000000
--- a/satella/coding/predicates/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from .number import between
-from .generic import one_of, equals, is_not_none, not_equal, has_attr
-from .sequences import length_is, length_multiple_of, shorter_than, longer_than
-from .decorators import attribute, item, p_all, p_any
-from .dictionaries import has_keys
-
-__all__ = ['between', 'one_of', 'length_is', 'shorter_than', 'length_multiple_of',
-           'equals', 'attribute', 'item', 'is_not_none', 'not_equal', 'longer_than',
-           'has_attr', 'p_all', 'p_any', 'has_keys']
diff --git a/satella/coding/predicates/decorators.py b/satella/coding/predicates/decorators.py
deleted file mode 100644
index 81b0efc6..00000000
--- a/satella/coding/predicates/decorators.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import typing as tp
-
-
-def p_all(*args: tp.Callable[[tp.Any], bool]) -> tp.Callable[[tp.Any], bool]:
-    """
-    Make a predicate returning True if all specified predicates return True
-    """
-    def predicate(v) -> bool:
-        return all(arg(v) for arg in args)
-    return predicate
-
-
-def p_any(*args: tp.Callable[[tp.Any], bool]) -> tp.Callable[[tp.Any], bool]:
-    """
-    Make a predicate returning True if any of specified predicates return True
-    """
-    def predicate(v) -> bool:
-        return any(arg(v) for arg in args)
-    return predicate
-
-
-def attribute(attr: str, p: tp.Callable[[tp.Any], bool]) -> tp.Callable[[tp.Any], bool]:
-    """
-    Make predicate p refer to attribute of the object passed to it.
-    """
-    def predicate(v) -> bool:
-        return p(getattr(v, attr))
-    return predicate
-
-
-def item(i, p: tp.Callable[[tp.Any], bool]) -> tp.Callable[[tp.Any], bool]:
-    """
-    Make predicate p refer to i-th item of the value passed to it
-
-    i doesn't have to be an integer, it will be passed to __getitem__
-    """
-    def predicate(v) -> bool:
-        return p(v[i])
-    return predicate
-
diff --git a/satella/coding/predicates/dictionaries.py b/satella/coding/predicates/dictionaries.py
deleted file mode 100644
index 6d31e509..00000000
--- a/satella/coding/predicates/dictionaries.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import typing as tp
-
-
-def has_keys(*keys: tp.Any) -> tp.Callable[[tp.Dict], bool]:
-    """
-    Return a predicate to check if your dictionary has all of given keys
-    """
-    def predicate(v: tp.Dict) -> bool:
-        for key in keys:
-            if key not in v:
-                return False
-        return True
-    return predicate
-
diff --git a/satella/coding/predicates/generic.py b/satella/coding/predicates/generic.py
deleted file mode 100644
index 6e745d68..00000000
--- a/satella/coding/predicates/generic.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import typing as tp
-
-
-def one_of(*args) -> tp.Callable[[tp.Any], bool]:
-    """
-    Return a predicate that will return True if passed value equals to one of the arguments
-
-    :param args: a list of arguments on which the predicate will return True
-    :param attribute: if given, then it will first try to access given attribute of v
-    """
-    def predicate(v) -> bool:
-        return v in args
-    return predicate
-
-
-def _is_not_none(v) -> bool:
-    return v is not None
-
-
-def is_not_none():
-    """
-    Return a predicate that will return True if passed element is not None
-    """
-    return _is_not_none
-
-
-def has_attr(x):
-    """
-    Build a predicate that returns True if passed element has attribute x
-    """
-    def predicate(v):
-        return hasattr(v, x)
-    return predicate
-
-
-def not_equal(x):
-    """
-    Build a predicate that returns True only if value passed to it does not equal x
-    """
-    def predicate(v):
-        return v != x
-    return predicate
-
-
-def equals(x):
-    """
-    Build a predicate that returns True only if value passed to it equals x
-    """
-    def predicate(v):
-        return v == x
-    return predicate
-
-
diff --git a/satella/coding/predicates/number.py b/satella/coding/predicates/number.py
deleted file mode 100644
index fd3d26c1..00000000
--- a/satella/coding/predicates/number.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import typing as tp
-import math
-
-Number = tp.Union[float, int]
-Predicate = tp.Callable[[Number], bool]
-
-
-def between(left: Number = -math.inf, right: Number = math.inf,
-            incl_left: bool = True, incl_right: bool = True) -> Predicate:
-    """
-    Build a predicate to check whether a given number is in particular range
-
-    :param left: predicate will be true for numbers larger than this
-    :param right: predicate will be true for numbers smaller than this
-    :param incl_left: whether to include left in the range for the predicate. Set to True
-        will result in a <= operator, whereas False will result in a >
-    :param incl_right: whether to include left in the range for the predicate
-    :param attribute: if given, then it will first try to access given attribute of v
-    """
-    def predicate(x: Number) -> bool:
-        if incl_left:
-            if x < left:
-                return False
-        else:
-            if x <= left:
-                return False
-
-        if incl_right:
-            if x > right:
-                return False
-        else:
-            if x >= right:
-                return False
-
-        return True
-    return predicate
diff --git a/satella/coding/predicates/sequences.py b/satella/coding/predicates/sequences.py
deleted file mode 100644
index 661c6b29..00000000
--- a/satella/coding/predicates/sequences.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import typing as tp
-
-
-def shorter_than(x) -> tp.Callable[[tp.Sequence], bool]:
-    """
-    Return a predicate that will return True if length of sequence is less than x
-
-    :param x: value of x
-    """
-    def predicate(v):
-        return len(v) < x
-    return predicate
-
-
-def longer_than(x):
-    """
-    Return a predicate that will return True if length of sequence is greater than x
-    """
-    def predicate(v):
-        return len(v) > x
-    return predicate
-
-
-def length_is(x) -> tp.Callable[[tp.Sequence], bool]:
-    """
-    Return a predicate that will return True if length of sequence is x
-    """
-    def predicate(v):
-        return len(v) == x
-    return predicate
-
-
-def length_multiple_of(x) -> tp.Callable[[tp.Sequence], bool]:
-    """
-    Return a predicate that will return True if length of sequence is a multiple of x
-    """
-    def predicate(v):
-        return not (len(v) % x)
-    return predicate
diff --git a/tests/test_coding/test_predicates.py b/tests/test_coding/test_predicates.py
index 03cb7632..d8a6ef68 100644
--- a/tests/test_coding/test_predicates.py
+++ b/tests/test_coding/test_predicates.py
@@ -1,90 +1,86 @@
 import unittest
 
 
-from satella.coding.predicates import between, one_of, length_is, shorter_than, \
-    length_multiple_of, attribute, equals, item, longer_than, is_not_none, not_equal, \
-    has_attr, p_all, p_any
+from satella.coding.predicates import x
 
 
 class TestPredicates(unittest.TestCase):
 
-    def test_p_all(self):
-        p = [1, 2]
-        self.assertTrue(p_all(item(equals(1), 0), item(equals(2), 1))(p))
-        self.assertFalse(p_all(item(equals(1), 0), item(equals(3), 1))(p))
-
-    def test_p_any(self):
-        p = [1, 2]
-        self.assertTrue(p_any(item(equals(1), 0), item(equals(3), 1))(p))
-        self.assertFalse(p_any(item(equals(4), 0), item(equals(3), 1))(p))
-
-    def test_has_attr(self):
+    def test_instanceof(self):
+        p = x.instanceof(int)
+        self.assertTrue(p(2))
+        self.assertFalse(p('2'))
+
+    def test_has_keys(self):
+        a = {'hello': 'world', 'hello2': 'world'}
+        p = x.has_keys('hello', 'hello2')
+        self.assertTrue(p(a))
+        del a['hello']
+        self.assertFalse(p(a))
+
+    def test_joined_predicates(self):
+        p = (x > 2) & (x < 6)
+        self.assertTrue(p(4))
+        self.assertFalse(p(1))
+        self.assertFalse(p(8))
+
+        p = (x < 2) | (x > 6)
+        self.assertTrue(p(1))
+        self.assertTrue(p(8))
+        self.assertFalse(p(4))
+
+    def test_ops(self):
+        p = (x + 2) == 2
+        self.assertTrue(p(0))
+        self.assertFalse(p(1))
+        p = (x - 2) == 0
+        self.assertTrue(p(2))
+        self.assertFalse(p(1))
+        p = (x * 2) == 2
+        self.assertTrue(p(1))
+        self.assertFalse(p(2))
+        p = (x / 2) == 1
+        self.assertTrue(p(2))
+        self.assertFalse(p(1))
+        p = (x + 2) % 3 == 0
+        self.assertTrue(p(1))
+        self.assertFalse(p(2))
+
+    def test_getattr(self):
         class A:
-            def __init__(self):
-                self.b = 2
-        a = A()
-        self.assertTrue(has_attr('b')(a))
-        self.assertFalse(has_attr('c')(a))
-
-    def test_not_equal(self):
-        self.assertTrue(not_equal(5)(6))
-        self.assertFalse(not_equal(5)(5))
-
-    def test_is_not_none(self):
-        self.assertTrue(is_not_none()(6))
-        self.assertFalse(is_not_none()(None))
-
-    def test_longer_than(self):
-        a = 'ala'
-        self.assertTrue(longer_than(2)(a))
-        self.assertFalse(longer_than(3)(a))
-
-    def test_length_is_attribute(self):
-        class Attr:
-            def __init__(self, b):
-                self.a = b
-
-        a = Attr('ala')
-        self.assertTrue(attribute('a', length_is(3))(a))
-        self.assertFalse(attribute('a', length_is(4))(a))
-
-    def test_length_is_item(self):
-        a = [1, 2, 5]
-
-        self.assertTrue(item(1, equals(2))(a))
-        self.assertFalse(item(0, equals(2))(a))
-
-    def test_length_is(self):
-        a = 'ala'
-        self.assertTrue(length_is(3)(a))
-        self.assertFalse(length_is(4)(a))
-
-    def test_shorter_than(self):
-        a = 'ala'
-        self.assertTrue(shorter_than(4)(a))
-        self.assertFalse(shorter_than(3)(a))
-
-    def test_length_multiple_of(self):
-        a = 'ala '
-        self.assertTrue(length_multiple_of(4)(a))
-        self.assertFalse(length_multiple_of(3)(a))
-
-    def test_one_of(self):
-        two_or_five = one_of(2, 5)
-        self.assertTrue(two_or_five(2))
-        self.assertTrue(two_or_five(5))
-        self.assertFalse(two_or_five(1))
-
-    def test_between(self):
-        between2_and_5 = between(2, 5)
-        self.assertTrue(between2_and_5(2))
-        self.assertTrue(between2_and_5(5))
-        self.assertTrue(between2_and_5(3))
-        self.assertFalse(between2_and_5(1))
-        self.assertFalse(between2_and_5(6))
-        between2_and_5 = between(2, 5, False, False)
-        self.assertFalse(between2_and_5(2))
-        self.assertFalse(between2_and_5(5))
-        self.assertTrue(between2_and_5(3))
-        self.assertFalse(between2_and_5(1))
-        self.assertFalse(between2_and_5(6))
+            def __init__(self, a=2):
+                self.attr = a
+
+        p = x.attr == 2
+        self.assertTrue(p(A()))
+        self.assertFalse(p(A(3)))
+
+    def test_getitem(self):
+        p = x[0] == 1
+        self.assertTrue(p([1, 2]))
+        self.assertFalse(p([2, 2]))
+
+    def test_len(self):
+        p = x.length() == 2
+        self.assertTrue(p([1, 2]))
+        self.assertFalse(p([]))
+
+    def test_equals(self):
+        p = x == 2
+        self.assertTrue(p(2))
+        self.assertFalse(p(3))
+        p = x > 2
+        self.assertTrue(p(3))
+        self.assertFalse(p(2))
+        p = x < 2
+        self.assertTrue(p(1))
+        self.assertFalse(p(2))
+        p = x >= 2
+        self.assertTrue(p(2))
+        p = x <= 2
+        self.assertTrue(p(1))
+        self.assertTrue(p(2))
+        self.assertFalse(p(3))
+        p = x != 2
+        self.assertTrue(p(1))
+        self.assertFalse(p(2))
-- 
GitLab