Explorar el Código

[ENG-3749] type safe vars (#4066)

* type safe vars

* fix behavior for dict and list
Khaleel Al-Adhami hace 7 meses
padre
commit
8663d4f978
Se han modificado 5 ficheros con 77 adiciones y 3 borrados
  1. 1 1
      reflex/__init__.py
  2. 2 0
      reflex/__init__.pyi
  3. 7 2
      reflex/state.py
  4. 2 0
      reflex/vars/__init__.py
  5. 65 0
      reflex/vars/base.py

+ 1 - 1
reflex/__init__.py

@@ -331,7 +331,7 @@ _MAPPING: dict = {
     "style": ["Style", "toggle_color_mode"],
     "style": ["Style", "toggle_color_mode"],
     "utils.imports": ["ImportVar"],
     "utils.imports": ["ImportVar"],
     "utils.serializers": ["serializer"],
     "utils.serializers": ["serializer"],
-    "vars": ["Var"],
+    "vars": ["Var", "field", "Field"],
 }
 }
 
 
 _SUBMODULES: set[str] = {
 _SUBMODULES: set[str] = {

+ 2 - 0
reflex/__init__.pyi

@@ -189,7 +189,9 @@ from .style import Style as Style
 from .style import toggle_color_mode as toggle_color_mode
 from .style import toggle_color_mode as toggle_color_mode
 from .utils.imports import ImportVar as ImportVar
 from .utils.imports import ImportVar as ImportVar
 from .utils.serializers import serializer as serializer
 from .utils.serializers import serializer as serializer
+from .vars import Field as Field
 from .vars import Var as Var
 from .vars import Var as Var
+from .vars import field as field
 
 
 del compat
 del compat
 RADIX_THEMES_MAPPING: dict
 RADIX_THEMES_MAPPING: dict

+ 7 - 2
reflex/state.py

@@ -32,6 +32,7 @@ from typing import (
     Type,
     Type,
     Union,
     Union,
     cast,
     cast,
+    get_args,
     get_type_hints,
     get_type_hints,
 )
 )
 
 
@@ -81,7 +82,7 @@ from reflex.utils.exceptions import (
 )
 )
 from reflex.utils.exec import is_testing_env
 from reflex.utils.exec import is_testing_env
 from reflex.utils.serializers import serializer
 from reflex.utils.serializers import serializer
-from reflex.utils.types import override
+from reflex.utils.types import get_origin, override
 from reflex.vars import VarData
 from reflex.vars import VarData
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -240,12 +241,16 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
     Returns:
     Returns:
         The Var instance.
         The Var instance.
     """
     """
+    from reflex.vars import Field
+
     field_name = format.format_state_name(cls.get_full_name()) + "." + f.name
     field_name = format.format_state_name(cls.get_full_name()) + "." + f.name
 
 
     return dispatch(
     return dispatch(
         field_name=field_name,
         field_name=field_name,
         var_data=VarData.from_state(cls, f.name),
         var_data=VarData.from_state(cls, f.name),
-        result_var_type=f.outer_type_,
+        result_var_type=f.outer_type_
+        if get_origin(f.outer_type_) is not Field
+        else get_args(f.outer_type_)[0],
     )
     )
 
 
 
 

+ 2 - 0
reflex/vars/__init__.py

@@ -1,8 +1,10 @@
 """Immutable-Based Var System."""
 """Immutable-Based Var System."""
 
 
+from .base import Field as Field
 from .base import LiteralVar as LiteralVar
 from .base import LiteralVar as LiteralVar
 from .base import Var as Var
 from .base import Var as Var
 from .base import VarData as VarData
 from .base import VarData as VarData
+from .base import field as field
 from .base import get_unique_variable_name as get_unique_variable_name
 from .base import get_unique_variable_name as get_unique_variable_name
 from .base import get_uuid_string_var as get_uuid_string_var
 from .base import get_uuid_string_var as get_uuid_string_var
 from .base import var_operation as var_operation
 from .base import var_operation as var_operation

+ 65 - 0
reflex/vars/base.py

@@ -2832,3 +2832,68 @@ def dispatch(
         _var_data=var_data,
         _var_data=var_data,
         _var_type=result_var_type,
         _var_type=result_var_type,
     ).guess_type()
     ).guess_type()
+
+
+V = TypeVar("V")
+
+
+class Field(Generic[T]):
+    """Shadow class for Var to allow for type hinting in the IDE."""
+
+    def __set__(self, instance, value: T):
+        """Set the Var.
+
+        Args:
+            instance: The instance of the class setting the Var.
+            value: The value to set the Var to.
+        """
+
+    @overload
+    def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
+
+    @overload
+    def __get__(self: Field[int], instance: None, owner) -> NumberVar: ...
+
+    @overload
+    def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
+
+    @overload
+    def __get__(self: Field[None], instance: None, owner) -> NoneVar: ...
+
+    @overload
+    def __get__(
+        self: Field[List[V]] | Field[Set[V]] | Field[Tuple[V, ...]],
+        instance: None,
+        owner,
+    ) -> ArrayVar[List[V]]: ...
+
+    @overload
+    def __get__(
+        self: Field[Dict[str, V]], instance: None, owner
+    ) -> ObjectVar[Dict[str, V]]: ...
+
+    @overload
+    def __get__(self, instance: None, owner) -> Var[T]: ...
+
+    @overload
+    def __get__(self, instance, owner) -> T: ...
+
+    def __get__(self, instance, owner):  # type: ignore
+        """Get the Var.
+
+        Args:
+            instance: The instance of the class accessing the Var.
+            owner: The class that the Var is attached to.
+        """
+
+
+def field(value: T) -> Field[T]:
+    """Create a Field with a value.
+
+    Args:
+        value: The value of the Field.
+
+    Returns:
+        The Field.
+    """
+    return value  # type: ignore