test_connection_banner.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. """Test case for displaying the connection banner when the websocket drops."""
  2. from typing import Generator
  3. import pytest
  4. from selenium.common.exceptions import NoSuchElementException
  5. from selenium.webdriver.common.by import By
  6. from reflex.testing import AppHarness, WebDriver
  7. from .utils import SessionStorage
  8. def ConnectionBanner():
  9. """App with a connection banner."""
  10. import asyncio
  11. import reflex as rx
  12. class State(rx.State):
  13. foo: int = 0
  14. @rx.event
  15. async def delay(self):
  16. await asyncio.sleep(5)
  17. def index():
  18. return rx.vstack(
  19. rx.text("Hello World"),
  20. rx.input(value=State.foo, read_only=True, id="counter"),
  21. rx.button(
  22. "Increment",
  23. id="increment",
  24. on_click=State.set_foo(State.foo + 1), # type: ignore
  25. ),
  26. rx.button("Delay", id="delay", on_click=State.delay),
  27. )
  28. app = rx.App(_state=rx.State)
  29. app.add_page(index)
  30. @pytest.fixture()
  31. def connection_banner(tmp_path) -> Generator[AppHarness, None, None]:
  32. """Start ConnectionBanner app at tmp_path via AppHarness.
  33. Args:
  34. tmp_path: pytest tmp_path fixture
  35. Yields:
  36. running AppHarness instance
  37. """
  38. with AppHarness.create(
  39. root=tmp_path,
  40. app_source=ConnectionBanner,
  41. ) as harness:
  42. yield harness
  43. CONNECTION_ERROR_XPATH = "//*[ contains(text(), 'Cannot connect to server') ]"
  44. def has_error_modal(driver: WebDriver) -> bool:
  45. """Check if the connection error modal is displayed.
  46. Args:
  47. driver: Selenium webdriver instance.
  48. Returns:
  49. True if the modal is displayed, False otherwise.
  50. """
  51. try:
  52. driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH)
  53. except NoSuchElementException:
  54. return False
  55. else:
  56. return True
  57. @pytest.mark.asyncio
  58. async def test_connection_banner(connection_banner: AppHarness):
  59. """Test that the connection banner is displayed when the websocket drops.
  60. Args:
  61. connection_banner: AppHarness instance.
  62. """
  63. assert connection_banner.app_instance is not None
  64. assert connection_banner.backend is not None
  65. driver = connection_banner.frontend()
  66. ss = SessionStorage(driver)
  67. assert connection_banner._poll_for(
  68. lambda: ss.get("token") is not None
  69. ), "token not found"
  70. assert connection_banner._poll_for(lambda: not has_error_modal(driver))
  71. delay_button = driver.find_element(By.ID, "delay")
  72. increment_button = driver.find_element(By.ID, "increment")
  73. counter_element = driver.find_element(By.ID, "counter")
  74. # Increment the counter
  75. increment_button.click()
  76. assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1"
  77. # Start an long event before killing the backend, to mark event_processing=true
  78. delay_button.click()
  79. # Get the backend port
  80. backend_port = connection_banner._poll_for_servers().getsockname()[1]
  81. # Kill the backend
  82. connection_banner.backend.should_exit = True
  83. if connection_banner.backend_thread is not None:
  84. connection_banner.backend_thread.join()
  85. # Error modal should now be displayed
  86. assert connection_banner._poll_for(lambda: has_error_modal(driver))
  87. # Increment the counter with backend down
  88. increment_button.click()
  89. assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1"
  90. # Bring the backend back up
  91. connection_banner._start_backend(port=backend_port)
  92. # Create a new StateManager to avoid async loop affinity issues w/ redis
  93. await connection_banner._reset_backend_state_manager()
  94. # Banner should be gone now
  95. assert connection_banner._poll_for(lambda: not has_error_modal(driver))
  96. # Count should have incremented after coming back up
  97. assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2"