فهرست منبع

Merge branch 'main' into add-validation-to-function-vars

Khaleel Al-Adhami 4 ماه پیش
والد
کامیت
f257122934

+ 7 - 0
reflex/.templates/web/utils/state.js

@@ -408,6 +408,13 @@ export const connect = async (
   });
   });
   // Ensure undefined fields in events are sent as null instead of removed
   // Ensure undefined fields in events are sent as null instead of removed
   socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v);
   socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v);
+  socket.current.io.decoder.tryParse = (str) => {
+    try {
+      return JSON5.parse(str);
+    } catch (e) {
+      return false;
+    }
+  };
 
 
   function checkVisibility() {
   function checkVisibility() {
     if (document.visibilityState === "visible") {
     if (document.visibilityState === "visible") {

+ 17 - 0
reflex/components/dynamic.py

@@ -136,6 +136,23 @@ def load_dynamic_serializer():
 
 
         module_code_lines.insert(0, "const React = window.__reflex.react;")
         module_code_lines.insert(0, "const React = window.__reflex.react;")
 
 
+        function_line = next(
+            index
+            for index, line in enumerate(module_code_lines)
+            if line.startswith("export default function")
+        )
+
+        module_code_lines = [
+            line
+            for _, line in sorted(
+                enumerate(module_code_lines),
+                key=lambda x: (
+                    not (x[1].startswith("import ") and x[0] < function_line),
+                    x[0],
+                ),
+            )
+        ]
+
         return "\n".join(
         return "\n".join(
             [
             [
                 "//__reflex_evaluate",
                 "//__reflex_evaluate",

+ 3 - 0
reflex/config.py

@@ -567,6 +567,9 @@ class EnvironmentVariables:
     # The maximum size of the reflex state in kilobytes.
     # The maximum size of the reflex state in kilobytes.
     REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000)
     REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000)
 
 
+    # Whether to use the turbopack bundler.
+    REFLEX_USE_TURBOPACK: EnvVar[bool] = env_var(True)
+
 
 
 environment = EnvironmentVariables()
 environment = EnvironmentVariables()
 
 

+ 1 - 1
reflex/constants/installer.py

@@ -182,7 +182,7 @@ class PackageJson(SimpleNamespace):
         "@emotion/react": "11.13.3",
         "@emotion/react": "11.13.3",
         "axios": "1.7.7",
         "axios": "1.7.7",
         "json5": "2.2.3",
         "json5": "2.2.3",
-        "next": "14.2.16",
+        "next": "15.1.4",
         "next-sitemap": "4.2.3",
         "next-sitemap": "4.2.3",
         "next-themes": "0.4.3",
         "next-themes": "0.4.3",
         "react": "18.3.1",
         "react": "18.3.1",

+ 4 - 2
reflex/reflex.py

@@ -519,7 +519,9 @@ def deploy(
     if prerequisites.needs_reinit(frontend=True):
     if prerequisites.needs_reinit(frontend=True):
         _init(name=config.app_name, loglevel=loglevel)
         _init(name=config.app_name, loglevel=loglevel)
     prerequisites.check_latest_package_version(constants.ReflexHostingCLI.MODULE_NAME)
     prerequisites.check_latest_package_version(constants.ReflexHostingCLI.MODULE_NAME)
-
+    extra: dict[str, str] = (
+        {"config_path": config_path} if config_path is not None else {}
+    )
     hosting_cli.deploy(
     hosting_cli.deploy(
         app_name=app_name,
         app_name=app_name,
         export_fn=lambda zip_dest_dir,
         export_fn=lambda zip_dest_dir,
@@ -545,7 +547,7 @@ def deploy(
         loglevel=type(loglevel).INFO,  # type: ignore
         loglevel=type(loglevel).INFO,  # type: ignore
         token=token,
         token=token,
         project=project,
         project=project,
-        config_path=config_path,
+        **extra,
     )
     )
 
 
 
 

+ 26 - 5
reflex/state.py

@@ -104,6 +104,7 @@ from reflex.utils.exceptions import (
     LockExpiredError,
     LockExpiredError,
     ReflexRuntimeError,
     ReflexRuntimeError,
     SetUndefinedStateVarError,
     SetUndefinedStateVarError,
+    StateMismatchError,
     StateSchemaMismatchError,
     StateSchemaMismatchError,
     StateSerializationError,
     StateSerializationError,
     StateTooLargeError,
     StateTooLargeError,
@@ -1543,7 +1544,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         # Return the direct parent of target_state_cls for subsequent linking.
         # Return the direct parent of target_state_cls for subsequent linking.
         return parent_state
         return parent_state
 
 
-    def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
+    def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get a state instance from the cache.
         """Get a state instance from the cache.
 
 
         Args:
         Args:
@@ -1551,11 +1552,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
         Returns:
         Returns:
             The instance of state_cls associated with this state's client_token.
             The instance of state_cls associated with this state's client_token.
+
+        Raises:
+            StateMismatchError: If the state instance is not of the expected type.
         """
         """
         root_state = self._get_root_state()
         root_state = self._get_root_state()
-        return root_state.get_substate(state_cls.get_full_name().split("."))
+        substate = root_state.get_substate(state_cls.get_full_name().split("."))
+        if not isinstance(substate, state_cls):
+            raise StateMismatchError(
+                f"Searched for state {state_cls.get_full_name()} but found {substate}."
+            )
+        return substate
 
 
-    async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
+    async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get a state instance from redis.
         """Get a state instance from redis.
 
 
         Args:
         Args:
@@ -1566,6 +1575,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
 
         Raises:
         Raises:
             RuntimeError: If redis is not used in this backend process.
             RuntimeError: If redis is not used in this backend process.
+            StateMismatchError: If the state instance is not of the expected type.
         """
         """
         # Fetch all missing parent states from redis.
         # Fetch all missing parent states from redis.
         parent_state_of_state_cls = await self._populate_parent_states(state_cls)
         parent_state_of_state_cls = await self._populate_parent_states(state_cls)
@@ -1577,14 +1587,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
                 f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
                 "(All states should already be available -- this is likely a bug).",
                 "(All states should already be available -- this is likely a bug).",
             )
             )
-        return await state_manager.get_state(
+
+        state_in_redis = await state_manager.get_state(
             token=_substate_key(self.router.session.client_token, state_cls),
             token=_substate_key(self.router.session.client_token, state_cls),
             top_level=False,
             top_level=False,
             get_substates=True,
             get_substates=True,
             parent_state=parent_state_of_state_cls,
             parent_state=parent_state_of_state_cls,
         )
         )
 
 
-    async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
+        if not isinstance(state_in_redis, state_cls):
+            raise StateMismatchError(
+                f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
+            )
+
+        return state_in_redis
+
+    async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
         """Get an instance of the state associated with this token.
         """Get an instance of the state associated with this token.
 
 
         Allows for arbitrary access to sibling states from within an event handler.
         Allows for arbitrary access to sibling states from within an event handler.
@@ -2316,6 +2334,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         return state
         return state
 
 
 
 
+T_STATE = TypeVar("T_STATE", bound=BaseState)
+
+
 class State(BaseState):
 class State(BaseState):
     """The app Base State."""
     """The app Base State."""
 
 

+ 47 - 3
reflex/utils/console.py

@@ -2,6 +2,11 @@
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
+import inspect
+import shutil
+from pathlib import Path
+from types import FrameType
+
 from rich.console import Console
 from rich.console import Console
 from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
 from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
 from rich.prompt import Prompt
 from rich.prompt import Prompt
@@ -188,6 +193,33 @@ def warn(msg: str, dedupe: bool = False, **kwargs):
         print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
         print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
 
 
 
 
+def _get_first_non_framework_frame() -> FrameType | None:
+    import click
+    import typer
+    import typing_extensions
+
+    import reflex as rx
+
+    # Exclude utility modules that should never be the source of deprecated reflex usage.
+    exclude_modules = [click, rx, typer, typing_extensions]
+    exclude_roots = [
+        p.parent.resolve()
+        if (p := Path(m.__file__)).name == "__init__.py"
+        else p.resolve()
+        for m in exclude_modules
+    ]
+    # Specifically exclude the reflex cli module.
+    if reflex_bin := shutil.which(b"reflex"):
+        exclude_roots.append(Path(reflex_bin.decode()))
+
+    frame = inspect.currentframe()
+    while frame := frame and frame.f_back:
+        frame_path = Path(inspect.getfile(frame)).resolve()
+        if not any(frame_path.is_relative_to(root) for root in exclude_roots):
+            break
+    return frame
+
+
 def deprecate(
 def deprecate(
     feature_name: str,
     feature_name: str,
     reason: str,
     reason: str,
@@ -206,15 +238,27 @@ def deprecate(
         dedupe: If True, suppress multiple console logs of deprecation message.
         dedupe: If True, suppress multiple console logs of deprecation message.
         kwargs: Keyword arguments to pass to the print function.
         kwargs: Keyword arguments to pass to the print function.
     """
     """
-    if feature_name not in _EMITTED_DEPRECATION_WARNINGS:
+    dedupe_key = feature_name
+    loc = ""
+
+    # See if we can find where the deprecation exists in "user code"
+    origin_frame = _get_first_non_framework_frame()
+    if origin_frame is not None:
+        filename = Path(origin_frame.f_code.co_filename)
+        if filename.is_relative_to(Path.cwd()):
+            filename = filename.relative_to(Path.cwd())
+        loc = f"{filename}:{origin_frame.f_lineno}"
+        dedupe_key = f"{dedupe_key} {loc}"
+
+    if dedupe_key not in _EMITTED_DEPRECATION_WARNINGS:
         msg = (
         msg = (
             f"{feature_name} has been deprecated in version {deprecation_version} {reason.rstrip('.')}. It will be completely "
             f"{feature_name} has been deprecated in version {deprecation_version} {reason.rstrip('.')}. It will be completely "
-            f"removed in {removal_version}"
+            f"removed in {removal_version}. ({loc})"
         )
         )
         if _LOG_LEVEL <= LogLevel.WARNING:
         if _LOG_LEVEL <= LogLevel.WARNING:
             print(f"[yellow]DeprecationWarning: {msg}[/yellow]", **kwargs)
             print(f"[yellow]DeprecationWarning: {msg}[/yellow]", **kwargs)
         if dedupe:
         if dedupe:
-            _EMITTED_DEPRECATION_WARNINGS.add(feature_name)
+            _EMITTED_DEPRECATION_WARNINGS.add(dedupe_key)
 
 
 
 
 def error(msg: str, dedupe: bool = False, **kwargs):
 def error(msg: str, dedupe: bool = False, **kwargs):

+ 4 - 0
reflex/utils/exceptions.py

@@ -163,6 +163,10 @@ class StateSerializationError(ReflexError):
     """Raised when the state cannot be serialized."""
     """Raised when the state cannot be serialized."""
 
 
 
 
+class StateMismatchError(ReflexError, ValueError):
+    """Raised when the state retrieved does not match the expected state."""
+
+
 class SystemPackageMissingError(ReflexError):
 class SystemPackageMissingError(ReflexError):
     """Raised when a system package is missing."""
     """Raised when a system package is missing."""
 
 

+ 5 - 1
reflex/utils/prerequisites.py

@@ -610,10 +610,14 @@ def initialize_web_directory():
     init_reflex_json(project_hash=project_hash)
     init_reflex_json(project_hash=project_hash)
 
 
 
 
+def _turbopack_flag() -> str:
+    return " --turbopack" if environment.REFLEX_USE_TURBOPACK.get() else ""
+
+
 def _compile_package_json():
 def _compile_package_json():
     return templates.PACKAGE_JSON.render(
     return templates.PACKAGE_JSON.render(
         scripts={
         scripts={
-            "dev": constants.PackageJson.Commands.DEV,
+            "dev": constants.PackageJson.Commands.DEV + _turbopack_flag(),
             "export": constants.PackageJson.Commands.EXPORT,
             "export": constants.PackageJson.Commands.EXPORT,
             "export_sitemap": constants.PackageJson.Commands.EXPORT_SITEMAP,
             "export_sitemap": constants.PackageJson.Commands.EXPORT_SITEMAP,
             "prod": constants.PackageJson.Commands.PROD,
             "prod": constants.PackageJson.Commands.PROD,

+ 18 - 11
reflex/utils/processes.py

@@ -17,6 +17,7 @@ import typer
 from redis.exceptions import RedisError
 from redis.exceptions import RedisError
 
 
 from reflex import constants
 from reflex import constants
+from reflex.config import environment
 from reflex.utils import console, path_ops, prerequisites
 from reflex.utils import console, path_ops, prerequisites
 
 
 
 
@@ -156,24 +157,30 @@ def new_process(args, run: bool = False, show_logs: bool = False, **kwargs):
     Raises:
     Raises:
         Exit: When attempting to run a command with a None value.
         Exit: When attempting to run a command with a None value.
     """
     """
-    node_bin_path = str(path_ops.get_node_bin_path())
-    if not node_bin_path and not prerequisites.CURRENTLY_INSTALLING_NODE:
-        console.warn(
-            "The path to the Node binary could not be found. Please ensure that Node is properly "
-            "installed and added to your system's PATH environment variable or try running "
-            "`reflex init` again."
-        )
+    # Check for invalid command first.
     if None in args:
     if None in args:
         console.error(f"Invalid command: {args}")
         console.error(f"Invalid command: {args}")
         raise typer.Exit(1)
         raise typer.Exit(1)
-    # Add the node bin path to the PATH environment variable.
+
+    path_env: str = os.environ.get("PATH", "")
+
+    # Add node_bin_path to the PATH environment variable.
+    if not environment.REFLEX_BACKEND_ONLY.get():
+        node_bin_path = str(path_ops.get_node_bin_path())
+        if not node_bin_path and not prerequisites.CURRENTLY_INSTALLING_NODE:
+            console.warn(
+                "The path to the Node binary could not be found. Please ensure that Node is properly "
+                "installed and added to your system's PATH environment variable or try running "
+                "`reflex init` again."
+            )
+        path_env = os.pathsep.join([node_bin_path, path_env])
+
     env: dict[str, str] = {
     env: dict[str, str] = {
         **os.environ,
         **os.environ,
-        "PATH": os.pathsep.join(
-            [node_bin_path if node_bin_path else "", os.environ["PATH"]]
-        ),  # type: ignore
+        "PATH": path_env,
         **kwargs.pop("env", {}),
         **kwargs.pop("env", {}),
     }
     }
+
     kwargs = {
     kwargs = {
         "env": env,
         "env": env,
         "stderr": None if show_logs else subprocess.STDOUT,
         "stderr": None if show_logs else subprocess.STDOUT,

+ 2 - 2
reflex/vars/base.py

@@ -626,7 +626,7 @@ class Var(Generic[VAR_TYPE]):
         if _var_is_local is not None:
         if _var_is_local is not None:
             console.deprecate(
             console.deprecate(
                 feature_name="_var_is_local",
                 feature_name="_var_is_local",
-                reason="The _var_is_local argument is not supported for Var."
+                reason="The _var_is_local argument is not supported for Var. "
                 "If you want to create a Var from a raw Javascript expression, use the constructor directly",
                 "If you want to create a Var from a raw Javascript expression, use the constructor directly",
                 deprecation_version="0.6.0",
                 deprecation_version="0.6.0",
                 removal_version="0.7.0",
                 removal_version="0.7.0",
@@ -634,7 +634,7 @@ class Var(Generic[VAR_TYPE]):
         if _var_is_string is not None:
         if _var_is_string is not None:
             console.deprecate(
             console.deprecate(
                 feature_name="_var_is_string",
                 feature_name="_var_is_string",
-                reason="The _var_is_string argument is not supported for Var."
+                reason="The _var_is_string argument is not supported for Var. "
                 "If you want to create a Var from a raw Javascript expression, use the constructor directly",
                 "If you want to create a Var from a raw Javascript expression, use the constructor directly",
                 deprecation_version="0.6.0",
                 deprecation_version="0.6.0",
                 removal_version="0.7.0",
                 removal_version="0.7.0",

+ 6 - 17
reflex/vars/number.py

@@ -12,7 +12,6 @@ from typing import TYPE_CHECKING, Any, Callable, NoReturn, TypeVar, Union, overl
 from reflex.constants.base import Dirs
 from reflex.constants.base import Dirs
 from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError
 from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError
 from reflex.utils.imports import ImportDict, ImportVar
 from reflex.utils.imports import ImportDict, ImportVar
-from reflex.utils.types import is_optional
 
 
 from .base import (
 from .base import (
     CustomVarOperationReturn,
     CustomVarOperationReturn,
@@ -349,7 +348,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
         """
         """
         if not isinstance(other, NUMBER_TYPES):
         if not isinstance(other, NUMBER_TYPES):
             raise_unsupported_operand_types("<", (type(self), type(other)))
             raise_unsupported_operand_types("<", (type(self), type(other)))
-        return less_than_operation(self, +other).guess_type()
+        return less_than_operation(+self, +other).guess_type()
 
 
     def __le__(self, other: number_types) -> BooleanVar:
     def __le__(self, other: number_types) -> BooleanVar:
         """Less than or equal comparison.
         """Less than or equal comparison.
@@ -362,7 +361,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
         """
         """
         if not isinstance(other, NUMBER_TYPES):
         if not isinstance(other, NUMBER_TYPES):
             raise_unsupported_operand_types("<=", (type(self), type(other)))
             raise_unsupported_operand_types("<=", (type(self), type(other)))
-        return less_than_or_equal_operation(self, +other).guess_type()
+        return less_than_or_equal_operation(+self, +other).guess_type()
 
 
     def __eq__(self, other: Any) -> BooleanVar:
     def __eq__(self, other: Any) -> BooleanVar:
         """Equal comparison.
         """Equal comparison.
@@ -374,7 +373,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
             The result of the comparison.
             The result of the comparison.
         """
         """
         if isinstance(other, NUMBER_TYPES):
         if isinstance(other, NUMBER_TYPES):
-            return equal_operation(self, +other).guess_type()
+            return equal_operation(+self, +other).guess_type()
         return equal_operation(self, other).guess_type()
         return equal_operation(self, other).guess_type()
 
 
     def __ne__(self, other: Any) -> BooleanVar:
     def __ne__(self, other: Any) -> BooleanVar:
@@ -387,7 +386,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
             The result of the comparison.
             The result of the comparison.
         """
         """
         if isinstance(other, NUMBER_TYPES):
         if isinstance(other, NUMBER_TYPES):
-            return not_equal_operation(self, +other).guess_type()
+            return not_equal_operation(+self, +other).guess_type()
         return not_equal_operation(self, other).guess_type()
         return not_equal_operation(self, other).guess_type()
 
 
     def __gt__(self, other: number_types) -> BooleanVar:
     def __gt__(self, other: number_types) -> BooleanVar:
@@ -401,7 +400,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
         """
         """
         if not isinstance(other, NUMBER_TYPES):
         if not isinstance(other, NUMBER_TYPES):
             raise_unsupported_operand_types(">", (type(self), type(other)))
             raise_unsupported_operand_types(">", (type(self), type(other)))
-        return greater_than_operation(self, +other).guess_type()
+        return greater_than_operation(+self, +other).guess_type()
 
 
     def __ge__(self, other: number_types) -> BooleanVar:
     def __ge__(self, other: number_types) -> BooleanVar:
         """Greater than or equal comparison.
         """Greater than or equal comparison.
@@ -414,17 +413,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
         """
         """
         if not isinstance(other, NUMBER_TYPES):
         if not isinstance(other, NUMBER_TYPES):
             raise_unsupported_operand_types(">=", (type(self), type(other)))
             raise_unsupported_operand_types(">=", (type(self), type(other)))
-        return greater_than_or_equal_operation(self, +other).guess_type()
-
-    def bool(self) -> BooleanVar:
-        """Boolean conversion.
-
-        Returns:
-            The boolean value of the number.
-        """
-        if is_optional(self._var_type):
-            return boolify((self != None) & (self != 0)).guess_type()  # noqa: E711
-        return self != 0
+        return greater_than_or_equal_operation(+self, +other).guess_type()
 
 
     def _is_strict_float(self) -> bool:
     def _is_strict_float(self) -> bool:
         """Check if the number is a float.
         """Check if the number is a float.

+ 14 - 0
tests/integration/test_computed_vars.py

@@ -58,6 +58,11 @@ def ComputedVars():
         def depends_on_count3(self) -> int:
         def depends_on_count3(self) -> int:
             return self.count
             return self.count
 
 
+        # special floats should be properly decoded on the frontend
+        @rx.var(cache=True, initial_value=[])
+        def special_floats(self) -> list[float]:
+            return [42.9, float("nan"), float("inf"), float("-inf")]
+
         @rx.event
         @rx.event
         def increment(self):
         def increment(self):
             self.count += 1
             self.count += 1
@@ -103,6 +108,11 @@ def ComputedVars():
                     State.depends_on_count3,
                     State.depends_on_count3,
                     id="depends_on_count3",
                     id="depends_on_count3",
                 ),
                 ),
+                rx.text("special_floats:"),
+                rx.text(
+                    State.special_floats.join(", "),
+                    id="special_floats",
+                ),
             ),
             ),
         )
         )
 
 
@@ -224,6 +234,10 @@ async def test_computed_vars(
     assert depends_on_count3
     assert depends_on_count3
     assert depends_on_count3.text == "0"
     assert depends_on_count3.text == "0"
 
 
+    special_floats = driver.find_element(By.ID, "special_floats")
+    assert special_floats
+    assert special_floats.text == "42.9, NaN, Infinity, -Infinity"
+
     increment = driver.find_element(By.ID, "increment")
     increment = driver.find_element(By.ID, "increment")
     assert increment.is_enabled()
     assert increment.is_enabled()