|
@@ -20,6 +20,7 @@ from typing import (
|
|
Any,
|
|
Any,
|
|
Callable,
|
|
Callable,
|
|
Dict,
|
|
Dict,
|
|
|
|
+ FrozenSet,
|
|
Generic,
|
|
Generic,
|
|
Iterable,
|
|
Iterable,
|
|
List,
|
|
List,
|
|
@@ -72,6 +73,7 @@ if TYPE_CHECKING:
|
|
|
|
|
|
|
|
|
|
VAR_TYPE = TypeVar("VAR_TYPE", covariant=True)
|
|
VAR_TYPE = TypeVar("VAR_TYPE", covariant=True)
|
|
|
|
+OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE")
|
|
|
|
|
|
warnings.filterwarnings("ignore", message="fields may not start with an underscore")
|
|
warnings.filterwarnings("ignore", message="fields may not start with an underscore")
|
|
|
|
|
|
@@ -119,6 +121,17 @@ class Var(Generic[VAR_TYPE]):
|
|
"""
|
|
"""
|
|
return self._js_expr
|
|
return self._js_expr
|
|
|
|
|
|
|
|
+ @property
|
|
|
|
+ def _var_field_name(self) -> str:
|
|
|
|
+ """The name of the field.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ The name of the field.
|
|
|
|
+ """
|
|
|
|
+ var_data = self._get_all_var_data()
|
|
|
|
+ field_name = var_data.field_name if var_data else None
|
|
|
|
+ return field_name or self._js_expr
|
|
|
|
+
|
|
@property
|
|
@property
|
|
@deprecated("Use `_js_expr` instead.")
|
|
@deprecated("Use `_js_expr` instead.")
|
|
def _var_name_unwrapped(self) -> str:
|
|
def _var_name_unwrapped(self) -> str:
|
|
@@ -181,7 +194,19 @@ class Var(Generic[VAR_TYPE]):
|
|
and self._get_all_var_data() == other._get_all_var_data()
|
|
and self._get_all_var_data() == other._get_all_var_data()
|
|
)
|
|
)
|
|
|
|
|
|
- def _replace(self, merge_var_data=None, **kwargs: Any):
|
|
|
|
|
|
+ @overload
|
|
|
|
+ def _replace(
|
|
|
|
+ self, _var_type: Type[OTHER_VAR_TYPE], merge_var_data=None, **kwargs: Any
|
|
|
|
+ ) -> Var[OTHER_VAR_TYPE]: ...
|
|
|
|
+
|
|
|
|
+ @overload
|
|
|
|
+ def _replace(
|
|
|
|
+ self, _var_type: GenericType | None = None, merge_var_data=None, **kwargs: Any
|
|
|
|
+ ) -> Self: ...
|
|
|
|
+
|
|
|
|
+ def _replace(
|
|
|
|
+ self, _var_type: GenericType | None = None, merge_var_data=None, **kwargs: Any
|
|
|
|
+ ) -> Self | Var:
|
|
"""Make a copy of this Var with updated fields.
|
|
"""Make a copy of this Var with updated fields.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
@@ -205,14 +230,20 @@ class Var(Generic[VAR_TYPE]):
|
|
"The _var_full_name_needs_state_prefix argument is not supported for Var."
|
|
"The _var_full_name_needs_state_prefix argument is not supported for Var."
|
|
)
|
|
)
|
|
|
|
|
|
- return dataclasses.replace(
|
|
|
|
|
|
+ value_with_replaced = dataclasses.replace(
|
|
self,
|
|
self,
|
|
|
|
+ _var_type=_var_type or self._var_type,
|
|
_var_data=VarData.merge(
|
|
_var_data=VarData.merge(
|
|
kwargs.get("_var_data", self._var_data), merge_var_data
|
|
kwargs.get("_var_data", self._var_data), merge_var_data
|
|
),
|
|
),
|
|
**kwargs,
|
|
**kwargs,
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ if (js_expr := kwargs.get("_js_expr", None)) is not None:
|
|
|
|
+ object.__setattr__(value_with_replaced, "_js_expr", js_expr)
|
|
|
|
+
|
|
|
|
+ return value_with_replaced
|
|
|
|
+
|
|
@classmethod
|
|
@classmethod
|
|
def create(
|
|
def create(
|
|
cls,
|
|
cls,
|
|
@@ -566,8 +597,7 @@ class Var(Generic[VAR_TYPE]):
|
|
Returns:
|
|
Returns:
|
|
The name of the setter function.
|
|
The name of the setter function.
|
|
"""
|
|
"""
|
|
- var_name_parts = self._js_expr.split(".")
|
|
|
|
- setter = constants.SETTER_PREFIX + var_name_parts[-1]
|
|
|
|
|
|
+ setter = constants.SETTER_PREFIX + self._var_field_name
|
|
var_data = self._get_all_var_data()
|
|
var_data = self._get_all_var_data()
|
|
if var_data is None:
|
|
if var_data is None:
|
|
return setter
|
|
return setter
|
|
@@ -581,7 +611,7 @@ class Var(Generic[VAR_TYPE]):
|
|
Returns:
|
|
Returns:
|
|
A function that that creates a setter for the var.
|
|
A function that that creates a setter for the var.
|
|
"""
|
|
"""
|
|
- actual_name = self._js_expr.split(".")[-1]
|
|
|
|
|
|
+ actual_name = self._var_field_name
|
|
|
|
|
|
def setter(state: BaseState, value: Any):
|
|
def setter(state: BaseState, value: Any):
|
|
"""Get the setter for the var.
|
|
"""Get the setter for the var.
|
|
@@ -623,7 +653,9 @@ class Var(Generic[VAR_TYPE]):
|
|
return StateOperation.create(
|
|
return StateOperation.create(
|
|
formatted_state_name,
|
|
formatted_state_name,
|
|
self,
|
|
self,
|
|
- _var_data=VarData.merge(VarData.from_state(state), self._var_data),
|
|
|
|
|
|
+ _var_data=VarData.merge(
|
|
|
|
+ VarData.from_state(state, self._js_expr), self._var_data
|
|
|
|
+ ),
|
|
).guess_type()
|
|
).guess_type()
|
|
|
|
|
|
def __eq__(self, other: Var | Any) -> BooleanVar:
|
|
def __eq__(self, other: Var | Any) -> BooleanVar:
|
|
@@ -1706,12 +1738,18 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|
while self._js_expr in state_where_defined.inherited_vars:
|
|
while self._js_expr in state_where_defined.inherited_vars:
|
|
state_where_defined = state_where_defined.get_parent_state()
|
|
state_where_defined = state_where_defined.get_parent_state()
|
|
|
|
|
|
- return self._replace(
|
|
|
|
- _js_expr=format_state_name(state_where_defined.get_full_name())
|
|
|
|
|
|
+ field_name = (
|
|
|
|
+ format_state_name(state_where_defined.get_full_name())
|
|
+ "."
|
|
+ "."
|
|
- + self._js_expr,
|
|
|
|
- merge_var_data=VarData.from_state(state_where_defined),
|
|
|
|
- ).guess_type()
|
|
|
|
|
|
+ + self._js_expr
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ return dispatch(
|
|
|
|
+ field_name,
|
|
|
|
+ var_data=VarData.from_state(state_where_defined, self._js_expr),
|
|
|
|
+ result_var_type=self._var_type,
|
|
|
|
+ existing_var=self,
|
|
|
|
+ )
|
|
|
|
|
|
if not self._cache:
|
|
if not self._cache:
|
|
return self.fget(instance)
|
|
return self.fget(instance)
|
|
@@ -2339,6 +2377,9 @@ class VarData:
|
|
# The name of the enclosing state.
|
|
# The name of the enclosing state.
|
|
state: str = dataclasses.field(default="")
|
|
state: str = dataclasses.field(default="")
|
|
|
|
|
|
|
|
+ # The name of the field in the state.
|
|
|
|
+ field_name: str = dataclasses.field(default="")
|
|
|
|
+
|
|
# Imports needed to render this var
|
|
# Imports needed to render this var
|
|
imports: ImmutableParsedImportDict = dataclasses.field(default_factory=tuple)
|
|
imports: ImmutableParsedImportDict = dataclasses.field(default_factory=tuple)
|
|
|
|
|
|
@@ -2348,6 +2389,7 @@ class VarData:
|
|
def __init__(
|
|
def __init__(
|
|
self,
|
|
self,
|
|
state: str = "",
|
|
state: str = "",
|
|
|
|
+ field_name: str = "",
|
|
imports: ImportDict | ParsedImportDict | None = None,
|
|
imports: ImportDict | ParsedImportDict | None = None,
|
|
hooks: dict[str, None] | None = None,
|
|
hooks: dict[str, None] | None = None,
|
|
):
|
|
):
|
|
@@ -2355,6 +2397,7 @@ class VarData:
|
|
|
|
|
|
Args:
|
|
Args:
|
|
state: The name of the enclosing state.
|
|
state: The name of the enclosing state.
|
|
|
|
+ field_name: The name of the field in the state.
|
|
imports: Imports needed to render this var.
|
|
imports: Imports needed to render this var.
|
|
hooks: Hooks that need to be present in the component to render this var.
|
|
hooks: Hooks that need to be present in the component to render this var.
|
|
"""
|
|
"""
|
|
@@ -2364,6 +2407,7 @@ class VarData:
|
|
)
|
|
)
|
|
)
|
|
)
|
|
object.__setattr__(self, "state", state)
|
|
object.__setattr__(self, "state", state)
|
|
|
|
+ object.__setattr__(self, "field_name", field_name)
|
|
object.__setattr__(self, "imports", immutable_imports)
|
|
object.__setattr__(self, "imports", immutable_imports)
|
|
object.__setattr__(self, "hooks", tuple(hooks or {}))
|
|
object.__setattr__(self, "hooks", tuple(hooks or {}))
|
|
|
|
|
|
@@ -2386,12 +2430,14 @@ class VarData:
|
|
The merged var data object.
|
|
The merged var data object.
|
|
"""
|
|
"""
|
|
state = ""
|
|
state = ""
|
|
|
|
+ field_name = ""
|
|
_imports = {}
|
|
_imports = {}
|
|
hooks = {}
|
|
hooks = {}
|
|
for var_data in others:
|
|
for var_data in others:
|
|
if var_data is None:
|
|
if var_data is None:
|
|
continue
|
|
continue
|
|
state = state or var_data.state
|
|
state = state or var_data.state
|
|
|
|
+ field_name = field_name or var_data.field_name
|
|
_imports = imports.merge_imports(_imports, var_data.imports)
|
|
_imports = imports.merge_imports(_imports, var_data.imports)
|
|
hooks.update(
|
|
hooks.update(
|
|
var_data.hooks
|
|
var_data.hooks
|
|
@@ -2399,9 +2445,10 @@ class VarData:
|
|
else {k: None for k in var_data.hooks}
|
|
else {k: None for k in var_data.hooks}
|
|
)
|
|
)
|
|
|
|
|
|
- if state or _imports or hooks:
|
|
|
|
|
|
+ if state or _imports or hooks or field_name:
|
|
return VarData(
|
|
return VarData(
|
|
state=state,
|
|
state=state,
|
|
|
|
+ field_name=field_name,
|
|
imports=_imports,
|
|
imports=_imports,
|
|
hooks=hooks,
|
|
hooks=hooks,
|
|
)
|
|
)
|
|
@@ -2413,38 +2460,15 @@ class VarData:
|
|
Returns:
|
|
Returns:
|
|
True if any field is set to a non-default value.
|
|
True if any field is set to a non-default value.
|
|
"""
|
|
"""
|
|
- return bool(self.state or self.imports or self.hooks)
|
|
|
|
-
|
|
|
|
- def __eq__(self, other: Any) -> bool:
|
|
|
|
- """Check if two var data objects are equal.
|
|
|
|
-
|
|
|
|
- Args:
|
|
|
|
- other: The other var data object to compare.
|
|
|
|
-
|
|
|
|
- Returns:
|
|
|
|
- True if all fields are equal and collapsed imports are equal.
|
|
|
|
- """
|
|
|
|
- if not isinstance(other, VarData):
|
|
|
|
- return False
|
|
|
|
-
|
|
|
|
- # Don't compare interpolations - that's added in by the decoder, and
|
|
|
|
- # not part of the vardata itself.
|
|
|
|
- return (
|
|
|
|
- self.state == other.state
|
|
|
|
- and self.hooks
|
|
|
|
- == (
|
|
|
|
- other.hooks if isinstance(other, VarData) else tuple(other.hooks.keys())
|
|
|
|
- )
|
|
|
|
- and imports.collapse_imports(self.imports)
|
|
|
|
- == imports.collapse_imports(other.imports)
|
|
|
|
- )
|
|
|
|
|
|
+ return bool(self.state or self.imports or self.hooks or self.field_name)
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
- def from_state(cls, state: Type[BaseState] | str) -> VarData:
|
|
|
|
|
|
+ def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
|
|
"""Set the state of the var.
|
|
"""Set the state of the var.
|
|
|
|
|
|
Args:
|
|
Args:
|
|
state: The state to set or the full name of the state.
|
|
state: The state to set or the full name of the state.
|
|
|
|
+ field_name: The name of the field in the state. Optional.
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
The var with the set state.
|
|
The var with the set state.
|
|
@@ -2452,8 +2476,9 @@ class VarData:
|
|
from reflex.utils import format
|
|
from reflex.utils import format
|
|
|
|
|
|
state_name = state if isinstance(state, str) else state.get_full_name()
|
|
state_name = state if isinstance(state, str) else state.get_full_name()
|
|
- new_var_data = VarData(
|
|
|
|
|
|
+ return VarData(
|
|
state=state_name,
|
|
state=state_name,
|
|
|
|
+ field_name=field_name,
|
|
hooks={
|
|
hooks={
|
|
"const {0} = useContext(StateContexts.{0})".format(
|
|
"const {0} = useContext(StateContexts.{0})".format(
|
|
format.format_state_name(state_name)
|
|
format.format_state_name(state_name)
|
|
@@ -2464,7 +2489,6 @@ class VarData:
|
|
"react": [ImportVar(tag="useContext")],
|
|
"react": [ImportVar(tag="useContext")],
|
|
},
|
|
},
|
|
)
|
|
)
|
|
- return new_var_data
|
|
|
|
|
|
|
|
|
|
|
|
def _decode_var_immutable(value: str) -> tuple[VarData | None, str]:
|
|
def _decode_var_immutable(value: str) -> tuple[VarData | None, str]:
|
|
@@ -2561,3 +2585,238 @@ REPLACED_NAMES = {
|
|
"set_state": "_var_set_state",
|
|
"set_state": "_var_set_state",
|
|
"deps": "_deps",
|
|
"deps": "_deps",
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+dispatchers: Dict[GenericType, Callable[[Var], Var]] = {}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def transform(fn: Callable[[Var], Var]) -> Callable[[Var], Var]:
|
|
|
|
+ """Register a function to transform a Var.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ fn: The function to register.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ The decorator.
|
|
|
|
+
|
|
|
|
+ Raises:
|
|
|
|
+ TypeError: If the return type of the function is not a Var.
|
|
|
|
+ TypeError: If the Var return type does not have a generic type.
|
|
|
|
+ ValueError: If a function for the generic type is already registered.
|
|
|
|
+ """
|
|
|
|
+ return_type = fn.__annotations__["return"]
|
|
|
|
+
|
|
|
|
+ origin = get_origin(return_type)
|
|
|
|
+
|
|
|
|
+ if origin is not Var:
|
|
|
|
+ raise TypeError(
|
|
|
|
+ f"Expected return type of {fn.__name__} to be a Var, got {origin}."
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ generic_args = get_args(return_type)
|
|
|
|
+
|
|
|
|
+ if not generic_args:
|
|
|
|
+ raise TypeError(
|
|
|
|
+ f"Expected Var return type of {fn.__name__} to have a generic type."
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ generic_type = get_origin(generic_args[0]) or generic_args[0]
|
|
|
|
+
|
|
|
|
+ if generic_type in dispatchers:
|
|
|
|
+ raise ValueError(f"Function for {generic_type} already registered.")
|
|
|
|
+
|
|
|
|
+ dispatchers[generic_type] = fn
|
|
|
|
+
|
|
|
|
+ return fn
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def generic_type_to_actual_type_map(
|
|
|
|
+ generic_type: GenericType, actual_type: GenericType
|
|
|
|
+) -> Dict[TypeVar, GenericType]:
|
|
|
|
+ """Map the generic type to the actual type.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ generic_type: The generic type.
|
|
|
|
+ actual_type: The actual type.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ The mapping of type variables to actual types.
|
|
|
|
+
|
|
|
|
+ Raises:
|
|
|
|
+ TypeError: If the generic type and actual type do not match.
|
|
|
|
+ TypeError: If the number of generic arguments and actual arguments do not match.
|
|
|
|
+ """
|
|
|
|
+ generic_origin = get_origin(generic_type) or generic_type
|
|
|
|
+ actual_origin = get_origin(actual_type) or actual_type
|
|
|
|
+
|
|
|
|
+ if generic_origin is not actual_origin:
|
|
|
|
+ if isinstance(generic_origin, TypeVar):
|
|
|
|
+ return {generic_origin: actual_origin}
|
|
|
|
+ raise TypeError(
|
|
|
|
+ f"Type mismatch: expected {generic_origin}, got {actual_origin}."
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ generic_args = get_args(generic_type)
|
|
|
|
+ actual_args = get_args(actual_type)
|
|
|
|
+
|
|
|
|
+ if len(generic_args) != len(actual_args):
|
|
|
|
+ raise TypeError(
|
|
|
|
+ f"Number of generic arguments mismatch: expected {len(generic_args)}, got {len(actual_args)}."
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # call recursively for nested generic types and merge the results
|
|
|
|
+ return {
|
|
|
|
+ k: v
|
|
|
|
+ for generic_arg, actual_arg in zip(generic_args, actual_args)
|
|
|
|
+ for k, v in generic_type_to_actual_type_map(generic_arg, actual_arg).items()
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def resolve_generic_type_with_mapping(
|
|
|
|
+ generic_type: GenericType, type_mapping: Dict[TypeVar, GenericType]
|
|
|
|
+):
|
|
|
|
+ """Resolve a generic type with a type mapping.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ generic_type: The generic type.
|
|
|
|
+ type_mapping: The type mapping.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ The resolved generic type.
|
|
|
|
+ """
|
|
|
|
+ if isinstance(generic_type, TypeVar):
|
|
|
|
+ return type_mapping.get(generic_type, generic_type)
|
|
|
|
+
|
|
|
|
+ generic_origin = get_origin(generic_type) or generic_type
|
|
|
|
+
|
|
|
|
+ generic_args = get_args(generic_type)
|
|
|
|
+
|
|
|
|
+ if not generic_args:
|
|
|
|
+ return generic_type
|
|
|
|
+
|
|
|
|
+ mapping_for_older_python = {
|
|
|
|
+ list: List,
|
|
|
|
+ set: Set,
|
|
|
|
+ dict: Dict,
|
|
|
|
+ tuple: Tuple,
|
|
|
|
+ frozenset: FrozenSet,
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return mapping_for_older_python.get(generic_origin, generic_origin)[
|
|
|
|
+ tuple(
|
|
|
|
+ resolve_generic_type_with_mapping(arg, type_mapping) for arg in generic_args
|
|
|
|
+ )
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def resolve_arg_type_from_return_type(
|
|
|
|
+ arg_type: GenericType, return_type: GenericType, actual_return_type: GenericType
|
|
|
|
+) -> GenericType:
|
|
|
|
+ """Resolve the argument type from the return type.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ arg_type: The argument type.
|
|
|
|
+ return_type: The return type.
|
|
|
|
+ actual_return_type: The requested return type.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ The argument type without the generics that are resolved.
|
|
|
|
+ """
|
|
|
|
+ return resolve_generic_type_with_mapping(
|
|
|
|
+ arg_type, generic_type_to_actual_type_map(return_type, actual_return_type)
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def dispatch(
|
|
|
|
+ field_name: str,
|
|
|
|
+ var_data: VarData,
|
|
|
|
+ result_var_type: GenericType,
|
|
|
|
+ existing_var: Var | None = None,
|
|
|
|
+) -> Var:
|
|
|
|
+ """Dispatch a Var to the appropriate transformation function.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ field_name: The name of the field.
|
|
|
|
+ var_data: The VarData associated with the Var.
|
|
|
|
+ result_var_type: The type of the Var.
|
|
|
|
+ existing_var: The existing Var to transform. Optional.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ The transformed Var.
|
|
|
|
+
|
|
|
|
+ Raises:
|
|
|
|
+ TypeError: If the return type of the function is not a Var.
|
|
|
|
+ TypeError: If the Var return type does not have a generic type.
|
|
|
|
+ TypeError: If the first argument of the function is not a Var.
|
|
|
|
+ TypeError: If the first argument of the function does not have a generic type
|
|
|
|
+ """
|
|
|
|
+ result_origin_var_type = get_origin(result_var_type) or result_var_type
|
|
|
|
+
|
|
|
|
+ if result_origin_var_type in dispatchers:
|
|
|
|
+ fn = dispatchers[result_origin_var_type]
|
|
|
|
+ fn_first_arg_type = list(inspect.signature(fn).parameters.values())[
|
|
|
|
+ 0
|
|
|
|
+ ].annotation
|
|
|
|
+
|
|
|
|
+ fn_return = inspect.signature(fn).return_annotation
|
|
|
|
+
|
|
|
|
+ fn_return_origin = get_origin(fn_return) or fn_return
|
|
|
|
+
|
|
|
|
+ if fn_return_origin is not Var:
|
|
|
|
+ raise TypeError(
|
|
|
|
+ f"Expected return type of {fn.__name__} to be a Var, got {fn_return}."
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ fn_return_generic_args = get_args(fn_return)
|
|
|
|
+
|
|
|
|
+ if not fn_return_generic_args:
|
|
|
|
+ raise TypeError(f"Expected generic type of {fn_return} to be a type.")
|
|
|
|
+
|
|
|
|
+ arg_origin = get_origin(fn_first_arg_type) or fn_first_arg_type
|
|
|
|
+
|
|
|
|
+ if arg_origin is not Var:
|
|
|
|
+ raise TypeError(
|
|
|
|
+ f"Expected first argument of {fn.__name__} to be a Var, got {fn_first_arg_type}."
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ arg_generic_args = get_args(fn_first_arg_type)
|
|
|
|
+
|
|
|
|
+ if not arg_generic_args:
|
|
|
|
+ raise TypeError(
|
|
|
|
+ f"Expected generic type of {fn_first_arg_type} to be a type."
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ arg_type = arg_generic_args[0]
|
|
|
|
+ fn_return_type = fn_return_generic_args[0]
|
|
|
|
+
|
|
|
|
+ var = (
|
|
|
|
+ Var(
|
|
|
|
+ field_name,
|
|
|
|
+ _var_data=var_data,
|
|
|
|
+ _var_type=resolve_arg_type_from_return_type(
|
|
|
|
+ arg_type, fn_return_type, result_var_type
|
|
|
|
+ ),
|
|
|
|
+ ).guess_type()
|
|
|
|
+ if existing_var is None
|
|
|
|
+ else existing_var._replace(
|
|
|
|
+ _var_type=resolve_arg_type_from_return_type(
|
|
|
|
+ arg_type, fn_return_type, result_var_type
|
|
|
|
+ ),
|
|
|
|
+ _var_data=var_data,
|
|
|
|
+ _js_expr=field_name,
|
|
|
|
+ ).guess_type()
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ return fn(var)
|
|
|
|
+
|
|
|
|
+ if existing_var is not None:
|
|
|
|
+ return existing_var._replace(
|
|
|
|
+ _js_expr=field_name,
|
|
|
|
+ _var_data=var_data,
|
|
|
|
+ _var_type=result_var_type,
|
|
|
|
+ ).guess_type()
|
|
|
|
+ return Var(
|
|
|
|
+ field_name,
|
|
|
|
+ _var_data=var_data,
|
|
|
|
+ _var_type=result_var_type,
|
|
|
|
+ ).guess_type()
|