소스 검색

move overlays to _app.js (#4794)

* move overlays to _app.js

* fix unit tests

* fix dynamic imports app

* fix unit cases once again

* clear custom compoent cache between app harness tests
Khaleel Al-Adhami 3 달 전
부모
커밋
d545ee3f0b
6개의 변경된 파일76개의 추가작업 그리고 47개의 파일을 삭제
  1. 59 33
      reflex/app.py
  2. 1 0
      reflex/compiler/compiler.py
  3. 3 1
      reflex/components/component.py
  4. 2 0
      reflex/testing.py
  5. 6 9
      reflex/utils/exec.py
  6. 5 4
      tests/units/test_app.py

+ 59 - 33
reflex/app.py

@@ -164,11 +164,11 @@ def default_backend_exception_handler(exception: Exception) -> EventSpec:
         return window_alert("\n".join(error_message))
         return window_alert("\n".join(error_message))
 
 
 
 
-def default_overlay_component() -> Component:
-    """Default overlay_component attribute for App.
+def extra_overlay_function() -> Optional[Component]:
+    """Extra overlay function to add to the overlay component.
 
 
     Returns:
     Returns:
-        The default overlay_component, which is a connection_modal.
+        The extra overlay function.
     """
     """
     config = get_config()
     config = get_config()
 
 
@@ -178,7 +178,8 @@ def default_overlay_component() -> Component:
         module, _, function_name = extra_config.rpartition(".")
         module, _, function_name = extra_config.rpartition(".")
         try:
         try:
             module = __import__(module)
             module = __import__(module)
-            config_overlay = getattr(module, function_name)()
+            config_overlay = Fragment.create(getattr(module, function_name)())
+            config_overlay._get_all_imports()
         except Exception as e:
         except Exception as e:
             from reflex.compiler.utils import save_error
             from reflex.compiler.utils import save_error
 
 
@@ -188,13 +189,27 @@ def default_overlay_component() -> Component:
                 f"Error loading extra_overlay_function {extra_config}. Error saved to {log_path}"
                 f"Error loading extra_overlay_function {extra_config}. Error saved to {log_path}"
             )
             )
 
 
-    return Fragment.create(
-        connection_pulser(),
-        connection_toaster(),
-        *([config_overlay] if config_overlay else []),
-        *([backend_disabled()] if config.is_reflex_cloud else []),
-        *codespaces.codespaces_auto_redirect(),
-    )
+    return config_overlay
+
+
+def default_overlay_component() -> Component:
+    """Default overlay_component attribute for App.
+
+    Returns:
+        The default overlay_component, which is a connection_modal.
+    """
+    config = get_config()
+    from reflex.components.component import memo
+
+    def default_overlay_components():
+        return Fragment.create(
+            connection_pulser(),
+            connection_toaster(),
+            *([backend_disabled()] if config.is_reflex_cloud else []),
+            *codespaces.codespaces_auto_redirect(),
+        )
+
+    return Fragment.create(memo(default_overlay_components)())
 
 
 
 
 def default_error_boundary(*children: Component) -> Component:
 def default_error_boundary(*children: Component) -> Component:
@@ -266,11 +281,26 @@ class App(MiddlewareMixin, LifespanMixin):
 
 
     # A component that is present on every page (defaults to the Connection Error banner).
     # A component that is present on every page (defaults to the Connection Error banner).
     overlay_component: Optional[Union[Component, ComponentCallable]] = (
     overlay_component: Optional[Union[Component, ComponentCallable]] = (
-        dataclasses.field(default_factory=default_overlay_component)
+        dataclasses.field(default=None)
     )
     )
 
 
     # Error boundary component to wrap the app with.
     # Error boundary component to wrap the app with.
-    error_boundary: Optional[ComponentCallable] = default_error_boundary
+    error_boundary: Optional[ComponentCallable] = dataclasses.field(default=None)
+
+    # App wraps to be applied to the whole app. Expected to be a dictionary of (order, name) to a function that takes whether the state is enabled and optionally returns a component.
+    app_wraps: Dict[tuple[int, str], Callable[[bool], Optional[Component]]] = (
+        dataclasses.field(
+            default_factory=lambda: {
+                (55, "ErrorBoundary"): (
+                    lambda stateful: default_error_boundary() if stateful else None
+                ),
+                (5, "Overlay"): (
+                    lambda stateful: default_overlay_component() if stateful else None
+                ),
+                (4, "ExtraOverlay"): lambda stateful: extra_overlay_function(),
+            }
+        )
+    )
 
 
     # Components to add to the head of every page.
     # Components to add to the head of every page.
     head_components: List[Component] = dataclasses.field(default_factory=list)
     head_components: List[Component] = dataclasses.field(default_factory=list)
@@ -880,25 +910,6 @@ class App(MiddlewareMixin, LifespanMixin):
         for k, component in self._pages.items():
         for k, component in self._pages.items():
             self._pages[k] = self._add_overlay_to_component(component)
             self._pages[k] = self._add_overlay_to_component(component)
 
 
-    def _add_error_boundary_to_component(self, component: Component) -> Component:
-        if self.error_boundary is None:
-            return component
-
-        component = self.error_boundary(*component.children)
-
-        return component
-
-    def _setup_error_boundary(self):
-        """If a State is not used and no error_boundary is specified, do not render the error boundary."""
-        if self._state is None and self.error_boundary is default_error_boundary:
-            self.error_boundary = None
-
-        for k, component in self._pages.items():
-            # Skip the 404 page
-            if k == constants.Page404.SLUG:
-                continue
-            self._pages[k] = self._add_error_boundary_to_component(component)
-
     def _setup_sticky_badge(self):
     def _setup_sticky_badge(self):
         """Add the sticky badge to the app."""
         """Add the sticky badge to the app."""
         for k, component in self._pages.items():
         for k, component in self._pages.items():
@@ -1039,7 +1050,6 @@ class App(MiddlewareMixin, LifespanMixin):
 
 
         self._validate_var_dependencies()
         self._validate_var_dependencies()
         self._setup_overlay_component()
         self._setup_overlay_component()
-        self._setup_error_boundary()
         if is_prod_mode() and config.show_built_with_reflex:
         if is_prod_mode() and config.show_built_with_reflex:
             self._setup_sticky_badge()
             self._setup_sticky_badge()
 
 
@@ -1066,6 +1076,22 @@ class App(MiddlewareMixin, LifespanMixin):
             # Add the custom components from the page to the set.
             # Add the custom components from the page to the set.
             custom_components |= component._get_all_custom_components()
             custom_components |= component._get_all_custom_components()
 
 
+        # Add the app wraps to the app.
+        for key, app_wrap in self.app_wraps.items():
+            component = app_wrap(self._state is not None)
+            if component is not None:
+                app_wrappers[key] = component
+                custom_components |= component._get_all_custom_components()
+
+        if self.error_boundary:
+            console.deprecate(
+                feature_name="App.error_boundary",
+                reason="Use app_wraps instead.",
+                deprecation_version="0.7.1",
+                removal_version="0.8.0",
+            )
+            app_wrappers[(55, "ErrorBoundary")] = self.error_boundary()
+
         # Perform auto-memoization of stateful components.
         # Perform auto-memoization of stateful components.
         with console.timing("Auto-memoize StatefulComponents"):
         with console.timing("Auto-memoize StatefulComponents"):
             (
             (

+ 1 - 0
reflex/compiler/compiler.py

@@ -78,6 +78,7 @@ def _compile_app(app_root: Component) -> str:
         hooks=app_root._get_all_hooks(),
         hooks=app_root._get_all_hooks(),
         window_libraries=window_libraries,
         window_libraries=window_libraries,
         render=app_root.render(),
         render=app_root.render(),
+        dynamic_imports=app_root._get_all_dynamic_imports(),
     )
     )
 
 
 
 

+ 3 - 1
reflex/components/component.py

@@ -23,6 +23,8 @@ from typing import (
     Union,
     Union,
 )
 )
 
 
+from typing_extensions import Self
+
 import reflex.state
 import reflex.state
 from reflex.base import Base
 from reflex.base import Base
 from reflex.compiler.templates import STATEFUL_COMPONENT
 from reflex.compiler.templates import STATEFUL_COMPONENT
@@ -685,7 +687,7 @@ class Component(BaseComponent, ABC):
         }
         }
 
 
     @classmethod
     @classmethod
-    def create(cls, *children, **props) -> Component:
+    def create(cls, *children, **props) -> Self:
         """Create the component.
         """Create the component.
 
 
         Args:
         Args:

+ 2 - 0
reflex/testing.py

@@ -43,6 +43,7 @@ import reflex.utils.exec
 import reflex.utils.format
 import reflex.utils.format
 import reflex.utils.prerequisites
 import reflex.utils.prerequisites
 import reflex.utils.processes
 import reflex.utils.processes
+from reflex.components.component import CustomComponent
 from reflex.config import environment
 from reflex.config import environment
 from reflex.state import (
 from reflex.state import (
     BaseState,
     BaseState,
@@ -254,6 +255,7 @@ class AppHarness:
         # disable telemetry reporting for tests
         # disable telemetry reporting for tests
 
 
         os.environ["TELEMETRY_ENABLED"] = "false"
         os.environ["TELEMETRY_ENABLED"] = "false"
+        CustomComponent.create().get_component.cache_clear()
         self.app_path.mkdir(parents=True, exist_ok=True)
         self.app_path.mkdir(parents=True, exist_ok=True)
         if self.app_source is not None:
         if self.app_source is not None:
             app_globals = self._get_globals_from_signature(self.app_source)
             app_globals = self._get_globals_from_signature(self.app_source)

+ 6 - 9
reflex/utils/exec.py

@@ -254,15 +254,12 @@ def get_reload_paths() -> Sequence[Path]:
     if config.app_module is not None and config.app_module.__file__:
     if config.app_module is not None and config.app_module.__file__:
         module_path = Path(config.app_module.__file__).resolve().parent
         module_path = Path(config.app_module.__file__).resolve().parent
 
 
-        while module_path.parent.name:
-            if any(
-                sibling_file.name == "__init__.py"
-                for sibling_file in module_path.parent.iterdir()
-            ):
-                # go up a level to find dir without `__init__.py`
-                module_path = module_path.parent
-            else:
-                break
+        while module_path.parent.name and any(
+            sibling_file.name == "__init__.py"
+            for sibling_file in module_path.parent.iterdir()
+        ):
+            # go up a level to find dir without `__init__.py`
+            module_path = module_path.parent
 
 
         reload_paths = [module_path]
         reload_paths = [module_path]
 
 

+ 5 - 4
tests/units/test_app.py

@@ -1299,6 +1299,7 @@ def test_app_wrap_compile_theme(
     app_js_lines = [
     app_js_lines = [
         line.strip() for line in app_js_contents.splitlines() if line.strip()
         line.strip() for line in app_js_contents.splitlines() if line.strip()
     ]
     ]
+    lines = "".join(app_js_lines)
     assert (
     assert (
         "function AppWrap({children}) {"
         "function AppWrap({children}) {"
         "return ("
         "return ("
@@ -1313,7 +1314,7 @@ def test_app_wrap_compile_theme(
         + ("</StrictMode>" if react_strict_mode else "")
         + ("</StrictMode>" if react_strict_mode else "")
         + ")"
         + ")"
         "}"
         "}"
-    ) in "".join(app_js_lines)
+    ) in lines
 
 
 
 
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
@@ -1362,6 +1363,7 @@ def test_app_wrap_priority(
     app_js_lines = [
     app_js_lines = [
         line.strip() for line in app_js_contents.splitlines() if line.strip()
         line.strip() for line in app_js_contents.splitlines() if line.strip()
     ]
     ]
+    lines = "".join(app_js_lines)
     assert (
     assert (
         "function AppWrap({children}) {"
         "function AppWrap({children}) {"
         "return (" + ("<StrictMode>" if react_strict_mode else "") + "<RadixThemesBox>"
         "return (" + ("<StrictMode>" if react_strict_mode else "") + "<RadixThemesBox>"
@@ -1374,9 +1376,8 @@ def test_app_wrap_priority(
         "</Fragment2>"
         "</Fragment2>"
         "</RadixThemesColorModeProvider>"
         "</RadixThemesColorModeProvider>"
         "</RadixThemesText>"
         "</RadixThemesText>"
-        "</RadixThemesBox>" + ("</StrictMode>" if react_strict_mode else "") + ")"
-        "}"
-    ) in "".join(app_js_lines)
+        "</RadixThemesBox>" + ("</StrictMode>" if react_strict_mode else "")
+    ) in lines
 
 
 
 
 def test_app_state_determination():
 def test_app_state_determination():