test_lifespan.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. """Test cases for the FastAPI lifespan integration."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.webdriver.common.by import By
  5. from reflex.testing import AppHarness
  6. from .utils import SessionStorage
def LifespanApp():
    """App with lifespan tasks and context.

    This function is handed to AppHarness as ``app_source``. The test then reads
    ``lifespan_task_global_getter`` / ``lifespan_context_global_getter`` off the
    harness's ``app_module``. So names bound at the top level of this body are
    presumably exposed as attributes of the generated app module. Keep them
    named as-is.
    """
    import asyncio
    from contextlib import asynccontextmanager

    import reflex as rx

    def create_tasks():
        """Build the lifespan task, lifespan context, state class and counter getters.

        Both counters live in this closure and are mutated via ``nonlocal``.
        The getters let code outside the closure (e.g. the test) read their
        current values.

        Returns:
            (lifespan_task, lifespan_context, LifespanState,
             lifespan_task_global_getter, lifespan_context_global_getter)
        """
        lifespan_task_global = 0
        lifespan_context_global = 0

        def lifespan_context_global_getter():
            """Return the current value of the lifespan-context counter."""
            return lifespan_context_global

        def lifespan_task_global_getter():
            """Return the current value of the lifespan-task counter."""
            return lifespan_task_global

        @asynccontextmanager
        async def lifespan_context(app, inc: int = 1):
            """Lifespan context manager: add ``inc`` on entry and again on exit.

            The app registers it with inc=2. The counter therefore reads 2 while
            the app runs and 4 after the context has exited.

            Args:
                app: the app object passed in by the lifespan machinery (only printed here).
                inc: amount added to the counter on entry and on exit.
            """
            nonlocal lifespan_context_global
            print(f"Lifespan context entered: {app}.")
            lifespan_context_global += inc
            try:
                yield
            finally:
                # Runs on normal exit and on error, so the exit increment always happens.
                print("Lifespan context exited.")
                lifespan_context_global += inc

        async def lifespan_task(inc: int = 1):
            """Background lifespan task: add ``inc`` to its counter every 0.1 s until cancelled.

            On cancellation the counter is reset to 0.

            Args:
                inc: amount added to the counter on each iteration.
            """
            nonlocal lifespan_task_global
            print("Lifespan global started.")
            try:
                while True:
                    lifespan_task_global += inc
                    await asyncio.sleep(0.1)
            except asyncio.CancelledError as ce:
                # NOTE(review): CancelledError is swallowed rather than re-raised,
                # so the task finishes normally instead of as "cancelled".
                # Presumably deliberate to keep shutdown quiet; the test relies on
                # the reset to 0 below. Confirm before changing.
                print(f"Lifespan global cancelled: {ce}.")
                lifespan_task_global = 0

        # State exposing both closure counters to the frontend.
        class LifespanState(rx.State):
            # Interval passed to rx.moment (presumably milliseconds; confirm).
            # The toggle button flips it between 100 and 0.
            interval: int = 100

            # cache=False: the value is not cached, so each read returns the
            # current closure counter.
            @rx.var(cache=False)
            def task_global(self) -> int:
                return lifespan_task_global

            @rx.var(cache=False)
            def context_global(self) -> int:
                return lifespan_context_global

            # No-op handler. rx.moment calls it via on_change, presumably so the
            # frontend keeps receiving refreshed var values.
            @rx.event
            def tick(self, date):
                pass

        return (
            lifespan_task,
            lifespan_context,
            LifespanState,
            lifespan_task_global_getter,
            lifespan_context_global_getter,
        )

    # Unpack at this level so these names are bound in the app-module namespace.
    (
        lifespan_task,
        lifespan_context,
        LifespanState,
        lifespan_task_global_getter,
        lifespan_context_global_getter,
    ) = create_tasks()

    def index():
        """Render both counters and a button hosting the ticking rx.moment.

        Clicking the button toggles the moment's interval between its current
        value and 0 (or 100 if it is 0).
        """
        return rx.vstack(
            rx.text(LifespanState.task_global, id="task_global"),
            rx.text(LifespanState.context_global, id="context_global"),
            rx.button(
                rx.moment(
                    interval=LifespanState.interval, on_change=LifespanState.tick
                ),
                on_click=LifespanState.set_interval(  # type: ignore
                    rx.cond(LifespanState.interval, 0, 100)
                ),
                id="toggle-tick",
            ),
        )

    app = rx.App()
    # Register the background coroutine (default inc=1) and the context manager with inc=2.
    app.register_lifespan_task(lifespan_task)
    app.register_lifespan_task(lifespan_context, inc=2)
    app.add_page(index)
  82. @pytest.fixture()
  83. def lifespan_app(tmp_path) -> Generator[AppHarness, None, None]:
  84. """Start LifespanApp app at tmp_path via AppHarness.
  85. Args:
  86. tmp_path: pytest tmp_path fixture
  87. Yields:
  88. running AppHarness instance
  89. """
  90. with AppHarness.create(
  91. root=tmp_path,
  92. app_source=LifespanApp,
  93. ) as harness:
  94. yield harness
  95. @pytest.mark.asyncio
  96. async def test_lifespan(lifespan_app: AppHarness):
  97. """Test the lifespan integration.
  98. Args:
  99. lifespan_app: harness for LifespanApp app
  100. """
  101. assert lifespan_app.app_module is not None, "app module is not found"
  102. assert lifespan_app.app_instance is not None, "app is not running"
  103. driver = lifespan_app.frontend()
  104. ss = SessionStorage(driver)
  105. assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
  106. context_global = driver.find_element(By.ID, "context_global")
  107. task_global = driver.find_element(By.ID, "task_global")
  108. assert context_global.text == "2"
  109. assert lifespan_app.app_module.lifespan_context_global_getter() == 2 # type: ignore
  110. original_task_global_text = task_global.text
  111. original_task_global_value = int(original_task_global_text)
  112. lifespan_app.poll_for_content(task_global, exp_not_equal=original_task_global_text)
  113. driver.find_element(By.ID, "toggle-tick").click() # avoid teardown errors
  114. assert (
  115. lifespan_app.app_module.lifespan_task_global_getter()
  116. > original_task_global_value
  117. ) # type: ignore
  118. assert int(task_global.text) > original_task_global_value
  119. # Kill the backend
  120. assert lifespan_app.backend is not None
  121. lifespan_app.backend.should_exit = True
  122. if lifespan_app.backend_thread is not None:
  123. lifespan_app.backend_thread.join()
  124. # Check that the lifespan tasks have been cancelled
  125. assert lifespan_app.app_module.lifespan_task_global_getter() == 0
  126. assert lifespan_app.app_module.lifespan_context_global_getter() == 4