123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- """Immutable function vars."""
- from __future__ import annotations
- import dataclasses
- import sys
- from functools import cached_property
- from typing import Any, Callable, Optional, Tuple, Union
- from reflex.experimental.vars.base import ImmutableVar, LiteralVar
- from reflex.vars import ImmutableVarData, Var, VarData
class FunctionVar(ImmutableVar):
    """Base class for immutable function vars.

    Two ways of applying arguments are offered: calling the var directly
    (``__call__``) partially applies it, while ``call`` builds a plain
    invocation expression.
    """

    def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
        """Partially apply the function to the given arguments.

        The result is a new function var taking ``...args``; its body calls
        this function with ``args`` followed by the spread ``...args``
        (roughly ``(...args) => func(a, b, ...args)``).

        Args:
            *args: The arguments to bind to the function.

        Returns:
            The function operation wrapping the partial call.
        """
        spread_rest = ImmutableVar.create_safe("...args")
        forwarded_call = VarOperationCall(self, *args, spread_rest)
        return ArgsFunctionOperation(("...args",), forwarded_call)

    def call(self, *args: Var | Any) -> VarOperationCall:
        """Build an expression invoking the function with exactly ``args``.

        Args:
            *args: The arguments to call the function with.

        Returns:
            The function call operation.
        """
        invocation = VarOperationCall(self, *args)
        return invocation
class FunctionStringVar(FunctionVar):
    """Function var whose expression is supplied verbatim as a string."""

    def __init__(self, func: str, _var_data: VarData | None = None) -> None:
        """Create a function var from a raw function expression.

        Args:
            func: The function to call; used verbatim as the var name.
            _var_data: Additional hooks and imports associated with the Var.
        """
        merged_data = ImmutableVarData.merge(_var_data)
        # Begin the MRO lookup after FunctionVar so the ImmutableVar
        # initializer is the one that runs.
        super(FunctionVar, self).__init__(
            _var_name=func,
            _var_type=Callable,
            _var_data=merged_data,
        )
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class VarOperationCall(ImmutableVar):
    """Base class for immutable vars that are the result of a function call.

    The var name has the form ``(func(arg1, arg2, ...))``. It is computed
    lazily from ``_func`` and ``_args`` and cached on first access.
    """

    # The function var being invoked.
    _func: Optional[FunctionVar] = dataclasses.field(default=None)
    # The call arguments. Each may be a Var or a plain Python value; plain
    # values are converted with ``LiteralVar.create`` when rendered.
    _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)

    def __init__(
        self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None
    ):
        """Initialize the function call var.

        Args:
            func: The function to call.
            *args: The arguments to call the function with.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(VarOperationCall, self).__init__(
            _var_name="",
            _var_type=Any,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        object.__setattr__(self, "_func", func)
        object.__setattr__(self, "_args", args)
        # Drop the placeholder name so that reading ``_var_name`` falls through
        # to ``__getattr__``, which serves the lazily computed name.
        object.__delattr__(self, "_var_name")

    def __getattr__(self, name):
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute of the var.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Name the class explicitly (as __init__ does): ``super(type(self),
        # self)`` would resolve back to this method in any subclass and
        # recurse forever.
        return super(VarOperationCall, self).__getattr__(name)

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        rendered_args = ", ".join(str(LiteralVar.create(arg)) for arg in self._args)
        return f"({str(self._func)}({rendered_args}))"

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self._func._get_all_var_data() if self._func is not None else None,
            # Arguments may be plain Python values without _get_all_var_data;
            # wrap them exactly as _cached_var_name does.
            *[LiteralVar.create(arg)._get_all_var_data() for arg in self._args],
            self._var_data,
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Wrapper method for cached property.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data

    def __post_init__(self):
        """Post-initialize the var.

        Intentionally a no-op. NOTE(review): presumably this overrides the
        base class's post-init hook so it does not run against the
        placeholder ``_var_name``; confirm against ImmutableVar.
        """
        pass
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArgsFunctionOperation(FunctionVar):
    """Base class for immutable function defined via arguments and return expression.

    The var name has the form ``((a, b) => (<return_expr>))``. It is computed
    lazily and cached on first access.
    """

    # Parameter names of the arrow function, emitted verbatim.
    _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
    # Body expression. May be a Var or a plain Python value; plain values are
    # converted with ``LiteralVar.create`` when rendered.
    _return_expr: Union[Var, Any] = dataclasses.field(default=None)

    def __init__(
        self,
        args_names: Tuple[str, ...],
        return_expr: Var | Any,
        _var_data: VarData | None = None,
    ) -> None:
        """Initialize the function with arguments var.

        Args:
            args_names: The names of the arguments.
            return_expr: The return expression of the function.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(ArgsFunctionOperation, self).__init__(
            _var_name="",
            _var_type=Callable,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        object.__setattr__(self, "_args_names", args_names)
        object.__setattr__(self, "_return_expr", return_expr)
        # Drop the placeholder name so that reading ``_var_name`` falls through
        # to ``__getattr__``, which serves the lazily computed name.
        object.__delattr__(self, "_var_name")

    def __getattr__(self, name):
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute of the var.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Name the class explicitly (as __init__ does): ``super(type(self),
        # self)`` would resolve back to this method in any subclass and
        # recurse forever.
        return super(ArgsFunctionOperation, self).__getattr__(name)

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        params = ", ".join(self._args_names)
        body = str(LiteralVar.create(self._return_expr))
        return f"(({params}) => ({body}))"

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            # The return expression may be a plain Python value without
            # _get_all_var_data; wrap it exactly as _cached_var_name does.
            LiteralVar.create(self._return_expr)._get_all_var_data(),
            self._var_data,
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Wrapper method for cached property.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data

    def __post_init__(self):
        """Post-initialize the var.

        Intentionally a no-op. NOTE(review): presumably this overrides the
        base class's post-init hook so it does not run against the
        placeholder ``_var_name``; confirm against ImmutableVar.
        """
        pass
|