test_lifespan.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. """Test cases for the FastAPI lifespan integration."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.testing import AppHarness
  6. from .utils import SessionStorage
def LifespanApp():
    """App with lifespan tasks and context.

    Defines two counters, a lifespan context manager and a lifespan task that
    mutate them, a state class exposing the counters to the page, and an
    ``rx.App`` with both lifespan callables registered.

    NOTE(review): the nested callables declare the counters ``global`` and the
    test reads them as attributes of the harness's ``app_module``, so this body
    is presumably executed as module-level code by AppHarness rather than
    called as a regular function -- confirm before renaming or restructuring.
    """
    import asyncio
    from contextlib import asynccontextmanager

    import reflex as rx

    # Counters mutated by the lifespan callables below and asserted on by the
    # test through ``app_module``.
    lifespan_task_global = 0
    lifespan_context_global = 0

    @asynccontextmanager
    async def lifespan_context(app, inc: int = 1):
        # Adds ``inc`` once on entry and once more on exit; registered below
        # with inc=2, so the counter reads 2 while entered and 4 after exit.
        global lifespan_context_global
        print(f"Lifespan context entered: {app}.")
        lifespan_context_global += inc  # pyright: ignore[reportUnboundVariable]
        try:
            yield
        finally:
            # Runs whether the context body finishes normally or raises.
            print("Lifespan context exited.")
            lifespan_context_global += inc

    async def lifespan_task(inc: int = 1):
        # Increments the task counter by ``inc`` every 0.1 s until cancelled.
        global lifespan_task_global
        print("Lifespan global started.")
        try:
            while True:
                lifespan_task_global += inc  # pyright: ignore[reportUnboundVariable]
                await asyncio.sleep(0.1)
        except asyncio.CancelledError as ce:
            # On cancellation the counter is reset to 0 (the test checks this
            # after shutdown); the CancelledError is not re-raised.
            print(f"Lifespan global cancelled: {ce}.")
            lifespan_task_global = 0

    class LifespanState(rx.State):
        # Exposes the current counter values to the frontend via ``rx.var``.
        @rx.var
        def task_global(self) -> int:
            return lifespan_task_global

        @rx.var
        def context_global(self) -> int:
            return lifespan_context_global

        @rx.event
        def tick(self, date):
            # Intentionally a no-op; wired to ``rx.moment``'s ``on_change`` in
            # ``index``. Presumably it only exists to trigger periodic state
            # updates so the rendered counters refresh -- TODO confirm.
            pass

    def index():
        # Element ids are what the selenium test uses to locate the counters.
        return rx.vstack(
            rx.text(LifespanState.task_global, id="task_global"),
            rx.text(LifespanState.context_global, id="context_global"),
            rx.moment(interval=100, on_change=LifespanState.tick),
        )

    app = rx.App()
    # The task is registered with its default ``inc``; the context manager is
    # registered with ``inc=2``.
    app.register_lifespan_task(lifespan_task)
    app.register_lifespan_task(lifespan_context, inc=2)
    app.add_page(index)
  54. @pytest.fixture()
  55. def lifespan_app(tmp_path) -> Generator[AppHarness, None, None]:
  56. """Start LifespanApp app at tmp_path via AppHarness.
  57. Args:
  58. tmp_path: pytest tmp_path fixture
  59. Yields:
  60. running AppHarness instance
  61. """
  62. with AppHarness.create(
  63. root=tmp_path,
  64. app_source=LifespanApp,
  65. ) as harness:
  66. yield harness
  67. @pytest.mark.asyncio
  68. async def test_lifespan(lifespan_app: AppHarness):
  69. """Test the lifespan integration.
  70. Args:
  71. lifespan_app: harness for LifespanApp app
  72. """
  73. assert lifespan_app.app_module is not None, "app module is not found"
  74. assert lifespan_app.app_instance is not None, "app is not running"
  75. driver = lifespan_app.frontend()
  76. ss = SessionStorage(driver)
  77. assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
  78. context_global = driver.find_element(By.ID, "context_global")
  79. task_global = driver.find_element(By.ID, "task_global")
  80. assert context_global.text == "2"
  81. assert lifespan_app.app_module.lifespan_context_global == 2 # type: ignore
  82. original_task_global_text = task_global.text
  83. original_task_global_value = int(original_task_global_text)
  84. lifespan_app.poll_for_content(task_global, exp_not_equal=original_task_global_text)
  85. assert lifespan_app.app_module.lifespan_task_global > original_task_global_value # type: ignore
  86. assert int(task_global.text) > original_task_global_value
  87. # Kill the backend
  88. assert lifespan_app.backend is not None
  89. lifespan_app.backend.should_exit = True
  90. if lifespan_app.backend_thread is not None:
  91. lifespan_app.backend_thread.join()
  92. # Check that the lifespan tasks have been cancelled
  93. assert lifespan_app.app_module.lifespan_task_global == 0
  94. assert lifespan_app.app_module.lifespan_context_global == 4