test_connection_banner.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. """Test case for displaying the connection banner when the websocket drops."""
  2. import functools
  3. from typing import Generator
  4. import pytest
  5. from selenium.common.exceptions import NoSuchElementException
  6. from selenium.webdriver.common.by import By
  7. from reflex.testing import AppHarness, WebDriver
  8. from .utils import SessionStorage
def ConnectionBanner(is_reflex_cloud: bool = False):
    """App with a connection banner.

    The fixtures pass this function (via ``functools.partial``) as the
    ``app_source`` of an ``AppHarness``. NOTE(review): AppHarness presumably
    turns the function body into the generated app module, which would explain
    why all imports live inside the function -- confirm against
    ``reflex.testing`` before adding references to module-level names here.

    Args:
        is_reflex_cloud: The value for config.is_reflex_cloud.
    """
    import asyncio

    import reflex as rx

    # Simulate reflex cloud deploy by overriding the flag on the config object
    # returned by rx.config.get_config().
    rx.config.get_config().is_reflex_cloud = is_reflex_cloud

    class State(rx.State):
        # Counter shown in the read-only "counter" input.
        foo: int = 0

        @rx.event
        async def delay(self):
            # Long-running handler (5 second sleep). The test clicks it just
            # before killing the backend so an event is still in flight.
            await asyncio.sleep(5)

    def index():
        return rx.vstack(
            rx.text("Hello World"),
            rx.input(value=State.foo, read_only=True, id="counter"),
            rx.button(
                "Increment",
                id="increment",
                # set_foo is not declared above (presumably a setter generated
                # for `foo`), hence the pyright suppression.
                on_click=State.set_foo(State.foo + 1),  # pyright: ignore [reportAttributeAccessIssue]
            ),
            rx.button("Delay", id="delay", on_click=State.delay),
        )

    # NOTE(review): passes the base rx.State explicitly as the app state;
    # confirm the intent of `_state` against rx.App.
    app = rx.App(_state=rx.State)
    app.add_page(index)
  36. @pytest.fixture(
  37. params=[False, True], ids=["reflex_cloud_disabled", "reflex_cloud_enabled"]
  38. )
  39. def simulate_is_reflex_cloud(request) -> bool:
  40. """Fixture to simulate reflex cloud deployment.
  41. Args:
  42. request: pytest request fixture.
  43. Returns:
  44. True if reflex cloud is enabled, False otherwise.
  45. """
  46. return request.param
  47. @pytest.fixture()
  48. def connection_banner(
  49. tmp_path,
  50. simulate_is_reflex_cloud: bool,
  51. ) -> Generator[AppHarness, None, None]:
  52. """Start ConnectionBanner app at tmp_path via AppHarness.
  53. Args:
  54. tmp_path: pytest tmp_path fixture
  55. simulate_is_reflex_cloud: Whether is_reflex_cloud is set for the app.
  56. Yields:
  57. running AppHarness instance
  58. """
  59. with AppHarness.create(
  60. root=tmp_path,
  61. app_source=functools.partial(
  62. ConnectionBanner, is_reflex_cloud=simulate_is_reflex_cloud
  63. ),
  64. app_name="connection_banner_reflex_cloud"
  65. if simulate_is_reflex_cloud
  66. else "connection_banner",
  67. ) as harness:
  68. yield harness
  69. CONNECTION_ERROR_XPATH = "//*[ contains(text(), 'Cannot connect to server') ]"
  70. def has_error_modal(driver: WebDriver) -> bool:
  71. """Check if the connection error modal is displayed.
  72. Args:
  73. driver: Selenium webdriver instance.
  74. Returns:
  75. True if the modal is displayed, False otherwise.
  76. """
  77. try:
  78. driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH)
  79. except NoSuchElementException:
  80. return False
  81. else:
  82. return True
  83. def has_cloud_banner(driver: WebDriver) -> bool:
  84. """Check if the cloud banner is displayed.
  85. Args:
  86. driver: Selenium webdriver instance.
  87. Returns:
  88. True if the banner is displayed, False otherwise.
  89. """
  90. try:
  91. driver.find_element(
  92. By.XPATH, "//*[ contains(text(), 'You ran out of compute credits.') ]"
  93. )
  94. except NoSuchElementException:
  95. return False
  96. else:
  97. return True
  98. def _assert_token(connection_banner, driver):
  99. """Poll for backend to be up.
  100. Args:
  101. connection_banner: AppHarness instance.
  102. driver: Selenium webdriver instance.
  103. """
  104. ss = SessionStorage(driver)
  105. assert connection_banner._poll_for(lambda: ss.get("token") is not None), (
  106. "token not found"
  107. )
@pytest.mark.asyncio
async def test_connection_banner(connection_banner: AppHarness):
    """Test that the connection banner is displayed when the websocket drops.

    Flow:
    1. Wait for the app to connect and increment the counter to 1.
    2. Start a long-running event, then kill the backend and expect the error
       modal.
    3. Click increment while offline; the counter must stay at 1.
    4. Restart the backend on the same port and expect the modal to go away
       and the offline increment to be applied (counter becomes 2).

    Args:
        connection_banner: AppHarness instance.
    """
    assert connection_banner.app_instance is not None
    assert connection_banner.backend is not None
    driver = connection_banner.frontend()

    _assert_token(connection_banner, driver)
    assert connection_banner._poll_for(lambda: not has_error_modal(driver))

    delay_button = driver.find_element(By.ID, "delay")
    increment_button = driver.find_element(By.ID, "increment")
    counter_element = driver.find_element(By.ID, "counter")

    # Increment the counter
    increment_button.click()
    assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1"

    # Start a long event before killing the backend, to mark event_processing=true
    delay_button.click()

    # Get the backend port so the backend can be restarted on the same one
    backend_port = connection_banner._poll_for_servers().getsockname()[1]

    # Kill the backend and wait for its thread to finish
    connection_banner.backend.should_exit = True
    if connection_banner.backend_thread is not None:
        connection_banner.backend_thread.join()

    # Error modal should now be displayed
    assert connection_banner._poll_for(lambda: has_error_modal(driver))

    # Increment the counter with backend down; the value must not change yet
    increment_button.click()
    assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1"

    # Bring the backend back up
    connection_banner._start_backend(port=backend_port)

    # Create a new StateManager to avoid async loop affinity issues w/ redis
    await connection_banner._reset_backend_state_manager()

    # Banner should be gone now
    assert connection_banner._poll_for(lambda: not has_error_modal(driver))

    # Count should have incremented after coming back up
    assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2"
@pytest.mark.asyncio
async def test_cloud_banner(
    connection_banner: AppHarness, simulate_is_reflex_cloud: bool
):
    """Test the Reflex Cloud "out of compute credits" banner.

    The banner is driven by the ``backend-enabled`` cookie:
    - with the cookie set to ``"truly"`` the banner must not appear;
    - with the cookie set to ``"false"`` the banner must appear only when
      ``is_reflex_cloud`` is enabled, and the app still connects otherwise;
    - after the cookie is deleted the banner must not appear.

    Args:
        connection_banner: AppHarness instance.
        simulate_is_reflex_cloud: Whether is_reflex_cloud is set for the app.
    """
    assert connection_banner.app_instance is not None
    assert connection_banner.backend is not None
    driver = connection_banner.frontend()

    # NOTE(review): presumably any value other than "false" counts as
    # "enabled" -- confirm against the frontend banner logic.
    driver.add_cookie({"name": "backend-enabled", "value": "truly"})
    driver.refresh()
    _assert_token(connection_banner, driver)
    assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))

    driver.add_cookie({"name": "backend-enabled", "value": "false"})
    driver.refresh()
    if simulate_is_reflex_cloud:
        # No token check in this branch (presumably the app does not connect
        # while the cloud backend is marked disabled).
        assert connection_banner._poll_for(lambda: has_cloud_banner(driver))
    else:
        # Outside cloud mode the cookie is expected to have no effect.
        _assert_token(connection_banner, driver)
        assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))

    driver.delete_cookie("backend-enabled")
    driver.refresh()
    _assert_token(connection_banner, driver)
    assert connection_banner._poll_for(lambda: not has_cloud_banner(driver))