dep_tracking.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
"""Track which state attributes a function reads by scanning its bytecode with `dis`."""
  2. from __future__ import annotations
  3. import contextlib
  4. import dataclasses
  5. import dis
  6. import enum
  7. import inspect
  8. from types import CellType, CodeType, FunctionType
  9. from typing import TYPE_CHECKING, Any, ClassVar, cast
  10. from reflex.utils.exceptions import VarValueError
  11. if TYPE_CHECKING:
  12. from reflex.state import BaseState
  13. from .base import Var
  14. CellEmpty = object()
  15. def get_cell_value(cell: CellType) -> Any:
  16. """Get the value of a cell object.
  17. Args:
  18. cell: The cell object to get the value from. (func.__closure__ objects)
  19. Returns:
  20. The value from the cell or CellEmpty if a ValueError is raised.
  21. """
  22. try:
  23. return cell.cell_contents
  24. except ValueError:
  25. return CellEmpty
  26. class ScanStatus(enum.Enum):
  27. """State of the dis instruction scanning loop."""
  28. SCANNING = enum.auto()
  29. GETTING_ATTR = enum.auto()
  30. GETTING_STATE = enum.auto()
  31. GETTING_VAR = enum.auto()
  32. @dataclasses.dataclass
  33. class DependencyTracker:
  34. """State machine for identifying state attributes that are accessed by a function."""
  35. func: FunctionType | CodeType = dataclasses.field()
  36. state_cls: type[BaseState] = dataclasses.field()
  37. dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict)
  38. scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)
  39. top_of_stack: str | None = dataclasses.field(default=None)
  40. tracked_locals: dict[str, type[BaseState]] = dataclasses.field(default_factory=dict)
  41. _getting_state_class: type[BaseState] | None = dataclasses.field(default=None)
  42. _getting_var_instructions: list[dis.Instruction] = dataclasses.field(
  43. default_factory=list
  44. )
  45. INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]
  46. def __post_init__(self):
  47. """After initializing, populate the dependencies dict."""
  48. with contextlib.suppress(AttributeError):
  49. # unbox functools.partial
  50. self.func = cast(FunctionType, self.func.func) # pyright: ignore[reportAttributeAccessIssue]
  51. with contextlib.suppress(AttributeError):
  52. # unbox EventHandler
  53. self.func = cast(FunctionType, self.func.fn) # pyright: ignore[reportAttributeAccessIssue,reportFunctionMemberAccess]
  54. if isinstance(self.func, FunctionType):
  55. with contextlib.suppress(AttributeError, IndexError):
  56. # the first argument to the function is the name of "self" arg
  57. self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls
  58. self._populate_dependencies()
  59. def _merge_deps(self, tracker: DependencyTracker) -> None:
  60. """Merge dependencies from another DependencyTracker.
  61. Args:
  62. tracker: The DependencyTracker to merge dependencies from.
  63. """
  64. for state_name, dep_name in tracker.dependencies.items():
  65. self.dependencies.setdefault(state_name, set()).update(dep_name)
  66. def load_attr_or_method(self, instruction: dis.Instruction) -> None:
  67. """Handle loading an attribute or method from the object on top of the stack.
  68. This method directly tracks attributes and recursively merges
  69. dependencies from analyzing the dependencies of any methods called.
  70. Args:
  71. instruction: The dis instruction to process.
  72. Raises:
  73. VarValueError: if the attribute is an disallowed name.
  74. """
  75. from .base import ComputedVar
  76. if instruction.argval in self.INVALID_NAMES:
  77. msg = f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
  78. raise VarValueError(msg)
  79. if instruction.argval == "get_state":
  80. # Special case: arbitrary state access requested.
  81. self.scan_status = ScanStatus.GETTING_STATE
  82. return
  83. if instruction.argval == "get_var_value":
  84. # Special case: arbitrary var access requested.
  85. self.scan_status = ScanStatus.GETTING_VAR
  86. return
  87. # Reset status back to SCANNING after attribute is accessed.
  88. self.scan_status = ScanStatus.SCANNING
  89. if not self.top_of_stack:
  90. return
  91. target_state = self.tracked_locals[self.top_of_stack]
  92. try:
  93. ref_obj = getattr(target_state, instruction.argval)
  94. except AttributeError:
  95. # Not found on this state class, maybe it is a dynamic attribute that will be picked up later.
  96. ref_obj = None
  97. if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar):
  98. # recurse into property fget functions
  99. ref_obj = ref_obj.fget
  100. if callable(ref_obj):
  101. # recurse into callable attributes
  102. self._merge_deps(
  103. type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state)
  104. )
  105. elif (
  106. instruction.argval in target_state.backend_vars
  107. or instruction.argval in target_state.vars
  108. ):
  109. # var access
  110. self.dependencies.setdefault(target_state.get_full_name(), set()).add(
  111. instruction.argval
  112. )
  113. def _get_globals(self) -> dict[str, Any]:
  114. """Get the globals of the function.
  115. Returns:
  116. The var names and values in the globals of the function.
  117. """
  118. if isinstance(self.func, CodeType):
  119. return {}
  120. return self.func.__globals__ # pyright: ignore[reportAttributeAccessIssue]
  121. def _get_closure(self) -> dict[str, Any]:
  122. """Get the closure of the function, with unbound values omitted.
  123. Returns:
  124. The var names and values in the closure of the function.
  125. """
  126. if isinstance(self.func, CodeType):
  127. return {}
  128. return {
  129. var_name: get_cell_value(cell)
  130. for var_name, cell in zip(
  131. self.func.__code__.co_freevars, # pyright: ignore[reportAttributeAccessIssue]
  132. self.func.__closure__ or (),
  133. strict=False,
  134. )
  135. if get_cell_value(cell) is not CellEmpty
  136. }
  137. def handle_getting_state(self, instruction: dis.Instruction) -> None:
  138. """Handle bytecode analysis when `get_state` was called in the function.
  139. If the wrapped function is getting an arbitrary state and saving it to a
  140. local variable, this method associates the local variable name with the
  141. state class in self.tracked_locals.
  142. When an attribute/method is accessed on a tracked local, it will be
  143. associated with this state.
  144. Args:
  145. instruction: The dis instruction to process.
  146. Raises:
  147. VarValueError: if the state class cannot be determined from the instruction.
  148. """
  149. from reflex.state import BaseState
  150. if instruction.opname == "LOAD_FAST":
  151. msg = f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
  152. raise VarValueError(msg)
  153. if isinstance(self.func, CodeType):
  154. msg = "Dependency detection cannot identify get_state class from a code object."
  155. raise VarValueError(msg)
  156. if instruction.opname == "LOAD_GLOBAL":
  157. # Special case: referencing state class from global scope.
  158. try:
  159. self._getting_state_class = self._get_globals()[instruction.argval]
  160. except (ValueError, KeyError) as ve:
  161. msg = f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
  162. raise VarValueError(msg) from ve
  163. elif instruction.opname == "LOAD_DEREF":
  164. # Special case: referencing state class from closure.
  165. try:
  166. self._getting_state_class = self._get_closure()[instruction.argval]
  167. except (ValueError, KeyError) as ve:
  168. msg = f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
  169. raise VarValueError(msg) from ve
  170. elif instruction.opname == "STORE_FAST":
  171. # Storing the result of get_state in a local variable.
  172. if not isinstance(self._getting_state_class, type) or not issubclass(
  173. self._getting_state_class, BaseState
  174. ):
  175. msg = f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
  176. raise VarValueError(msg)
  177. self.tracked_locals[instruction.argval] = self._getting_state_class
  178. self.scan_status = ScanStatus.SCANNING
  179. self._getting_state_class = None
  180. def _eval_var(self) -> Var:
  181. """Evaluate instructions from the wrapped function to get the Var object.
  182. Returns:
  183. The Var object.
  184. Raises:
  185. VarValueError: if the source code for the var cannot be determined.
  186. """
  187. # Get the original source code and eval it to get the Var.
  188. module = inspect.getmodule(self.func)
  189. positions0 = self._getting_var_instructions[0].positions
  190. positions1 = self._getting_var_instructions[-1].positions
  191. if module is None or positions0 is None or positions1 is None:
  192. msg = f"Cannot determine the source code for the var in {self.func!r}."
  193. raise VarValueError(msg)
  194. start_line = positions0.lineno
  195. start_column = positions0.col_offset
  196. end_line = positions1.end_lineno
  197. end_column = positions1.end_col_offset
  198. if (
  199. start_line is None
  200. or start_column is None
  201. or end_line is None
  202. or end_column is None
  203. ):
  204. msg = f"Cannot determine the source code for the var in {self.func!r}."
  205. raise VarValueError(msg)
  206. source = inspect.getsource(module).splitlines(True)[start_line - 1 : end_line]
  207. # Create a python source string snippet.
  208. if len(source) > 1:
  209. snipped_source = "".join(
  210. [
  211. *source[0][start_column:],
  212. *(source[1:-2] if len(source) > 2 else []),
  213. *source[-1][: end_column - 1],
  214. ]
  215. )
  216. else:
  217. snipped_source = source[0][start_column : end_column - 1]
  218. # Evaluate the string in the context of the function's globals and closure.
  219. return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
  220. def handle_getting_var(self, instruction: dis.Instruction) -> None:
  221. """Handle bytecode analysis when `get_var_value` was called in the function.
  222. This only really works if the expression passed to `get_var_value` is
  223. evaluable in the function's global scope or closure, so getting the var
  224. value from a var saved in a local variable or in the class instance is
  225. not possible.
  226. Args:
  227. instruction: The dis instruction to process.
  228. Raises:
  229. VarValueError: if the source code for the var cannot be determined.
  230. """
  231. if instruction.opname == "CALL" and self._getting_var_instructions:
  232. if self._getting_var_instructions:
  233. the_var = self._eval_var()
  234. the_var_data = the_var._get_all_var_data()
  235. if the_var_data is None:
  236. msg = f"Cannot determine the source code for the var in {self.func!r}."
  237. raise VarValueError(msg)
  238. self.dependencies.setdefault(the_var_data.state, set()).add(
  239. the_var_data.field_name
  240. )
  241. self._getting_var_instructions.clear()
  242. self.scan_status = ScanStatus.SCANNING
  243. else:
  244. self._getting_var_instructions.append(instruction)
  245. def _populate_dependencies(self) -> None:
  246. """Update self.dependencies based on the disassembly of self.func.
  247. Save references to attributes accessed on "self" or other fetched states.
  248. Recursively called when the function makes a method call on "self" or
  249. define comprehensions or nested functions that may reference "self".
  250. """
  251. for instruction in dis.get_instructions(self.func):
  252. if self.scan_status == ScanStatus.GETTING_STATE:
  253. self.handle_getting_state(instruction)
  254. elif self.scan_status == ScanStatus.GETTING_VAR:
  255. self.handle_getting_var(instruction)
  256. elif (
  257. instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
  258. and instruction.argval in self.tracked_locals
  259. ):
  260. # bytecode loaded the class instance to the top of stack, next load instruction
  261. # is referencing an attribute on self
  262. self.top_of_stack = instruction.argval
  263. self.scan_status = ScanStatus.GETTING_ATTR
  264. elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in (
  265. "LOAD_ATTR",
  266. "LOAD_METHOD",
  267. ):
  268. self.load_attr_or_method(instruction)
  269. self.top_of_stack = None
  270. elif instruction.opname == "LOAD_CONST" and isinstance(
  271. instruction.argval, CodeType
  272. ):
  273. # recurse into nested functions / comprehensions, which can reference
  274. # instance attributes from the outer scope
  275. self._merge_deps(
  276. type(self)(
  277. func=instruction.argval,
  278. state_cls=self.state_cls,
  279. tracked_locals=self.tracked_locals,
  280. )
  281. )