test_connection_banner.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. """Test case for displaying the connection banner when the websocket drops."""
  2. import functools
  3. from collections.abc import Generator
  4. import pytest
  5. from selenium.common.exceptions import NoSuchElementException
  6. from selenium.webdriver.common.by import By
  7. from reflex import constants
  8. from reflex.config import environment
  9. from reflex.testing import AppHarness, WebDriver
  10. from .utils import SessionStorage
  def ConnectionBanner():
      """App with a connection banner.

      The ``connection_banner`` fixture hands this function (wrapped in
      ``functools.partial``) to ``AppHarness.create`` as ``app_source``.
      Presumably AppHarness uses this function's source as the app module,
      which is why the imports live inside the body — confirm against
      ``reflex.testing.AppHarness`` before restructuring.

      The page shows a read-only counter input (id="counter"), an "Increment"
      button (id="increment") and a "Delay" button (id="delay") whose handler
      sleeps for 5 seconds.
      """
      import asyncio

      import reflex as rx

      class State(rx.State):
          # Value displayed in the read-only #counter input.
          foo: int = 0

          @rx.event
          async def delay(self):
              # Deliberately slow handler: test_connection_banner clicks the
              # "Delay" button right before stopping the backend so that an
              # event is still in progress at that moment.
              await asyncio.sleep(5)

      def index():
          return rx.vstack(
              rx.text("Hello World"),
              rx.input(value=State.foo, read_only=True, id="counter"),
              rx.button(
                  "Increment",
                  id="increment",
                  # set_foo is not declared on State above; presumably it is the
                  # setter reflex generates for `foo` (pyright cannot see it,
                  # hence the ignore) — confirm.
                  on_click=State.set_foo(State.foo + 1),  # pyright: ignore [reportAttributeAccessIssue]
              ),
              rx.button("Delay", id="delay", on_click=State.delay),
          )

      app = rx.App()
      app.add_page(index)
  @pytest.fixture(
      params=[constants.CompileContext.RUN, constants.CompileContext.DEPLOY],
      ids=["compile_context_run", "compile_context_deploy"],
  )
  def simulate_compile_context(request) -> constants.CompileContext:
      """Parametrize dependent tests over the RUN and DEPLOY compile contexts.

      The DEPLOY variant is how these tests simulate a Reflex Cloud deployment.

      Args:
          request: pytest request fixture; ``request.param`` carries the
              compile context chosen for the current parametrized run.

      Returns:
          The compile context the app should be started with.
      """
      selected_context: constants.CompileContext = request.param
      return selected_context
  @pytest.fixture()
  def connection_banner(
      tmp_path,
      simulate_compile_context: constants.CompileContext,
  ) -> Generator[AppHarness, None, None]:
      """Start the ConnectionBanner app at ``tmp_path`` via AppHarness.

      ``REFLEX_COMPILE_CONTEXT`` is set to the requested context for the lifetime
      of the harness and restored afterwards. Without the restore, a DEPLOY run
      would leave the process-wide environment variable set to "deploy" for
      every test that runs later in the same session.

      Args:
          tmp_path: pytest tmp_path fixture.
          simulate_compile_context: Which context to run the app with.

      Yields:
          The running AppHarness instance.
      """
      compile_context_var = environment.REFLEX_COMPILE_CONTEXT
      # Explicitly-set value before this test, or None when the variable is unset.
      previous_context = compile_context_var.getenv()
      compile_context_var.set(simulate_compile_context)

      # The DEPLOY variant uses a different app name than the RUN variant.
      app_name = (
          "connection_banner_reflex_cloud"
          if simulate_compile_context == constants.CompileContext.DEPLOY
          else "connection_banner"
      )
      try:
          with AppHarness.create(
              root=tmp_path,
              app_source=functools.partial(ConnectionBanner),
              app_name=app_name,
          ) as harness:
              yield harness
      finally:
          # Put the environment back the way we found it; passing None unsets
          # the variable again (reflex EnvVar.set semantics).
          compile_context_var.set(previous_context)
  CONNECTION_ERROR_XPATH = "//*[ contains(text(), 'Cannot connect to server') ]"


  def has_error_modal(driver: WebDriver) -> bool:
      """Report whether the "Cannot connect to server" error is on the page.

      Args:
          driver: Selenium webdriver instance.

      Returns:
          True if some element matches CONNECTION_ERROR_XPATH, False otherwise.
      """
      modal_found = True
      try:
          driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH)
      except NoSuchElementException:
          # Selenium signals "no match" with this exception.
          modal_found = False
      return modal_found
  def has_cloud_banner(driver: WebDriver) -> bool:
      """Report whether the "This app is paused" cloud banner is on the page.

      Args:
          driver: Selenium webdriver instance.

      Returns:
          True if an element containing the banner text exists, False otherwise.
      """
      paused_banner_xpath = "//*[ contains(text(), 'This app is paused') ]"
      try:
          driver.find_element(By.XPATH, paused_banner_xpath)
          return True
      except NoSuchElementException:
          return False
  def _assert_token(connection_banner, driver):
      """Wait until the page has stored a ``token`` entry in session storage.

      The tests call this as their readiness check after loading or refreshing
      the page. It fails with "token not found" if ``_poll_for`` gives up before
      the entry appears.

      Args:
          connection_banner: AppHarness instance used for polling.
          driver: Selenium webdriver instance whose session storage is read.
      """
      session_storage = SessionStorage(driver)

      def token_present() -> bool:
          return session_storage.get("token") is not None

      token_appeared = connection_banner._poll_for(token_present)
      assert token_appeared, "token not found"
  @pytest.mark.asyncio
  async def test_connection_banner(connection_banner: AppHarness):
      """Check that the connection error modal follows the backend's availability.

      Scenario:
      1. Increment the counter while connected.
      2. Start the slow "Delay" event.
      3. Stop the backend.
      4. Verify that the modal appears and that a click made while disconnected
         leaves the counter at "1".
      5. Restart the backend on the same port.
      6. Verify that the modal disappears and the counter reaches "2".

      Args:
          connection_banner: AppHarness instance.
      """
      harness = connection_banner
      assert harness.app_instance is not None
      assert harness.backend is not None

      driver = harness.frontend()
      _assert_token(harness, driver)
      assert harness._poll_for(lambda: not has_error_modal(driver))

      delay_btn, increment_btn, counter = (
          driver.find_element(By.ID, element_id)
          for element_id in ("delay", "increment", "counter")
      )

      # Connected: the click takes the counter from "0" to "1".
      increment_btn.click()
      assert harness.poll_for_value(counter, exp_not_equal="0") == "1"

      # Kick off the slow event so one is still running when the backend is
      # stopped (intended to leave event_processing=true on the client).
      delay_btn.click()

      # Remember the backend's port so it can be restarted on the same one.
      port = harness._poll_for_servers().getsockname()[1]

      # Stop the backend and wait for its thread to finish.
      harness.backend.should_exit = True
      backend_thread = harness.backend_thread
      if backend_thread is not None:
          backend_thread.join()

      # Disconnected: the error modal must show up.
      assert harness._poll_for(lambda: has_error_modal(driver))

      # A click while disconnected must not change the displayed value yet.
      increment_btn.click()
      assert harness.poll_for_value(counter, exp_not_equal="0") == "1"

      # Restart the backend on the original port.
      harness._start_backend(port=port)
      # Fresh StateManager to avoid async loop affinity issues w/ redis.
      await harness._reset_backend_state_manager()

      # Reconnected: the modal goes away and the offline click is applied.
      assert harness._poll_for(lambda: not has_error_modal(driver))
      assert harness.poll_for_value(counter, exp_not_equal="1") == "2"
  @pytest.mark.asyncio
  async def test_cloud_banner(
      connection_banner: AppHarness, simulate_compile_context: constants.CompileContext
  ):
      """Test the "This app is paused" cloud banner driven by the backend-enabled cookie.

      Expected behavior by cookie value:
      - ``backend-enabled`` is "truly": no banner is shown.
      - The cookie is "false", DEPLOY compile context: the banner appears.
      - The cookie is "false", RUN compile context: no banner, and a session
        token is still stored.
      - The cookie has been deleted: no banner again.

      Args:
          connection_banner: AppHarness instance.
          simulate_compile_context: Which context to set for the app.
      """
      assert connection_banner.app_instance is not None
      assert connection_banner.backend is not None
      cookie_name = "backend-enabled"

      driver = connection_banner.frontend()

      # Backend marked as enabled: no paused banner.
      driver.add_cookie({"name": cookie_name, "value": "truly"})
      driver.refresh()
      _assert_token(connection_banner, driver)
      assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))

      # Backend marked as disabled: only the DEPLOY build shows the banner.
      driver.add_cookie({"name": cookie_name, "value": "false"})
      driver.refresh()
      if simulate_compile_context == constants.CompileContext.DEPLOY:
          assert connection_banner._poll_for(lambda: has_cloud_banner(driver))
      else:
          _assert_token(connection_banner, driver)
          assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))

      # Cookie removed: the banner must be gone in every context.
      driver.delete_cookie(cookie_name)
      driver.refresh()
      _assert_token(connection_banner, driver)
      assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))
  167. driver.refresh()
  168. _assert_token(connection_banner, driver)
  169. assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))