Browse Source

App: only render default overlay_component when DefaultState is not used (#1744)

Masen Furer 1 year ago
parent
commit
2e014422f5
4 changed files with 174 additions and 5 deletions
  1. 90 0
      integration/test_connection_banner.py
  2. 19 1
      reflex/app.py
  3. 2 2
      reflex/testing.py
  4. 63 2
      tests/test_app.py

+ 90 - 0
integration/test_connection_banner.py

@@ -0,0 +1,90 @@
+"""Test case for displaying the connection banner when the websocket drops."""
+
+from typing import Generator
+
+import pytest
+from selenium.common.exceptions import NoSuchElementException
+from selenium.webdriver.common.by import By
+
+from reflex.testing import AppHarness, WebDriver
+
+
+def ConnectionBanner():
+    """App with a connection banner."""
+    import reflex as rx
+
+    class State(rx.State):
+        foo: int = 0
+
+    def index():
+        return rx.text("Hello World")
+
+    app = rx.App(state=State)
+    app.add_page(index)
+    app.compile()
+
+
+@pytest.fixture()
+def connection_banner(tmp_path) -> Generator[AppHarness, None, None]:
+    """Start ConnectionBanner app at tmp_path via AppHarness.
+
+    Args:
+        tmp_path: pytest tmp_path fixture
+
+    Yields:
+        running AppHarness instance
+    """
+    with AppHarness.create(
+        root=tmp_path,
+        app_source=ConnectionBanner,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+CONNECTION_ERROR_XPATH = "//*[ text() = 'Connection Error' ]"
+
+
+def has_error_modal(driver: WebDriver) -> bool:
+    """Check if the connection error modal is displayed.
+
+    Args:
+        driver: Selenium webdriver instance.
+
+    Returns:
+        True if the modal is displayed, False otherwise.
+    """
+    try:
+        driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH)
+        return True
+    except NoSuchElementException:
+        return False
+
+
+def test_connection_banner(connection_banner: AppHarness):
+    """Test that the connection banner is displayed when the websocket drops.
+
+    Args:
+        connection_banner: AppHarness instance.
+    """
+    assert connection_banner.app_instance is not None
+    assert connection_banner.backend is not None
+    driver = connection_banner.frontend()
+
+    connection_banner._poll_for(lambda: not has_error_modal(driver))
+
+    # Get the backend port
+    backend_port = connection_banner._poll_for_servers().getsockname()[1]
+
+    # Kill the backend
+    connection_banner.backend.should_exit = True
+    if connection_banner.backend_thread is not None:
+        connection_banner.backend_thread.join()
+
+    # Error modal should now be displayed
+    connection_banner._poll_for(lambda: has_error_modal(driver))
+
+    # Bring the backend back up
+    connection_banner._start_backend(port=backend_port)
+
+    # Banner should be gone now
+    connection_banner._poll_for(lambda: not has_error_modal(driver))

+ 19 - 1
reflex/app.py

@@ -57,6 +57,15 @@ ComponentCallable = Callable[[], Component]
 Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
 
 
+def default_overlay_component() -> Component:
+    """Default overlay_component attribute for App.
+
+    Returns:
+        The default overlay_component, which is a connection_modal.
+    """
+    return connection_modal()
+
+
 class App(Base):
     """A Reflex application."""
 
@@ -97,7 +106,9 @@ class App(Base):
     event_namespace: Optional[AsyncNamespace] = None
 
     # A component that is present on every page.
-    overlay_component: Optional[Union[Component, ComponentCallable]] = connection_modal
+    overlay_component: Optional[
+        Union[Component, ComponentCallable]
+    ] = default_overlay_component
 
     def __init__(self, *args, **kwargs):
         """Initialize the app.
@@ -179,6 +190,13 @@ class App(Base):
         # Set up the admin dash.
         self.setup_admin_dash()
 
+        # If a State is not used and no overlay_component is specified, do not render the connection modal
+        if (
+            self.state is DefaultState
+            and self.overlay_component is default_overlay_component
+        ):
+            self.overlay_component = None
+
     def __repr__(self) -> str:
         """Get the string representation of the app.
 

+ 2 - 2
reflex/testing.py

@@ -163,14 +163,14 @@ class AppHarness:
             self.app_module = reflex.utils.prerequisites.get_app(reload=True)
         self.app_instance = self.app_module.app
 
-    def _start_backend(self):
+    def _start_backend(self, port=0):
         if self.app_instance is None:
             raise RuntimeError("App was not initialized.")
         self.backend = uvicorn.Server(
             uvicorn.Config(
                 app=self.app_instance.api,
                 host="127.0.0.1",
-                port=0,
+                port=port,
             )
         )
         self.backend_thread = threading.Thread(target=self.backend.run)

+ 63 - 2
tests/test_app.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import io
 import os.path
 import sys
@@ -16,8 +18,15 @@ from starlette_admin.contrib.sqla.admin import Admin
 from starlette_admin.contrib.sqla.view import ModelView
 
 from reflex import AdminDash, constants
-from reflex.app import App, DefaultState, process, upload
-from reflex.components import Box
+from reflex.app import (
+    App,
+    ComponentCallable,
+    DefaultState,
+    default_overlay_component,
+    process,
+    upload,
+)
+from reflex.components import Box, Component, Cond, Fragment, Text
 from reflex.event import Event, get_hydrate_event
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
@@ -945,3 +954,55 @@ async def test_process_events(gen_state, mocker):
 
     assert app.state_manager.get_state("token").value == 5
     assert app.postprocess.call_count == 6
+
+
+@pytest.mark.parametrize(
+    ("state", "overlay_component", "exp_page_child"),
+    [
+        (DefaultState, default_overlay_component, None),
+        (DefaultState, None, None),
+        (DefaultState, Text.create("foo"), Text),
+        (State, default_overlay_component, Fragment),
+        (State, None, None),
+        (State, Text.create("foo"), Text),
+        (State, lambda: Text.create("foo"), Text),
+    ],
+)
+def test_overlay_component(
+    state: State | None,
+    overlay_component: Component | ComponentCallable | None,
+    exp_page_child: Type[Component] | None,
+):
+    """Test that the overlay component is set correctly.
+
+    Args:
+        state: The state class to pass to App.
+        overlay_component: The overlay_component to pass to App.
+        exp_page_child: The type of the expected child in the page fragment.
+    """
+    app = App(state=state, overlay_component=overlay_component)
+    if exp_page_child is None:
+        assert app.overlay_component is None
+    elif isinstance(exp_page_child, Fragment):
+        assert app.overlay_component is not None
+        generated_component = app._generate_component(app.overlay_component)
+        assert isinstance(generated_component, Fragment)
+        assert isinstance(
+            generated_component.children[0],
+            Cond,  # ConnectionModal is a Cond under the hood
+        )
+    else:
+        assert app.overlay_component is not None
+        assert isinstance(
+            app._generate_component(app.overlay_component),
+            exp_page_child,
+        )
+
+    app.add_page(Box.create("Index"), route="/test")
+    page = app.pages["test"]
+    if exp_page_child is not None:
+        assert len(page.children) == 3
+        children_types = (type(child) for child in page.children)
+        assert exp_page_child in children_types
+    else:
+        assert len(page.children) == 2