test_lifespan.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. """Test cases for the Starlette lifespan integration."""
  2. import functools
  3. from collections.abc import Generator
  4. import pytest
  5. from selenium.webdriver.common.by import By
  6. from reflex.testing import AppHarness
  7. from .utils import SessionStorage
  def LifespanApp(
      mount_cached_fastapi: bool = False, mount_api_transformer: bool = False
  ) -> None:
      """Reflex app that registers one lifespan task and one lifespan context.

      Two counters record lifespan activity and are rendered on the index page
      so that a browser-driven test can observe them.

      NOTE(review): the ``global`` statements in the nested callables only bind
      to the counters below if this body ends up running at module scope. The
      accompanying test reads both counters as attributes of the harness's app
      module, which suggests that is the case -- confirm against AppHarness.

      Args:
          mount_cached_fastapi: When true, ``app.api`` is accessed and asserted
              to be non-None before the lifespan callables are registered.
          mount_api_transformer: When true, a new ``FastAPI()`` instance is
              handed to ``rx.App`` as ``api_transformer``; otherwise ``None``.
      """
      import asyncio
      from contextlib import asynccontextmanager

      import reflex as rx
      from fastapi import FastAPI

      # Counters mutated by the lifespan callables and surfaced by the state vars.
      lifespan_task_global = 0
      lifespan_context_global = 0

      @asynccontextmanager
      async def lifespan_context(app, inc: int = 1):
          # Adds ``inc`` once on entry and once more on exit.
          global lifespan_context_global
          print(f"Lifespan context entered: {app}.")
          lifespan_context_global += inc  # pyright: ignore[reportUnboundVariable]
          try:
              yield
          finally:
              print("Lifespan context exited.")
              lifespan_context_global += inc

      async def lifespan_task(inc: int = 1):
          # Adds ``inc`` after every 0.1 s sleep; on cancellation the counter is
          # reset to 0, which is what the test checks after backend shutdown.
          global lifespan_task_global
          print("Lifespan global started.")
          try:
              while True:
                  lifespan_task_global += inc  # pyright: ignore[reportUnboundVariable, reportPossiblyUnboundVariable]
                  await asyncio.sleep(0.1)
          except asyncio.CancelledError as cancelled:
              print(f"Lifespan global cancelled: {cancelled}.")
              lifespan_task_global = 0

      class LifespanState(rx.State):
          # Interval handed to the rx.moment component on the page.
          interval: int = 100

          # Both vars are uncached so every state update re-reads the counters.
          @rx.var(cache=False)
          def task_global(self) -> int:
              return lifespan_task_global

          @rx.var(cache=False)
          def context_global(self) -> int:
              return lifespan_context_global

          @rx.event
          def tick(self, date):
              # Intentionally empty: the handler exists only to be triggered by
              # the moment component's on_change.
              pass

      def index():
          task_text = rx.text(LifespanState.task_global, id="task_global")
          context_text = rx.text(LifespanState.context_global, id="context_global")
          ticker = rx.moment(
              interval=LifespanState.interval, on_change=LifespanState.tick
          )
          # A click sets interval to 0 when it is currently truthy, else to 100.
          toggle_interval = LifespanState.set_interval(  # pyright: ignore [reportAttributeAccessIssue]
              rx.cond(LifespanState.interval, 0, 100)
          )
          return rx.vstack(
              task_text,
              context_text,
              rx.button(ticker, on_click=toggle_interval, id="toggle-tick"),
          )

      transformer = FastAPI() if mount_api_transformer else None
      app = rx.App(api_transformer=transformer)
      if mount_cached_fastapi:
          assert app.api is not None

      # The context is registered with inc=2, so it contributes 2 on entry and
      # 2 more on exit.
      app.register_lifespan_task(lifespan_task)
      app.register_lifespan_task(lifespan_context, inc=2)
      app.add_page(index)
  @pytest.fixture(
      params=[False, True], ids=["no_api_transformer", "mount_api_transformer"]
  )
  def mount_api_transformer(request: pytest.FixtureRequest) -> bool:
      """Parametrize tests over apps built with and without an api_transformer.

      Args:
          request: pytest fixture request carrying the current parameter.

      Returns:
          The current parameter: True to use api_transformer, False otherwise.
      """
      use_transformer: bool = request.param
      return use_transformer
  @pytest.fixture(params=[False, True], ids=["no_fastapi", "mount_cached_fastapi"])
  def mount_cached_fastapi(request: pytest.FixtureRequest) -> bool:
      """Parametrize tests over apps that do and do not use the cached FastAPI (app.api).

      Args:
          request: pytest fixture request carrying the current parameter.

      Returns:
          The current parameter: True to use the cached FastAPI, False otherwise.
      """
      use_cached_fastapi: bool = request.param
      return use_cached_fastapi
  @pytest.fixture()
  def lifespan_app(
      tmp_path, mount_api_transformer: bool, mount_cached_fastapi: bool
  ) -> Generator[AppHarness, None, None]:
      """Run LifespanApp under an AppHarness rooted at tmp_path.

      Args:
          tmp_path: pytest tmp_path fixture, used as the harness root.
          mount_api_transformer: Forwarded to LifespanApp.
          mount_cached_fastapi: Forwarded to LifespanApp.

      Yields:
          The AppHarness instance, for as long as its context manager is open.
      """
      # Bind both flags onto the app source before handing it to the harness.
      configured_source = functools.partial(
          LifespanApp,
          mount_cached_fastapi=mount_cached_fastapi,
          mount_api_transformer=mount_api_transformer,
      )
      # The app name encodes both flags, so each parameter combination gets
      # its own name.
      harness_name = f"lifespanapp_fastapi{mount_cached_fastapi}_transformer{mount_api_transformer}"
      with AppHarness.create(
          root=tmp_path,
          app_source=configured_source,
          app_name=harness_name,
      ) as harness:
          yield harness
  @pytest.mark.asyncio
  async def test_lifespan(lifespan_app: AppHarness):
      """Check that lifespan callables start with the backend and stop with it.

      The context counter must read 2 while the app runs and 4 after shutdown;
      the task counter must grow while the app runs and be 0 after shutdown.

      Args:
          lifespan_app: harness for LifespanApp app
      """
      app_module = lifespan_app.app_module
      assert app_module is not None, "app module is not found"
      assert lifespan_app.app_instance is not None, "app is not running"

      browser = lifespan_app.frontend()
      storage = SessionStorage(browser)
      assert AppHarness._poll_for(lambda: storage.get("token") is not None), (
          "token not found"
      )

      context_element = browser.find_element(By.ID, "context_global")
      task_element = browser.find_element(By.ID, "task_global")

      # The context was registered with inc=2, so one entry shows up as 2.
      rendered_context = lifespan_app.poll_for_content(
          context_element, exp_not_equal="0"
      )
      assert rendered_context == "2"
      assert app_module.lifespan_context_global == 2

      # Wait for the rendered task counter to move past its first observed value.
      first_task_text = task_element.text
      first_task_value = int(first_task_text)
      lifespan_app.poll_for_content(task_element, exp_not_equal=first_task_text)
      browser.find_element(By.ID, "toggle-tick").click()  # avoid teardown errors
      assert app_module.lifespan_task_global > first_task_value
      assert int(task_element.text) > first_task_value

      # Kill the backend
      backend = lifespan_app.backend
      assert backend is not None
      backend.should_exit = True
      backend_thread = lifespan_app.backend_thread
      if backend_thread is not None:
          backend_thread.join()

      # Check that the lifespan tasks have been cancelled
      assert app_module.lifespan_task_global == 0
      assert app_module.lifespan_context_global == 4
  143. assert lifespan_app.app_module.lifespan_context_global == 4