test_lifespan.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. """Test cases for the FastAPI lifespan integration."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.testing import AppHarness
  6. from .utils import SessionStorage
def LifespanApp():
    """App with lifespan tasks and context.

    The test below reads ``lifespan_task_global`` and
    ``lifespan_context_global`` as attributes of ``AppHarness.app_module``,
    so this function body is evidently used as the app module's top-level
    source rather than being called directly (NOTE(review): confirm against
    ``AppHarness.create(app_source=...)``). The ``global`` statements in the
    nested coroutines rely on that: they rebind these module-level counters.
    Renaming either counter would break the test.
    """
    import asyncio
    from contextlib import asynccontextmanager

    import reflex as rx

    # Incremented by ``lifespan_task`` while it runs; reset to 0 when it is
    # cancelled.
    lifespan_task_global = 0
    # Incremented by ``lifespan_context`` once on entry and once on exit.
    lifespan_context_global = 0

    @asynccontextmanager
    async def lifespan_context(app, inc: int = 1):
        """Lifespan context that bumps ``lifespan_context_global`` twice.

        Adds ``inc`` when the context is entered and again (in ``finally``)
        when it exits. With the ``inc=2`` registration below, the counter
        reads 2 while the app runs and 4 after shutdown.

        Args:
            app: Supplied by the lifespan machinery, since the registration
                below does not pass it (NOTE(review): presumably the ASGI app
                instance; confirm). Only used in the log message.
            inc: Amount added on entry and again on exit.
        """
        global lifespan_context_global
        print(f"Lifespan context entered: {app}.")
        lifespan_context_global += inc  # pyright: ignore[reportUnboundVariable]
        try:
            yield
        finally:
            # Runs on normal exit and on error alike, so shutdown is always
            # recorded in the counter.
            print("Lifespan context exited.")
            lifespan_context_global += inc

    async def lifespan_task(inc: int = 1):
        """Background loop that bumps ``lifespan_task_global`` every 0.1 s.

        Args:
            inc: Amount added on each iteration.
        """
        global lifespan_task_global
        print("Lifespan global started.")
        try:
            while True:
                lifespan_task_global += inc  # pyright: ignore[reportUnboundVariable]
                await asyncio.sleep(0.1)
        except asyncio.CancelledError as ce:
            # Reset to 0 so the test can observe that the task was cancelled
            # at shutdown. NOTE(review): the CancelledError is swallowed, not
            # re-raised, so the task ends normally. Presumably this is
            # intentional to keep shutdown quiet; confirm.
            print(f"Lifespan global cancelled: {ce}.")
            lifespan_task_global = 0

    class LifespanState(rx.State):
        """State exposing the module-level lifespan counters to the page."""

        @rx.var
        def task_global(self) -> int:
            """Current value of the module-level ``lifespan_task_global``."""
            return lifespan_task_global

        @rx.var
        def context_global(self) -> int:
            """Current value of the module-level ``lifespan_context_global``."""
            return lifespan_context_global

        def tick(self, date):
            """No-op handler wired to ``rx.moment``'s ``on_change``.

            Presumably it exists only to generate periodic events, so the
            computed vars above are re-evaluated and pushed to the page.
            Confirm.
            """
            pass

    def index():
        """Render both counters; the element ids are what the test looks up."""
        return rx.vstack(
            rx.text(LifespanState.task_global, id="task_global"),
            rx.text(LifespanState.context_global, id="context_global"),
            # interval=100 (presumably milliseconds) drives LifespanState.tick.
            rx.moment(interval=100, on_change=LifespanState.tick),
        )

    app = rx.App()
    # A plain coroutine function and an async context manager factory are both
    # registered. The extra ``inc=2`` is evidently forwarded, because the test
    # expects the context counter to read 2 after startup.
    app.register_lifespan_task(lifespan_task)
    app.register_lifespan_task(lifespan_context, inc=2)
    app.add_page(index)
  53. @pytest.fixture()
  54. def lifespan_app(tmp_path) -> Generator[AppHarness, None, None]:
  55. """Start LifespanApp app at tmp_path via AppHarness.
  56. Args:
  57. tmp_path: pytest tmp_path fixture
  58. Yields:
  59. running AppHarness instance
  60. """
  61. with AppHarness.create(
  62. root=tmp_path,
  63. app_source=LifespanApp, # type: ignore
  64. ) as harness:
  65. yield harness
  66. @pytest.mark.asyncio
  67. async def test_lifespan(lifespan_app: AppHarness):
  68. """Test the lifespan integration.
  69. Args:
  70. lifespan_app: harness for LifespanApp app
  71. """
  72. assert lifespan_app.app_module is not None, "app module is not found"
  73. assert lifespan_app.app_instance is not None, "app is not running"
  74. driver = lifespan_app.frontend()
  75. ss = SessionStorage(driver)
  76. assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
  77. context_global = driver.find_element(By.ID, "context_global")
  78. task_global = driver.find_element(By.ID, "task_global")
  79. assert context_global.text == "2"
  80. assert lifespan_app.app_module.lifespan_context_global == 2 # type: ignore
  81. original_task_global_text = task_global.text
  82. original_task_global_value = int(original_task_global_text)
  83. lifespan_app.poll_for_content(task_global, exp_not_equal=original_task_global_text)
  84. assert lifespan_app.app_module.lifespan_task_global > original_task_global_value # type: ignore
  85. assert int(task_global.text) > original_task_global_value
  86. # Kill the backend
  87. assert lifespan_app.backend is not None
  88. lifespan_app.backend.should_exit = True
  89. if lifespan_app.backend_thread is not None:
  90. lifespan_app.backend_thread.join()
  91. # Check that the lifespan tasks have been cancelled
  92. assert lifespan_app.app_module.lifespan_task_global == 0
  93. assert lifespan_app.app_module.lifespan_context_global == 4