import functools from typing import Dict, List import pytest from plotly.graph_objects import Figure import reflex as rx from reflex.base import Base from reflex.constants import IS_HYDRATED, RouteVar from reflex.event import Event, EventHandler from reflex.state import State from reflex.utils import format from reflex.vars import BaseVar, ComputedVar class Object(Base): """A test object fixture.""" prop1: int = 42 prop2: str = "hello" class TestState(State): """A test state.""" # Set this class as not test one __test__ = False num1: int num2: float = 3.14 key: str map_key: str = "a" array: List[float] = [1, 2, 3.14] mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]} obj: Object = Object() complex: Dict[int, Object] = {1: Object(), 2: Object()} fig: Figure = Figure() @ComputedVar def sum(self) -> float: """Dynamically sum the numbers. Returns: The sum of the numbers. """ return self.num1 + self.num2 @ComputedVar def upper(self) -> str: """Uppercase the key. Returns: The uppercased key. """ return self.key.upper() def do_something(self): """Do something.""" pass class ChildState(TestState): """A child state fixture.""" value: str count: int = 23 def change_both(self, value: str, count: int): """Change both the value and count. Args: value: The new value. count: The new count. """ self.value = value.upper() self.count = count * 2 class ChildState2(TestState): """A child state fixture.""" value: str class GrandchildState(ChildState): """A grandchild state fixture.""" value2: str def do_nothing(self): """Do something.""" pass @pytest.fixture def test_state() -> TestState: """A state. Returns: A test state. """ return TestState() # type: ignore @pytest.fixture def child_state(test_state) -> ChildState: """A child state. Args: test_state: A test state. Returns: A test child state. """ child_state = test_state.get_substate(["child_state"]) assert child_state is not None return child_state @pytest.fixture def child_state2(test_state) -> ChildState2: """A second child state. Args: test_state: A test state. Returns: A second test child state. """ child_state2 = test_state.get_substate(["child_state2"]) assert child_state2 is not None return child_state2 @pytest.fixture def grandchild_state(child_state) -> GrandchildState: """A state. Args: child_state: A child state. Returns: A test state. """ grandchild_state = child_state.get_substate(["grandchild_state"]) assert grandchild_state is not None return grandchild_state def test_base_class_vars(test_state): """Test that the class vars are set correctly. Args: test_state: A state. """ fields = test_state.get_fields() cls = type(test_state) for field in fields: if field in test_state.get_skip_vars(): continue prop = getattr(cls, field) assert isinstance(prop, BaseVar) assert prop.name == field assert cls.num1.type_ == int assert cls.num2.type_ == float assert cls.key.type_ == str def test_computed_class_var(test_state): """Test that the class computed vars are set correctly. Args: test_state: A state. """ cls = type(test_state) vars = [(prop.name, prop.type_) for prop in cls.computed_vars.values()] assert ("sum", float) in vars assert ("upper", str) in vars def test_class_vars(test_state): """Test that the class vars are set correctly. Args: test_state: A state. """ cls = type(test_state) assert set(cls.vars.keys()) == { IS_HYDRATED, # added by hydrate_middleware to all State "num1", "num2", "key", "map_key", "array", "mapping", "obj", "complex", "sum", "upper", "fig", } def test_event_handlers(test_state): """Test that event handler is set correctly. Args: test_state: A state. """ expected = { "do_something", "set_array", "set_complex", "set_fig", "set_key", "set_mapping", "set_num1", "set_num2", "set_obj", } cls = type(test_state) assert set(cls.event_handlers.keys()).intersection(expected) == expected def test_default_value(test_state): """Test that the default value of a var is correct. Args: test_state: A state. """ assert test_state.num1 == 0 assert test_state.num2 == 3.14 assert test_state.key == "" assert test_state.sum == 3.14 assert test_state.upper == "" def test_computed_vars(test_state): """Test that the computed var is computed correctly. Args: test_state: A state. """ test_state.num1 = 1 test_state.num2 = 4 assert test_state.sum == 5 test_state.key = "hello world" assert test_state.upper == "HELLO WORLD" def test_dict(test_state): """Test that the dict representation of a state is correct. Args: test_state: A state. """ substates = {"child_state", "child_state2"} assert set(test_state.dict().keys()) == set(test_state.vars.keys()) | substates assert ( set(test_state.dict(include_computed=False).keys()) == set(test_state.base_vars) | substates ) def test_default_setters(test_state): """Test that we can set default values. Args: test_state: A state. """ for prop_name in test_state.base_vars: # Each base var should have a default setter. assert hasattr(test_state, f"set_{prop_name}") def test_class_indexing_with_vars(): """Test that we can index into a state var with another var.""" prop = TestState.array[TestState.num1] assert str(prop) == "{test_state.array.at(test_state.num1)}" prop = TestState.mapping["a"][TestState.num1] assert str(prop) == '{test_state.mapping["a"].at(test_state.num1)}' prop = TestState.mapping[TestState.map_key] assert str(prop) == "{test_state.mapping[test_state.map_key]}" def test_class_attributes(): """Test that we can get class attributes.""" prop = TestState.obj.prop1 assert str(prop) == "{test_state.obj.prop1}" prop = TestState.complex[1].prop1 assert str(prop) == "{test_state.complex[1].prop1}" def test_get_parent_state(): """Test getting the parent state.""" assert TestState.get_parent_state() is None assert ChildState.get_parent_state() == TestState assert ChildState2.get_parent_state() == TestState assert GrandchildState.get_parent_state() == ChildState def test_get_substates(): """Test getting the substates.""" assert TestState.get_substates() == {ChildState, ChildState2} assert ChildState.get_substates() == {GrandchildState} assert ChildState2.get_substates() == set() assert GrandchildState.get_substates() == set() def test_get_name(): """Test getting the name of a state.""" assert TestState.get_name() == "test_state" assert ChildState.get_name() == "child_state" assert ChildState2.get_name() == "child_state2" assert GrandchildState.get_name() == "grandchild_state" def test_get_full_name(): """Test getting the full name.""" assert TestState.get_full_name() == "test_state" assert ChildState.get_full_name() == "test_state.child_state" assert ChildState2.get_full_name() == "test_state.child_state2" assert GrandchildState.get_full_name() == "test_state.child_state.grandchild_state" def test_get_class_substate(): """Test getting the substate of a class.""" assert TestState.get_class_substate(("child_state",)) == ChildState assert TestState.get_class_substate(("child_state2",)) == ChildState2 assert ChildState.get_class_substate(("grandchild_state",)) == GrandchildState assert ( TestState.get_class_substate(("child_state", "grandchild_state")) == GrandchildState ) with pytest.raises(ValueError): TestState.get_class_substate(("invalid_child",)) with pytest.raises(ValueError): TestState.get_class_substate( ( "child_state", "invalid_child", ) ) def test_get_class_var(): """Test getting the var of a class.""" assert TestState.get_class_var(("num1",)) == TestState.num1 assert TestState.get_class_var(("num2",)) == TestState.num2 assert ChildState.get_class_var(("value",)) == ChildState.value assert GrandchildState.get_class_var(("value2",)) == GrandchildState.value2 assert TestState.get_class_var(("child_state", "value")) == ChildState.value assert ( TestState.get_class_var(("child_state", "grandchild_state", "value2")) == GrandchildState.value2 ) assert ( ChildState.get_class_var(("grandchild_state", "value2")) == GrandchildState.value2 ) with pytest.raises(ValueError): TestState.get_class_var(("invalid_var",)) with pytest.raises(ValueError): TestState.get_class_var( ( "child_state", "invalid_var", ) ) def test_set_class_var(): """Test setting the var of a class.""" with pytest.raises(AttributeError): TestState.num3 # type: ignore TestState._set_var(BaseVar(name="num3", type_=int).set_state(TestState)) var = TestState.num3 # type: ignore assert var.name == "num3" assert var.type_ == int assert var.state == TestState.get_full_name() def test_set_parent_and_substates(test_state, child_state, grandchild_state): """Test setting the parent and substates. Args: test_state: A state. child_state: A child state. grandchild_state: A grandchild state. """ assert len(test_state.substates) == 2 assert set(test_state.substates) == {"child_state", "child_state2"} assert child_state.parent_state == test_state assert len(child_state.substates) == 1 assert set(child_state.substates) == {"grandchild_state"} assert grandchild_state.parent_state == child_state assert len(grandchild_state.substates) == 0 def test_get_child_attribute(test_state, child_state, child_state2, grandchild_state): """Test getting the attribute of a state. Args: test_state: A state. child_state: A child state. child_state2: A child state. grandchild_state: A grandchild state. """ assert test_state.num1 == 0 assert child_state.value == "" assert child_state2.value == "" assert child_state.count == 23 assert grandchild_state.value2 == "" with pytest.raises(AttributeError): test_state.invalid with pytest.raises(AttributeError): test_state.child_state.invalid with pytest.raises(AttributeError): test_state.child_state.grandchild_state.invalid def test_set_child_attribute(test_state, child_state, grandchild_state): """Test setting the attribute of a state. Args: test_state: A state. child_state: A child state. grandchild_state: A grandchild state. """ test_state.num1 = 10 assert test_state.num1 == 10 assert child_state.num1 == 10 assert grandchild_state.num1 == 10 grandchild_state.num1 = 5 assert test_state.num1 == 5 assert child_state.num1 == 5 assert grandchild_state.num1 == 5 child_state.value = "test" assert child_state.value == "test" assert grandchild_state.value == "test" grandchild_state.value = "test2" assert child_state.value == "test2" assert grandchild_state.value == "test2" grandchild_state.value2 = "test3" assert grandchild_state.value2 == "test3" def test_get_substate(test_state, child_state, child_state2, grandchild_state): """Test getting the substate of a state. Args: test_state: A state. child_state: A child state. child_state2: A child state. grandchild_state: A grandchild state. """ assert test_state.get_substate(("child_state",)) == child_state assert test_state.get_substate(("child_state2",)) == child_state2 assert ( test_state.get_substate(("child_state", "grandchild_state")) == grandchild_state ) assert child_state.get_substate(("grandchild_state",)) == grandchild_state with pytest.raises(ValueError): test_state.get_substate(("invalid",)) with pytest.raises(ValueError): test_state.get_substate(("child_state", "invalid")) with pytest.raises(ValueError): test_state.get_substate(("child_state", "grandchild_state", "invalid")) def test_set_dirty_var(test_state): """Test changing state vars marks the value as dirty. Args: test_state: A state. """ # Initially there should be no dirty vars. assert test_state.dirty_vars == set() # Setting a var should mark it as dirty. test_state.num1 = 1 assert test_state.dirty_vars == {"num1", "sum"} # Setting another var should mark it as dirty. test_state.num2 = 2 assert test_state.dirty_vars == {"num1", "num2", "sum"} # Cleaning the state should remove all dirty vars. test_state.clean() assert test_state.dirty_vars == set() def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_state): """Test changing substate vars marks the value as dirty. Args: test_state: A state. child_state: A child state. child_state2: A child state. grandchild_state: A grandchild state. """ # Initially there should be no dirty vars. assert test_state.dirty_vars == set() assert child_state.dirty_vars == set() assert child_state2.dirty_vars == set() assert grandchild_state.dirty_vars == set() # Setting a var should mark it as dirty. child_state.value = "test" assert child_state.dirty_vars == {"value"} assert test_state.dirty_substates == {"child_state"} assert child_state.dirty_substates == set() # Cleaning the parent state should remove the dirty substate. test_state.clean() assert test_state.dirty_substates == set() assert child_state.dirty_vars == set() # Setting a var on the grandchild should bubble up. grandchild_state.value2 = "test2" assert child_state.dirty_substates == {"grandchild_state"} assert test_state.dirty_substates == {"child_state"} # Cleaning the middle state should keep the parent state dirty. child_state.clean() assert test_state.dirty_substates == {"child_state"} assert child_state.dirty_substates == set() assert grandchild_state.dirty_vars == set() def test_reset(test_state, child_state): """Test resetting the state. Args: test_state: A state. child_state: A child state. """ # Set some values. test_state.num1 = 1 test_state.num2 = 2 child_state.value = "test" # Reset the state. test_state.reset() # The values should be reset. assert test_state.num1 == 0 assert test_state.num2 == 3.14 assert child_state.value == "" # The dirty vars should be reset. assert test_state.dirty_vars == set() assert child_state.dirty_vars == set() # The dirty substates should be reset. assert test_state.dirty_substates == set() @pytest.mark.asyncio async def test_process_event_simple(test_state): """Test processing an event. Args: test_state: A state. """ assert test_state.num1 == 0 event = Event(token="t", name="set_num1", payload={"value": 69}) update = await test_state._process(event).__anext__() # The event should update the value. assert test_state.num1 == 69 # The delta should contain the changes, including computed vars. # assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}} assert update.delta == {"test_state": {"num1": 69, "sum": 72.14, "upper": ""}} assert update.events == [] @pytest.mark.asyncio async def test_process_event_substate(test_state, child_state, grandchild_state): """Test processing an event on a substate. Args: test_state: A state. child_state: A child state. grandchild_state: A grandchild state. """ # Events should bubble down to the substate. assert child_state.value == "" assert child_state.count == 23 event = Event( token="t", name="child_state.change_both", payload={"value": "hi", "count": 12} ) update = await test_state._process(event).__anext__() assert child_state.value == "HI" assert child_state.count == 24 assert update.delta == { "test_state": {"sum": 3.14, "upper": ""}, "test_state.child_state": {"value": "HI", "count": 24}, } test_state.clean() # Test with the granchild state. assert grandchild_state.value2 == "" event = Event( token="t", name="child_state.grandchild_state.set_value2", payload={"value": "new"}, ) update = await test_state._process(event).__anext__() assert grandchild_state.value2 == "new" assert update.delta == { "test_state": {"sum": 3.14, "upper": ""}, "test_state.child_state.grandchild_state": {"value2": "new"}, } @pytest.mark.asyncio async def test_process_event_generator(gen_state): """Test event handlers that generate multiple updates. Args: gen_state: A state. """ gen_state = gen_state() event = Event( token="t", name="go", payload={"c": 5}, ) gen = gen_state._process(event) count = 0 async for update in gen: count += 1 if count == 6: assert update.delta == {} assert update.final else: assert gen_state.value == count assert update.delta == { "gen_state": {"value": count}, } assert not update.final assert count == 6 def test_format_event_handler(): """Test formatting an event handler.""" assert ( format.format_event_handler(TestState.do_something) == "test_state.do_something" # type: ignore ) assert ( format.format_event_handler(ChildState.change_both) # type: ignore == "test_state.child_state.change_both" ) assert ( format.format_event_handler(GrandchildState.do_nothing) # type: ignore == "test_state.child_state.grandchild_state.do_nothing" ) def test_get_token(test_state, mocker, router_data): """Test that the token obtained from the router_data is correct. Args: test_state: The test state. mocker: Pytest Mocker object. router_data: The router data fixture. """ mocker.patch.object(test_state, "router_data", router_data) assert test_state.get_token() == "b181904c-3953-4a79-dc18-ae9518c22f05" def test_get_sid(test_state, mocker, router_data): """Test getting session id. Args: test_state: A state. mocker: Pytest Mocker object. router_data: The router data fixture. """ mocker.patch.object(test_state, "router_data", router_data) assert test_state.get_sid() == "9fpxSzPb9aFMb4wFAAAH" def test_get_headers(test_state, mocker, router_data, router_data_headers): """Test getting client headers. Args: test_state: A state. mocker: Pytest Mocker object. router_data: The router data fixture. router_data_headers: The expected headers. """ mocker.patch.object(test_state, "router_data", router_data) assert test_state.get_headers() == router_data_headers def test_get_client_ip(test_state, mocker, router_data): """Test getting client IP. Args: test_state: A state. mocker: Pytest Mocker object. router_data: The router data fixture. """ mocker.patch.object(test_state, "router_data", router_data) assert test_state.get_client_ip() == "127.0.0.1" def test_get_cookies(test_state, mocker, router_data): """Test getting client cookies. Args: test_state: A state. mocker: Pytest Mocker object. router_data: The router data fixture. """ mocker.patch.object(test_state, "router_data", router_data) assert test_state.get_cookies() == { "csrftoken": "mocktoken", "name": "reflex", "list_cookies": ["some", "random", "cookies"], "dict_cookies": {"name": "reflex"}, "val": True, } def test_get_current_page(test_state): assert test_state.get_current_page() == "" route = "mypage/subpage" test_state.router_data = {RouteVar.PATH: route} assert test_state.get_current_page() == route def test_get_query_params(test_state): assert test_state.get_query_params() == {} params = {"p1": "a", "p2": "b"} test_state.router_data = {RouteVar.QUERY: params} assert test_state.get_query_params() == params def test_add_var(test_state): test_state.add_var("dynamic_int", int, 42) assert test_state.dynamic_int == 42 test_state.add_var("dynamic_list", List[int], [5, 10]) assert test_state.dynamic_list == [5, 10] assert test_state.dynamic_list == [5, 10] # how to test that one? # test_state.dynamic_list.append(15) # assert test_state.dynamic_list == [5, 10, 15] test_state.add_var("dynamic_dict", Dict[str, int], {"k1": 5, "k2": 10}) assert test_state.dynamic_dict == {"k1": 5, "k2": 10} assert test_state.dynamic_dict == {"k1": 5, "k2": 10} def test_add_var_default_handlers(test_state): test_state.add_var("rand_int", int, 10) assert "set_rand_int" in test_state.event_handlers assert isinstance(test_state.event_handlers["set_rand_int"], EventHandler) class InterdependentState(State): """A state with 3 vars and 3 computed vars. x: a variable that no computed var depends on v1: a varable that one computed var directly depeneds on _v2: a backend variable that one computed var directly depends on v1x2: a computed var that depends on v1 v2x2: a computed var that depends on backend var _v2 v1x2x2: a computed var that depends on computed var v1x2 """ x: int = 0 v1: int = 0 _v2: int = 1 @rx.cached_var def v1x2(self) -> int: """Depends on var v1. Returns: Var v1 multiplied by 2 """ return self.v1 * 2 @rx.cached_var def v2x2(self) -> int: """Depends on backend var _v2. Returns: backend var _v2 multiplied by 2 """ return self._v2 * 2 @rx.cached_var def v1x2x2(self) -> int: """Depends on ComputedVar v1x2. Returns: ComputedVar v1x2 multiplied by 2 """ return self.v1x2 * 2 @pytest.fixture def interdependent_state() -> State: """A state with varying dependency between vars. Returns: instance of InterdependentState """ s = InterdependentState() s.dict() # prime initial relationships by accessing all ComputedVars return s def test_not_dirty_computed_var_from_var(interdependent_state): """Set Var that no ComputedVar depends on, expect no recalculation. Args: interdependent_state: A state with varying Var dependencies. """ interdependent_state.x = 5 assert interdependent_state.get_delta() == { interdependent_state.get_full_name(): {"x": 5}, } def test_dirty_computed_var_from_var(interdependent_state): """Set Var that ComputedVar depends on, expect recalculation. The other ComputedVar depends on the changed ComputedVar and should also be recalculated. No other ComputedVars should be recalculated. Args: interdependent_state: A state with varying Var dependencies. """ interdependent_state.v1 = 1 assert interdependent_state.get_delta() == { interdependent_state.get_full_name(): {"v1": 1, "v1x2": 2, "v1x2x2": 4}, } def test_dirty_computed_var_from_backend_var(interdependent_state): """Set backend var that ComputedVar depends on, expect recalculation. Args: interdependent_state: A state with varying Var dependencies. """ interdependent_state._v2 = 2 assert interdependent_state.get_delta() == { interdependent_state.get_full_name(): {"v2x2": 4}, } def test_per_state_backend_var(interdependent_state): """Set backend var on one instance, expect no affect in other instances. Args: interdependent_state: A state with varying Var dependencies. """ s2 = InterdependentState() assert s2._v2 == interdependent_state._v2 interdependent_state._v2 = 2 assert s2._v2 != interdependent_state._v2 s3 = InterdependentState() assert s3._v2 != interdependent_state._v2 # both s2 and s3 should still have the default value assert s2._v2 == s3._v2 # changing s2._v2 should not affect others s2._v2 = 4 assert s2._v2 != interdependent_state._v2 assert s2._v2 != s3._v2 def test_child_state(): """Test that the child state computed vars can reference parent state vars.""" class MainState(State): v: int = 2 class ChildState(MainState): @ComputedVar def rendered_var(self): return self.v ms = MainState() cs = ms.substates[ChildState.get_name()] assert ms.v == 2 assert cs.v == 2 assert cs.rendered_var == 2 def test_conditional_computed_vars(): """Test that computed vars can have conditionals.""" class MainState(State): flag: bool = False t1: str = "a" t2: str = "b" @ComputedVar def rendered_var(self) -> str: if self.flag: return self.t1 return self.t2 ms = MainState() # Initially there are no dirty computed vars. assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"} assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"} assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"} assert ms.computed_vars["rendered_var"].deps(objclass=MainState) == { "flag", "t1", "t2", } def test_event_handlers_convert_to_fns(test_state, child_state): """Test that when the state is initialized, event handlers are converted to fns. Args: test_state: A state with event handlers. child_state: A child state with event handlers. """ # The class instances should be event handlers. assert isinstance(TestState.do_something, EventHandler) assert isinstance(ChildState.change_both, EventHandler) # The object instances should be fns. test_state.do_something() child_state.change_both(value="goose", count=9) assert child_state.value == "GOOSE" assert child_state.count == 18 def test_event_handlers_call_other_handlers(): """Test that event handlers can call other event handlers.""" class MainState(State): v: int = 0 def set_v(self, v: int): self.v = v def set_v2(self, v: int): self.set_v(v) ms = MainState() ms.set_v2(1) assert ms.v == 1 def test_computed_var_cached(): """Test that a ComputedVar doesn't recalculate when accessed.""" comp_v_calls = 0 class ComputedState(State): v: int = 0 @rx.cached_var def comp_v(self) -> int: nonlocal comp_v_calls comp_v_calls += 1 return self.v cs = ComputedState() assert cs.dict()["v"] == 0 assert comp_v_calls == 1 assert cs.dict()["comp_v"] == 0 assert comp_v_calls == 1 assert cs.comp_v == 0 assert comp_v_calls == 1 cs.v = 1 assert comp_v_calls == 1 assert cs.comp_v == 1 assert comp_v_calls == 2 def test_computed_var_cached_depends_on_non_cached(): """Test that a cached_var is recalculated if it depends on non-cached ComputedVar.""" class ComputedState(State): v: int = 0 @rx.var def no_cache_v(self) -> int: return self.v @rx.cached_var def dep_v(self) -> int: return self.no_cache_v @rx.cached_var def comp_v(self) -> int: return self.v cs = ComputedState() assert cs.dirty_vars == set() assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}} cs.clean() assert cs.dirty_vars == set() assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}} cs.clean() assert cs.dirty_vars == set() cs.v = 1 assert cs.dirty_vars == {"v", "comp_v", "dep_v", "no_cache_v"} assert cs.get_delta() == { cs.get_name(): {"v": 1, "no_cache_v": 1, "dep_v": 1, "comp_v": 1} } cs.clean() assert cs.dirty_vars == set() assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}} cs.clean() assert cs.dirty_vars == set() assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}} cs.clean() assert cs.dirty_vars == set() def test_computed_var_depends_on_parent_non_cached(): """Child state cached_var that depends on parent state un cached var is always recalculated.""" counter = 0 class ParentState(State): @rx.var def no_cache_v(self) -> int: nonlocal counter counter += 1 return counter class ChildState(ParentState): @rx.cached_var def dep_v(self) -> int: return self.no_cache_v ps = ParentState() cs = ps.substates[ChildState.get_name()] assert ps.dirty_vars == set() assert cs.dirty_vars == set() assert ps.dict() == { cs.get_name(): {"dep_v": 2}, "no_cache_v": 1, IS_HYDRATED: False, } assert ps.dict() == { cs.get_name(): {"dep_v": 4}, "no_cache_v": 3, IS_HYDRATED: False, } assert ps.dict() == { cs.get_name(): {"dep_v": 6}, "no_cache_v": 5, IS_HYDRATED: False, } assert counter == 6 @pytest.mark.parametrize("use_partial", [True, False]) def test_cached_var_depends_on_event_handler(use_partial: bool): """A cached_var that calls an event handler calculates deps correctly. Args: use_partial: if true, replace the EventHandler with functools.partial """ counter = 0 class HandlerState(State): x: int = 42 def handler(self): self.x = self.x + 1 @rx.cached_var def cached_x_side_effect(self) -> int: self.handler() nonlocal counter counter += 1 return counter if use_partial: HandlerState.handler = functools.partial(HandlerState.handler.fn) assert isinstance(HandlerState.handler, functools.partial) else: assert isinstance(HandlerState.handler, EventHandler) s = HandlerState() assert "cached_x_side_effect" in s.computed_var_dependencies["x"] assert s.cached_x_side_effect == 1 assert s.x == 43 s.handler() assert s.cached_x_side_effect == 2 assert s.x == 45 def test_backend_method(): """A method with leading underscore should be callable from event handler.""" class BackendMethodState(State): def _be_method(self): return True def handler(self): assert self._be_method() bms = BackendMethodState() bms.handler() assert bms._be_method()