瀏覽代碼

ConnectionModal and ConnectionBanner cleanup (#1379)

Masen Furer 1 年之前
父節點
當前提交
6b481ecfc3

+ 1 - 6
reflex/.templates/jinja/web/pages/index.js.jinja2

@@ -14,7 +14,7 @@ export default function Component() {
   const focusRef = useRef();
   const focusRef = useRef();
   
   
   // Main event loop.
   // Main event loop.
-  const [Event, notConnected] = useContext(EventLoopContext)
+  const [Event, connectError] = useContext(EventLoopContext)
 
 
   // Set focus to the specified element.
   // Set focus to the specified element.
   useEffect(() => {
   useEffect(() => {
@@ -37,12 +37,7 @@ export default function Component() {
   {% endfor %}
   {% endfor %}
 
 
   return (
   return (
-  <Fragment>
-      {%- if err_comp -%}
-            {{ utils.render(err_comp, indent_width=1) }}
-       {%- endif -%}
     {{utils.render(render, indent_width=0)}}
     {{utils.render(render, indent_width=0)}}
-    </Fragment>
   )
   )
 }
 }
 {% endblock %}
 {% endblock %}

+ 2 - 2
reflex/.templates/web/pages/_app.js

@@ -15,12 +15,12 @@ const GlobalStyles = css`
 `;
 `;
 
 
 function EventLoopProvider({ children }) {
 function EventLoopProvider({ children }) {
-  const [state, Event, notConnected] = useEventLoop(
+  const [state, Event, connectError] = useEventLoop(
     initialState,
     initialState,
     initialEvents,
     initialEvents,
   )
   )
   return (
   return (
-    <EventLoopContext.Provider value={[Event, notConnected]}>
+    <EventLoopContext.Provider value={[Event, connectError]}>
       <StateContext.Provider value={state}>
       <StateContext.Provider value={state}>
         {children}
         {children}
       </StateContext.Provider>
       </StateContext.Provider>

+ 9 - 9
reflex/.templates/web/utils/state.js

@@ -250,14 +250,14 @@ export const processEvent = async (
  * @param socket The socket object to connect.
  * @param socket The socket object to connect.
  * @param dispatch The function to queue state update
  * @param dispatch The function to queue state update
  * @param transports The transports to use.
  * @param transports The transports to use.
- * @param setNotConnected The function to update connection state.
+ * @param setConnectError The function to update connection error value.
  * @param initial_events Array of events to seed the queue after connecting.
  * @param initial_events Array of events to seed the queue after connecting.
  */
  */
 export const connect = async (
 export const connect = async (
   socket,
   socket,
   dispatch,
   dispatch,
   transports,
   transports,
-  setNotConnected,
+  setConnectError,
   initial_events = [],
   initial_events = [],
 ) => {
 ) => {
   // Get backend URL object from the endpoint.
   // Get backend URL object from the endpoint.
@@ -272,11 +272,11 @@ export const connect = async (
   // Once the socket is open, hydrate the page.
   // Once the socket is open, hydrate the page.
   socket.current.on("connect", () => {
   socket.current.on("connect", () => {
     queueEvents(initial_events, socket)
     queueEvents(initial_events, socket)
-    setNotConnected(false)
+    setConnectError(null)
   });
   });
 
 
   socket.current.on('connect_error', (error) => {
   socket.current.on('connect_error', (error) => {
-    setNotConnected(true)
+    setConnectError(error)
   });
   });
 
 
   // On each received message, queue the updates and events.
   // On each received message, queue the updates and events.
@@ -357,10 +357,10 @@ export const E = (name, payload = {}, handler = null) => {
  * @param initial_state The initial page state.
  * @param initial_state The initial page state.
  * @param initial_events Array of events to seed the queue after connecting.
  * @param initial_events Array of events to seed the queue after connecting.
  *
  *
- * @returns [state, Event, notConnected] -
+ * @returns [state, Event, connectError] -
  *   state is a reactive dict,
  *   state is a reactive dict,
  *   Event is used to queue an event, and
  *   Event is used to queue an event, and
- *   notConnected is a reactive boolean indicating whether the websocket is connected.
+ *   connectError is a reactive js error from the websocket connection (or null if connected).
  */
  */
 export const useEventLoop = (
 export const useEventLoop = (
   initial_state = {},
   initial_state = {},
@@ -369,7 +369,7 @@ export const useEventLoop = (
   const socket = useRef(null)
   const socket = useRef(null)
   const router = useRouter()
   const router = useRouter()
   const [state, dispatch] = useReducer(applyDelta, initial_state)
   const [state, dispatch] = useReducer(applyDelta, initial_state)
-  const [notConnected, setNotConnected] = useState(false)
+  const [connectError, setConnectError] = useState(null)
   
   
   // Function to add new events to the event queue.
   // Function to add new events to the event queue.
   const Event = (events, _e) => {
   const Event = (events, _e) => {
@@ -386,7 +386,7 @@ export const useEventLoop = (
 
 
     // Initialize the websocket connection.
     // Initialize the websocket connection.
     if (!socket.current) {
     if (!socket.current) {
-      connect(socket, dispatch, ['websocket', 'polling'], setNotConnected, initial_events)
+      connect(socket, dispatch, ['websocket', 'polling'], setConnectError, initial_events)
     }
     }
     (async () => {
     (async () => {
       // Process all outstanding events.
       // Process all outstanding events.
@@ -395,7 +395,7 @@ export const useEventLoop = (
       }
       }
     })()
     })()
   })
   })
-  return [state, Event, notConnected]
+  return [state, Event, connectError]
 }
 }
 
 
 /***
 /***

+ 43 - 20
reflex/app.py

@@ -1,4 +1,5 @@
 """The main Reflex app."""
 """The main Reflex app."""
+from __future__ import annotations
 
 
 import asyncio
 import asyncio
 import inspect
 import inspect
@@ -29,6 +30,7 @@ from reflex.admin import AdminDash
 from reflex.base import Base
 from reflex.base import Base
 from reflex.compiler import compiler
 from reflex.compiler import compiler
 from reflex.compiler import utils as compiler_utils
 from reflex.compiler import utils as compiler_utils
+from reflex.components import connection_modal
 from reflex.components.component import Component, ComponentStyle
 from reflex.components.component import Component, ComponentStyle
 from reflex.components.layout.fragment import Fragment
 from reflex.components.layout.fragment import Fragment
 from reflex.config import get_config
 from reflex.config import get_config
@@ -88,12 +90,12 @@ class App(Base):
     # Admin dashboard
     # Admin dashboard
     admin_dash: Optional[AdminDash] = None
     admin_dash: Optional[AdminDash] = None
 
 
-    # The component to render if there is a connection error to the server.
-    connect_error_component: Optional[Component] = None
-
     # The async server name space
     # The async server name space
     event_namespace: Optional[AsyncNamespace] = None
     event_namespace: Optional[AsyncNamespace] = None
 
 
+    # A component that is present on every page.
+    overlay_component: Optional[Union[Component, ComponentCallable]] = connection_modal
+
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         """Initialize the app.
         """Initialize the app.
 
 
@@ -106,6 +108,10 @@ class App(Base):
                         Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist
                         Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist
                         of the DefaultState and the client app state).
                         of the DefaultState and the client app state).
         """
         """
+        if "connect_error_component" in kwargs:
+            raise ValueError(
+                "`connect_error_component` is deprecated, use `overlay_component` instead"
+            )
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
         state_subclasses = State.__subclasses__()
         state_subclasses = State.__subclasses__()
         inferred_state = state_subclasses[-1]
         inferred_state = state_subclasses[-1]
@@ -269,6 +275,31 @@ class App(Base):
         else:
         else:
             self.middleware.insert(index, middleware)
             self.middleware.insert(index, middleware)
 
 
+    @staticmethod
+    def _generate_component(component: Component | ComponentCallable) -> Component:
+        """Generate a component from a callable.
+
+        Args:
+            component: The component function to call or Component to return as-is.
+
+        Returns:
+            The generated component.
+
+        Raises:
+            TypeError: When an invalid component function is passed.
+        """
+        try:
+            return component if isinstance(component, Component) else component()
+        except TypeError as e:
+            message = str(e)
+            if "BaseVar" in message or "ComputedVar" in message:
+                raise TypeError(
+                    "You may be trying to use an invalid Python function on a state var. "
+                    "When referencing a var inside your render code, only limited var operations are supported. "
+                    "See the var operation docs here: https://reflex.dev/docs/state/vars/#var-operations"
+                ) from e
+            raise e
+
     def add_page(
     def add_page(
         self,
         self,
         component: Union[Component, ComponentCallable],
         component: Union[Component, ComponentCallable],
@@ -296,9 +327,6 @@ class App(Base):
             on_load: The event handler(s) that will be called each time the page load.
             on_load: The event handler(s) that will be called each time the page load.
             meta: The metadata of the page.
             meta: The metadata of the page.
             script_tags: List of script tags to be added to component
             script_tags: List of script tags to be added to component
-
-        Raises:
-            TypeError: If an invalid var operation is used.
         """
         """
         # If the route is not set, get it from the callable.
         # If the route is not set, get it from the callable.
         if route is None:
         if route is None:
@@ -314,20 +342,16 @@ class App(Base):
         self.state.setup_dynamic_args(get_route_args(route))
         self.state.setup_dynamic_args(get_route_args(route))
 
 
         # Generate the component if it is a callable.
         # Generate the component if it is a callable.
-        try:
-            component = component if isinstance(component, Component) else component()
-        except TypeError as e:
-            message = str(e)
-            if "BaseVar" in message or "ComputedVar" in message:
-                raise TypeError(
-                    "You may be trying to use an invalid Python function on a state var. "
-                    "When referencing a var inside your render code, only limited var operations are supported. "
-                    "See the var operation docs here: https://reflex.dev/docs/state/vars/#var-operations"
-                ) from e
-            raise e
+        component = self._generate_component(component)
 
 
-        # Wrap the component in a fragment.
-        component = Fragment.create(component)
+        # Wrap the component in a fragment with optional overlay.
+        if self.overlay_component is not None:
+            component = Fragment.create(
+                self._generate_component(self.overlay_component),
+                component,
+            )
+        else:
+            component = Fragment.create(component)
 
 
         # Add meta information to the component.
         # Add meta information to the component.
         compiler_utils.add_meta(
         compiler_utils.add_meta(
@@ -497,7 +521,6 @@ class App(Base):
                             route,
                             route,
                             component,
                             component,
                             self.state,
                             self.state,
-                            self.connect_error_component,
                         ),
                         ),
                     )
                     )
                 )
                 )

+ 1 - 10
reflex/compiler/compiler.py

@@ -89,14 +89,12 @@ def _compile_contexts(state: Type[State]) -> str:
 def _compile_page(
 def _compile_page(
     component: Component,
     component: Component,
     state: Type[State],
     state: Type[State],
-    connect_error_component,
 ) -> str:
 ) -> str:
     """Compile the component given the app state.
     """Compile the component given the app state.
 
 
     Args:
     Args:
         component: The component to compile.
         component: The component to compile.
         state: The app state.
         state: The app state.
-        connect_error_component: The component to render on sever connection error.
 
 
     Returns:
     Returns:
         The compiled component.
         The compiled component.
@@ -113,7 +111,6 @@ def _compile_page(
         state_name=state.get_name(),
         state_name=state.get_name(),
         hooks=component.get_hooks(),
         hooks=component.get_hooks(),
         render=component.render(),
         render=component.render(),
-        err_comp=connect_error_component.render() if connect_error_component else None,
     )
     )
 
 
 
 
@@ -221,7 +218,6 @@ def compile_page(
     path: str,
     path: str,
     component: Component,
     component: Component,
     state: Type[State],
     state: Type[State],
-    connect_error_component: Component,
 ) -> Tuple[str, str]:
 ) -> Tuple[str, str]:
     """Compile a single page.
     """Compile a single page.
 
 
@@ -229,7 +225,6 @@ def compile_page(
         path: The path to compile the page to.
         path: The path to compile the page to.
         component: The component to compile.
         component: The component to compile.
         state: The app state.
         state: The app state.
-        connect_error_component: The component to render on sever connection error.
 
 
     Returns:
     Returns:
         The path and code of the compiled page.
         The path and code of the compiled page.
@@ -238,11 +233,7 @@ def compile_page(
     output_path = utils.get_page_path(path)
     output_path = utils.get_page_path(path)
 
 
     # Add the style to the component.
     # Add the style to the component.
-    code = _compile_page(
-        component,
-        state,
-        connect_error_component,
-    )
+    code = _compile_page(component, state)
     return output_path, code
     return output_path, code
 
 
 
 

+ 1 - 0
reflex/components/__init__.py

@@ -31,6 +31,7 @@ badge = Badge.create
 code = Code.create
 code = Code.create
 code_block = CodeBlock.create
 code_block = CodeBlock.create
 connection_banner = ConnectionBanner.create
 connection_banner = ConnectionBanner.create
+connection_modal = ConnectionModal.create
 data_table = DataTable.create
 data_table = DataTable.create
 divider = Divider.create
 divider = Divider.create
 list = List.create
 list = List.create

+ 1 - 1
reflex/components/overlay/__init__.py

@@ -8,7 +8,7 @@ from .alertdialog import (
     AlertDialogHeader,
     AlertDialogHeader,
     AlertDialogOverlay,
     AlertDialogOverlay,
 )
 )
-from .banner import ConnectionBanner
+from .banner import ConnectionBanner, ConnectionModal
 from .drawer import (
 from .drawer import (
     Drawer,
     Drawer,
     DrawerBody,
     DrawerBody,

+ 54 - 2
reflex/components/overlay/banner.py

@@ -1,11 +1,41 @@
 """Banner components."""
 """Banner components."""
+from __future__ import annotations
+
 from typing import Optional
 from typing import Optional
 
 
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.layout import Box, Cond, Fragment
 from reflex.components.layout import Box, Cond, Fragment
+from reflex.components.overlay.modal import Modal
 from reflex.components.typography import Text
 from reflex.components.typography import Text
 from reflex.vars import Var
 from reflex.vars import Var
 
 
+connection_error = Var.create_safe(
+    value="(connectError !== null) ? connectError.message : ''",
+    is_local=False,
+    is_string=False,
+)
+has_connection_error = Var.create_safe(
+    value="connectError !== null",
+    is_string=False,
+)
+has_connection_error.type_ = bool
+
+
+def default_connection_error() -> list[str | Var]:
+    """Get the default connection error message.
+
+    Returns:
+        The default connection error message.
+    """
+    from reflex.config import get_config
+
+    return [
+        "Cannot connect to server: ",
+        connection_error,
+        ". Check if server is reachable at ",
+        get_config().api_url or "<API_URL not set>",
+    ]
+
 
 
 class ConnectionBanner(Cond):
 class ConnectionBanner(Cond):
     """A connection banner component."""
     """A connection banner component."""
@@ -23,11 +53,33 @@ class ConnectionBanner(Cond):
         if not comp:
         if not comp:
             comp = Box.create(
             comp = Box.create(
                 Text.create(
                 Text.create(
-                    "cannot connect to server. Check if server is reachable",
+                    *default_connection_error(),
                     bg="red",
                     bg="red",
                     color="white",
                     color="white",
                 ),
                 ),
                 textAlign="center",
                 textAlign="center",
             )
             )
 
 
-        return super().create(Var.create("notConnected"), comp, Fragment.create())  # type: ignore
+        return super().create(has_connection_error, comp, Fragment.create())  # type: ignore
+
+
+class ConnectionModal(Modal):
+    """A connection status modal window."""
+
+    @classmethod
+    def create(cls, comp: Optional[Component] = None) -> Component:
+        """Create a connection banner component.
+
+        Args:
+            comp: The component to render when there's a server connection error.
+
+        Returns:
+            The connection banner component.
+        """
+        if not comp:
+            comp = Text.create(*default_connection_error())
+        return super().create(
+            header="Connection Error",
+            body=comp,
+            is_open=has_connection_error,
+        )