Pārlūkot izejas kodu

add backend disabled dialog (#4715)

* add backend disabled dialog

* pyi that guy

* pyi the other guy

* extend test_connection_banner to also test the cloud banner

* oops, need asyncio _inside_ the app

* Update reflex/components/core/banner.py

Co-authored-by: Masen Furer <m_github@0x26.net>

* use universal cookies

* fix pre-commit

* revert universal cookie 🍪

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
Khaleel Al-Adhami 3 mēneši atpakaļ
vecāks
revīzija
335816cbf7

+ 15 - 6
reflex/.templates/web/utils/state.js

@@ -106,6 +106,18 @@ export const getBackendURL = (url_str) => {
   return endpoint;
 };
 
+/**
+ * Check if the backend is disabled.
+ *
+ * @returns True if the backend is disabled, false otherwise.
+ */
+export const isBackendDisabled = () => {
+  const cookie = document.cookie
+    .split("; ")
+    .find((row) => row.startsWith("backend-enabled="));
+  return cookie !== undefined && cookie.split("=")[1] == "false";
+};
+
 /**
  * Determine if any event in the event queue is stateful.
  *
@@ -301,10 +313,7 @@ export const applyEvent = async (event, socket) => {
 
   // Send the event to the server.
   if (socket) {
-    socket.emit(
-      "event",
-      event,
-    );
+    socket.emit("event", event);
     return true;
   }
 
@@ -497,7 +506,7 @@ export const uploadFiles = async (
     return false;
   }
 
-  const upload_ref_name = `__upload_controllers_${upload_id}`
+  const upload_ref_name = `__upload_controllers_${upload_id}`;
 
   if (refs[upload_ref_name]) {
     console.log("Upload already in progress for ", upload_id);
@@ -815,7 +824,7 @@ export const useEventLoop = (
       return;
     }
     // only use websockets if state is present
-    if (Object.keys(initialState).length > 1) {
+    if (Object.keys(initialState).length > 1 && !isBackendDisabled()) {
       // Initialize the websocket connection.
       if (!socket.current) {
         connect(

+ 8 - 1
reflex/app.py

@@ -59,7 +59,11 @@ from reflex.components.component import (
     ComponentStyle,
     evaluate_style_namespaces,
 )
-from reflex.components.core.banner import connection_pulser, connection_toaster
+from reflex.components.core.banner import (
+    backend_disabled,
+    connection_pulser,
+    connection_toaster,
+)
 from reflex.components.core.breakpoints import set_breakpoints
 from reflex.components.core.client_side_routing import (
     Default404Page,
@@ -158,9 +162,12 @@ def default_overlay_component() -> Component:
     Returns:
         The default overlay_component, which is a connection_modal.
     """
+    config = get_config()
+
     return Fragment.create(
         connection_pulser(),
         connection_toaster(),
+        *([backend_disabled()] if config.is_reflex_cloud else []),
         *codespaces.codespaces_auto_redirect(),
     )
 

+ 79 - 0
reflex/components/core/banner.py

@@ -4,8 +4,10 @@ from __future__ import annotations
 
 from typing import Optional
 
+from reflex import constants
 from reflex.components.component import Component
 from reflex.components.core.cond import cond
+from reflex.components.datadisplay.logo import svg_logo
 from reflex.components.el.elements.typography import Div
 from reflex.components.lucide.icon import Icon
 from reflex.components.radix.themes.components.dialog import (
@@ -293,7 +295,84 @@ class ConnectionPulser(Div):
         )
 
 
+class BackendDisabled(Div):
+    """A component that displays a message when the backend is disabled."""
+
+    @classmethod
+    def create(cls, **props) -> Component:
+        """Create a backend disabled component.
+
+        Args:
+            **props: The properties of the component.
+
+        Returns:
+            The backend disabled component.
+        """
+        import reflex as rx
+
+        is_backend_disabled = Var(
+            "backendDisabled",
+            _var_type=bool,
+            _var_data=VarData(
+                hooks={
+                    "const [backendDisabled, setBackendDisabled] = useState(false);": None,
+                    "useEffect(() => { setBackendDisabled(isBackendDisabled()); }, []);": None,
+                },
+                imports={
+                    f"$/{constants.Dirs.STATE_PATH}": [
+                        ImportVar(tag="isBackendDisabled")
+                    ],
+                },
+            ),
+        )
+
+        return super().create(
+            rx.cond(
+                is_backend_disabled,
+                rx.box(
+                    rx.box(
+                        rx.card(
+                            rx.vstack(
+                                svg_logo(),
+                                rx.text(
+                                    "You ran out of compute credits.",
+                                ),
+                                rx.callout(
+                                    rx.fragment(
+                                        "Please upgrade your plan or raise your compute credits at ",
+                                        rx.link(
+                                            "Reflex Cloud.",
+                                            href="https://cloud.reflex.dev/",
+                                        ),
+                                    ),
+                                    width="100%",
+                                    icon="info",
+                                    variant="surface",
+                                ),
+                            ),
+                            font_size="20px",
+                            font_family='"Inter", "Helvetica", "Arial", sans-serif',
+                            variant="classic",
+                        ),
+                        position="fixed",
+                        top="50%",
+                        left="50%",
+                        transform="translate(-50%, -50%)",
+                        width="40ch",
+                        max_width="90vw",
+                    ),
+                    position="fixed",
+                    z_index=9999,
+                    backdrop_filter="grayscale(1) blur(5px)",
+                    width="100dvw",
+                    height="100dvh",
+                ),
+            )
+        )
+
+
 connection_banner = ConnectionBanner.create
 connection_modal = ConnectionModal.create
 connection_toaster = ConnectionToaster.create
 connection_pulser = ConnectionPulser.create
+backend_disabled = BackendDisabled.create

+ 86 - 0
reflex/components/core/banner.pyi

@@ -350,7 +350,93 @@ class ConnectionPulser(Div):
         """
         ...
 
+class BackendDisabled(Div):
+    @overload
+    @classmethod
+    def create(  # type: ignore
+        cls,
+        *children,
+        access_key: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
+        auto_capitalize: Optional[
+            Union[Var[Union[bool, int, str]], bool, int, str]
+        ] = None,
+        content_editable: Optional[
+            Union[Var[Union[bool, int, str]], bool, int, str]
+        ] = None,
+        context_menu: Optional[
+            Union[Var[Union[bool, int, str]], bool, int, str]
+        ] = None,
+        dir: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
+        draggable: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
+        enter_key_hint: Optional[
+            Union[Var[Union[bool, int, str]], bool, int, str]
+        ] = None,
+        hidden: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
+        input_mode: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
+        item_prop: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
+        lang: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
+        role: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
+        slot: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
+        spell_check: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
+        tab_index: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
+        title: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
+        style: Optional[Style] = None,
+        key: Optional[Any] = None,
+        id: Optional[Any] = None,
+        class_name: Optional[Any] = None,
+        autofocus: Optional[bool] = None,
+        custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
+        on_blur: Optional[EventType[[], BASE_STATE]] = None,
+        on_click: Optional[EventType[[], BASE_STATE]] = None,
+        on_context_menu: Optional[EventType[[], BASE_STATE]] = None,
+        on_double_click: Optional[EventType[[], BASE_STATE]] = None,
+        on_focus: Optional[EventType[[], BASE_STATE]] = None,
+        on_mount: Optional[EventType[[], BASE_STATE]] = None,
+        on_mouse_down: Optional[EventType[[], BASE_STATE]] = None,
+        on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None,
+        on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None,
+        on_mouse_move: Optional[EventType[[], BASE_STATE]] = None,
+        on_mouse_out: Optional[EventType[[], BASE_STATE]] = None,
+        on_mouse_over: Optional[EventType[[], BASE_STATE]] = None,
+        on_mouse_up: Optional[EventType[[], BASE_STATE]] = None,
+        on_scroll: Optional[EventType[[], BASE_STATE]] = None,
+        on_unmount: Optional[EventType[[], BASE_STATE]] = None,
+        **props,
+    ) -> "BackendDisabled":
+        """Create a backend disabled component.
+
+        Args:
+            access_key: Provides a hint for generating a keyboard shortcut for the current element.
+            auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user.
+            content_editable: Indicates whether the element's content is editable.
+            context_menu: Defines the ID of a <menu> element which will serve as the element's context menu.
+            dir: Defines the text direction. Allowed values are ltr (Left-To-Right) or rtl (Right-To-Left)
+            draggable: Defines whether the element can be dragged.
+            enter_key_hint: Hints what media types the media element is able to play.
+            hidden: Defines whether the element is hidden.
+            input_mode: Defines the type of the element.
+            item_prop: Defines the name of the element for metadata purposes.
+            lang: Defines the language used in the element.
+            role: Defines the role of the element.
+            slot: Assigns a slot in a shadow DOM shadow tree to an element.
+            spell_check: Defines whether the element may be checked for spelling errors.
+            tab_index: Defines the position of the current element in the tabbing order.
+            title: Defines a tooltip for the element.
+            style: The style of the component.
+            key: A unique key for the component.
+            id: The id for the component.
+            class_name: The class name for the component.
+            autofocus: Whether the component should take the focus once the page is loaded
+            custom_attrs: custom attribute
+            **props: The properties of the component.
+
+        Returns:
+            The backend disabled component.
+        """
+        ...
+
 connection_banner = ConnectionBanner.create
 connection_modal = ConnectionModal.create
 connection_toaster = ConnectionToaster.create
 connection_pulser = ConnectionPulser.create
+backend_disabled = BackendDisabled.create

+ 1 - 1
reflex/components/radix/themes/components/card.py

@@ -20,7 +20,7 @@ class Card(elements.Div, RadixThemesComponent):
     # Card size: "1" - "5"
     size: Var[Responsive[Literal["1", "2", "3", "4", "5"],]]
 
-    # Variant of Card: "solid" | "soft" | "outline" | "ghost"
+    # Variant of Card: "surface" | "classic" | "ghost"
     variant: Var[Literal["surface", "classic", "ghost"]]
 
 

+ 1 - 1
reflex/components/radix/themes/components/card.pyi

@@ -94,7 +94,7 @@ class Card(elements.Div, RadixThemesComponent):
             *children: Child components.
             as_child: Change the default rendered element for the one passed as a child, merging their props and behavior.
             size: Card size: "1" - "5"
-            variant: Variant of Card: "solid" | "soft" | "outline" | "ghost"
+            variant: Variant of Card: "surface" | "classic" | "ghost"
             access_key: Provides a hint for generating a keyboard shortcut for the current element.
             auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user.
             content_editable: Indicates whether the element's content is editable.

+ 3 - 0
reflex/config.py

@@ -703,6 +703,9 @@ class Config(Base):
     # Path to file containing key-values pairs to override in the environment; Dotenv format.
     env_file: Optional[str] = None
 
+    # Whether the app is running in the reflex cloud environment.
+    is_reflex_cloud: bool = False
+
     def __init__(self, *args, **kwargs):
         """Initialize the config values.
 

+ 102 - 8
tests/integration/test_connection_banner.py

@@ -1,5 +1,6 @@
 """Test case for displaying the connection banner when the websocket drops."""
 
+import functools
 from typing import Generator
 
 import pytest
@@ -11,12 +12,19 @@ from reflex.testing import AppHarness, WebDriver
 from .utils import SessionStorage
 
 
-def ConnectionBanner():
-    """App with a connection banner."""
+def ConnectionBanner(is_reflex_cloud: bool = False):
+    """App with a connection banner.
+
+    Args:
+        is_reflex_cloud: The value for config.is_reflex_cloud.
+    """
     import asyncio
 
     import reflex as rx
 
+    # Simulate reflex cloud deploy
+    rx.config.get_config().is_reflex_cloud = is_reflex_cloud
+
     class State(rx.State):
         foo: int = 0
 
@@ -40,19 +48,43 @@ def ConnectionBanner():
     app.add_page(index)
 
 
+@pytest.fixture(
+    params=[False, True], ids=["reflex_cloud_disabled", "reflex_cloud_enabled"]
+)
+def simulate_is_reflex_cloud(request) -> bool:
+    """Fixture to simulate reflex cloud deployment.
+
+    Args:
+        request: pytest request fixture.
+
+    Returns:
+        True if reflex cloud is enabled, False otherwise.
+    """
+    return request.param
+
+
 @pytest.fixture()
-def connection_banner(tmp_path) -> Generator[AppHarness, None, None]:
+def connection_banner(
+    tmp_path,
+    simulate_is_reflex_cloud: bool,
+) -> Generator[AppHarness, None, None]:
     """Start ConnectionBanner app at tmp_path via AppHarness.
 
     Args:
         tmp_path: pytest tmp_path fixture
+        simulate_is_reflex_cloud: Whether is_reflex_cloud is set for the app.
 
     Yields:
         running AppHarness instance
     """
     with AppHarness.create(
         root=tmp_path,
-        app_source=ConnectionBanner,
+        app_source=functools.partial(
+            ConnectionBanner, is_reflex_cloud=simulate_is_reflex_cloud
+        ),
+        app_name="connection_banner_reflex_cloud"
+        if simulate_is_reflex_cloud
+        else "connection_banner",
     ) as harness:
         yield harness
 
@@ -77,6 +109,38 @@ def has_error_modal(driver: WebDriver) -> bool:
         return True
 
 
+def has_cloud_banner(driver: WebDriver) -> bool:
+    """Check if the cloud banner is displayed.
+
+    Args:
+        driver: Selenium webdriver instance.
+
+    Returns:
+        True if the banner is displayed, False otherwise.
+    """
+    try:
+        driver.find_element(
+            By.XPATH, "//*[ contains(text(), 'You ran out of compute credits.') ]"
+        )
+    except NoSuchElementException:
+        return False
+    else:
+        return True
+
+
+def _assert_token(connection_banner, driver):
+    """Poll for backend to be up.
+
+    Args:
+        connection_banner: AppHarness instance.
+        driver: Selenium webdriver instance.
+    """
+    ss = SessionStorage(driver)
+    assert connection_banner._poll_for(
+        lambda: ss.get("token") is not None
+    ), "token not found"
+
+
 @pytest.mark.asyncio
 async def test_connection_banner(connection_banner: AppHarness):
     """Test that the connection banner is displayed when the websocket drops.
@@ -88,10 +152,7 @@ async def test_connection_banner(connection_banner: AppHarness):
     assert connection_banner.backend is not None
     driver = connection_banner.frontend()
 
-    ss = SessionStorage(driver)
-    assert connection_banner._poll_for(
-        lambda: ss.get("token") is not None
-    ), "token not found"
+    _assert_token(connection_banner, driver)
 
     assert connection_banner._poll_for(lambda: not has_error_modal(driver))
 
@@ -132,3 +193,36 @@ async def test_connection_banner(connection_banner: AppHarness):
 
     # Count should have incremented after coming back up
     assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2"
+
+
+@pytest.mark.asyncio
+async def test_cloud_banner(
+    connection_banner: AppHarness, simulate_is_reflex_cloud: bool
+):
+    """Test that the connection banner is displayed when the websocket drops.
+
+    Args:
+        connection_banner: AppHarness instance.
+        simulate_is_reflex_cloud: Whether is_reflex_cloud is set for the app.
+    """
+    assert connection_banner.app_instance is not None
+    assert connection_banner.backend is not None
+    driver = connection_banner.frontend()
+
+    driver.add_cookie({"name": "backend-enabled", "value": "truly"})
+    driver.refresh()
+    _assert_token(connection_banner, driver)
+    assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))
+
+    driver.add_cookie({"name": "backend-enabled", "value": "false"})
+    driver.refresh()
+    if simulate_is_reflex_cloud:
+        assert connection_banner._poll_for(lambda: has_cloud_banner(driver))
+    else:
+        _assert_token(connection_banner, driver)
+        assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))
+
+    driver.delete_cookie("backend-enabled")
+    driver.refresh()
+    _assert_token(connection_banner, driver)
+    assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))