dep_tracking.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
"""Bytecode-based tracking of the state vars that a function depends on."""
  2. from __future__ import annotations
  3. import contextlib
  4. import dataclasses
  5. import dis
  6. import enum
  7. import inspect
  8. from types import CellType, CodeType, FunctionType
  9. from typing import TYPE_CHECKING, Any, ClassVar, cast
  10. from reflex.utils.exceptions import VarValueError
  11. if TYPE_CHECKING:
  12. from reflex.state import BaseState
  13. from .base import Var
  14. CellEmpty = object()
  15. def get_cell_value(cell: CellType) -> Any:
  16. """Get the value of a cell object.
  17. Args:
  18. cell: The cell object to get the value from. (func.__closure__ objects)
  19. Returns:
  20. The value from the cell or CellEmpty if a ValueError is raised.
  21. """
  22. try:
  23. return cell.cell_contents
  24. except ValueError:
  25. return CellEmpty
  26. class ScanStatus(enum.Enum):
  27. """State of the dis instruction scanning loop."""
  28. SCANNING = enum.auto()
  29. GETTING_ATTR = enum.auto()
  30. GETTING_STATE = enum.auto()
  31. GETTING_VAR = enum.auto()
  32. @dataclasses.dataclass
  33. class DependencyTracker:
  34. """State machine for identifying state attributes that are accessed by a function."""
  35. func: FunctionType | CodeType = dataclasses.field()
  36. state_cls: type[BaseState] = dataclasses.field()
  37. dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict)
  38. scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)
  39. top_of_stack: str | None = dataclasses.field(default=None)
  40. tracked_locals: dict[str, type[BaseState]] = dataclasses.field(default_factory=dict)
  41. _getting_state_class: type[BaseState] | None = dataclasses.field(default=None)
  42. _getting_var_instructions: list[dis.Instruction] = dataclasses.field(
  43. default_factory=list
  44. )
  45. INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]
  46. def __post_init__(self):
  47. """After initializing, populate the dependencies dict."""
  48. with contextlib.suppress(AttributeError):
  49. # unbox functools.partial
  50. self.func = cast(FunctionType, self.func.func) # pyright: ignore[reportAttributeAccessIssue]
  51. with contextlib.suppress(AttributeError):
  52. # unbox EventHandler
  53. self.func = cast(FunctionType, self.func.fn) # pyright: ignore[reportAttributeAccessIssue]
  54. if isinstance(self.func, FunctionType):
  55. with contextlib.suppress(AttributeError, IndexError):
  56. # the first argument to the function is the name of "self" arg
  57. self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls
  58. self._populate_dependencies()
  59. def _merge_deps(self, tracker: DependencyTracker) -> None:
  60. """Merge dependencies from another DependencyTracker.
  61. Args:
  62. tracker: The DependencyTracker to merge dependencies from.
  63. """
  64. for state_name, dep_name in tracker.dependencies.items():
  65. self.dependencies.setdefault(state_name, set()).update(dep_name)
  66. def load_attr_or_method(self, instruction: dis.Instruction) -> None:
  67. """Handle loading an attribute or method from the object on top of the stack.
  68. This method directly tracks attributes and recursively merges
  69. dependencies from analyzing the dependencies of any methods called.
  70. Args:
  71. instruction: The dis instruction to process.
  72. Raises:
  73. VarValueError: if the attribute is an disallowed name.
  74. """
  75. from .base import ComputedVar
  76. if instruction.argval in self.INVALID_NAMES:
  77. raise VarValueError(
  78. f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
  79. )
  80. if instruction.argval == "get_state":
  81. # Special case: arbitrary state access requested.
  82. self.scan_status = ScanStatus.GETTING_STATE
  83. return
  84. if instruction.argval == "get_var_value":
  85. # Special case: arbitrary var access requested.
  86. self.scan_status = ScanStatus.GETTING_VAR
  87. return
  88. # Reset status back to SCANNING after attribute is accessed.
  89. self.scan_status = ScanStatus.SCANNING
  90. if not self.top_of_stack:
  91. return
  92. target_state = self.tracked_locals[self.top_of_stack]
  93. try:
  94. ref_obj = getattr(target_state, instruction.argval)
  95. except AttributeError:
  96. # Not found on this state class, maybe it is a dynamic attribute that will be picked up later.
  97. ref_obj = None
  98. if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar):
  99. # recurse into property fget functions
  100. ref_obj = ref_obj.fget
  101. if callable(ref_obj):
  102. # recurse into callable attributes
  103. self._merge_deps(
  104. type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state)
  105. )
  106. elif (
  107. instruction.argval in target_state.backend_vars
  108. or instruction.argval in target_state.vars
  109. ):
  110. # var access
  111. self.dependencies.setdefault(target_state.get_full_name(), set()).add(
  112. instruction.argval
  113. )
  114. def _get_globals(self) -> dict[str, Any]:
  115. """Get the globals of the function.
  116. Returns:
  117. The var names and values in the globals of the function.
  118. """
  119. if isinstance(self.func, CodeType):
  120. return {}
  121. return self.func.__globals__ # pyright: ignore[reportAttributeAccessIssue]
  122. def _get_closure(self) -> dict[str, Any]:
  123. """Get the closure of the function, with unbound values omitted.
  124. Returns:
  125. The var names and values in the closure of the function.
  126. """
  127. if isinstance(self.func, CodeType):
  128. return {}
  129. return {
  130. var_name: get_cell_value(cell)
  131. for var_name, cell in zip(
  132. self.func.__code__.co_freevars, # pyright: ignore[reportAttributeAccessIssue]
  133. self.func.__closure__ or (),
  134. strict=False,
  135. )
  136. if get_cell_value(cell) is not CellEmpty
  137. }
  138. def handle_getting_state(self, instruction: dis.Instruction) -> None:
  139. """Handle bytecode analysis when `get_state` was called in the function.
  140. If the wrapped function is getting an arbitrary state and saving it to a
  141. local variable, this method associates the local variable name with the
  142. state class in self.tracked_locals.
  143. When an attribute/method is accessed on a tracked local, it will be
  144. associated with this state.
  145. Args:
  146. instruction: The dis instruction to process.
  147. Raises:
  148. VarValueError: if the state class cannot be determined from the instruction.
  149. """
  150. from reflex.state import BaseState
  151. if instruction.opname == "LOAD_FAST":
  152. raise VarValueError(
  153. f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
  154. )
  155. if isinstance(self.func, CodeType):
  156. raise VarValueError(
  157. "Dependency detection cannot identify get_state class from a code object."
  158. )
  159. if instruction.opname == "LOAD_GLOBAL":
  160. # Special case: referencing state class from global scope.
  161. try:
  162. self._getting_state_class = self._get_globals()[instruction.argval]
  163. except (ValueError, KeyError) as ve:
  164. raise VarValueError(
  165. f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
  166. ) from ve
  167. elif instruction.opname == "LOAD_DEREF":
  168. # Special case: referencing state class from closure.
  169. try:
  170. self._getting_state_class = self._get_closure()[instruction.argval]
  171. except (ValueError, KeyError) as ve:
  172. raise VarValueError(
  173. f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
  174. ) from ve
  175. elif instruction.opname == "STORE_FAST":
  176. # Storing the result of get_state in a local variable.
  177. if not isinstance(self._getting_state_class, type) or not issubclass(
  178. self._getting_state_class, BaseState
  179. ):
  180. raise VarValueError(
  181. f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
  182. )
  183. self.tracked_locals[instruction.argval] = self._getting_state_class
  184. self.scan_status = ScanStatus.SCANNING
  185. self._getting_state_class = None
  186. def _eval_var(self) -> Var:
  187. """Evaluate instructions from the wrapped function to get the Var object.
  188. Returns:
  189. The Var object.
  190. Raises:
  191. VarValueError: if the source code for the var cannot be determined.
  192. """
  193. # Get the original source code and eval it to get the Var.
  194. module = inspect.getmodule(self.func)
  195. positions0 = self._getting_var_instructions[0].positions
  196. positions1 = self._getting_var_instructions[-1].positions
  197. if module is None or positions0 is None or positions1 is None:
  198. raise VarValueError(
  199. f"Cannot determine the source code for the var in {self.func!r}."
  200. )
  201. start_line = positions0.lineno
  202. start_column = positions0.col_offset
  203. end_line = positions1.end_lineno
  204. end_column = positions1.end_col_offset
  205. if (
  206. start_line is None
  207. or start_column is None
  208. or end_line is None
  209. or end_column is None
  210. ):
  211. raise VarValueError(
  212. f"Cannot determine the source code for the var in {self.func!r}."
  213. )
  214. source = inspect.getsource(module).splitlines(True)[start_line - 1 : end_line]
  215. # Create a python source string snippet.
  216. if len(source) > 1:
  217. snipped_source = "".join(
  218. [
  219. *source[0][start_column:],
  220. *(source[1:-2] if len(source) > 2 else []),
  221. *source[-1][: end_column - 1],
  222. ]
  223. )
  224. else:
  225. snipped_source = source[0][start_column : end_column - 1]
  226. # Evaluate the string in the context of the function's globals and closure.
  227. return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
  228. def handle_getting_var(self, instruction: dis.Instruction) -> None:
  229. """Handle bytecode analysis when `get_var_value` was called in the function.
  230. This only really works if the expression passed to `get_var_value` is
  231. evaluable in the function's global scope or closure, so getting the var
  232. value from a var saved in a local variable or in the class instance is
  233. not possible.
  234. Args:
  235. instruction: The dis instruction to process.
  236. Raises:
  237. VarValueError: if the source code for the var cannot be determined.
  238. """
  239. if instruction.opname == "CALL" and self._getting_var_instructions:
  240. if self._getting_var_instructions:
  241. the_var = self._eval_var()
  242. the_var_data = the_var._get_all_var_data()
  243. if the_var_data is None:
  244. raise VarValueError(
  245. f"Cannot determine the source code for the var in {self.func!r}."
  246. )
  247. self.dependencies.setdefault(the_var_data.state, set()).add(
  248. the_var_data.field_name
  249. )
  250. self._getting_var_instructions.clear()
  251. self.scan_status = ScanStatus.SCANNING
  252. else:
  253. self._getting_var_instructions.append(instruction)
  254. def _populate_dependencies(self) -> None:
  255. """Update self.dependencies based on the disassembly of self.func.
  256. Save references to attributes accessed on "self" or other fetched states.
  257. Recursively called when the function makes a method call on "self" or
  258. define comprehensions or nested functions that may reference "self".
  259. """
  260. for instruction in dis.get_instructions(self.func):
  261. if self.scan_status == ScanStatus.GETTING_STATE:
  262. self.handle_getting_state(instruction)
  263. elif self.scan_status == ScanStatus.GETTING_VAR:
  264. self.handle_getting_var(instruction)
  265. elif (
  266. instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
  267. and instruction.argval in self.tracked_locals
  268. ):
  269. # bytecode loaded the class instance to the top of stack, next load instruction
  270. # is referencing an attribute on self
  271. self.top_of_stack = instruction.argval
  272. self.scan_status = ScanStatus.GETTING_ATTR
  273. elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in (
  274. "LOAD_ATTR",
  275. "LOAD_METHOD",
  276. ):
  277. self.load_attr_or_method(instruction)
  278. self.top_of_stack = None
  279. elif instruction.opname == "LOAD_CONST" and isinstance(
  280. instruction.argval, CodeType
  281. ):
  282. # recurse into nested functions / comprehensions, which can reference
  283. # instance attributes from the outer scope
  284. self._merge_deps(
  285. type(self)(
  286. func=instruction.argval,
  287. state_cls=self.state_cls,
  288. tracked_locals=self.tracked_locals,
  289. )
  290. )