1
0
Эх сурвалжийг харах

Introduce a `_propagation_context` to keep track of visited properties (#4628)

This PR tries to solve #4626 with a new context manager
`_propagation_context` that keeps track of visited properties using a
ContextVar `propagation_visited`. The new implementation better
encapsulates the logic around the `visited` set. This way it simplifies
other parts of the code as they don't have to deal with this set
anymore.

It also fixes the incorrect check
```py
if source_obj_id in visited:
    return
```

@evnchn, @sfeltman Would you like to have a look?
Falko Schindler 1 сар өмнө
parent
commit
3df8b2b16e
1 өөрчлөгдсөн 43 нэмэгдсэн , 26 устгасан
  1. 43 26
      nicegui/binding.py

+ 43 - 26
nicegui/binding.py

@@ -6,6 +6,8 @@ import dataclasses
 import time
 import weakref
 from collections import defaultdict
+from contextlib import contextmanager
+from contextvars import ContextVar
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -32,6 +34,8 @@ if TYPE_CHECKING:
 
 MAX_PROPAGATION_TIME = 0.01
 
+propagation_visited: ContextVar[Optional[Set[Tuple[int, str]]]] = ContextVar('propagation_visited', default=None)
+
 bindings: DefaultDict[Tuple[int, str], List] = defaultdict(list)
 bindable_properties: weakref.WeakValueDictionary[Tuple[int, str], Any] = weakref.WeakValueDictionary()
 active_links: List[Tuple[Any, str, Any, str, Callable[[Any], Any]]] = []
@@ -66,41 +70,54 @@ async def refresh_loop() -> None:
         await asyncio.sleep(core.app.config.binding_refresh_interval)
 
 
+@contextmanager
+def _propagation_context():
+    visited = propagation_visited.get()
+    is_root_call = visited is None
+    if is_root_call:
+        visited = set()
+        token = propagation_visited.set(visited)
+    try:
+        yield visited
+    finally:
+        if is_root_call:
+            propagation_visited.reset(token)
+
+
 def _refresh_step() -> None:
-    visited: Set[Tuple[int, str]] = set()
     t = time.time()
-    for link in active_links:
-        (source_obj, source_name, target_obj, target_name, transform) = link
-        if _has_attribute(source_obj, source_name):
-            value = transform(_get_attribute(source_obj, source_name))
-            if not _has_attribute(target_obj, target_name) or _get_attribute(target_obj, target_name) != value:
-                _set_attribute(target_obj, target_name, value)
-                _propagate(target_obj, target_name, visited)
-        del link, source_obj, target_obj  # pylint: disable=modified-iterating-list
+    with _propagation_context():
+        for link in active_links:
+            (source_obj, source_name, target_obj, target_name, transform) = link
+            if _has_attribute(source_obj, source_name):
+                value = transform(_get_attribute(source_obj, source_name))
+                if not _has_attribute(target_obj, target_name) or _get_attribute(target_obj, target_name) != value:
+                    _set_attribute(target_obj, target_name, value)
+                    _propagate(target_obj, target_name)
+            del link, source_obj, target_obj  # pylint: disable=modified-iterating-list
     if time.time() - t > MAX_PROPAGATION_TIME:
         log.warning(f'binding propagation for {len(active_links)} active links took {time.time() - t:.3f} s')
 
 
-def _propagate(source_obj: Any, source_name: str, visited: Optional[Set[Tuple[int, str]]] = None) -> None:
-    if visited is None:
-        visited = set()
-    source_obj_id = id(source_obj)
-    if source_obj_id in visited:
-        return
-    visited.add((source_obj_id, source_name))
+def _propagate(source_obj: Any, source_name: str) -> None:
+    with _propagation_context() as visited:
+        source_obj_id = id(source_obj)
+        if (source_obj_id, source_name) in visited:
+            return
+        visited.add((source_obj_id, source_name))
 
-    if not _has_attribute(source_obj, source_name):
-        return
-    source_value = _get_attribute(source_obj, source_name)
+        if not _has_attribute(source_obj, source_name):
+            return
+        source_value = _get_attribute(source_obj, source_name)
 
-    for _, target_obj, target_name, transform in bindings.get((source_obj_id, source_name), []):
-        if (id(target_obj), target_name) in visited:
-            continue
+        for _, target_obj, target_name, transform in bindings.get((source_obj_id, source_name), []):
+            if (id(target_obj), target_name) in visited:
+                continue
 
-        target_value = transform(source_value)
-        if not _has_attribute(target_obj, target_name) or _get_attribute(target_obj, target_name) != target_value:
-            _set_attribute(target_obj, target_name, target_value)
-            _propagate(target_obj, target_name, visited)
+            target_value = transform(source_value)
+            if not _has_attribute(target_obj, target_name) or _get_attribute(target_obj, target_name) != target_value:
+                _set_attribute(target_obj, target_name, target_value)
+                _propagate(target_obj, target_name)
 
 
 def bind_to(self_obj: Any, self_name: str, other_obj: Any, other_name: str, forward: Callable[[Any], Any]) -> None: