test_app.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247
  1. from __future__ import annotations
  2. import io
  3. import os.path
  4. import sys
  5. import unittest.mock
  6. import uuid
  7. from pathlib import Path
  8. from typing import Generator, List, Tuple, Type
  9. if sys.version_info.major >= 3 and sys.version_info.minor > 7:
  10. from unittest.mock import AsyncMock # type: ignore
  11. else:
  12. # python 3.7 doesn't ship with unittest.mock
  13. from asynctest import CoroutineMock as AsyncMock
  14. import pytest
  15. import sqlmodel
  16. from fastapi import UploadFile
  17. from starlette_admin.auth import AuthProvider
  18. from starlette_admin.contrib.sqla.admin import Admin
  19. from starlette_admin.contrib.sqla.view import ModelView
  20. import reflex.components.radix.themes as rdxt
  21. from reflex import AdminDash, constants
  22. from reflex.app import (
  23. App,
  24. ComponentCallable,
  25. DefaultState,
  26. default_overlay_component,
  27. process,
  28. upload,
  29. )
  30. from reflex.components import Box, Component, Cond, Fragment, Text
  31. from reflex.event import Event, get_hydrate_event
  32. from reflex.middleware import HydrateMiddleware
  33. from reflex.model import Model
  34. from reflex.state import RouterData, State, StateManagerRedis, StateUpdate
  35. from reflex.style import Style
  36. from reflex.utils import format
  37. from reflex.vars import ComputedVar
  38. from .conftest import chdir
  39. from .states import (
  40. ChildFileUploadState,
  41. FileStateBase1,
  42. FileUploadState,
  43. GenState,
  44. GrandChildFileUploadState,
  45. )
  46. @pytest.fixture
  47. def index_page():
  48. """An index page.
  49. Returns:
  50. The index page.
  51. """
  52. def index():
  53. return Box.create("Index")
  54. return index
  55. @pytest.fixture
  56. def about_page():
  57. """An about page.
  58. Returns:
  59. The about page.
  60. """
  61. def about():
  62. return Box.create("About")
  63. return about
  64. class ATestState(State):
  65. """A simple state for testing."""
  66. var: int
  67. @pytest.fixture()
  68. def test_state() -> Type[State]:
  69. """A default state.
  70. Returns:
  71. A default state.
  72. """
  73. return ATestState
  74. @pytest.fixture()
  75. def redundant_test_state() -> Type[State]:
  76. """A default state.
  77. Returns:
  78. A default state.
  79. """
  80. class RedundantTestState(State):
  81. var: int
  82. return RedundantTestState
  83. @pytest.fixture(scope="session")
  84. def test_model() -> Type[Model]:
  85. """A default model.
  86. Returns:
  87. A default model.
  88. """
  89. class TestModel(Model, table=True): # type: ignore
  90. pass
  91. return TestModel
  92. @pytest.fixture(scope="session")
  93. def test_model_auth() -> Type[Model]:
  94. """A default model.
  95. Returns:
  96. A default model.
  97. """
  98. class TestModelAuth(Model, table=True): # type: ignore
  99. """A test model with auth."""
  100. pass
  101. return TestModelAuth
  102. @pytest.fixture()
  103. def test_get_engine():
  104. """A default database engine.
  105. Returns:
  106. A default database engine.
  107. """
  108. enable_admin = True
  109. url = "sqlite:///test.db"
  110. return sqlmodel.create_engine(
  111. url,
  112. echo=False,
  113. connect_args={"check_same_thread": False} if enable_admin else {},
  114. )
  115. @pytest.fixture()
  116. def test_custom_auth_admin() -> Type[AuthProvider]:
  117. """A default auth provider.
  118. Returns:
  119. A default default auth provider.
  120. """
  121. class TestAuthProvider(AuthProvider):
  122. """A test auth provider."""
  123. login_path: str = "/login"
  124. logout_path: str = "/logout"
  125. def login(self):
  126. """Login."""
  127. pass
  128. def is_authenticated(self):
  129. """Is authenticated."""
  130. pass
  131. def get_admin_user(self):
  132. """Get admin user."""
  133. pass
  134. def logout(self):
  135. """Logout."""
  136. pass
  137. return TestAuthProvider
  138. def test_default_app(app: App):
  139. """Test creating an app with no args.
  140. Args:
  141. app: The app to test.
  142. """
  143. assert app.state() == DefaultState()
  144. assert app.middleware == [HydrateMiddleware()]
  145. assert app.style == Style()
  146. assert app.admin_dash is None
  147. def test_multiple_states_error(monkeypatch, test_state, redundant_test_state):
  148. """Test that an error is thrown when multiple classes subclass rx.State.
  149. Args:
  150. monkeypatch: Pytest monkeypatch object.
  151. test_state: A test state subclassing rx.State.
  152. redundant_test_state: Another test state subclassing rx.State.
  153. """
  154. monkeypatch.delenv(constants.PYTEST_CURRENT_TEST)
  155. with pytest.raises(ValueError):
  156. App()
  157. def test_add_page_default_route(app: App, index_page, about_page):
  158. """Test adding a page to an app.
  159. Args:
  160. app: The app to test.
  161. index_page: The index page.
  162. about_page: The about page.
  163. """
  164. assert app.pages == {}
  165. app.add_page(index_page)
  166. assert set(app.pages.keys()) == {"index"}
  167. app.add_page(about_page)
  168. assert set(app.pages.keys()) == {"index", "about"}
  169. def test_add_page_set_route(app: App, index_page, windows_platform: bool):
  170. """Test adding a page to an app.
  171. Args:
  172. app: The app to test.
  173. index_page: The index page.
  174. windows_platform: Whether the system is windows.
  175. """
  176. route = "test" if windows_platform else "/test"
  177. assert app.pages == {}
  178. app.add_page(index_page, route=route)
  179. assert set(app.pages.keys()) == {"test"}
  180. def test_add_page_set_route_dynamic(app: App, index_page, windows_platform: bool):
  181. """Test adding a page with dynamic route variable to an app.
  182. Args:
  183. app: The app to test.
  184. index_page: The index page.
  185. windows_platform: Whether the system is windows.
  186. """
  187. route = "/test/[dynamic]"
  188. if windows_platform:
  189. route.lstrip("/").replace("/", "\\")
  190. assert app.pages == {}
  191. app.add_page(index_page, route=route)
  192. assert set(app.pages.keys()) == {"test/[dynamic]"}
  193. assert "dynamic" in app.state.computed_vars
  194. assert app.state.computed_vars["dynamic"]._deps(objclass=DefaultState) == {
  195. constants.ROUTER
  196. }
  197. assert constants.ROUTER in app.state().computed_var_dependencies
  198. def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
  199. """Test adding a page to an app.
  200. Args:
  201. app: The app to test.
  202. index_page: The index page.
  203. windows_platform: Whether the system is windows.
  204. """
  205. route = "test\\nested" if windows_platform else "/test/nested"
  206. assert app.pages == {}
  207. app.add_page(index_page, route=route)
  208. assert set(app.pages.keys()) == {route.strip(os.path.sep)}
  209. def test_initialize_with_admin_dashboard(test_model):
  210. """Test setting the admin dashboard of an app.
  211. Args:
  212. test_model: The default model.
  213. """
  214. app = App(admin_dash=AdminDash(models=[test_model]))
  215. assert app.admin_dash is not None
  216. assert len(app.admin_dash.models) > 0
  217. assert app.admin_dash.models[0] == test_model
  218. def test_initialize_with_custom_admin_dashboard(
  219. test_get_engine,
  220. test_custom_auth_admin,
  221. test_model_auth,
  222. ):
  223. """Test setting the custom admin dashboard of an app.
  224. Args:
  225. test_get_engine: The default database engine.
  226. test_model_auth: The default model for an auth admin dashboard.
  227. test_custom_auth_admin: The custom auth provider.
  228. """
  229. custom_admin = Admin(engine=test_get_engine, auth_provider=test_custom_auth_admin)
  230. app = App(admin_dash=AdminDash(models=[test_model_auth], admin=custom_admin))
  231. assert app.admin_dash is not None
  232. assert app.admin_dash.admin is not None
  233. assert len(app.admin_dash.models) > 0
  234. assert app.admin_dash.models[0] == test_model_auth
  235. assert app.admin_dash.admin.auth_provider == test_custom_auth_admin
  236. def test_initialize_admin_dashboard_with_view_overrides(test_model):
  237. """Test setting the admin dashboard of an app with view class overriden.
  238. Args:
  239. test_model: The default model.
  240. """
  241. class TestModelView(ModelView):
  242. pass
  243. app = App(
  244. admin_dash=AdminDash(
  245. models=[test_model], view_overrides={test_model: TestModelView}
  246. )
  247. )
  248. assert app.admin_dash is not None
  249. assert app.admin_dash.models == [test_model]
  250. assert app.admin_dash.view_overrides[test_model] == TestModelView
  251. @pytest.mark.asyncio
  252. async def test_initialize_with_state(test_state: Type[ATestState], token: str):
  253. """Test setting the state of an app.
  254. Args:
  255. test_state: The default state.
  256. token: a Token.
  257. """
  258. app = App(state=test_state)
  259. assert app.state == test_state
  260. # Get a state for a given token.
  261. state = await app.state_manager.get_state(token)
  262. assert isinstance(state, test_state)
  263. assert state.var == 0 # type: ignore
  264. if isinstance(app.state_manager, StateManagerRedis):
  265. await app.state_manager.redis.close()
  266. @pytest.mark.asyncio
  267. async def test_set_and_get_state(test_state):
  268. """Test setting and getting the state of an app with different tokens.
  269. Args:
  270. test_state: The default state.
  271. """
  272. app = App(state=test_state)
  273. # Create two tokens.
  274. token1 = str(uuid.uuid4())
  275. token2 = str(uuid.uuid4())
  276. # Get the default state for each token.
  277. state1 = await app.state_manager.get_state(token1)
  278. state2 = await app.state_manager.get_state(token2)
  279. assert state1.var == 0 # type: ignore
  280. assert state2.var == 0 # type: ignore
  281. # Set the vars to different values.
  282. state1.var = 1
  283. state2.var = 2
  284. await app.state_manager.set_state(token1, state1)
  285. await app.state_manager.set_state(token2, state2)
  286. # Get the states again and check the values.
  287. state1 = await app.state_manager.get_state(token1)
  288. state2 = await app.state_manager.get_state(token2)
  289. assert state1.var == 1 # type: ignore
  290. assert state2.var == 2 # type: ignore
  291. if isinstance(app.state_manager, StateManagerRedis):
  292. await app.state_manager.redis.close()
  293. @pytest.mark.asyncio
  294. async def test_dynamic_var_event(test_state: Type[ATestState], token: str):
  295. """Test that the default handler of a dynamic generated var
  296. works as expected.
  297. Args:
  298. test_state: State Fixture.
  299. token: a Token.
  300. """
  301. state = test_state() # type: ignore
  302. state.add_var("int_val", int, 0)
  303. result = await state._process(
  304. Event(
  305. token=token,
  306. name=f"{test_state.get_name()}.set_int_val",
  307. router_data={"pathname": "/", "query": {}},
  308. payload={"value": 50},
  309. )
  310. ).__anext__()
  311. assert result.delta == {test_state.get_name(): {"int_val": 50}}
  312. @pytest.mark.asyncio
  313. @pytest.mark.parametrize(
  314. "event_tuples",
  315. [
  316. pytest.param(
  317. [
  318. (
  319. "list_mutation_test_state.make_friend",
  320. {
  321. "list_mutation_test_state": {
  322. "plain_friends": ["Tommy", "another-fd"]
  323. }
  324. },
  325. ),
  326. (
  327. "list_mutation_test_state.change_first_friend",
  328. {
  329. "list_mutation_test_state": {
  330. "plain_friends": ["Jenny", "another-fd"]
  331. }
  332. },
  333. ),
  334. ],
  335. id="append then __setitem__",
  336. ),
  337. pytest.param(
  338. [
  339. (
  340. "list_mutation_test_state.unfriend_first_friend",
  341. {"list_mutation_test_state": {"plain_friends": []}},
  342. ),
  343. (
  344. "list_mutation_test_state.make_friend",
  345. {"list_mutation_test_state": {"plain_friends": ["another-fd"]}},
  346. ),
  347. ],
  348. id="delitem then append",
  349. ),
  350. pytest.param(
  351. [
  352. (
  353. "list_mutation_test_state.make_friends_with_colleagues",
  354. {
  355. "list_mutation_test_state": {
  356. "plain_friends": ["Tommy", "Peter", "Jimmy"]
  357. }
  358. },
  359. ),
  360. (
  361. "list_mutation_test_state.remove_tommy",
  362. {"list_mutation_test_state": {"plain_friends": ["Peter", "Jimmy"]}},
  363. ),
  364. (
  365. "list_mutation_test_state.remove_last_friend",
  366. {"list_mutation_test_state": {"plain_friends": ["Peter"]}},
  367. ),
  368. (
  369. "list_mutation_test_state.unfriend_all_friends",
  370. {"list_mutation_test_state": {"plain_friends": []}},
  371. ),
  372. ],
  373. id="extend, remove, pop, clear",
  374. ),
  375. pytest.param(
  376. [
  377. (
  378. "list_mutation_test_state.add_jimmy_to_second_group",
  379. {
  380. "list_mutation_test_state": {
  381. "friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]
  382. }
  383. },
  384. ),
  385. (
  386. "list_mutation_test_state.remove_first_person_from_first_group",
  387. {
  388. "list_mutation_test_state": {
  389. "friends_in_nested_list": [[], ["Jenny", "Jimmy"]]
  390. }
  391. },
  392. ),
  393. (
  394. "list_mutation_test_state.remove_first_group",
  395. {
  396. "list_mutation_test_state": {
  397. "friends_in_nested_list": [["Jenny", "Jimmy"]]
  398. }
  399. },
  400. ),
  401. ],
  402. id="nested list",
  403. ),
  404. pytest.param(
  405. [
  406. (
  407. "list_mutation_test_state.add_jimmy_to_tommy_friends",
  408. {
  409. "list_mutation_test_state": {
  410. "friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}
  411. }
  412. },
  413. ),
  414. (
  415. "list_mutation_test_state.remove_jenny_from_tommy",
  416. {
  417. "list_mutation_test_state": {
  418. "friends_in_dict": {"Tommy": ["Jimmy"]}
  419. }
  420. },
  421. ),
  422. (
  423. "list_mutation_test_state.tommy_has_no_fds",
  424. {"list_mutation_test_state": {"friends_in_dict": {"Tommy": []}}},
  425. ),
  426. ],
  427. id="list in dict",
  428. ),
  429. ],
  430. )
  431. async def test_list_mutation_detection__plain_list(
  432. event_tuples: List[Tuple[str, List[str]]],
  433. list_mutation_state: State,
  434. token: str,
  435. ):
  436. """Test list mutation detection
  437. when reassignment is not explicitly included in the logic.
  438. Args:
  439. event_tuples: From parametrization.
  440. list_mutation_state: A state with list mutation features.
  441. token: a Token.
  442. """
  443. for event_name, expected_delta in event_tuples:
  444. result = await list_mutation_state._process(
  445. Event(
  446. token=token,
  447. name=event_name,
  448. router_data={"pathname": "/", "query": {}},
  449. payload={},
  450. )
  451. ).__anext__()
  452. assert result.delta == expected_delta
  453. @pytest.mark.asyncio
  454. @pytest.mark.parametrize(
  455. "event_tuples",
  456. [
  457. pytest.param(
  458. [
  459. (
  460. "dict_mutation_test_state.add_age",
  461. {
  462. "dict_mutation_test_state": {
  463. "details": {"name": "Tommy", "age": 20}
  464. }
  465. },
  466. ),
  467. (
  468. "dict_mutation_test_state.change_name",
  469. {
  470. "dict_mutation_test_state": {
  471. "details": {"name": "Jenny", "age": 20}
  472. }
  473. },
  474. ),
  475. (
  476. "dict_mutation_test_state.remove_last_detail",
  477. {"dict_mutation_test_state": {"details": {"name": "Jenny"}}},
  478. ),
  479. ],
  480. id="update then __setitem__",
  481. ),
  482. pytest.param(
  483. [
  484. (
  485. "dict_mutation_test_state.clear_details",
  486. {"dict_mutation_test_state": {"details": {}}},
  487. ),
  488. (
  489. "dict_mutation_test_state.add_age",
  490. {"dict_mutation_test_state": {"details": {"age": 20}}},
  491. ),
  492. ],
  493. id="delitem then update",
  494. ),
  495. pytest.param(
  496. [
  497. (
  498. "dict_mutation_test_state.add_age",
  499. {
  500. "dict_mutation_test_state": {
  501. "details": {"name": "Tommy", "age": 20}
  502. }
  503. },
  504. ),
  505. (
  506. "dict_mutation_test_state.remove_name",
  507. {"dict_mutation_test_state": {"details": {"age": 20}}},
  508. ),
  509. (
  510. "dict_mutation_test_state.pop_out_age",
  511. {"dict_mutation_test_state": {"details": {}}},
  512. ),
  513. ],
  514. id="add, remove, pop",
  515. ),
  516. pytest.param(
  517. [
  518. (
  519. "dict_mutation_test_state.remove_home_address",
  520. {
  521. "dict_mutation_test_state": {
  522. "address": [{}, {"work": "work address"}]
  523. }
  524. },
  525. ),
  526. (
  527. "dict_mutation_test_state.add_street_to_home_address",
  528. {
  529. "dict_mutation_test_state": {
  530. "address": [
  531. {"street": "street address"},
  532. {"work": "work address"},
  533. ]
  534. }
  535. },
  536. ),
  537. ],
  538. id="dict in list",
  539. ),
  540. pytest.param(
  541. [
  542. (
  543. "dict_mutation_test_state.change_friend_name",
  544. {
  545. "dict_mutation_test_state": {
  546. "friend_in_nested_dict": {
  547. "name": "Nikhil",
  548. "friend": {"name": "Tommy"},
  549. }
  550. }
  551. },
  552. ),
  553. (
  554. "dict_mutation_test_state.add_friend_age",
  555. {
  556. "dict_mutation_test_state": {
  557. "friend_in_nested_dict": {
  558. "name": "Nikhil",
  559. "friend": {"name": "Tommy", "age": 30},
  560. }
  561. }
  562. },
  563. ),
  564. (
  565. "dict_mutation_test_state.remove_friend",
  566. {
  567. "dict_mutation_test_state": {
  568. "friend_in_nested_dict": {"name": "Nikhil"}
  569. }
  570. },
  571. ),
  572. ],
  573. id="nested dict",
  574. ),
  575. ],
  576. )
  577. async def test_dict_mutation_detection__plain_list(
  578. event_tuples: List[Tuple[str, List[str]]],
  579. dict_mutation_state: State,
  580. token: str,
  581. ):
  582. """Test dict mutation detection
  583. when reassignment is not explicitly included in the logic.
  584. Args:
  585. event_tuples: From parametrization.
  586. dict_mutation_state: A state with dict mutation features.
  587. token: a Token.
  588. """
  589. for event_name, expected_delta in event_tuples:
  590. result = await dict_mutation_state._process(
  591. Event(
  592. token=token,
  593. name=event_name,
  594. router_data={"pathname": "/", "query": {}},
  595. payload={},
  596. )
  597. ).__anext__()
  598. assert result.delta == expected_delta
  599. @pytest.mark.asyncio
  600. @pytest.mark.parametrize(
  601. ("state", "delta"),
  602. [
  603. (
  604. FileUploadState,
  605. {"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
  606. ),
  607. (
  608. ChildFileUploadState,
  609. {
  610. "file_state_base1.child_file_upload_state": {
  611. "img_list": ["image1.jpg", "image2.jpg"]
  612. }
  613. },
  614. ),
  615. (
  616. GrandChildFileUploadState,
  617. {
  618. "file_state_base1.file_state_base2.grand_child_file_upload_state": {
  619. "img_list": ["image1.jpg", "image2.jpg"]
  620. }
  621. },
  622. ),
  623. ],
  624. )
  625. async def test_upload_file(tmp_path, state, delta, token: str):
  626. """Test that file upload works correctly.
  627. Args:
  628. tmp_path: Temporary path.
  629. state: The state class.
  630. delta: Expected delta
  631. token: a Token.
  632. """
  633. state._tmp_path = tmp_path
  634. app = App(state=state if state is FileUploadState else FileStateBase1)
  635. app.event_namespace.emit = AsyncMock() # type: ignore
  636. current_state = await app.state_manager.get_state(token)
  637. data = b"This is binary data"
  638. # Create a binary IO object and write data to it
  639. bio = io.BytesIO()
  640. bio.write(data)
  641. if state is FileUploadState:
  642. handler_prefix = f"{token}:{state.get_name()}"
  643. else:
  644. handler_prefix = f"{token}:{state.get_full_name().partition('.')[2]}"
  645. file1 = UploadFile(
  646. filename=f"{handler_prefix}.multi_handle_upload:True:image1.jpg",
  647. file=bio,
  648. )
  649. file2 = UploadFile(
  650. filename=f"{handler_prefix}.multi_handle_upload:True:image2.jpg",
  651. file=bio,
  652. )
  653. upload_fn = upload(app)
  654. await upload_fn([file1, file2])
  655. state_update = StateUpdate(delta=delta, events=[], final=True)
  656. app.event_namespace.emit.assert_called_with( # type: ignore
  657. "event", state_update.json(), to=current_state.get_sid()
  658. )
  659. current_state = await app.state_manager.get_state(token)
  660. state_dict = current_state.dict()
  661. for substate in state.get_full_name().split(".")[1:]:
  662. state_dict = state_dict[substate]
  663. assert state_dict["img_list"] == [
  664. "image1.jpg",
  665. "image2.jpg",
  666. ]
  667. if isinstance(app.state_manager, StateManagerRedis):
  668. await app.state_manager.redis.close()
  669. @pytest.mark.asyncio
  670. @pytest.mark.parametrize(
  671. "state",
  672. [FileUploadState, ChildFileUploadState, GrandChildFileUploadState],
  673. )
  674. async def test_upload_file_without_annotation(state, tmp_path, token):
  675. """Test that an error is thrown when there's no param annotated with rx.UploadFile or List[UploadFile].
  676. Args:
  677. state: The state class.
  678. tmp_path: Temporary path.
  679. token: a Token.
  680. """
  681. data = b"This is binary data"
  682. # Create a binary IO object and write data to it
  683. bio = io.BytesIO()
  684. bio.write(data)
  685. state._tmp_path = tmp_path
  686. app = App(state=state if state is FileUploadState else FileStateBase1)
  687. if state is FileUploadState:
  688. state_name = state.get_name()
  689. else:
  690. state_name = state.get_full_name().partition(".")[2]
  691. handler_prefix = f"{token}:{state_name}"
  692. file1 = UploadFile(
  693. filename=f"{handler_prefix}.handle_upload2:True:image1.jpg",
  694. file=bio,
  695. )
  696. file2 = UploadFile(
  697. filename=f"{handler_prefix}.handle_upload2:True:image2.jpg",
  698. file=bio,
  699. )
  700. fn = upload(app)
  701. with pytest.raises(ValueError) as err:
  702. await fn([file1, file2])
  703. assert (
  704. err.value.args[0]
  705. == f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
  706. )
  707. if isinstance(app.state_manager, StateManagerRedis):
  708. await app.state_manager.redis.close()
  709. class DynamicState(State):
  710. """State class for testing dynamic route var.
  711. This is defined at module level because event handlers cannot be addressed
  712. correctly when the class is defined as a local.
  713. There are several counters:
  714. * loaded: counts how many times `on_load` was triggered by the hydrate middleware
  715. * counter: counts how many times `on_counter` was triggered by a non-navigational event
  716. -> these events should NOT trigger reload or recalculation of router_data dependent vars
  717. * side_effect_counter: counts how many times a computed var was
  718. recalculated when the dynamic route var was dirty
  719. """
  720. loaded: int = 0
  721. counter: int = 0
  722. # side_effect_counter: int = 0
  723. def on_load(self):
  724. """Event handler for page on_load, should trigger for all navigation events."""
  725. self.loaded = self.loaded + 1
  726. def on_counter(self):
  727. """Increment the counter var."""
  728. self.counter = self.counter + 1
  729. @ComputedVar
  730. def comp_dynamic(self) -> str:
  731. """A computed var that depends on the dynamic var.
  732. Returns:
  733. same as self.dynamic
  734. """
  735. # self.side_effect_counter = self.side_effect_counter + 1
  736. return self.dynamic
  737. @pytest.mark.asyncio
  738. async def test_dynamic_route_var_route_change_completed_on_load(
  739. index_page,
  740. windows_platform: bool,
  741. token: str,
  742. ):
  743. """Create app with dynamic route var, and simulate navigation.
  744. on_load should fire, allowing any additional vars to be updated before the
  745. initial page hydrate.
  746. Args:
  747. index_page: The index page.
  748. windows_platform: Whether the system is windows.
  749. token: a Token.
  750. """
  751. arg_name = "dynamic"
  752. route = f"/test/[{arg_name}]"
  753. if windows_platform:
  754. route.lstrip("/").replace("/", "\\")
  755. app = App(state=DynamicState)
  756. assert arg_name not in app.state.vars
  757. app.add_page(index_page, route=route, on_load=DynamicState.on_load) # type: ignore
  758. assert arg_name in app.state.vars
  759. assert arg_name in app.state.computed_vars
  760. assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
  761. constants.ROUTER
  762. }
  763. assert constants.ROUTER in app.state().computed_var_dependencies
  764. sid = "mock_sid"
  765. client_ip = "127.0.0.1"
  766. state = await app.state_manager.get_state(token)
  767. assert state.dynamic == ""
  768. exp_vals = ["foo", "foobar", "baz"]
  769. def _event(name, val, **kwargs):
  770. return Event(
  771. token=kwargs.pop("token", token),
  772. name=name,
  773. router_data=kwargs.pop(
  774. "router_data", {"pathname": route, "query": {arg_name: val}}
  775. ),
  776. payload=kwargs.pop("payload", {}),
  777. **kwargs,
  778. )
  779. def _dynamic_state_event(name, val, **kwargs):
  780. return _event(
  781. name=format.format_event_handler(getattr(DynamicState, name)), # type: ignore
  782. val=val,
  783. **kwargs,
  784. )
  785. prev_exp_val = ""
  786. for exp_index, exp_val in enumerate(exp_vals):
  787. hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
  788. exp_router_data = {
  789. "headers": {},
  790. "ip": client_ip,
  791. "sid": sid,
  792. "token": token,
  793. **hydrate_event.router_data,
  794. }
  795. exp_router = RouterData(exp_router_data)
  796. process_coro = process(
  797. app,
  798. event=hydrate_event,
  799. sid=sid,
  800. headers={},
  801. client_ip=client_ip,
  802. )
  803. update = await process_coro.__anext__() # type: ignore
  804. # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
  805. assert update == StateUpdate(
  806. delta={
  807. state.get_name(): {
  808. arg_name: exp_val,
  809. f"comp_{arg_name}": exp_val,
  810. constants.CompileVars.IS_HYDRATED: False,
  811. "loaded": exp_index,
  812. "counter": exp_index,
  813. "router": exp_router,
  814. # "side_effect_counter": exp_index,
  815. }
  816. },
  817. events=[
  818. _dynamic_state_event(
  819. name="on_load",
  820. val=exp_val,
  821. router_data=exp_router_data,
  822. ),
  823. _dynamic_state_event(
  824. name="set_is_hydrated",
  825. payload={"value": True},
  826. val=exp_val,
  827. router_data=exp_router_data,
  828. ),
  829. ],
  830. )
  831. if isinstance(app.state_manager, StateManagerRedis):
  832. # When redis is used, the state is not updated until the processing is complete
  833. state = await app.state_manager.get_state(token)
  834. assert state.dynamic == prev_exp_val
  835. # complete the processing
  836. with pytest.raises(StopAsyncIteration):
  837. await process_coro.__anext__() # type: ignore
  838. # check that router data was written to the state_manager store
  839. state = await app.state_manager.get_state(token)
  840. assert state.dynamic == exp_val
  841. process_coro = process(
  842. app,
  843. event=_dynamic_state_event(name="on_load", val=exp_val),
  844. sid=sid,
  845. headers={},
  846. client_ip=client_ip,
  847. )
  848. on_load_update = await process_coro.__anext__() # type: ignore
  849. assert on_load_update == StateUpdate(
  850. delta={
  851. state.get_name(): {
  852. # These computed vars _shouldn't_ be here, because they didn't change
  853. arg_name: exp_val,
  854. f"comp_{arg_name}": exp_val,
  855. "loaded": exp_index + 1,
  856. },
  857. },
  858. events=[],
  859. )
  860. # complete the processing
  861. with pytest.raises(StopAsyncIteration):
  862. await process_coro.__anext__() # type: ignore
  863. process_coro = process(
  864. app,
  865. event=_dynamic_state_event(
  866. name="set_is_hydrated", payload={"value": True}, val=exp_val
  867. ),
  868. sid=sid,
  869. headers={},
  870. client_ip=client_ip,
  871. )
  872. on_set_is_hydrated_update = await process_coro.__anext__() # type: ignore
  873. assert on_set_is_hydrated_update == StateUpdate(
  874. delta={
  875. state.get_name(): {
  876. # These computed vars _shouldn't_ be here, because they didn't change
  877. arg_name: exp_val,
  878. f"comp_{arg_name}": exp_val,
  879. "is_hydrated": True,
  880. },
  881. },
  882. events=[],
  883. )
  884. # complete the processing
  885. with pytest.raises(StopAsyncIteration):
  886. await process_coro.__anext__() # type: ignore
  887. # a simple state update event should NOT trigger on_load or route var side effects
  888. process_coro = process(
  889. app,
  890. event=_dynamic_state_event(name="on_counter", val=exp_val),
  891. sid=sid,
  892. headers={},
  893. client_ip=client_ip,
  894. )
  895. update = await process_coro.__anext__() # type: ignore
  896. assert update == StateUpdate(
  897. delta={
  898. state.get_name(): {
  899. # These computed vars _shouldn't_ be here, because they didn't change
  900. f"comp_{arg_name}": exp_val,
  901. arg_name: exp_val,
  902. "counter": exp_index + 1,
  903. }
  904. },
  905. events=[],
  906. )
  907. # complete the processing
  908. with pytest.raises(StopAsyncIteration):
  909. await process_coro.__anext__() # type: ignore
  910. prev_exp_val = exp_val
  911. state = await app.state_manager.get_state(token)
  912. assert state.loaded == len(exp_vals)
  913. assert state.counter == len(exp_vals)
  914. # print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}")
  915. # assert state.side_effect_counter == len(exp_vals)
  916. if isinstance(app.state_manager, StateManagerRedis):
  917. await app.state_manager.redis.close()
  918. @pytest.mark.asyncio
  919. async def test_process_events(mocker, token: str):
  920. """Test that an event is processed properly and that it is postprocessed
  921. n+1 times. Also check that the processing flag of the last stateupdate is set to
  922. False.
  923. Args:
  924. mocker: mocker object.
  925. token: a Token.
  926. """
  927. router_data = {
  928. "pathname": "/",
  929. "query": {},
  930. "token": token,
  931. "sid": "mock_sid",
  932. "headers": {},
  933. "ip": "127.0.0.1",
  934. }
  935. app = App(state=GenState)
  936. mocker.patch.object(app, "postprocess", AsyncMock())
  937. event = Event(
  938. token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data
  939. )
  940. async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): # type: ignore
  941. pass
  942. assert (await app.state_manager.get_state(token)).value == 5
  943. assert app.postprocess.call_count == 6
  944. if isinstance(app.state_manager, StateManagerRedis):
  945. await app.state_manager.redis.close()
  946. @pytest.mark.parametrize(
  947. ("state", "overlay_component", "exp_page_child"),
  948. [
  949. (DefaultState, default_overlay_component, None),
  950. (DefaultState, None, None),
  951. (DefaultState, Text.create("foo"), Text),
  952. (State, default_overlay_component, Fragment),
  953. (State, None, None),
  954. (State, Text.create("foo"), Text),
  955. (State, lambda: Text.create("foo"), Text),
  956. ],
  957. )
  958. def test_overlay_component(
  959. state: State | None,
  960. overlay_component: Component | ComponentCallable | None,
  961. exp_page_child: Type[Component] | None,
  962. ):
  963. """Test that the overlay component is set correctly.
  964. Args:
  965. state: The state class to pass to App.
  966. overlay_component: The overlay_component to pass to App.
  967. exp_page_child: The type of the expected child in the page fragment.
  968. """
  969. app = App(state=state, overlay_component=overlay_component)
  970. if exp_page_child is None:
  971. assert app.overlay_component is None
  972. elif isinstance(exp_page_child, Fragment):
  973. assert app.overlay_component is not None
  974. generated_component = app._generate_component(app.overlay_component) # type: ignore
  975. assert isinstance(generated_component, Fragment)
  976. assert isinstance(
  977. generated_component.children[0],
  978. Cond, # ConnectionModal is a Cond under the hood
  979. )
  980. else:
  981. assert app.overlay_component is not None
  982. assert isinstance(
  983. app._generate_component(app.overlay_component), # type: ignore
  984. exp_page_child,
  985. )
  986. app.add_page(Box.create("Index"), route="/test")
  987. page = app.pages["test"]
  988. if exp_page_child is not None:
  989. assert len(page.children) == 3
  990. children_types = (type(child) for child in page.children)
  991. assert exp_page_child in children_types
  992. else:
  993. assert len(page.children) == 2
  994. @pytest.fixture
  995. def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]:
  996. """Fixture for an app that can be compiled.
  997. Args:
  998. tmp_path: Temporary path.
  999. Yields:
  1000. Tuple containing (app instance, Path to ".web" directory)
  1001. The working directory is set to the app dir (parent of .web),
  1002. allowing app.compile() to be called.
  1003. """
  1004. app_path = tmp_path / "app"
  1005. web_dir = app_path / ".web"
  1006. web_dir.mkdir(parents=True)
  1007. (web_dir / "package.json").touch()
  1008. app = App()
  1009. app.get_frontend_packages = unittest.mock.Mock()
  1010. with chdir(app_path):
  1011. yield app, web_dir
  1012. def test_app_wrap_compile_theme(compilable_app):
  1013. """Test that the radix theme component wraps the app.
  1014. Args:
  1015. compilable_app: compilable_app fixture.
  1016. """
  1017. app, web_dir = compilable_app
  1018. app.theme = rdxt.theme(accent_color="plum")
  1019. app.compile()
  1020. app_js_contents = (web_dir / "pages" / "_app.js").read_text()
  1021. app_js_lines = [
  1022. line.strip() for line in app_js_contents.splitlines() if line.strip()
  1023. ]
  1024. assert (
  1025. "function AppWrap({children}) {"
  1026. "return ("
  1027. "<RadixThemesTheme accentColor={`plum`}>"
  1028. "{children}"
  1029. "</RadixThemesTheme>"
  1030. ")"
  1031. "}"
  1032. ) in "".join(app_js_lines)
  1033. def test_app_wrap_priority(compilable_app):
  1034. """Test that the app wrap components are wrapped in the correct order.
  1035. Args:
  1036. compilable_app: compilable_app fixture.
  1037. """
  1038. app, web_dir = compilable_app
  1039. class Fragment1(Component):
  1040. tag = "Fragment1"
  1041. def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:
  1042. return {(99, "Box"): Box.create()}
  1043. class Fragment2(Component):
  1044. tag = "Fragment2"
  1045. def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:
  1046. return {(50, "Text"): Text.create()}
  1047. class Fragment3(Component):
  1048. tag = "Fragment3"
  1049. def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:
  1050. return {(10, "Fragment2"): Fragment2.create()}
  1051. def page():
  1052. return Fragment1.create(Fragment3.create())
  1053. app.add_page(page)
  1054. app.compile()
  1055. app_js_contents = (web_dir / "pages" / "_app.js").read_text()
  1056. app_js_lines = [
  1057. line.strip() for line in app_js_contents.splitlines() if line.strip()
  1058. ]
  1059. assert (
  1060. "function AppWrap({children}) {"
  1061. "return ("
  1062. "<Box>"
  1063. "<ChakraProvider theme={extendTheme(theme)}>"
  1064. "<Global styles={GlobalStyles}/>"
  1065. "<ChakraColorModeProvider>"
  1066. "<Text>"
  1067. "<Fragment2>"
  1068. "{children}"
  1069. "</Fragment2>"
  1070. "</Text>"
  1071. "</ChakraColorModeProvider>"
  1072. "</ChakraProvider>"
  1073. "</Box>"
  1074. ")"
  1075. "}"
  1076. ) in "".join(app_js_lines)