Bläddra i källkod

[REF-843] Automatically update api_url and deploy_url (#1954)

Masen Furer 1 år sedan
förälder
incheckning
684912e33b

+ 21 - 2
reflex/.templates/web/utils/state.js

@@ -67,6 +67,24 @@ export const getToken = () => {
   return token;
 };
 
+/**
+ * Get the URL for the websocket connection
+ * @returns The websocket URL object.
+ */
+export const getEventURL = () => {
+  // Get backend URL object from the endpoint.
+  const endpoint = new URL(EVENTURL);
+  if (endpoint.hostname === "localhost") {
+    // If the backend URL references localhost, and the frontend is not on localhost,
+    // then use the frontend host.
+    const frontend_hostname = window.location.hostname;
+    if (frontend_hostname !== "localhost") {
+      endpoint.hostname = frontend_hostname;
+    }
+  }
+  return endpoint
+}
+
 /**
  * Apply a delta to the state.
  * @param state The state to apply the delta to.
@@ -289,9 +307,10 @@ export const connect = async (
   client_storage = {},
 ) => {
   // Get backend URL object from the endpoint.
-  const endpoint = new URL(EVENTURL);
+  const endpoint = getEventURL()
+
   // Create the socket.
-  socket.current = io(EVENTURL, {
+  socket.current = io(endpoint.href, {
     path: endpoint["pathname"],
     transports: transports,
     autoUnref: false,

+ 23 - 5
reflex/components/overlay/banner.py

@@ -3,11 +3,13 @@ from __future__ import annotations
 
 from typing import Optional
 
+from reflex.components.base.bare import Bare
 from reflex.components.component import Component
 from reflex.components.layout import Box, Cond
 from reflex.components.overlay.modal import Modal
 from reflex.components.typography import Text
-from reflex.vars import Var
+from reflex.utils import imports
+from reflex.vars import ImportVar, Var
 
 connection_error: Var = Var.create_safe(
     value="(connectError !== null) ? connectError.message : ''",
@@ -21,19 +23,35 @@ has_connection_error: Var = Var.create_safe(
 has_connection_error.type_ = bool
 
 
-def default_connection_error() -> list[str | Var]:
+class WebsocketTargetURL(Bare):
+    """A component that renders the websocket target URL."""
+
+    def _get_imports(self) -> imports.ImportDict:
+        return {
+            "/utils/state.js": {ImportVar(tag="getEventURL")},
+        }
+
+    @classmethod
+    def create(cls) -> Component:
+        """Create a websocket target URL component.
+
+        Returns:
+            The websocket target URL component.
+        """
+        return super().create(contents="{getEventURL().href}")
+
+
+def default_connection_error() -> list[str | Var | Component]:
     """Get the default connection error message.
 
     Returns:
         The default connection error message.
     """
-    from reflex.config import get_config
-
     return [
         "Cannot connect to server: ",
         connection_error,
         ". Check if server is reachable at ",
-        get_config().api_url or "<API_URL not set>",
+        WebsocketTargetURL.create(),
     ]
 
 

+ 49 - 4
reflex/config.py

@@ -6,7 +6,9 @@ import importlib
 import os
 import sys
 import urllib.parse
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Set
+
+import pydantic
 
 from reflex import constants
 from reflex.base import Base
@@ -191,6 +193,9 @@ class Config(Base):
     # The username.
     username: Optional[str] = None
 
+    # Attributes that were explicitly set by the user.
+    _non_default_attributes: Set[str] = pydantic.PrivateAttr(set())
+
     def __init__(self, *args, **kwargs):
         """Initialize the config values.
 
@@ -204,7 +209,14 @@ class Config(Base):
         self.check_deprecated_values(**kwargs)
 
         # Update the config from environment variables.
-        self.update_from_env()
+        env_kwargs = self.update_from_env()
+        for key, env_value in env_kwargs.items():
+            setattr(self, key, env_value)
+
+        # Update default URLs if ports were set
+        kwargs.update(env_kwargs)
+        self._non_default_attributes.update(kwargs)
+        self._replace_defaults(**kwargs)
 
     @staticmethod
     def check_deprecated_values(**kwargs):
@@ -227,13 +239,16 @@ class Config(Base):
                 "env_path is deprecated - use environment variables instead"
             )
 
-    def update_from_env(self):
+    def update_from_env(self) -> dict[str, Any]:
         """Update the config from environment variables.
 
+        Returns:
+            The updated config values.
 
         Raises:
             ValueError: If an environment variable is set to an invalid type.
         """
+        updated_values = {}
         # Iterate over the fields.
         for key, field in self.__fields__.items():
             # The env var name is the key in uppercase.
@@ -260,7 +275,9 @@ class Config(Base):
                     raise
 
                 # Set the value.
-                setattr(self, key, env_var)
+                updated_values[key] = env_var
+
+        return updated_values
 
     def get_event_namespace(self) -> str | None:
         """Get the websocket event namespace.
@@ -274,6 +291,34 @@ class Config(Base):
         event_url = constants.Endpoint.EVENT.get_url()
         return urllib.parse.urlsplit(event_url).path
 
+    def _replace_defaults(self, **kwargs):
+        """Replace formatted defaults when the caller provides updates.
+
+        Args:
+            **kwargs: The kwargs passed to the config or from the env.
+        """
+        if "api_url" not in self._non_default_attributes and "backend_port" in kwargs:
+            self.api_url = f"http://localhost:{kwargs['backend_port']}"
+
+        if (
+            "deploy_url" not in self._non_default_attributes
+            and "frontend_port" in kwargs
+        ):
+            self.deploy_url = f"http://localhost:{kwargs['frontend_port']}"
+
+    def _set_persistent(self, **kwargs):
+        """Set values in this config and in the environment so they persist into subprocess.
+
+        Args:
+            **kwargs: The kwargs passed to the config.
+        """
+        for key, value in kwargs.items():
+            if value is not None:
+                os.environ[key.upper()] = str(value)
+            setattr(self, key, value)
+        self._non_default_attributes.update(kwargs)
+        self._replace_defaults(**kwargs)
+
 
 def get_config(reload: bool = False) -> Config:
     """Get the app config.

+ 1 - 0
reflex/config.pyi

@@ -97,5 +97,6 @@ class Config(Base):
     def check_deprecated_values(**kwargs) -> None: ...
     def update_from_env(self) -> None: ...
     def get_event_namespace(self) -> str | None: ...
+    def _set_persistent(self, **kwargs) -> None: ...
 
 def get_config(reload: bool = ...) -> Config: ...

+ 6 - 0
reflex/reflex.py

@@ -140,6 +140,12 @@ def run(
     if backend and processes.is_process_on_port(backend_port):
         backend_port = processes.change_or_terminate_port(backend_port, "backend")
 
+    # Apply the new ports to the config.
+    if frontend_port != str(config.frontend_port):
+        config._set_persistent(frontend_port=frontend_port)
+    if backend_port != str(config.backend_port):
+        config._set_persistent(backend_port=backend_port)
+
     console.rule("[bold]Starting Reflex App")
 
     if frontend:

+ 92 - 0
tests/test_config.py

@@ -108,3 +108,95 @@ def test_event_namespace(mocker, kwargs, expected):
     config = reflex.config.get_config()
     assert conf == config
     assert config.get_event_namespace() == expected
+
+
+DEFAULT_CONFIG = rx.Config(app_name="a")
+
+
+@pytest.mark.parametrize(
+    ("config_kwargs", "env_vars", "set_persistent_vars", "exp_config_values"),
+    [
+        (
+            {},
+            {},
+            {},
+            {
+                "api_url": DEFAULT_CONFIG.api_url,
+                "backend_port": DEFAULT_CONFIG.backend_port,
+                "deploy_url": DEFAULT_CONFIG.deploy_url,
+                "frontend_port": DEFAULT_CONFIG.frontend_port,
+            },
+        ),
+        # Ports set in config kwargs
+        (
+            {"backend_port": 8001, "frontend_port": 3001},
+            {},
+            {},
+            {
+                "api_url": "http://localhost:8001",
+                "backend_port": 8001,
+                "deploy_url": "http://localhost:3001",
+                "frontend_port": 3001,
+            },
+        ),
+        # Ports set in environment take precendence
+        (
+            {"backend_port": 8001, "frontend_port": 3001},
+            {"BACKEND_PORT": 8002},
+            {},
+            {
+                "api_url": "http://localhost:8002",
+                "backend_port": 8002,
+                "deploy_url": "http://localhost:3001",
+                "frontend_port": 3001,
+            },
+        ),
+        # Ports set on the command line take precendence
+        (
+            {"backend_port": 8001, "frontend_port": 3001},
+            {"BACKEND_PORT": 8002},
+            {"frontend_port": "3005"},
+            {
+                "api_url": "http://localhost:8002",
+                "backend_port": 8002,
+                "deploy_url": "http://localhost:3005",
+                "frontend_port": 3005,
+            },
+        ),
+        # api_url / deploy_url already set should not be overridden
+        (
+            {"api_url": "http://foo.bar:8900", "deploy_url": "http://foo.bar:3001"},
+            {"BACKEND_PORT": 8002},
+            {"frontend_port": "3005"},
+            {
+                "api_url": "http://foo.bar:8900",
+                "backend_port": 8002,
+                "deploy_url": "http://foo.bar:3001",
+                "frontend_port": 3005,
+            },
+        ),
+    ],
+)
+def test_replace_defaults(
+    monkeypatch,
+    config_kwargs,
+    env_vars,
+    set_persistent_vars,
+    exp_config_values,
+):
+    """Test that the config replaces defaults with values from the environment.
+
+    Args:
+        monkeypatch: The pytest monkeypatch object.
+        config_kwargs: The config kwargs.
+        env_vars: The environment variables.
+        set_persistent_vars: The values passed to config._set_persistent variables.
+        exp_config_values: The expected config values.
+    """
+    mock_os_env = os.environ.copy()
+    monkeypatch.setattr(reflex.config.os, "environ", mock_os_env)  # type: ignore
+    mock_os_env.update({k: str(v) for k, v in env_vars.items()})
+    c = rx.Config(app_name="a", **config_kwargs)
+    c._set_persistent(**set_persistent_vars)
+    for key, value in exp_config_values.items():
+        assert getattr(c, key) == value