|
@@ -91,6 +91,12 @@ def default_overlay_component() -> Component:
|
|
|
return Fragment.create(connection_pulser(), connection_modal())
|
|
|
|
|
|
|
|
|
+class OverlayFragment(Fragment):
|
|
|
+ """Alias for Fragment, used to wrap the overlay_component."""
|
|
|
+
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
class App(Base):
|
|
|
"""A Reflex application."""
|
|
|
|
|
@@ -159,7 +165,7 @@ class App(Base):
|
|
|
|
|
|
Raises:
|
|
|
ValueError: If the event namespace is not provided in the config.
|
|
|
- Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist
|
|
|
+ Also, if there are multiple client subclasses of rx.BaseState(Subclasses of rx.BaseState should consist
|
|
|
of the DefaultState and the client app state).
|
|
|
"""
|
|
|
if "connect_error_component" in kwargs:
|
|
@@ -167,12 +173,12 @@ class App(Base):
|
|
|
"`connect_error_component` is deprecated, use `overlay_component` instead"
|
|
|
)
|
|
|
super().__init__(*args, **kwargs)
|
|
|
- state_subclasses = BaseState.__subclasses__()
|
|
|
+ base_state_subclasses = BaseState.__subclasses__()
|
|
|
|
|
|
# Special case to allow test cases have multiple subclasses of rx.BaseState.
|
|
|
if not is_testing_env():
|
|
|
# Only one Base State class is allowed.
|
|
|
- if len(state_subclasses) > 1:
|
|
|
+ if len(base_state_subclasses) > 1:
|
|
|
raise ValueError(
|
|
|
"rx.BaseState cannot be subclassed multiple times. use rx.State instead"
|
|
|
)
|
|
@@ -184,12 +190,6 @@ class App(Base):
|
|
|
deprecation_version="0.3.5",
|
|
|
removal_version="0.5.0",
|
|
|
)
|
|
|
- # 2 substates are built-in and not considered when determining if app is stateless.
|
|
|
- if len(State.class_subclasses) > 2:
|
|
|
- self.state = State
|
|
|
- # Get the config
|
|
|
- config = get_config()
|
|
|
-
|
|
|
# Add middleware.
|
|
|
self.middleware.append(HydrateMiddleware())
|
|
|
|
|
@@ -198,45 +198,60 @@ class App(Base):
|
|
|
self.add_cors()
|
|
|
self.add_default_endpoints()
|
|
|
|
|
|
- if self.state:
|
|
|
- # Set up the state manager.
|
|
|
- self._state_manager = StateManager.create(state=self.state)
|
|
|
-
|
|
|
- # Set up the Socket.IO AsyncServer.
|
|
|
- self.sio = AsyncServer(
|
|
|
- async_mode="asgi",
|
|
|
- cors_allowed_origins=(
|
|
|
- "*"
|
|
|
- if config.cors_allowed_origins == ["*"]
|
|
|
- else config.cors_allowed_origins
|
|
|
- ),
|
|
|
- cors_credentials=True,
|
|
|
- max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
|
|
|
- ping_interval=constants.Ping.INTERVAL,
|
|
|
- ping_timeout=constants.Ping.TIMEOUT,
|
|
|
- )
|
|
|
+ self.setup_state()
|
|
|
|
|
|
- # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
|
|
|
- self.socket_app = ASGIApp(self.sio, socketio_path="")
|
|
|
- namespace = config.get_event_namespace()
|
|
|
+ # Set up the admin dash.
|
|
|
+ self.setup_admin_dash()
|
|
|
|
|
|
- if not namespace:
|
|
|
- raise ValueError("event namespace must be provided in the config.")
|
|
|
+ def enable_state(self) -> None:
|
|
|
+ """Enable state for the app."""
|
|
|
+ if not self.state:
|
|
|
+ self.state = State
|
|
|
+ self.setup_state()
|
|
|
|
|
|
- # Create the event namespace and attach the main app. Not related to any paths.
|
|
|
- self.event_namespace = EventNamespace(namespace, self)
|
|
|
+ def setup_state(self) -> None:
|
|
|
+ """Set up the state for the app.
|
|
|
|
|
|
- # Register the event namespace with the socket.
|
|
|
- self.sio.register_namespace(self.event_namespace)
|
|
|
- # Mount the socket app with the API.
|
|
|
- self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
|
|
|
+ Raises:
|
|
|
+ ValueError: If the event namespace is not provided in the config.
|
|
|
+ If the state has not been enabled.
|
|
|
+ """
|
|
|
+ if not self.state:
|
|
|
+ return
|
|
|
|
|
|
- # Set up the admin dash.
|
|
|
- self.setup_admin_dash()
|
|
|
+ config = get_config()
|
|
|
|
|
|
- # If a State is not used and no overlay_component is specified, do not render the connection modal
|
|
|
- if self.state is None and self.overlay_component is default_overlay_component:
|
|
|
- self.overlay_component = None
|
|
|
+ # Set up the state manager.
|
|
|
+ self._state_manager = StateManager.create(state=self.state)
|
|
|
+
|
|
|
+ # Set up the Socket.IO AsyncServer.
|
|
|
+ self.sio = AsyncServer(
|
|
|
+ async_mode="asgi",
|
|
|
+ cors_allowed_origins=(
|
|
|
+ "*"
|
|
|
+ if config.cors_allowed_origins == ["*"]
|
|
|
+ else config.cors_allowed_origins
|
|
|
+ ),
|
|
|
+ cors_credentials=True,
|
|
|
+ max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
|
|
|
+ ping_interval=constants.Ping.INTERVAL,
|
|
|
+ ping_timeout=constants.Ping.TIMEOUT,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
|
|
|
+ self.socket_app = ASGIApp(self.sio, socketio_path="")
|
|
|
+ namespace = config.get_event_namespace()
|
|
|
+
|
|
|
+ if not namespace:
|
|
|
+ raise ValueError("event namespace must be provided in the config.")
|
|
|
+
|
|
|
+ # Create the event namespace and attach the main app. Not related to any paths.
|
|
|
+ self.event_namespace = EventNamespace(namespace, self)
|
|
|
+
|
|
|
+ # Register the event namespace with the socket.
|
|
|
+ self.sio.register_namespace(self.event_namespace)
|
|
|
+ # Mount the socket app with the API.
|
|
|
+ self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
|
|
|
|
|
|
def __repr__(self) -> str:
|
|
|
"""Get the string representation of the app.
|
|
@@ -430,21 +445,24 @@ class App(Base):
|
|
|
# Check if the route given is valid
|
|
|
verify_route_validity(route)
|
|
|
|
|
|
- # Apply dynamic args to the route.
|
|
|
- if self.state:
|
|
|
- self.state.setup_dynamic_args(get_route_args(route))
|
|
|
+ # Setup dynamic args for the route.
|
|
|
+ # this state assignment is only required for tests using the deprecated state kwarg for App
|
|
|
+ state = self.state if self.state else State
|
|
|
+ state.setup_dynamic_args(get_route_args(route))
|
|
|
|
|
|
# Generate the component if it is a callable.
|
|
|
component = self._generate_component(component)
|
|
|
|
|
|
- # Wrap the component in a fragment with optional overlay.
|
|
|
- if self.overlay_component is not None:
|
|
|
- component = Fragment.create(
|
|
|
- self._generate_component(self.overlay_component),
|
|
|
- component,
|
|
|
- )
|
|
|
- else:
|
|
|
- component = Fragment.create(component)
|
|
|
+ if self.state is None:
|
|
|
+ for var in component._get_vars(include_children=True):
|
|
|
+ if not var._var_data:
|
|
|
+ continue
|
|
|
+ if not var._var_data.state:
|
|
|
+ continue
|
|
|
+ self.enable_state()
|
|
|
+ break
|
|
|
+
|
|
|
+ component = OverlayFragment.create(component)
|
|
|
|
|
|
# Add meta information to the component.
|
|
|
compiler_utils.add_meta(
|
|
@@ -649,6 +667,28 @@ class App(Base):
|
|
|
# By default, compile the app.
|
|
|
return True
|
|
|
|
|
|
+ def _add_overlay_to_component(self, component: Component) -> Component:
|
|
|
+ if self.overlay_component is None:
|
|
|
+ return component
|
|
|
+
|
|
|
+ children = component.children
|
|
|
+ overlay_component = self._generate_component(self.overlay_component)
|
|
|
+
|
|
|
+ if children[0] == overlay_component:
|
|
|
+ return component
|
|
|
+
|
|
|
+ # recreate OverlayFragment with overlay_component as first child
|
|
|
+ component = OverlayFragment.create(overlay_component, *children)
|
|
|
+
|
|
|
+ return component
|
|
|
+
|
|
|
+ def _setup_overlay_component(self):
|
|
|
+ """If a State is not used and no overlay_component is specified, do not render the connection modal."""
|
|
|
+ if self.state is None and self.overlay_component is default_overlay_component:
|
|
|
+ self.overlay_component = None
|
|
|
+ for k, component in self.pages.items():
|
|
|
+ self.pages[k] = self._add_overlay_to_component(component)
|
|
|
+
|
|
|
def compile(self):
|
|
|
"""compile_() is the new function for performing compilation.
|
|
|
Reflex framework will call it automatically as needed.
|
|
@@ -682,6 +722,8 @@ class App(Base):
|
|
|
if not self._should_compile():
|
|
|
return
|
|
|
|
|
|
+ self._setup_overlay_component()
|
|
|
+
|
|
|
# Create a progress bar.
|
|
|
progress = Progress(
|
|
|
*Progress.get_default_columns()[:-1],
|