Просмотр исходного кода

rx.event(background=True) (#4263)

* event background True

* fix typo

* fix overloads

* forgor

* remove extra parens

* more forgor
Khaleel Al-Adhami 6 месяцев назад
Родитель
Сommit
4260a0cfc3

+ 1 - 1
reflex/app.py

@@ -1389,7 +1389,7 @@ def upload(app: App):
         if isinstance(func, EventHandler):
         if isinstance(func, EventHandler):
             if func.is_background:
             if func.is_background:
                 raise UploadTypeError(
                 raise UploadTypeError(
-                    f"@rx.background is not supported for upload handler `{handler}`.",
+                    f"@rx.event(background=True) is not supported for upload handler `{handler}`.",
                 )
                 )
             func = func.fn
             func = func.fn
         if isinstance(func, functools.partial):
         if isinstance(func, functools.partial):

+ 96 - 22
reflex/event.py

@@ -83,7 +83,7 @@ class Event:
 BACKGROUND_TASK_MARKER = "_reflex_background_task"
 BACKGROUND_TASK_MARKER = "_reflex_background_task"
 
 
 
 
-def background(fn):
+def background(fn, *, __internal_reflex_call: bool = False):
     """Decorator to mark event handler as running in the background.
     """Decorator to mark event handler as running in the background.
 
 
     Args:
     Args:
@@ -96,6 +96,13 @@ def background(fn):
     Raises:
     Raises:
         TypeError: If the function is not a coroutine function or async generator.
         TypeError: If the function is not a coroutine function or async generator.
     """
     """
+    if not __internal_reflex_call:
+        console.deprecate(
+            "background-decorator",
+            "Use `rx.event(background=True)` instead.",
+            "0.6.5",
+            "0.7.0",
+        )
     if not inspect.iscoroutinefunction(fn) and not inspect.isasyncgenfunction(fn):
     if not inspect.iscoroutinefunction(fn) and not inspect.isasyncgenfunction(fn):
         raise TypeError("Background task must be async function or generator.")
         raise TypeError("Background task must be async function or generator.")
     setattr(fn, BACKGROUND_TASK_MARKER, True)
     setattr(fn, BACKGROUND_TASK_MARKER, True)
@@ -1457,6 +1464,8 @@ V3 = TypeVar("V3")
 V4 = TypeVar("V4")
 V4 = TypeVar("V4")
 V5 = TypeVar("V5")
 V5 = TypeVar("V5")
 
 
+background_event_decorator = background
+
 if sys.version_info >= (3, 10):
 if sys.version_info >= (3, 10):
     from typing import Concatenate
     from typing import Concatenate
 
 
@@ -1557,32 +1566,12 @@ if sys.version_info >= (3, 10):
 
 
             return partial(self.func, instance)  # type: ignore
             return partial(self.func, instance)  # type: ignore
 
 
-    def event_handler(func: Callable[Concatenate[Any, P], T]) -> EventCallback[P, T]:
-        """Wrap a function to be used as an event.
 
 
-        Args:
-            func: The function to wrap.
-
-        Returns:
-            The wrapped function.
-        """
-        return func  # type: ignore
 else:
 else:
 
 
     class EventCallback(Generic[P, T]):
     class EventCallback(Generic[P, T]):
         """A descriptor that wraps a function to be used as an event."""
         """A descriptor that wraps a function to be used as an event."""
 
 
-    def event_handler(func: Callable[P, T]) -> Callable[P, T]:
-        """Wrap a function to be used as an event.
-
-        Args:
-            func: The function to wrap.
-
-        Returns:
-            The wrapped function.
-        """
-        return func
-
 
 
 G = ParamSpec("G")
 G = ParamSpec("G")
 
 
@@ -1608,8 +1597,93 @@ class EventNamespace(types.SimpleNamespace):
     EventChainVar = EventChainVar
     EventChainVar = EventChainVar
     LiteralEventChainVar = LiteralEventChainVar
     LiteralEventChainVar = LiteralEventChainVar
     EventType = EventType
     EventType = EventType
+    EventCallback = EventCallback
+
+    if sys.version_info >= (3, 10):
+
+        @overload
+        @staticmethod
+        def __call__(
+            func: None = None, *, background: bool | None = None
+        ) -> Callable[[Callable[Concatenate[Any, P], T]], EventCallback[P, T]]: ...
+
+        @overload
+        @staticmethod
+        def __call__(
+            func: Callable[Concatenate[Any, P], T],
+            *,
+            background: bool | None = None,
+        ) -> EventCallback[P, T]: ...
+
+        @staticmethod
+        def __call__(
+            func: Callable[Concatenate[Any, P], T] | None = None,
+            *,
+            background: bool | None = None,
+        ) -> Union[
+            EventCallback[P, T],
+            Callable[[Callable[Concatenate[Any, P], T]], EventCallback[P, T]],
+        ]:
+            """Wrap a function to be used as an event.
+
+            Args:
+                func: The function to wrap.
+                background: Whether the event should be run in the background. Defaults to False.
+
+            Returns:
+                The wrapped function.
+            """
+
+            def wrapper(func: Callable[Concatenate[Any, P], T]) -> EventCallback[P, T]:
+                if background is True:
+                    return background_event_decorator(func, __internal_reflex_call=True)  # type: ignore
+                return func  # type: ignore
+
+            if func is not None:
+                return wrapper(func)
+            return wrapper
+    else:
+
+        @overload
+        @staticmethod
+        def __call__(
+            func: None = None, *, background: bool | None = None
+        ) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
+
+        @overload
+        @staticmethod
+        def __call__(
+            func: Callable[P, T], *, background: bool | None = None
+        ) -> Callable[P, T]: ...
+
+        @staticmethod
+        def __call__(
+            func: Callable[P, T] | None = None,
+            *,
+            background: bool | None = None,
+        ) -> Union[
+            Callable[P, T],
+            Callable[[Callable[P, T]], Callable[P, T]],
+        ]:
+            """Wrap a function to be used as an event.
+
+            Args:
+                func: The function to wrap.
+                background: Whether the event should be run in the background. Defaults to False.
+
+            Returns:
+                The wrapped function.
+            """
+
+            def wrapper(func: Callable[P, T]) -> Callable[P, T]:
+                if background is True:
+                    return background_event_decorator(func, __internal_reflex_call=True)  # type: ignore
+                return func  # type: ignore
+
+            if func is not None:
+                return wrapper(func)
+            return wrapper
 
 
-    __call__ = staticmethod(event_handler)
     get_event = staticmethod(get_event)
     get_event = staticmethod(get_event)
     get_hydrate_event = staticmethod(get_hydrate_event)
     get_hydrate_event = staticmethod(get_hydrate_event)
     fix_events = staticmethod(fix_events)
     fix_events = staticmethod(fix_events)

+ 1 - 1
reflex/experimental/misc.py

@@ -7,7 +7,7 @@ from typing import Any
 async def run_in_thread(func) -> Any:
 async def run_in_thread(func) -> Any:
     """Run a function in a separate thread.
     """Run a function in a separate thread.
 
 
-    To not block the UI event queue, run_in_thread must be inside inside a rx.background() decorated method.
+    To not block the UI event queue, run_in_thread must be inside inside a rx.event(background=True) decorated method.
 
 
     Args:
     Args:
         func (callable): The non-async function to run.
         func (callable): The non-async function to run.

+ 2 - 2
reflex/state.py

@@ -2346,7 +2346,7 @@ class StateProxy(wrapt.ObjectProxy):
         class State(rx.State):
         class State(rx.State):
             counter: int = 0
             counter: int = 0
 
 
-            @rx.background
+            @rx.event(background=True)
             async def bg_increment(self):
             async def bg_increment(self):
                 await asyncio.sleep(1)
                 await asyncio.sleep(1)
                 async with self:
                 async with self:
@@ -3248,7 +3248,7 @@ class StateManagerRedis(StateManager):
             raise LockExpiredError(
             raise LockExpiredError(
                 f"Lock expired for token {token} while processing. Consider increasing "
                 f"Lock expired for token {token} while processing. Consider increasing "
                 f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
                 f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
-                "or use `@rx.background` decorator for long-running tasks."
+                "or use `@rx.event(background=True)` decorator for long-running tasks."
             )
             )
         client_token, substate_name = _split_substate_key(token)
         client_token, substate_name = _split_substate_key(token)
         # If the substate name on the token doesn't match the instance name, it cannot have a parent.
         # If the substate name on the token doesn't match the instance name, it cannot have a parent.

+ 9 - 16
tests/integration/test_background_task.py

@@ -1,4 +1,4 @@
-"""Test @rx.background task functionality."""
+"""Test @rx.event(background=True) task functionality."""
 
 
 from typing import Generator
 from typing import Generator
 
 
@@ -22,8 +22,7 @@ def BackgroundTask():
         _task_id: int = 0
         _task_id: int = 0
         iterations: int = 10
         iterations: int = 10
 
 
-        @rx.background
-        @rx.event
+        @rx.event(background=True)
         async def handle_event(self):
         async def handle_event(self):
             async with self:
             async with self:
                 self._task_id += 1
                 self._task_id += 1
@@ -32,8 +31,7 @@ def BackgroundTask():
                     self.counter += 1
                     self.counter += 1
                 await asyncio.sleep(0.005)
                 await asyncio.sleep(0.005)
 
 
-        @rx.background
-        @rx.event
+        @rx.event(background=True)
         async def handle_event_yield_only(self):
         async def handle_event_yield_only(self):
             async with self:
             async with self:
                 self._task_id += 1
                 self._task_id += 1
@@ -48,7 +46,7 @@ def BackgroundTask():
         def increment(self):
         def increment(self):
             self.counter += 1
             self.counter += 1
 
 
-        @rx.background
+        @rx.event(background=True)
         async def increment_arbitrary(self, amount: int):
         async def increment_arbitrary(self, amount: int):
             async with self:
             async with self:
                 self.counter += int(amount)
                 self.counter += int(amount)
@@ -61,8 +59,7 @@ def BackgroundTask():
         async def blocking_pause(self):
         async def blocking_pause(self):
             await asyncio.sleep(0.02)
             await asyncio.sleep(0.02)
 
 
-        @rx.background
-        @rx.event
+        @rx.event(background=True)
         async def non_blocking_pause(self):
         async def non_blocking_pause(self):
             await asyncio.sleep(0.02)
             await asyncio.sleep(0.02)
 
 
@@ -74,15 +71,13 @@ def BackgroundTask():
                     self.counter += 1
                     self.counter += 1
                 await asyncio.sleep(0.005)
                 await asyncio.sleep(0.005)
 
 
-        @rx.background
-        @rx.event
+        @rx.event(background=True)
         async def handle_racy_event(self):
         async def handle_racy_event(self):
             await asyncio.gather(
             await asyncio.gather(
                 self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
                 self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
             )
             )
 
 
-        @rx.background
-        @rx.event
+        @rx.event(background=True)
         async def nested_async_with_self(self):
         async def nested_async_with_self(self):
             async with self:
             async with self:
                 self.counter += 1
                 self.counter += 1
@@ -94,8 +89,7 @@ def BackgroundTask():
             third_state = await self.get_state(ThirdState)
             third_state = await self.get_state(ThirdState)
             await third_state._triple_count()
             await third_state._triple_count()
 
 
-        @rx.background
-        @rx.event
+        @rx.event(background=True)
         async def yield_in_async_with_self(self):
         async def yield_in_async_with_self(self):
             async with self:
             async with self:
                 self.counter += 1
                 self.counter += 1
@@ -103,8 +97,7 @@ def BackgroundTask():
                 self.counter += 1
                 self.counter += 1
 
 
     class OtherState(rx.State):
     class OtherState(rx.State):
-        @rx.background
-        @rx.event
+        @rx.event(background=True)
         async def get_other_state(self):
         async def get_other_state(self):
             async with self:
             async with self:
                 state = await self.get_state(State)
                 state = await self.get_state(State)

+ 3 - 3
tests/units/states/upload.py

@@ -71,7 +71,7 @@ class FileUploadState(State):
             assert file.filename is not None
             assert file.filename is not None
             self.img_list.append(file.filename)
             self.img_list.append(file.filename)
 
 
-    @rx.background
+    @rx.event(background=True)
     async def bg_upload(self, files: List[rx.UploadFile]):
     async def bg_upload(self, files: List[rx.UploadFile]):
         """Background task cannot be upload handler.
         """Background task cannot be upload handler.
 
 
@@ -119,7 +119,7 @@ class ChildFileUploadState(FileStateBase1):
             assert file.filename is not None
             assert file.filename is not None
             self.img_list.append(file.filename)
             self.img_list.append(file.filename)
 
 
-    @rx.background
+    @rx.event(background=True)
     async def bg_upload(self, files: List[rx.UploadFile]):
     async def bg_upload(self, files: List[rx.UploadFile]):
         """Background task cannot be upload handler.
         """Background task cannot be upload handler.
 
 
@@ -167,7 +167,7 @@ class GrandChildFileUploadState(FileStateBase2):
             assert file.filename is not None
             assert file.filename is not None
             self.img_list.append(file.filename)
             self.img_list.append(file.filename)
 
 
-    @rx.background
+    @rx.event(background=True)
     async def bg_upload(self, files: List[rx.UploadFile]):
     async def bg_upload(self, files: List[rx.UploadFile]):
         """Background task cannot be upload handler.
         """Background task cannot be upload handler.
 
 

+ 1 - 1
tests/units/test_app.py

@@ -874,7 +874,7 @@ async def test_upload_file_background(state, tmp_path, token):
         await fn(request_mock, [file_mock])
         await fn(request_mock, [file_mock])
     assert (
     assert (
         err.value.args[0]
         err.value.args[0]
-        == f"@rx.background is not supported for upload handler `{state.get_full_name()}.bg_upload`."
+        == f"@rx.event(background=True) is not supported for upload handler `{state.get_full_name()}.bg_upload`."
     )
     )
 
 
     if isinstance(app.state_manager, StateManagerRedis):
     if isinstance(app.state_manager, StateManagerRedis):

+ 3 - 3
tests/units/test_state.py

@@ -1965,7 +1965,7 @@ class BackgroundTaskState(BaseState):
         """
         """
         return self.order
         return self.order
 
 
-    @rx.background
+    @rx.event(background=True)
     async def background_task(self):
     async def background_task(self):
         """A background task that updates the state."""
         """A background task that updates the state."""
         async with self:
         async with self:
@@ -2002,7 +2002,7 @@ class BackgroundTaskState(BaseState):
             self.other()  # direct calling event handlers works in context
             self.other()  # direct calling event handlers works in context
             self._private_method()
             self._private_method()
 
 
-    @rx.background
+    @rx.event(background=True)
     async def background_task_reset(self):
     async def background_task_reset(self):
         """A background task that resets the state."""
         """A background task that resets the state."""
         with pytest.raises(ImmutableStateError):
         with pytest.raises(ImmutableStateError):
@@ -2016,7 +2016,7 @@ class BackgroundTaskState(BaseState):
         async with self:
         async with self:
             self.order.append("reset")
             self.order.append("reset")
 
 
-    @rx.background
+    @rx.event(background=True)
     async def background_task_generator(self):
     async def background_task_generator(self):
         """A background task generator that does nothing.
         """A background task generator that does nothing.