Procházet zdrojové kódy

[REF-843] Automatically update api_url and deploy_url (#1954)

Masen Furer před 1 rokem
rodič
revize
684912e33b

+ 21 - 2
reflex/.templates/web/utils/state.js

@@ -67,6 +67,24 @@ export const getToken = () => {
   return token;
   return token;
 };
 };
 
 
+/**
+ * Get the URL for the websocket connection
+ * @returns The websocket URL object.
+ */
+export const getEventURL = () => {
+  // Get backend URL object from the endpoint.
+  const endpoint = new URL(EVENTURL);
+  if (endpoint.hostname === "localhost") {
+    // If the backend URL references localhost, and the frontend is not on localhost,
+    // then use the frontend host.
+    const frontend_hostname = window.location.hostname;
+    if (frontend_hostname !== "localhost") {
+      endpoint.hostname = frontend_hostname;
+    }
+  }
+  return endpoint
+}
+
 /**
 /**
  * Apply a delta to the state.
  * Apply a delta to the state.
  * @param state The state to apply the delta to.
  * @param state The state to apply the delta to.
@@ -289,9 +307,10 @@ export const connect = async (
   client_storage = {},
   client_storage = {},
 ) => {
 ) => {
   // Get backend URL object from the endpoint.
   // Get backend URL object from the endpoint.
-  const endpoint = new URL(EVENTURL);
+  const endpoint = getEventURL()
+
   // Create the socket.
   // Create the socket.
-  socket.current = io(EVENTURL, {
+  socket.current = io(endpoint.href, {
     path: endpoint["pathname"],
     path: endpoint["pathname"],
     transports: transports,
     transports: transports,
     autoUnref: false,
     autoUnref: false,

+ 23 - 5
reflex/components/overlay/banner.py

@@ -3,11 +3,13 @@ from __future__ import annotations
 
 
 from typing import Optional
 from typing import Optional
 
 
+from reflex.components.base.bare import Bare
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.layout import Box, Cond
 from reflex.components.layout import Box, Cond
 from reflex.components.overlay.modal import Modal
 from reflex.components.overlay.modal import Modal
 from reflex.components.typography import Text
 from reflex.components.typography import Text
-from reflex.vars import Var
+from reflex.utils import imports
+from reflex.vars import ImportVar, Var
 
 
 connection_error: Var = Var.create_safe(
 connection_error: Var = Var.create_safe(
     value="(connectError !== null) ? connectError.message : ''",
     value="(connectError !== null) ? connectError.message : ''",
@@ -21,19 +23,35 @@ has_connection_error: Var = Var.create_safe(
 has_connection_error.type_ = bool
 has_connection_error.type_ = bool
 
 
 
 
-def default_connection_error() -> list[str | Var]:
+class WebsocketTargetURL(Bare):
+    """A component that renders the websocket target URL."""
+
+    def _get_imports(self) -> imports.ImportDict:
+        return {
+            "/utils/state.js": {ImportVar(tag="getEventURL")},
+        }
+
+    @classmethod
+    def create(cls) -> Component:
+        """Create a websocket target URL component.
+
+        Returns:
+            The websocket target URL component.
+        """
+        return super().create(contents="{getEventURL().href}")
+
+
+def default_connection_error() -> list[str | Var | Component]:
     """Get the default connection error message.
     """Get the default connection error message.
 
 
     Returns:
     Returns:
         The default connection error message.
         The default connection error message.
     """
     """
-    from reflex.config import get_config
-
     return [
     return [
         "Cannot connect to server: ",
         "Cannot connect to server: ",
         connection_error,
         connection_error,
         ". Check if server is reachable at ",
         ". Check if server is reachable at ",
-        get_config().api_url or "<API_URL not set>",
+        WebsocketTargetURL.create(),
     ]
     ]
 
 
 
 

+ 49 - 4
reflex/config.py

@@ -6,7 +6,9 @@ import importlib
 import os
 import os
 import sys
 import sys
 import urllib.parse
 import urllib.parse
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Set
+
+import pydantic
 
 
 from reflex import constants
 from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
@@ -191,6 +193,9 @@ class Config(Base):
     # The username.
     # The username.
     username: Optional[str] = None
     username: Optional[str] = None
 
 
+    # Attributes that were explicitly set by the user.
+    _non_default_attributes: Set[str] = pydantic.PrivateAttr(set())
+
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         """Initialize the config values.
         """Initialize the config values.
 
 
@@ -204,7 +209,14 @@ class Config(Base):
         self.check_deprecated_values(**kwargs)
         self.check_deprecated_values(**kwargs)
 
 
         # Update the config from environment variables.
         # Update the config from environment variables.
-        self.update_from_env()
+        env_kwargs = self.update_from_env()
+        for key, env_value in env_kwargs.items():
+            setattr(self, key, env_value)
+
+        # Update default URLs if ports were set
+        kwargs.update(env_kwargs)
+        self._non_default_attributes.update(kwargs)
+        self._replace_defaults(**kwargs)
 
 
     @staticmethod
     @staticmethod
     def check_deprecated_values(**kwargs):
     def check_deprecated_values(**kwargs):
@@ -227,13 +239,16 @@ class Config(Base):
                 "env_path is deprecated - use environment variables instead"
                 "env_path is deprecated - use environment variables instead"
             )
             )
 
 
-    def update_from_env(self):
+    def update_from_env(self) -> dict[str, Any]:
         """Update the config from environment variables.
         """Update the config from environment variables.
 
 
+        Returns:
+            The updated config values.
 
 
         Raises:
         Raises:
             ValueError: If an environment variable is set to an invalid type.
             ValueError: If an environment variable is set to an invalid type.
         """
         """
+        updated_values = {}
         # Iterate over the fields.
         # Iterate over the fields.
         for key, field in self.__fields__.items():
         for key, field in self.__fields__.items():
             # The env var name is the key in uppercase.
             # The env var name is the key in uppercase.
@@ -260,7 +275,9 @@ class Config(Base):
                     raise
                     raise
 
 
                 # Set the value.
                 # Set the value.
-                setattr(self, key, env_var)
+                updated_values[key] = env_var
+
+        return updated_values
 
 
     def get_event_namespace(self) -> str | None:
     def get_event_namespace(self) -> str | None:
         """Get the websocket event namespace.
         """Get the websocket event namespace.
@@ -274,6 +291,34 @@ class Config(Base):
         event_url = constants.Endpoint.EVENT.get_url()
         event_url = constants.Endpoint.EVENT.get_url()
         return urllib.parse.urlsplit(event_url).path
         return urllib.parse.urlsplit(event_url).path
 
 
+    def _replace_defaults(self, **kwargs):
+        """Replace formatted defaults when the caller provides updates.
+
+        Args:
+            **kwargs: The kwargs passed to the config or from the env.
+        """
+        if "api_url" not in self._non_default_attributes and "backend_port" in kwargs:
+            self.api_url = f"http://localhost:{kwargs['backend_port']}"
+
+        if (
+            "deploy_url" not in self._non_default_attributes
+            and "frontend_port" in kwargs
+        ):
+            self.deploy_url = f"http://localhost:{kwargs['frontend_port']}"
+
+    def _set_persistent(self, **kwargs):
+        """Set values in this config and in the environment so they persist into subprocess.
+
+        Args:
+            **kwargs: The kwargs passed to the config.
+        """
+        for key, value in kwargs.items():
+            if value is not None:
+                os.environ[key.upper()] = str(value)
+            setattr(self, key, value)
+        self._non_default_attributes.update(kwargs)
+        self._replace_defaults(**kwargs)
+
 
 
 def get_config(reload: bool = False) -> Config:
 def get_config(reload: bool = False) -> Config:
     """Get the app config.
     """Get the app config.

+ 1 - 0
reflex/config.pyi

@@ -97,5 +97,6 @@ class Config(Base):
     def check_deprecated_values(**kwargs) -> None: ...
     def check_deprecated_values(**kwargs) -> None: ...
     def update_from_env(self) -> None: ...
     def update_from_env(self) -> None: ...
     def get_event_namespace(self) -> str | None: ...
     def get_event_namespace(self) -> str | None: ...
+    def _set_persistent(self, **kwargs) -> None: ...
 
 
 def get_config(reload: bool = ...) -> Config: ...
 def get_config(reload: bool = ...) -> Config: ...

+ 6 - 0
reflex/reflex.py

@@ -140,6 +140,12 @@ def run(
     if backend and processes.is_process_on_port(backend_port):
     if backend and processes.is_process_on_port(backend_port):
         backend_port = processes.change_or_terminate_port(backend_port, "backend")
         backend_port = processes.change_or_terminate_port(backend_port, "backend")
 
 
+    # Apply the new ports to the config.
+    if frontend_port != str(config.frontend_port):
+        config._set_persistent(frontend_port=frontend_port)
+    if backend_port != str(config.backend_port):
+        config._set_persistent(backend_port=backend_port)
+
     console.rule("[bold]Starting Reflex App")
     console.rule("[bold]Starting Reflex App")
 
 
     if frontend:
     if frontend:

+ 92 - 0
tests/test_config.py

@@ -108,3 +108,95 @@ def test_event_namespace(mocker, kwargs, expected):
     config = reflex.config.get_config()
     config = reflex.config.get_config()
     assert conf == config
     assert conf == config
     assert config.get_event_namespace() == expected
     assert config.get_event_namespace() == expected
+
+
+DEFAULT_CONFIG = rx.Config(app_name="a")
+
+
+@pytest.mark.parametrize(
+    ("config_kwargs", "env_vars", "set_persistent_vars", "exp_config_values"),
+    [
+        (
+            {},
+            {},
+            {},
+            {
+                "api_url": DEFAULT_CONFIG.api_url,
+                "backend_port": DEFAULT_CONFIG.backend_port,
+                "deploy_url": DEFAULT_CONFIG.deploy_url,
+                "frontend_port": DEFAULT_CONFIG.frontend_port,
+            },
+        ),
+        # Ports set in config kwargs
+        (
+            {"backend_port": 8001, "frontend_port": 3001},
+            {},
+            {},
+            {
+                "api_url": "http://localhost:8001",
+                "backend_port": 8001,
+                "deploy_url": "http://localhost:3001",
+                "frontend_port": 3001,
+            },
+        ),
+        # Ports set in environment take precendence
+        (
+            {"backend_port": 8001, "frontend_port": 3001},
+            {"BACKEND_PORT": 8002},
+            {},
+            {
+                "api_url": "http://localhost:8002",
+                "backend_port": 8002,
+                "deploy_url": "http://localhost:3001",
+                "frontend_port": 3001,
+            },
+        ),
+        # Ports set on the command line take precendence
+        (
+            {"backend_port": 8001, "frontend_port": 3001},
+            {"BACKEND_PORT": 8002},
+            {"frontend_port": "3005"},
+            {
+                "api_url": "http://localhost:8002",
+                "backend_port": 8002,
+                "deploy_url": "http://localhost:3005",
+                "frontend_port": 3005,
+            },
+        ),
+        # api_url / deploy_url already set should not be overridden
+        (
+            {"api_url": "http://foo.bar:8900", "deploy_url": "http://foo.bar:3001"},
+            {"BACKEND_PORT": 8002},
+            {"frontend_port": "3005"},
+            {
+                "api_url": "http://foo.bar:8900",
+                "backend_port": 8002,
+                "deploy_url": "http://foo.bar:3001",
+                "frontend_port": 3005,
+            },
+        ),
+    ],
+)
+def test_replace_defaults(
+    monkeypatch,
+    config_kwargs,
+    env_vars,
+    set_persistent_vars,
+    exp_config_values,
+):
+    """Test that the config replaces defaults with values from the environment.
+
+    Args:
+        monkeypatch: The pytest monkeypatch object.
+        config_kwargs: The config kwargs.
+        env_vars: The environment variables.
+        set_persistent_vars: The values passed to config._set_persistent variables.
+        exp_config_values: The expected config values.
+    """
+    mock_os_env = os.environ.copy()
+    monkeypatch.setattr(reflex.config.os, "environ", mock_os_env)  # type: ignore
+    mock_os_env.update({k: str(v) for k, v in env_vars.items()})
+    c = rx.Config(app_name="a", **config_kwargs)
+    c._set_persistent(**set_persistent_vars)
+    for key, value in exp_config_values.items():
+        assert getattr(c, key) == value