|
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|
|
import asyncio
|
|
|
import contextlib
|
|
|
import dataclasses
|
|
|
+import functools
|
|
|
import inspect
|
|
|
import os
|
|
|
import pathlib
|
|
@@ -20,6 +21,7 @@ import types
|
|
|
from http.server import SimpleHTTPRequestHandler
|
|
|
from typing import (
|
|
|
TYPE_CHECKING,
|
|
|
+ Any,
|
|
|
AsyncIterator,
|
|
|
Callable,
|
|
|
Coroutine,
|
|
@@ -135,6 +137,8 @@ class AppHarness:
|
|
|
if app_name is None:
|
|
|
if app_source is None:
|
|
|
app_name = root.name.lower()
|
|
|
+ elif isinstance(app_source, functools.partial):
|
|
|
+ app_name = app_source.func.__name__.lower()
|
|
|
else:
|
|
|
app_name = app_source.__name__.lower()
|
|
|
return cls(
|
|
@@ -144,13 +148,54 @@ class AppHarness:
|
|
|
app_module_path=root / app_name / f"{app_name}.py",
|
|
|
)
|
|
|
|
|
|
+ def _get_globals_from_signature(self, func: Any) -> dict[str, Any]:
|
|
|
+ """Get the globals from a function or module object.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ func: function or module object
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict of globals
|
|
|
+ """
|
|
|
+ overrides = {}
|
|
|
+ glbs = {}
|
|
|
+ if not callable(func):
|
|
|
+ return glbs
|
|
|
+ if isinstance(func, functools.partial):
|
|
|
+ overrides = func.keywords
|
|
|
+ func = func.func
|
|
|
+ for param in inspect.signature(func).parameters.values():
|
|
|
+ if param.default is not inspect.Parameter.empty:
|
|
|
+ glbs[param.name] = param.default
|
|
|
+ glbs.update(overrides)
|
|
|
+ return glbs
|
|
|
+
|
|
|
+ def _get_source_from_func(self, func: Any) -> str:
|
|
|
+ """Get the source from a function or module object.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ func: function or module object
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ source code
|
|
|
+ """
|
|
|
+ source = inspect.getsource(func)
|
|
|
+ source = re.sub(r"^\s*def\s+\w+\s*\(.*?\):", "", source, flags=re.DOTALL)
|
|
|
+ return textwrap.dedent(source)
|
|
|
+
|
|
|
def _initialize_app(self):
|
|
|
os.environ["TELEMETRY_ENABLED"] = "" # disable telemetry reporting for tests
|
|
|
self.app_path.mkdir(parents=True, exist_ok=True)
|
|
|
if self.app_source is not None:
|
|
|
+ app_globals = self._get_globals_from_signature(self.app_source)
|
|
|
+ if isinstance(self.app_source, functools.partial):
|
|
|
+ self.app_source = self.app_source.func # type: ignore
|
|
|
# get the source from a function or module object
|
|
|
- source_code = textwrap.dedent(
|
|
|
- "".join(inspect.getsource(self.app_source).splitlines(True)[1:]),
|
|
|
+ source_code = "\n".join(
|
|
|
+ [
|
|
|
+ "\n".join(f"{k} = {v!r}" for k, v in app_globals.items()),
|
|
|
+ self._get_source_from_func(self.app_source),
|
|
|
+ ]
|
|
|
)
|
|
|
with chdir(self.app_path):
|
|
|
reflex.reflex._init(
|
|
@@ -167,11 +212,11 @@ class AppHarness:
|
|
|
# self.app_module.app.
|
|
|
self.app_module = reflex.utils.prerequisites.get_compiled_app(reload=True)
|
|
|
self.app_instance = self.app_module.app
|
|
|
- if isinstance(self.app_instance.state_manager, StateManagerRedis):
|
|
|
+ if isinstance(self.app_instance._state_manager, StateManagerRedis):
|
|
|
# Create our own redis connection for testing.
|
|
|
self.state_manager = StateManagerRedis.create(self.app_instance.state)
|
|
|
else:
|
|
|
- self.state_manager = self.app_instance.state_manager
|
|
|
+ self.state_manager = self.app_instance._state_manager
|
|
|
|
|
|
def _get_backend_shutdown_handler(self):
|
|
|
if self.backend is None:
|
|
@@ -181,10 +226,13 @@ class AppHarness:
|
|
|
|
|
|
async def _shutdown_redis(*args, **kwargs) -> None:
|
|
|
# ensure redis is closed before event loop
|
|
|
- if self.app_instance is not None and isinstance(
|
|
|
- self.app_instance.state_manager, StateManagerRedis
|
|
|
- ):
|
|
|
- await self.app_instance.state_manager.close()
|
|
|
+ try:
|
|
|
+ if self.app_instance is not None and isinstance(
|
|
|
+ self.app_instance.state_manager, StateManagerRedis
|
|
|
+ ):
|
|
|
+ await self.app_instance.state_manager.close()
|
|
|
+ except ValueError:
|
|
|
+ pass
|
|
|
await original_shutdown(*args, **kwargs)
|
|
|
|
|
|
return _shutdown_redis
|