From 47c9aff2314e2dcac8f293b45e1ca155af64be33 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Tue, 5 Mar 2024 08:17:15 +0100
Subject: [PATCH] added a way to register an object to be cleaned up via
 MemoryPressureManager

---
 CHANGELOG.md                                |  1 +
 docs/instrumentation/memory.rst             |  4 +++
 satella/__init__.py                         |  2 +-
 satella/instrumentation/memory/memthread.py | 27 ++++++++++++++++++++-
 tests/test_instrumentation/test_memory.py   | 16 ++++++++++++
 5 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 81002101..c3a0f665 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,4 +2,5 @@
 
 * added DeferredValue
 * satella.coding.Context are considered experimental
+* added a way to register an object to be cleaned up via MemoryPressureManager
 
diff --git a/docs/instrumentation/memory.rst b/docs/instrumentation/memory.rst
index 2674a8ce..589ab711 100644
--- a/docs/instrumentation/memory.rst
+++ b/docs/instrumentation/memory.rst
@@ -54,6 +54,10 @@ that level 1 is still in effect. You can register your handlers here:
 .. autoclass:: satella.instrumentation.memory.MemoryPressureManager
     :members:
 
+Note that you can also register objects to have their methods called on entering a memory
+severity level, if these objects have a way to to for example drop some data onto disk and
+decrease memory usage via :meth:`~satella.instrumentation.memory.MemoryPressureManager.cleanup_on_entered`.
+
 install_force_gc_collect
 ------------------------
 
diff --git a/satella/__init__.py b/satella/__init__.py
index 6aaf86b1..d7226f3a 100644
--- a/satella/__init__.py
+++ b/satella/__init__.py
@@ -1 +1 @@
-__version__ = '2.24.2a3'
+__version__ = '2.24.2a4'
diff --git a/satella/instrumentation/memory/memthread.py b/satella/instrumentation/memory/memthread.py
index 247ae15f..10bfb4ef 100644
--- a/satella/instrumentation/memory/memthread.py
+++ b/satella/instrumentation/memory/memthread.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
 import logging
 import os
 import typing as tp
+import weakref
 
 import psutil
 
@@ -38,7 +40,8 @@ class MemoryPressureManager(IntervalTerminableThread):
 
     :param maximum_available: maximum amount of memory that this program can use
     :param severity_levels: this defines the levels of severity. A level is reached when program's
-        consumption is other this many percent of it's maximum_available amount of memory.
+        consumption is other this many percent of it's maximum_available amount of memory. Note that you need to
+        specify only the abnormal memory levels, the default level of 0 will be added automatically.
     :param check_interval: amount of seconds of pause between consecutive checks, or
         a time string
     :param log_transitions: whether to log to logger when a transition takes place
@@ -67,11 +70,26 @@ class MemoryPressureManager(IntervalTerminableThread):
         self.callbacks_on_left = [CallableGroup(gather=False) for _ in
                                   range(len(
                                       self.severity_levels))]  # type: tp.List[CallableGroup]
+        self.objects_to_cleanup_on_entered = [[] for _ in range(len(self.severity_levels))]
         self.callbacks_on_memory_normal = CallableGroup(gather=False)
         self.severity_level = 0  # type: int
         self.stopped = False  # type: bool
         self.start()
 
+    @staticmethod
+    def cleanup_on_entered(target_level: int, obj: tp.Any,
+                           collector: tp.Callable[[tp.Any], None] = lambda y: y.cleanup()):
+        """
+        Attempt to recover memory by calling a particular method on an object.
+
+        A weak reference will be stored to this object
+
+        :param target_level: cleanup will be attempted on entering this severity level
+        :param obj: object to call this on
+        :param collector: a lambda to call a routine on this object
+        """
+        MemoryPressureManager().objects_to_cleanup_on_entered[target_level].append((weakref.ref(obj), collector))
+
     def advance_to_severity_level(self, target_level: int):
         while self.severity_level != target_level:
             delta = target_level - self.severity_level
@@ -81,6 +99,13 @@ class MemoryPressureManager(IntervalTerminableThread):
                 # Means we are ENTERING a severity level
                 self.severity_level += delta
                 self.callbacks_on_entered[self.severity_level]()
+                new_list = []
+                for ref, collector in self.objects_to_cleanup_on_entered[self.severity_level]:
+                    obj = ref()
+                    if obj is not None:
+                        collector(obj)
+                        new_list.append((ref, collector))
+                self.objects_to_cleanup_on_entered[self.severity_level] = new_list
                 if self.log_transitions:
                     logger.warning('Entered severity level %s' % (self.severity_level,))
             elif delta < 0:
diff --git a/tests/test_instrumentation/test_memory.py b/tests/test_instrumentation/test_memory.py
index 7fa6daeb..f07b8a99 100644
--- a/tests/test_instrumentation/test_memory.py
+++ b/tests/test_instrumentation/test_memory.py
@@ -53,9 +53,21 @@ class TestMemory(unittest.TestCase):
              'cancelled': 0,
              'mem_normal': 0}
 
+        class ObjectToCleanup:
+            def __init__(self):
+                self.cleaned_up = False
+
+            def cleanup(self):
+                self.cleaned_up = True
+
+        obj1 = ObjectToCleanup()
+        obj2 = ObjectToCleanup()
+
         cc = CustomCondition(lambda: a['level_2_engaged'])
 
         MemoryPressureManager(None, [odc, All(cc, Any(cc, cc))], 2)
+        MemoryPressureManager.cleanup_on_entered(1, obj1)
+        MemoryPressureManager.cleanup_on_entered(2, obj2)
 
         def memory_normal():
             nonlocal a
@@ -96,7 +108,9 @@ class TestMemory(unittest.TestCase):
         self.assertTrue(a['memory'])
         self.assertFalse(a['improved'])
         self.assertGreater(a['calls'], 0)
+        self.assertTrue(obj1.cleaned_up)
         self.assertEqual(a['times_entered_1'], 1)
+        del obj1
         odc.value = False
         time.sleep(3)
         self.assertTrue(a['improved'])
@@ -105,8 +119,10 @@ class TestMemory(unittest.TestCase):
         self.assertEqual(a['mem_normal'], 1)
         a['level_2_engaged'] = True
         time.sleep(3)
+        self.assertEqual(MemoryPressureManager().objects_to_cleanup_on_entered[1], [])
         self.assertEqual(MemoryPressureManager().severity_level, 2)
         self.assertEqual(a['cancelled'], 1)
         self.assertEqual(a['times_entered_1'], 2)
         self.assertTrue(a['level_2_confirmed'])
         self.assertEqual(a['mem_normal'], 1)
+        self.assertTrue(obj2.cleaned_up)
-- 
GitLab