mock_state.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright 2021-2024 Avaiga Private Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  4. # the License. You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  9. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  10. # specific language governing permissions and limitations under the License.
  11. import typing as t
  12. from .. import Gui, State
  13. from ..utils import _MapDict
  14. class MockState(State):
  15. """A Mock implementation for `State`.
  16. TODO
  17. example of use:
  18. ```py
  19. def test_callback():
  20. ms = MockState(Gui(""), a = 1)
  21. on_action(ms) # function to test
  22. assert ms.a == 2
  23. ```
  24. """
  25. __VARS = "vars"
  26. def __init__(self, gui: Gui, **kwargs) -> None:
  27. super().__setattr__(MockState.__VARS, {k: _MapDict(v) if isinstance(v, dict) else v for k, v in kwargs.items()})
  28. self._gui = gui
  29. super().__init__()
  30. def get_gui(self) -> Gui:
  31. return self._gui
  32. def __getattribute__(self, name: str) -> t.Any:
  33. if (attr := t.cast(dict, super().__getattribute__(MockState.__VARS)).get(name, None)) is not None:
  34. return attr
  35. try:
  36. return super().__getattribute__(name)
  37. except Exception:
  38. return None
  39. def __setattr__(self, name: str, value: t.Any) -> None:
  40. t.cast(dict, super().__getattribute__(MockState.__VARS))[name] = (
  41. _MapDict(value) if isinstance(value, dict) else value
  42. )
  43. def __getitem__(self, key: str):
  44. return self
  45. def __enter__(self):
  46. return self
  47. def __exit__(self, exc_type, exc_value, traceback):
  48. return True
  49. def broadcast(self, name: str, value: t.Any):
  50. pass