Selaa lähdekoodia

unionize base var fields types (#4153)

* unionize base var fields types

* add tests

* fix union types for vars (#4152)

* remove 3.11 special casing

* special case on version

* fix old versions of python

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
Khaleel Al-Adhami 7 kuukautta sitten
vanhempi
säilyke
b1d449897a
4 muutettua tiedostoa jossa 66 lisäystä ja 27 poistoa
  1. 23 5
      reflex/utils/types.py
  2. 1 21
      reflex/vars/base.py
  3. 3 1
      reflex/vars/object.py
  4. 39 0
      tests/units/test_var.py

+ 23 - 5
reflex/utils/types.py

@@ -182,6 +182,26 @@ def is_generic_alias(cls: GenericType) -> bool:
     return isinstance(cls, GenericAliasTypes)
 
 
+def unionize(*args: GenericType) -> Type:
+    """Unionize the types.
+
+    Args:
+        args: The types to unionize.
+
+    Returns:
+        The unionized types.
+    """
+    if not args:
+        return Any
+    if len(args) == 1:
+        return args[0]
+    # We are bisecting the args list here to avoid hitting the recursion limit
+    # In Python versions >= 3.11, we can simply do `return Union[*args]`
+    midpoint = len(args) // 2
+    first_half, second_half = args[:midpoint], args[midpoint:]
+    return Union[unionize(*first_half), unionize(*second_half)]
+
+
 def is_none(cls: GenericType) -> bool:
     """Check if a class is None.
 
@@ -358,11 +378,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
             return type_
     elif is_union(cls):
         # Check in each arg of the annotation.
-        for arg in get_args(cls):
-            type_ = get_attribute_access_type(arg, name)
-            if type_ is not None:
-                # Return the first attribute type that is accessible.
-                return type_
+        return unionize(
+            *(get_attribute_access_type(arg, name) for arg in get_args(cls))
+        )
     elif isinstance(cls, type):
         # Bare class
         if sys.version_info >= (3, 10):

+ 1 - 21
reflex/vars/base.py

@@ -56,7 +56,7 @@ from reflex.utils.imports import (
     ParsedImportDict,
     parse_imports,
 )
-from reflex.utils.types import GenericType, Self, get_origin, has_args
+from reflex.utils.types import GenericType, Self, get_origin, has_args, unionize
 
 if TYPE_CHECKING:
     from reflex.state import BaseState
@@ -1237,26 +1237,6 @@ def var_operation(
     return wrapper
 
 
-def unionize(*args: Type) -> Type:
-    """Unionize the types.
-
-    Args:
-        args: The types to unionize.
-
-    Returns:
-        The unionized types.
-    """
-    if not args:
-        return Any
-    if len(args) == 1:
-        return args[0]
-    # We are bisecting the args list here to avoid hitting the recursion limit
-    # In Python versions >= 3.11, we can simply do `return Union[*args]`
-    midpoint = len(args) // 2
-    first_half, second_half = args[:midpoint], args[midpoint:]
-    return Union[unionize(*first_half), unionize(*second_half)]
-
-
 def figure_out_type(value: Any) -> types.GenericType:
     """Figure out the type of the value.
 

+ 3 - 1
reflex/vars/object.py

@@ -262,7 +262,9 @@ class ObjectVar(Var[OBJECT_TYPE]):
             var_type = get_args(var_type)[0]
 
         fixed_type = var_type if isclass(var_type) else get_origin(var_type)
-        if isclass(fixed_type) and not issubclass(fixed_type, dict):
+        if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or (
+            fixed_type in types.UnionTypes
+        ):
             attribute_type = get_attribute_access_type(var_type, name)
             if attribute_type is None:
                 raise VarAttributeError(

+ 39 - 0
tests/units/test_var.py

@@ -1,5 +1,6 @@
 import json
 import math
+import sys
 import typing
 from typing import Dict, List, Optional, Set, Tuple, Union, cast
 
@@ -398,6 +399,44 @@ def test_list_tuple_contains(var, expected):
     assert str(var.contains(other_var)) == f"{expected}.includes(other)"
 
 
+class Foo(rx.Base):
+    """Foo class."""
+
+    bar: int
+    baz: str
+
+
+class Bar(rx.Base):
+    """Bar class."""
+
+    bar: str
+    baz: str
+    foo: int
+
+
+@pytest.mark.parametrize(
+    ("var", "var_type"),
+    (
+        [
+            (Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar),
+            (Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]),
+        ]
+        if sys.version_info >= (3, 10)
+        else []
+    )
+    + [
+        (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]),
+        (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str),
+        (
+            Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo,
+            Union[int, None],
+        ),
+    ],
+)
+def test_var_types(var, var_type):
+    assert var._var_type == var_type
+
+
 @pytest.mark.parametrize(
     "var, expected",
     [