소스 검색

ConnectionModal and ConnectionBanner cleanup (#1379)

Masen Furer 1 년 전
부모
커밋
6b481ecfc3

+ 1 - 6
reflex/.templates/jinja/web/pages/index.js.jinja2

@@ -14,7 +14,7 @@ export default function Component() {
   const focusRef = useRef();
   
   // Main event loop.
-  const [Event, notConnected] = useContext(EventLoopContext)
+  const [Event, connectError] = useContext(EventLoopContext)
 
   // Set focus to the specified element.
   useEffect(() => {
@@ -37,12 +37,7 @@ export default function Component() {
   {% endfor %}
 
   return (
-  <Fragment>
-      {%- if err_comp -%}
-            {{ utils.render(err_comp, indent_width=1) }}
-       {%- endif -%}
     {{utils.render(render, indent_width=0)}}
-    </Fragment>
   )
 }
 {% endblock %}

+ 2 - 2
reflex/.templates/web/pages/_app.js

@@ -15,12 +15,12 @@ const GlobalStyles = css`
 `;
 
 function EventLoopProvider({ children }) {
-  const [state, Event, notConnected] = useEventLoop(
+  const [state, Event, connectError] = useEventLoop(
     initialState,
     initialEvents,
   )
   return (
-    <EventLoopContext.Provider value={[Event, notConnected]}>
+    <EventLoopContext.Provider value={[Event, connectError]}>
       <StateContext.Provider value={state}>
         {children}
       </StateContext.Provider>

+ 9 - 9
reflex/.templates/web/utils/state.js

@@ -250,14 +250,14 @@ export const processEvent = async (
  * @param socket The socket object to connect.
  * @param dispatch The function to queue state update
  * @param transports The transports to use.
- * @param setNotConnected The function to update connection state.
+ * @param setConnectError The function to update connection error value.
  * @param initial_events Array of events to seed the queue after connecting.
  */
 export const connect = async (
   socket,
   dispatch,
   transports,
-  setNotConnected,
+  setConnectError,
   initial_events = [],
 ) => {
   // Get backend URL object from the endpoint.
@@ -272,11 +272,11 @@ export const connect = async (
   // Once the socket is open, hydrate the page.
   socket.current.on("connect", () => {
     queueEvents(initial_events, socket)
-    setNotConnected(false)
+    setConnectError(null)
   });
 
   socket.current.on('connect_error', (error) => {
-    setNotConnected(true)
+    setConnectError(error)
   });
 
   // On each received message, queue the updates and events.
@@ -357,10 +357,10 @@ export const E = (name, payload = {}, handler = null) => {
  * @param initial_state The initial page state.
  * @param initial_events Array of events to seed the queue after connecting.
  *
- * @returns [state, Event, notConnected] -
+ * @returns [state, Event, connectError] -
  *   state is a reactive dict,
  *   Event is used to queue an event, and
- *   notConnected is a reactive boolean indicating whether the websocket is connected.
+ *   connectError is a reactive js error from the websocket connection (or null if connected).
  */
 export const useEventLoop = (
   initial_state = {},
@@ -369,7 +369,7 @@ export const useEventLoop = (
   const socket = useRef(null)
   const router = useRouter()
   const [state, dispatch] = useReducer(applyDelta, initial_state)
-  const [notConnected, setNotConnected] = useState(false)
+  const [connectError, setConnectError] = useState(null)
   
   // Function to add new events to the event queue.
   const Event = (events, _e) => {
@@ -386,7 +386,7 @@ export const useEventLoop = (
 
     // Initialize the websocket connection.
     if (!socket.current) {
-      connect(socket, dispatch, ['websocket', 'polling'], setNotConnected, initial_events)
+      connect(socket, dispatch, ['websocket', 'polling'], setConnectError, initial_events)
     }
     (async () => {
       // Process all outstanding events.
@@ -395,7 +395,7 @@ export const useEventLoop = (
       }
     })()
   })
-  return [state, Event, notConnected]
+  return [state, Event, connectError]
 }
 
 /***

+ 43 - 20
reflex/app.py

@@ -1,4 +1,5 @@
 """The main Reflex app."""
+from __future__ import annotations
 
 import asyncio
 import inspect
@@ -29,6 +30,7 @@ from reflex.admin import AdminDash
 from reflex.base import Base
 from reflex.compiler import compiler
 from reflex.compiler import utils as compiler_utils
+from reflex.components import connection_modal
 from reflex.components.component import Component, ComponentStyle
 from reflex.components.layout.fragment import Fragment
 from reflex.config import get_config
@@ -88,12 +90,12 @@ class App(Base):
     # Admin dashboard
     admin_dash: Optional[AdminDash] = None
 
-    # The component to render if there is a connection error to the server.
-    connect_error_component: Optional[Component] = None
-
     # The async server name space
     event_namespace: Optional[AsyncNamespace] = None
 
+    # A component that is present on every page.
+    overlay_component: Optional[Union[Component, ComponentCallable]] = connection_modal
+
     def __init__(self, *args, **kwargs):
         """Initialize the app.
 
@@ -106,6 +108,10 @@ class App(Base):
                         Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist
                         of the DefaultState and the client app state).
         """
+        if "connect_error_component" in kwargs:
+            raise ValueError(
+                "`connect_error_component` is deprecated, use `overlay_component` instead"
+            )
         super().__init__(*args, **kwargs)
         state_subclasses = State.__subclasses__()
         inferred_state = state_subclasses[-1]
@@ -269,6 +275,31 @@ class App(Base):
         else:
             self.middleware.insert(index, middleware)
 
+    @staticmethod
+    def _generate_component(component: Component | ComponentCallable) -> Component:
+        """Generate a component from a callable.
+
+        Args:
+            component: The component function to call or Component to return as-is.
+
+        Returns:
+            The generated component.
+
+        Raises:
+            TypeError: When an invalid component function is passed.
+        """
+        try:
+            return component if isinstance(component, Component) else component()
+        except TypeError as e:
+            message = str(e)
+            if "BaseVar" in message or "ComputedVar" in message:
+                raise TypeError(
+                    "You may be trying to use an invalid Python function on a state var. "
+                    "When referencing a var inside your render code, only limited var operations are supported. "
+                    "See the var operation docs here: https://reflex.dev/docs/state/vars/#var-operations"
+                ) from e
+            raise e
+
     def add_page(
         self,
         component: Union[Component, ComponentCallable],
@@ -296,9 +327,6 @@ class App(Base):
             on_load: The event handler(s) that will be called each time the page load.
             meta: The metadata of the page.
             script_tags: List of script tags to be added to component
-
-        Raises:
-            TypeError: If an invalid var operation is used.
         """
         # If the route is not set, get it from the callable.
         if route is None:
@@ -314,20 +342,16 @@ class App(Base):
         self.state.setup_dynamic_args(get_route_args(route))
 
         # Generate the component if it is a callable.
-        try:
-            component = component if isinstance(component, Component) else component()
-        except TypeError as e:
-            message = str(e)
-            if "BaseVar" in message or "ComputedVar" in message:
-                raise TypeError(
-                    "You may be trying to use an invalid Python function on a state var. "
-                    "When referencing a var inside your render code, only limited var operations are supported. "
-                    "See the var operation docs here: https://reflex.dev/docs/state/vars/#var-operations"
-                ) from e
-            raise e
+        component = self._generate_component(component)
 
-        # Wrap the component in a fragment.
-        component = Fragment.create(component)
+        # Wrap the component in a fragment with optional overlay.
+        if self.overlay_component is not None:
+            component = Fragment.create(
+                self._generate_component(self.overlay_component),
+                component,
+            )
+        else:
+            component = Fragment.create(component)
 
         # Add meta information to the component.
         compiler_utils.add_meta(
@@ -497,7 +521,6 @@ class App(Base):
                             route,
                             component,
                             self.state,
-                            self.connect_error_component,
                         ),
                     )
                 )

+ 1 - 10
reflex/compiler/compiler.py

@@ -89,14 +89,12 @@ def _compile_contexts(state: Type[State]) -> str:
 def _compile_page(
     component: Component,
     state: Type[State],
-    connect_error_component,
 ) -> str:
     """Compile the component given the app state.
 
     Args:
         component: The component to compile.
         state: The app state.
-        connect_error_component: The component to render on sever connection error.
 
     Returns:
         The compiled component.
@@ -113,7 +111,6 @@ def _compile_page(
         state_name=state.get_name(),
         hooks=component.get_hooks(),
         render=component.render(),
-        err_comp=connect_error_component.render() if connect_error_component else None,
     )
 
 
@@ -221,7 +218,6 @@ def compile_page(
     path: str,
     component: Component,
     state: Type[State],
-    connect_error_component: Component,
 ) -> Tuple[str, str]:
     """Compile a single page.
 
@@ -229,7 +225,6 @@ def compile_page(
         path: The path to compile the page to.
         component: The component to compile.
         state: The app state.
-        connect_error_component: The component to render on sever connection error.
 
     Returns:
         The path and code of the compiled page.
@@ -238,11 +233,7 @@ def compile_page(
     output_path = utils.get_page_path(path)
 
     # Add the style to the component.
-    code = _compile_page(
-        component,
-        state,
-        connect_error_component,
-    )
+    code = _compile_page(component, state)
     return output_path, code
 
 

+ 1 - 0
reflex/components/__init__.py

@@ -31,6 +31,7 @@ badge = Badge.create
 code = Code.create
 code_block = CodeBlock.create
 connection_banner = ConnectionBanner.create
+connection_modal = ConnectionModal.create
 data_table = DataTable.create
 divider = Divider.create
 list = List.create

+ 1 - 1
reflex/components/overlay/__init__.py

@@ -8,7 +8,7 @@ from .alertdialog import (
     AlertDialogHeader,
     AlertDialogOverlay,
 )
-from .banner import ConnectionBanner
+from .banner import ConnectionBanner, ConnectionModal
 from .drawer import (
     Drawer,
     DrawerBody,

+ 54 - 2
reflex/components/overlay/banner.py

@@ -1,11 +1,41 @@
 """Banner components."""
+from __future__ import annotations
+
 from typing import Optional
 
 from reflex.components.component import Component
 from reflex.components.layout import Box, Cond, Fragment
+from reflex.components.overlay.modal import Modal
 from reflex.components.typography import Text
 from reflex.vars import Var
 
+connection_error = Var.create_safe(
+    value="(connectError !== null) ? connectError.message : ''",
+    is_local=False,
+    is_string=False,
+)
+has_connection_error = Var.create_safe(
+    value="connectError !== null",
+    is_string=False,
+)
+has_connection_error.type_ = bool
+
+
+def default_connection_error() -> list[str | Var]:
+    """Get the default connection error message.
+
+    Returns:
+        The default connection error message.
+    """
+    from reflex.config import get_config
+
+    return [
+        "Cannot connect to server: ",
+        connection_error,
+        ". Check if server is reachable at ",
+        get_config().api_url or "<API_URL not set>",
+    ]
+
 
 class ConnectionBanner(Cond):
     """A connection banner component."""
@@ -23,11 +53,33 @@ class ConnectionBanner(Cond):
         if not comp:
             comp = Box.create(
                 Text.create(
-                    "cannot connect to server. Check if server is reachable",
+                    *default_connection_error(),
                     bg="red",
                     color="white",
                 ),
                 textAlign="center",
             )
 
-        return super().create(Var.create("notConnected"), comp, Fragment.create())  # type: ignore
+        return super().create(has_connection_error, comp, Fragment.create())  # type: ignore
+
+
+class ConnectionModal(Modal):
+    """A connection status modal window."""
+
+    @classmethod
+    def create(cls, comp: Optional[Component] = None) -> Component:
+        """Create a connection banner component.
+
+        Args:
+            comp: The component to render when there's a server connection error.
+
+        Returns:
+            The connection banner component.
+        """
+        if not comp:
+            comp = Text.create(*default_connection_error())
+        return super().create(
+            header="Connection Error",
+            body=comp,
+            is_open=has_connection_error,
+        )