Quellcode durchsuchen

Implement __copy__ for observable collections (fixes #3023) (#3046)

* implement __copy__ for observable collections (fixes #3023)

* distinguish between deep and shallow copies

* fix deepcopy of observable dictionaries
Falko Schindler vor 1 Jahr
Ursprung
Commit
28253947f5
3 geänderte Dateien mit 52 neuen und 0 gelöschten Zeilen
  1. 24 0
      nicegui/observables.py
  2. 13 0
      tests/test_observables.py
  3. 15 0
      tests/test_storage.py

+ 24 - 0
nicegui/observables.py

@@ -2,8 +2,11 @@ from __future__ import annotations
 
 import abc
 import time
+from copy import deepcopy
 from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Set, SupportsIndex, Union
 
+from typing_extensions import Self
+
 from . import events
 
 
@@ -38,6 +41,9 @@ class ObservableCollection(abc.ABC):  # noqa: B024
         self._change_handlers.append(handler)
 
     def _observe(self, data: Any) -> Any:
+        if isinstance(data, ObservableCollection):
+            data.on_change(self._handle_change)
+            return data
         if isinstance(data, dict):
             return ObservableDict(data, _parent=self)
         if isinstance(data, list):
@@ -46,6 +52,24 @@ class ObservableCollection(abc.ABC):  # noqa: B024
             return ObservableSet(data, _parent=self)
         return data
 
+    def __copy__(self) -> Self:
+        if isinstance(self, dict):
+            return ObservableDict(self, _parent=self._parent)
+        if isinstance(self, list):
+            return ObservableList(self, _parent=self._parent)
+        if isinstance(self, set):
+            return ObservableSet(self, _parent=self._parent)
+        raise NotImplementedError(f'ObservableCollection.__copy__ not implemented for {type(self)}')
+
+    def __deepcopy__(self, memo: Dict) -> Self:
+        if isinstance(self, dict):
+            return ObservableDict({key: deepcopy(self[key]) for key in self}, _parent=self._parent)
+        if isinstance(self, list):
+            return ObservableList([deepcopy(item) for item in self], _parent=self._parent)
+        if isinstance(self, set):
+            return ObservableSet({deepcopy(item) for item in self}, _parent=self._parent)
+        raise NotImplementedError(f'ObservableCollection.__deepcopy__ not implemented for {type(self)}')
+
 
 class ObservableDict(ObservableCollection, dict):
 

+ 13 - 0
tests/test_observables.py

@@ -1,4 +1,5 @@
 import asyncio
+import copy
 import sys
 
 from nicegui import ui
@@ -153,3 +154,15 @@ def test_setting_change_handler():
     data.on_change(increment_counter)
     data.append(2)
     assert count == 1
+
+
+def test_copy():
+    a = ObservableList([[1, 2, 3], [4, 5, 6]])
+    b = copy.copy(a)
+    c = copy.deepcopy(a)
+    a.append([7, 8, 9])
+    a[0][0] = 0
+
+    assert a == [[0, 2, 3], [4, 5, 6], [7, 8, 9]]
+    assert b == [[0, 2, 3], [4, 5, 6]]
+    assert c == [[1, 2, 3], [4, 5, 6]]

+ 15 - 0
tests/test_storage.py

@@ -1,4 +1,5 @@
 import asyncio
+import copy
 from pathlib import Path
 
 import httpx
@@ -263,3 +264,17 @@ def test_clear_client_storage(screen: Screen):
         assert app.storage.client == {}
 
     screen.open('/')
+
+
+def test_deepcopy(screen: Screen):
+    # https://github.com/zauberzeug/nicegui/issues/3023
+    @ui.page('/')
+    def page():
+        app.storage.general['a'] = {'b': 0}
+        copy.deepcopy(app.storage.general['a'])
+        ui.label('Loaded')
+
+    screen.open('/')
+    screen.should_contain('Loaded')
+    screen.wait(0.5)
+    assert Path('.nicegui', 'storage-general.json').read_text('utf-8') == '{"a":{"b":0}}'