Переглянути джерело

Support callback for rx.call_script (#2045)

Masen Furer 1 рік тому
батько
коміт
f66c6c3361

+ 341 - 0
integration/test_call_script.py

@@ -0,0 +1,341 @@
+"""Integration tests for client side storage."""
+from __future__ import annotations
+
+from typing import Generator
+
+import pytest
+from selenium.webdriver.common.by import By
+from selenium.webdriver.remote.webdriver import WebDriver
+
+from reflex.testing import AppHarness
+
+
+def CallScript():
+    """A test app for browser javascript integration."""
+    import reflex as rx
+
+    inline_scripts = """
+    let inline_counter = 0
+    function inline1() {
+        inline_counter += 1
+        return "inline1"
+    }
+    function inline2() {
+        inline_counter += 1
+        console.log("inline2")
+    }
+    function inline3() {
+        inline_counter += 1
+        return {inline3: 42, a: [1, 2, 3], s: 'js', o: {a: 1, b: 2}}
+    }
+    """
+
+    external_scripts = inline_scripts.replace("inline", "external")
+
+    class CallScriptState(rx.State):
+        results: list[str | dict | list | None] = []
+        inline_counter: int = 0
+        external_counter: int = 0
+
+        def call_script_callback(self, result):
+            self.results.append(result)
+
+        def call_script_callback_other_arg(self, result, other_arg):
+            self.results.append([other_arg, result])
+
+        def call_scripts_inline_yield(self):
+            yield rx.call_script("inline1()")
+            yield rx.call_script("inline2()")
+            yield rx.call_script("inline3()")
+
+        def call_script_inline_return(self):
+            return rx.call_script("inline2()")
+
+        def call_scripts_inline_yield_callback(self):
+            yield rx.call_script(
+                "inline1()", callback=CallScriptState.call_script_callback
+            )
+            yield rx.call_script(
+                "inline2()", callback=CallScriptState.call_script_callback
+            )
+            yield rx.call_script(
+                "inline3()", callback=CallScriptState.call_script_callback
+            )
+
+        def call_script_inline_return_callback(self):
+            return rx.call_script(
+                "inline3()", callback=CallScriptState.call_script_callback
+            )
+
+        def call_script_inline_return_lambda(self):
+            return rx.call_script(
+                "inline2()",
+                callback=lambda result: CallScriptState.call_script_callback_other_arg(  # type: ignore
+                    result, "lambda"
+                ),
+            )
+
+        def get_inline_counter(self):
+            return rx.call_script(
+                "inline_counter",
+                callback=CallScriptState.set_inline_counter,  # type: ignore
+            )
+
+        def call_scripts_external_yield(self):
+            yield rx.call_script("external1()")
+            yield rx.call_script("external2()")
+            yield rx.call_script("external3()")
+
+        def call_script_external_return(self):
+            return rx.call_script("external2()")
+
+        def call_scripts_external_yield_callback(self):
+            yield rx.call_script(
+                "external1()", callback=CallScriptState.call_script_callback
+            )
+            yield rx.call_script(
+                "external2()", callback=CallScriptState.call_script_callback
+            )
+            yield rx.call_script(
+                "external3()", callback=CallScriptState.call_script_callback
+            )
+
+        def call_script_external_return_callback(self):
+            return rx.call_script(
+                "external3()", callback=CallScriptState.call_script_callback
+            )
+
+        def call_script_external_return_lambda(self):
+            return rx.call_script(
+                "external2()",
+                callback=lambda result: CallScriptState.call_script_callback_other_arg(  # type: ignore
+                    result, "lambda"
+                ),
+            )
+
+        def get_external_counter(self):
+            return rx.call_script(
+                "external_counter",
+                callback=CallScriptState.set_external_counter,  # type: ignore
+            )
+
+        def reset_(self):
+            yield rx.call_script("inline_counter = 0; external_counter = 0")
+            self.reset()
+
+    app = rx.App(state=CallScriptState)
+    with open("assets/external.js", "w") as f:
+        f.write(external_scripts)
+
+    @app.add_page
+    def index():
+        return rx.vstack(
+            rx.input(
+                value=CallScriptState.router.session.client_token,
+                is_read_only=True,
+                id="token",
+            ),
+            rx.input(
+                value=CallScriptState.inline_counter.to(str),  # type: ignore
+                id="inline_counter",
+                is_read_only=True,
+            ),
+            rx.input(
+                value=CallScriptState.external_counter.to(str),  # type: ignore
+                id="external_counter",
+                is_read_only=True,
+            ),
+            rx.text_area(
+                value=CallScriptState.results.to_string(),  # type: ignore
+                id="results",
+                is_read_only=True,
+            ),
+            rx.script(inline_scripts),
+            rx.script(src="/external.js"),
+            rx.button(
+                "call_scripts_inline_yield",
+                on_click=CallScriptState.call_scripts_inline_yield,
+                id="inline_yield",
+            ),
+            rx.button(
+                "call_script_inline_return",
+                on_click=CallScriptState.call_script_inline_return,
+                id="inline_return",
+            ),
+            rx.button(
+                "call_scripts_inline_yield_callback",
+                on_click=CallScriptState.call_scripts_inline_yield_callback,
+                id="inline_yield_callback",
+            ),
+            rx.button(
+                "call_script_inline_return_callback",
+                on_click=CallScriptState.call_script_inline_return_callback,
+                id="inline_return_callback",
+            ),
+            rx.button(
+                "call_script_inline_return_lambda",
+                on_click=CallScriptState.call_script_inline_return_lambda,
+                id="inline_return_lambda",
+            ),
+            rx.button(
+                "call_scripts_external_yield",
+                on_click=CallScriptState.call_scripts_external_yield,
+                id="external_yield",
+            ),
+            rx.button(
+                "call_script_external_return",
+                on_click=CallScriptState.call_script_external_return,
+                id="external_return",
+            ),
+            rx.button(
+                "call_scripts_external_yield_callback",
+                on_click=CallScriptState.call_scripts_external_yield_callback,
+                id="external_yield_callback",
+            ),
+            rx.button(
+                "call_script_external_return_callback",
+                on_click=CallScriptState.call_script_external_return_callback,
+                id="external_return_callback",
+            ),
+            rx.button(
+                "call_script_external_return_lambda",
+                on_click=CallScriptState.call_script_external_return_lambda,
+                id="external_return_lambda",
+            ),
+            rx.button(
+                "Update Inline Counter",
+                on_click=CallScriptState.get_inline_counter,
+                id="update_inline_counter",
+            ),
+            rx.button(
+                "Update External Counter",
+                on_click=CallScriptState.get_external_counter,
+                id="update_external_counter",
+            ),
+            rx.button("Reset", id="reset", on_click=CallScriptState.reset_),
+        )
+
+    app.compile()
+
+
+@pytest.fixture(scope="session")
+def call_script(tmp_path_factory) -> Generator[AppHarness, None, None]:
+    """Start CallScript app at tmp_path via AppHarness.
+
+    Args:
+        tmp_path_factory: pytest tmp_path_factory fixture
+
+    Yields:
+        running AppHarness instance
+    """
+    with AppHarness.create(
+        root=tmp_path_factory.mktemp("call_script"),
+        app_source=CallScript,  # type: ignore
+    ) as harness:
+        yield harness
+
+
+@pytest.fixture
+def driver(call_script: AppHarness) -> Generator[WebDriver, None, None]:
+    """Get an instance of the browser open to the call_script app.
+
+    Args:
+        call_script: harness for CallScript app
+
+    Yields:
+        WebDriver instance.
+    """
+    assert call_script.app_instance is not None, "app is not running"
+    driver = call_script.frontend()
+    try:
+        yield driver
+    finally:
+        driver.quit()
+
+
+def assert_token(call_script: AppHarness, driver: WebDriver) -> str:
+    """Get the token associated with backend state.
+
+    Args:
+        call_script: harness for CallScript app.
+        driver: WebDriver instance.
+
+    Returns:
+        The token visible in the driver browser.
+    """
+    assert call_script.app_instance is not None
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+
+    # wait for the backend connection to send the token
+    token = call_script.poll_for_value(token_input)
+    assert token is not None
+
+    return token
+
+
+@pytest.mark.parametrize("script", ["inline", "external"])
+def test_call_script(
+    call_script: AppHarness,
+    driver: WebDriver,
+    script: str,
+):
+    """Test calling javascript functions from python.
+
+    Args:
+        call_script: harness for CallScript app.
+        driver: WebDriver instance.
+        script: The type of script to test.
+    """
+    assert_token(call_script, driver)
+    reset_button = driver.find_element(By.ID, "reset")
+    update_counter_button = driver.find_element(By.ID, f"update_{script}_counter")
+    counter = driver.find_element(By.ID, f"{script}_counter")
+    results = driver.find_element(By.ID, "results")
+    yield_button = driver.find_element(By.ID, f"{script}_yield")
+    return_button = driver.find_element(By.ID, f"{script}_return")
+    yield_callback_button = driver.find_element(By.ID, f"{script}_yield_callback")
+    return_callback_button = driver.find_element(By.ID, f"{script}_return_callback")
+    return_lambda_button = driver.find_element(By.ID, f"{script}_return_lambda")
+
+    yield_button.click()
+    update_counter_button.click()
+    assert call_script.poll_for_value(counter, exp_not_equal="0") == "3"
+    reset_button.click()
+    assert call_script.poll_for_value(counter, exp_not_equal="3") == "0"
+    return_button.click()
+    update_counter_button.click()
+    assert call_script.poll_for_value(counter, exp_not_equal="0") == "1"
+    reset_button.click()
+    assert call_script.poll_for_value(counter, exp_not_equal="1") == "0"
+
+    yield_callback_button.click()
+    update_counter_button.click()
+    assert call_script.poll_for_value(counter, exp_not_equal="0") == "3"
+    assert call_script.poll_for_value(
+        results, exp_not_equal="[]"
+    ) == '["%s1",null,{"%s3":42,"a":[1,2,3],"s":"js","o":{"a":1,"b":2}}]' % (
+        script,
+        script,
+    )
+    reset_button.click()
+    assert call_script.poll_for_value(counter, exp_not_equal="3") == "0"
+
+    return_callback_button.click()
+    update_counter_button.click()
+    assert call_script.poll_for_value(counter, exp_not_equal="0") == "1"
+    assert (
+        call_script.poll_for_value(results, exp_not_equal="[]")
+        == '[{"%s3":42,"a":[1,2,3],"s":"js","o":{"a":1,"b":2}}]' % script
+    )
+    reset_button.click()
+    assert call_script.poll_for_value(counter, exp_not_equal="1") == "0"
+
+    return_lambda_button.click()
+    update_counter_button.click()
+    assert call_script.poll_for_value(counter, exp_not_equal="0") == "1"
+    assert (
+        call_script.poll_for_value(results, exp_not_equal="[]") == '[["lambda",null]]'
+    )
+    reset_button.click()
+    assert call_script.poll_for_value(counter, exp_not_equal="1") == "0"

+ 9 - 3
reflex/.templates/web/utils/state.js

@@ -198,7 +198,10 @@ export const applyEvent = async (event, socket) => {
 
   if (event.name == "_call_script") {
     try {
-      eval(event.payload.javascript_code);
+      const eval_result = eval(event.payload.javascript_code);
+      if (event.payload.callback) {
+        eval(event.payload.callback)(eval_result)
+      }
     } catch (e) {
       console.log("_call_script", e);
     }
@@ -213,7 +216,7 @@ export const applyEvent = async (event, socket) => {
 
   // Send the event to the server.
   if (socket) {
-    socket.emit("event", JSON.stringify(event));
+    socket.emit("event", JSON.stringify(event, (k, v) => v === undefined ? null : v));
     return true;
   }
 
@@ -407,7 +410,10 @@ export const hydrateClientStorage = (client_storage) => {
     for (const state_key in client_storage.cookies) {
       const cookie_options = client_storage.cookies[state_key]
       const cookie_name = cookie_options.name || state_key
-      client_storage_values.cookies[state_key] = cookies.get(cookie_name)
+      const cookie_value = cookies.get(cookie_name)
+      if (cookie_value !== undefined) {
+        client_storage_values.cookies[state_key] = cookies.get(cookie_name)
+      }
     }
   }
   if (client_storage.local_storage && (typeof window !== 'undefined')) {

+ 34 - 1
reflex/event.py

@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import inspect
+from types import FunctionType
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -408,19 +409,51 @@ def download(url: str, filename: Optional[str] = None) -> EventSpec:
     )
 
 
-def call_script(javascript_code: str) -> EventSpec:
+def _callback_arg_spec(eval_result):
+    """ArgSpec for call_script callback function.
+
+    Args:
+        eval_result: The result of the javascript execution.
+
+    Returns:
+        Args for the callback function
+    """
+    return [eval_result]
+
+
+def call_script(
+    javascript_code: str,
+    callback: EventHandler | Callable | None = None,
+) -> EventSpec:
     """Create an event handler that executes arbitrary javascript code.
 
     Args:
         javascript_code: The code to execute.
+        callback: EventHandler that will receive the result of evaluating the javascript code.
 
     Returns:
         EventSpec: An event that will execute the client side javascript.
+
+    Raises:
+        ValueError: If the callback is not a valid event handler.
     """
+    callback_kwargs = {}
+    if callback is not None:
+        arg_name = parse_args_spec(_callback_arg_spec)[0]._var_name
+        if isinstance(callback, EventHandler):
+            event_spec = call_event_handler(callback, _callback_arg_spec)
+        elif isinstance(callback, FunctionType):
+            event_spec = call_event_fn(callback, _callback_arg_spec)[0]
+        else:
+            raise ValueError("Cannot use {callback!r} as a call_script callback.")
+        callback_kwargs = {
+            "callback": f"({arg_name}) => queueEvents([{format.format_event(event_spec)}], {constants.CompileVars.SOCKET})"
+        }
     return server_side(
         "_call_script",
         get_fn_signature(call_script),
         javascript_code=javascript_code,
+        **callback_kwargs,
     )
 
 

+ 1 - 1
reflex/utils/format.py

@@ -417,7 +417,7 @@ def format_event(event_spec: EventSpec) -> str:
             ":".join(
                 (
                     name._var_name,
-                    wrap(json.dumps(val._var_name).strip('"'), "`")
+                    wrap(json.dumps(val._var_name).strip('"').replace("`", "\\`"), "`")
                     if val._var_is_string
                     else val._var_full_name,
                 )