Преглед на файлове

[REF-723+] Upload with progress and cancellation (#1899)

Masen Furer преди 1 година
родител
ревизия
7eccc6d988

+ 135 - 11
integration/test_upload.py

@@ -1,13 +1,14 @@
 """Integration tests for file upload."""
 """Integration tests for file upload."""
 from __future__ import annotations
 from __future__ import annotations
 
 
+import asyncio
 import time
 import time
 from typing import Generator
 from typing import Generator
 
 
 import pytest
 import pytest
 from selenium.webdriver.common.by import By
 from selenium.webdriver.common.by import By
 
 
-from reflex.testing import AppHarness
+from reflex.testing import AppHarness, WebDriver
 
 
 
 
 def UploadFile():
 def UploadFile():
@@ -16,12 +17,28 @@ def UploadFile():
 
 
     class UploadState(rx.State):
     class UploadState(rx.State):
         _file_data: dict[str, str] = {}
         _file_data: dict[str, str] = {}
+        event_order: list[str] = []
+        progress_dicts: list[dict] = []
 
 
         async def handle_upload(self, files: list[rx.UploadFile]):
         async def handle_upload(self, files: list[rx.UploadFile]):
             for file in files:
             for file in files:
                 upload_data = await file.read()
                 upload_data = await file.read()
                 self._file_data[file.filename or ""] = upload_data.decode("utf-8")
                 self._file_data[file.filename or ""] = upload_data.decode("utf-8")
 
 
+        async def handle_upload_secondary(self, files: list[rx.UploadFile]):
+            for file in files:
+                upload_data = await file.read()
+                self._file_data[file.filename or ""] = upload_data.decode("utf-8")
+                yield UploadState.chain_event
+
+        def upload_progress(self, progress):
+            assert progress
+            self.event_order.append("upload_progress")
+            self.progress_dicts.append(progress)
+
+        def chain_event(self):
+            self.event_order.append("chain_event")
+
     def index():
     def index():
         return rx.vstack(
         return rx.vstack(
             rx.input(
             rx.input(
@@ -29,6 +46,7 @@ def UploadFile():
                 is_read_only=True,
                 is_read_only=True,
                 id="token",
                 id="token",
             ),
             ),
+            rx.heading("Default Upload"),
             rx.upload(
             rx.upload(
                 rx.vstack(
                 rx.vstack(
                     rx.button("Select File"),
                     rx.button("Select File"),
@@ -52,6 +70,47 @@ def UploadFile():
                 on_click=rx.clear_selected_files,
                 on_click=rx.clear_selected_files,
                 id="clear_button",
                 id="clear_button",
             ),
             ),
+            rx.heading("Secondary Upload"),
+            rx.upload(
+                rx.vstack(
+                    rx.button("Select File"),
+                    rx.text("Drag and drop files here or click to select files"),
+                ),
+                id="secondary",
+            ),
+            rx.button(
+                "Upload",
+                on_click=UploadState.handle_upload_secondary(  # type: ignore
+                    rx.upload_files(
+                        upload_id="secondary",
+                        on_upload_progress=UploadState.upload_progress,
+                    ),
+                ),
+                id="upload_button_secondary",
+            ),
+            rx.box(
+                rx.foreach(
+                    rx.selected_files("secondary"),
+                    lambda f: rx.text(f),
+                ),
+                id="selected_files_secondary",
+            ),
+            rx.button(
+                "Clear",
+                on_click=rx.clear_selected_files("secondary"),
+                id="clear_button_secondary",
+            ),
+            rx.vstack(
+                rx.foreach(
+                    UploadState.progress_dicts,  # type: ignore
+                    lambda d: rx.text(d.to_string()),
+                )
+            ),
+            rx.button(
+                "Cancel",
+                on_click=rx.cancel_upload("secondary"),
+                id="cancel_button_secondary",
+            ),
         )
         )
 
 
     app = rx.App(state=UploadState)
     app = rx.App(state=UploadState)
@@ -94,14 +153,18 @@ def driver(upload_file: AppHarness):
         driver.quit()
         driver.quit()
 
 
 
 
+@pytest.mark.parametrize("secondary", [False, True])
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
+async def test_upload_file(
+    tmp_path, upload_file: AppHarness, driver: WebDriver, secondary: bool
+):
     """Submit a file upload and check that it arrived on the backend.
     """Submit a file upload and check that it arrived on the backend.
 
 
     Args:
     Args:
         tmp_path: pytest tmp_path fixture
         tmp_path: pytest tmp_path fixture
         upload_file: harness for UploadFile app.
         upload_file: harness for UploadFile app.
         driver: WebDriver instance.
         driver: WebDriver instance.
+        secondary: whether to use the secondary upload form
     """
     """
     assert upload_file.app_instance is not None
     assert upload_file.app_instance is not None
     token_input = driver.find_element(By.ID, "token")
     token_input = driver.find_element(By.ID, "token")
@@ -110,9 +173,13 @@ async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
     token = upload_file.poll_for_value(token_input)
     token = upload_file.poll_for_value(token_input)
     assert token is not None
     assert token is not None
 
 
-    upload_box = driver.find_element(By.XPATH, "//input[@type='file']")
+    suffix = "_secondary" if secondary else ""
+
+    upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[
+        1 if secondary else 0
+    ]
     assert upload_box
     assert upload_box
-    upload_button = driver.find_element(By.ID, "upload_button")
+    upload_button = driver.find_element(By.ID, f"upload_button{suffix}")
     assert upload_button
     assert upload_button
 
 
     exp_name = "test.txt"
     exp_name = "test.txt"
@@ -132,9 +199,15 @@ async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
     assert file_data[exp_name] == exp_contents
     assert file_data[exp_name] == exp_contents
 
 
     # check that the selected files are displayed
     # check that the selected files are displayed
-    selected_files = driver.find_element(By.ID, "selected_files")
+    selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
     assert selected_files.text == exp_name
     assert selected_files.text == exp_name
 
 
+    state = await upload_file.get_state(token)
+    if secondary:
+        # only the secondary form tracks progress and chain events
+        assert state.event_order.count("upload_progress") == 1
+        assert state.event_order.count("chain_event") == 1
+
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
 async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
@@ -186,13 +259,17 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
         assert file_data[exp_name] == exp_contents
         assert file_data[exp_name] == exp_contents
 
 
 
 
-def test_clear_files(tmp_path, upload_file: AppHarness, driver):
+@pytest.mark.parametrize("secondary", [False, True])
+def test_clear_files(
+    tmp_path, upload_file: AppHarness, driver: WebDriver, secondary: bool
+):
     """Select then clear several file uploads and check that they are cleared.
     """Select then clear several file uploads and check that they are cleared.
 
 
     Args:
     Args:
         tmp_path: pytest tmp_path fixture
         tmp_path: pytest tmp_path fixture
         upload_file: harness for UploadFile app.
         upload_file: harness for UploadFile app.
         driver: WebDriver instance.
         driver: WebDriver instance.
+        secondary: whether to use the secondary upload form.
     """
     """
     assert upload_file.app_instance is not None
     assert upload_file.app_instance is not None
     token_input = driver.find_element(By.ID, "token")
     token_input = driver.find_element(By.ID, "token")
@@ -201,9 +278,13 @@ def test_clear_files(tmp_path, upload_file: AppHarness, driver):
     token = upload_file.poll_for_value(token_input)
     token = upload_file.poll_for_value(token_input)
     assert token is not None
     assert token is not None
 
 
-    upload_box = driver.find_element(By.XPATH, "//input[@type='file']")
+    suffix = "_secondary" if secondary else ""
+
+    upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[
+        1 if secondary else 0
+    ]
     assert upload_box
     assert upload_box
-    upload_button = driver.find_element(By.ID, "upload_button")
+    upload_button = driver.find_element(By.ID, f"upload_button{suffix}")
     assert upload_button
     assert upload_button
 
 
     exp_files = {
     exp_files = {
@@ -219,13 +300,56 @@ def test_clear_files(tmp_path, upload_file: AppHarness, driver):
     time.sleep(0.2)
     time.sleep(0.2)
 
 
     # check that the selected files are displayed
     # check that the selected files are displayed
-    selected_files = driver.find_element(By.ID, "selected_files")
+    selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
     assert selected_files.text == "\n".join(exp_files)
     assert selected_files.text == "\n".join(exp_files)
 
 
-    clear_button = driver.find_element(By.ID, "clear_button")
+    clear_button = driver.find_element(By.ID, f"clear_button{suffix}")
     assert clear_button
     assert clear_button
     clear_button.click()
     clear_button.click()
 
 
     # check that the selected files are cleared
     # check that the selected files are cleared
-    selected_files = driver.find_element(By.ID, "selected_files")
+    selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
     assert selected_files.text == ""
     assert selected_files.text == ""
+
+
+# TODO: drag and drop directory
+# https://gist.github.com/florentbr/349b1ab024ca9f3de56e6bf8af2ac69e
+
+
+@pytest.mark.asyncio
+async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDriver):
+    """Submit a large file upload and cancel it.
+
+    Args:
+        tmp_path: pytest tmp_path fixture
+        upload_file: harness for UploadFile app.
+        driver: WebDriver instance.
+    """
+    assert upload_file.app_instance is not None
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+    # wait for the backend connection to send the token
+    token = upload_file.poll_for_value(token_input)
+    assert token is not None
+
+    upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1]
+    upload_button = driver.find_element(By.ID, f"upload_button_secondary")
+    cancel_button = driver.find_element(By.ID, f"cancel_button_secondary")
+
+    exp_name = "large.txt"
+    target_file = tmp_path / exp_name
+    with target_file.open("wb") as f:
+        f.seek(1024 * 1024 * 256)
+        f.write(b"0")
+
+    upload_box.send_keys(str(target_file))
+    upload_button.click()
+    await asyncio.sleep(0.3)
+    cancel_button.click()
+
+    # look up the backend state and assert on progress
+    state = await upload_file.get_state(token)
+    assert state.progress_dicts
+    assert exp_name not in state._file_data
+
+    target_file.unlink()

+ 82 - 33
reflex/.templates/web/utils/state.js

@@ -32,6 +32,11 @@ let event_processing = false
 // Array holding pending events to be processed.
 // Array holding pending events to be processed.
 const event_queue = [];
 const event_queue = [];
 
 
+// Pending upload promises, by id
+const upload_controllers = {};
+// Upload files state by id
+export const upload_files = {};
+
 /**
 /**
  * Generate a UUID (Used for session tokens).
  * Generate a UUID (Used for session tokens).
  * Taken from: https://stackoverflow.com/questions/105034/how-do-i-create-a-guid-uuid
  * Taken from: https://stackoverflow.com/questions/105034/how-do-i-create-a-guid-uuid
@@ -235,14 +240,22 @@ export const applyEvent = async (event, socket) => {
 /**
 /**
  * Send an event to the server via REST.
  * Send an event to the server via REST.
  * @param event The current event.
  * @param event The current event.
- * @param state The state with the event queue.
+ * @param socket The socket object to send the response event(s) on.
  *
  *
  * @returns Whether the event was sent.
  * @returns Whether the event was sent.
  */
  */
-export const applyRestEvent = async (event) => {
+export const applyRestEvent = async (event, socket) => {
   let eventSent = false;
   let eventSent = false;
   if (event.handler == "uploadFiles") {
   if (event.handler == "uploadFiles") {
-    eventSent = await uploadFiles(event.name, event.payload.files);
+    // Start upload, but do not wait for it, which would block other events.
+    uploadFiles(
+      event.name,
+      event.payload.files,
+      event.payload.upload_id,
+      event.payload.on_upload_progress,
+      socket
+    );
+    return false;
   }
   }
   return eventSent;
   return eventSent;
 };
 };
@@ -283,7 +296,7 @@ export const processEvent = async (
   let eventSent = false
   let eventSent = false
   // Process events with handlers via REST and all others via websockets.
   // Process events with handlers via REST and all others via websockets.
   if (event.handler) {
   if (event.handler) {
-    eventSent = await applyRestEvent(event);
+    eventSent = await applyRestEvent(event, socket);
   } else {
   } else {
     eventSent = await applyEvent(event, socket);
     eventSent = await applyEvent(event, socket);
   }
   }
@@ -347,50 +360,86 @@ export const connect = async (
  *
  *
  * @param state The state to apply the delta to.
  * @param state The state to apply the delta to.
  * @param handler The handler to use.
  * @param handler The handler to use.
+ * @param upload_id The upload id to use.
+ * @param on_upload_progress The function to call on upload progress.
+ * @param socket the websocket connection
  *
  *
- * @returns Whether the files were uploaded.
+ * @returns The response from posting to the UPLOADURL endpoint.
  */
  */
-export const uploadFiles = async (handler, files) => {
+export const uploadFiles = async (handler, files, upload_id, on_upload_progress, socket) => {
   // return if there's no file to upload
   // return if there's no file to upload
   if (files.length == 0) {
   if (files.length == 0) {
     return false;
     return false;
   }
   }
 
 
-  const headers = {
-    "Content-Type": files[0].type,
-  };
+  if (upload_controllers[upload_id]) {
+    console.log("Upload already in progress for ", upload_id)
+    return false;
+  }
+
+  let resp_idx = 0;
+  const eventHandler = (progressEvent) => {
+    // handle any delta / event streamed from the upload event handler
+    const chunks = progressEvent.event.target.responseText.trim().split("\n")
+    chunks.slice(resp_idx).map((chunk) => {
+      try {
+        socket._callbacks.$event.map((f) => {
+          f(chunk)
+        })
+        resp_idx += 1
+      } catch (e) {
+        console.log("Error parsing chunk", chunk, e)
+        return
+      }
+    })
+  }
+
+  const controller = new AbortController()
+  const config = {
+    headers: {
+      "Reflex-Client-Token": getToken(),
+      "Reflex-Event-Handler": handler,
+    },
+    signal: controller.signal,
+    onDownloadProgress: eventHandler,
+  }
+  if (on_upload_progress) {
+    config["onUploadProgress"] = on_upload_progress
+  }
   const formdata = new FormData();
   const formdata = new FormData();
 
 
   // Add the token and handler to the file name.
   // Add the token and handler to the file name.
-  for (let i = 0; i < files.length; i++) {
+  files.forEach((file) => {
     formdata.append(
     formdata.append(
       "files",
       "files",
-      files[i],
-      getToken() + ":" + handler + ":" + files[i].name
+      file,
+      file.path || file.name
     );
     );
-  }
+  })
 
 
   // Send the file to the server.
   // Send the file to the server.
-  await axios.post(UPLOADURL, formdata, headers)
-    .then(() => { return true; })
-    .catch(
-      error => {
-        if (error.response) {
-          // The request was made and the server responded with a status code
-          // that falls out of the range of 2xx
-          console.log(error.response.data);
-        } else if (error.request) {
-          // The request was made but no response was received
-          // `error.request` is an instance of XMLHttpRequest in the browser and an instance of
-          // http.ClientRequest in node.js
-          console.log(error.request);
-        } else {
-          // Something happened in setting up the request that triggered an Error
-          console.log(error.message);
-        }
-        return false;
-      }
-    )
+  upload_controllers[upload_id] = controller
+
+  try {
+    return await axios.post(UPLOADURL, formdata, config)
+  } catch (error) {
+    if (error.response) {
+      // The request was made and the server responded with a status code
+      // that falls out of the range of 2xx
+      console.log(error.response.data);
+    } else if (error.request) {
+      // The request was made but no response was received
+      // `error.request` is an instance of XMLHttpRequest in the browser and an instance of
+      // http.ClientRequest in node.js
+      console.log(error.request);
+    } else {
+      // Something happened in setting up the request that triggered an Error
+      console.log(error.message);
+    }
+    return false;
+  } finally {
+    delete upload_controllers[upload_id]
+  }
 };
 };
 
 
 /**
 /**

+ 1 - 0
reflex/__init__.py

@@ -229,6 +229,7 @@ _ALL_COMPONENTS = [
 
 
 _ALL_COMPONENTS += [to_snake_case(component) for component in _ALL_COMPONENTS]
 _ALL_COMPONENTS += [to_snake_case(component) for component in _ALL_COMPONENTS]
 _ALL_COMPONENTS += [
 _ALL_COMPONENTS += [
+    "cancel_upload",
     "components",
     "components",
     "color_mode_cond",
     "color_mode_cond",
     "desktop_only",
     "desktop_only",

+ 10 - 0
reflex/__init__.pyi

@@ -58,6 +58,9 @@ from reflex.components import ConnectionModal as ConnectionModal
 from reflex.components import Container as Container
 from reflex.components import Container as Container
 from reflex.components import DataTable as DataTable
 from reflex.components import DataTable as DataTable
 from reflex.components import DataEditor as DataEditor
 from reflex.components import DataEditor as DataEditor
+from reflex.components import DataEditorTheme as DataEditorTheme
+from reflex.components import DatePicker as DatePicker
+from reflex.components import DateTimePicker as DateTimePicker
 from reflex.components import DebounceInput as DebounceInput
 from reflex.components import DebounceInput as DebounceInput
 from reflex.components import Divider as Divider
 from reflex.components import Divider as Divider
 from reflex.components import Drawer as Drawer
 from reflex.components import Drawer as Drawer
@@ -265,6 +268,9 @@ from reflex.components import connection_modal as connection_modal
 from reflex.components import container as container
 from reflex.components import container as container
 from reflex.components import data_table as data_table
 from reflex.components import data_table as data_table
 from reflex.components import data_editor as data_editor
 from reflex.components import data_editor as data_editor
+from reflex.components import data_editor_theme as data_editor_theme
+from reflex.components import date_picker as date_picker
+from reflex.components import date_time_picker as date_time_picker
 from reflex.components import debounce_input as debounce_input
 from reflex.components import debounce_input as debounce_input
 from reflex.components import divider as divider
 from reflex.components import divider as divider
 from reflex.components import drawer as drawer
 from reflex.components import drawer as drawer
@@ -421,7 +427,9 @@ from reflex.components import visually_hidden as visually_hidden
 from reflex.components import vstack as vstack
 from reflex.components import vstack as vstack
 from reflex.components import wrap as wrap
 from reflex.components import wrap as wrap
 from reflex.components import wrap_item as wrap_item
 from reflex.components import wrap_item as wrap_item
+from reflex.components import cancel_upload as cancel_upload
 from reflex import components as components
 from reflex import components as components
+from reflex.components import color_mode_cond as color_mode_cond
 from reflex.components import desktop_only as desktop_only
 from reflex.components import desktop_only as desktop_only
 from reflex.components import mobile_only as mobile_only
 from reflex.components import mobile_only as mobile_only
 from reflex.components import tablet_only as tablet_only
 from reflex.components import tablet_only as tablet_only
@@ -429,7 +437,9 @@ from reflex.components import mobile_and_tablet as mobile_and_tablet
 from reflex.components import tablet_and_desktop as tablet_and_desktop
 from reflex.components import tablet_and_desktop as tablet_and_desktop
 from reflex.components import selected_files as selected_files
 from reflex.components import selected_files as selected_files
 from reflex.components import clear_selected_files as clear_selected_files
 from reflex.components import clear_selected_files as clear_selected_files
+from reflex.components import EditorButtonList as EditorButtonList
 from reflex.components import EditorOptions as EditorOptions
 from reflex.components import EditorOptions as EditorOptions
+from reflex.components import NoSSRComponent as NoSSRComponent
 from reflex.components.component import memo as memo
 from reflex.components.component import memo as memo
 from reflex.components.graphing import recharts as recharts
 from reflex.components.graphing import recharts as recharts
 from reflex import config as config
 from reflex import config as config

+ 76 - 45
reflex/app.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 
 
 import asyncio
 import asyncio
 import contextlib
 import contextlib
-import inspect
+import functools
 import os
 import os
 from multiprocessing.pool import ThreadPool
 from multiprocessing.pool import ThreadPool
 from typing import (
 from typing import (
@@ -17,10 +17,13 @@ from typing import (
     Set,
     Set,
     Type,
     Type,
     Union,
     Union,
+    get_args,
+    get_type_hints,
 )
 )
 
 
-from fastapi import FastAPI, UploadFile
+from fastapi import FastAPI, HTTPException, Request, UploadFile
 from fastapi.middleware import cors
 from fastapi.middleware import cors
+from fastapi.responses import StreamingResponse
 from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
 from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
 from socketio import ASGIApp, AsyncNamespace, AsyncServer
 from socketio import ASGIApp, AsyncNamespace, AsyncServer
 from starlette_admin.contrib.sqla.admin import Admin
 from starlette_admin.contrib.sqla.admin import Admin
@@ -880,62 +883,90 @@ def upload(app: App):
         The upload function.
         The upload function.
     """
     """
 
 
-    async def upload_file(files: List[UploadFile]):
+    async def upload_file(request: Request, files: List[UploadFile]):
         """Upload a file.
         """Upload a file.
 
 
         Args:
         Args:
+            request: The FastAPI request object.
             files: The file(s) to upload.
             files: The file(s) to upload.
 
 
+        Returns:
+            StreamingResponse yielding newline-delimited JSON of StateUpdate
+            emitted by the upload handler.
+
         Raises:
         Raises:
             ValueError: if there are no args with supported annotation.
             ValueError: if there are no args with supported annotation.
+            TypeError: if a background task is used as the handler.
+            HTTPException: when the request does not include token / handler headers.
         """
         """
-        assert files[0].filename is not None
-        token, handler = files[0].filename.split(":")[:2]
-        for file in files:
-            assert file.filename is not None
-            file.filename = file.filename.split(":")[-1]
+        token = request.headers.get("reflex-client-token")
+        handler = request.headers.get("reflex-event-handler")
+
+        if not token or not handler:
+            raise HTTPException(
+                status_code=400,
+                detail="Missing reflex-client-token or reflex-event-handler header.",
+            )
 
 
         # Get the state for the session.
         # Get the state for the session.
-        async with app.state_manager.modify_state(token) as state:
-            # get the current session ID
-            sid = state.router.session.session_id
-            # get the current state(parent state/substate)
-            path = handler.split(".")[:-1]
-            current_state = state.get_substate(path)
-            handler_upload_param = ()
-
-            # get handler function
-            func = getattr(current_state, handler.split(".")[-1])
-
-            # check if there exists any handler args with annotation, List[UploadFile]
-            for k, v in inspect.getfullargspec(
-                func.fn if isinstance(func, EventHandler) else func
-            ).annotations.items():
-                if types.is_generic_alias(v) and types._issubclass(
-                    v.__args__[0], UploadFile
-                ):
-                    handler_upload_param = (k, v)
-                    break
-
-            if not handler_upload_param:
-                raise ValueError(
-                    f"`{handler}` handler should have a parameter annotated as List["
-                    f"rx.UploadFile]"
+        state = await app.state_manager.get_state(token)
+
+        # get the current session ID
+        # get the current state(parent state/substate)
+        path = handler.split(".")[:-1]
+        current_state = state.get_substate(path)
+        handler_upload_param = ()
+
+        # get handler function
+        func = getattr(type(current_state), handler.split(".")[-1])
+
+        # check if there exists any handler args with annotation, List[UploadFile]
+        if isinstance(func, EventHandler):
+            if func.is_background:
+                raise TypeError(
+                    f"@rx.background is not supported for upload handler `{handler}`.",
                 )
                 )
+            func = func.fn
+        if isinstance(func, functools.partial):
+            func = func.func
+        for k, v in get_type_hints(func).items():
+            if types.is_generic_alias(v) and types._issubclass(
+                get_args(v)[0],
+                UploadFile,
+            ):
+                handler_upload_param = (k, v)
+                break
 
 
-            event = Event(
-                token=token,
-                name=handler,
-                payload={handler_upload_param[0]: files},
+        if not handler_upload_param:
+            raise ValueError(
+                f"`{handler}` handler should have a parameter annotated as "
+                "List[rx.UploadFile]"
             )
             )
-            async for update in state._process(event):
-                # Postprocess the event.
-                update = await app.postprocess(state, event, update)
-                # Send update to client
-                await app.event_namespace.emit_update(  # type: ignore
-                    update=update,
-                    sid=sid,
-                )
+
+        event = Event(
+            token=token,
+            name=handler,
+            payload={handler_upload_param[0]: files},
+        )
+
+        async def _ndjson_updates():
+            """Process the upload event, generating ndjson updates.
+
+            Yields:
+                Each state update as JSON followed by a new line.
+            """
+            # Process the event.
+            async with app.state_manager.modify_state(token) as state:
+                async for update in state._process(event):
+                    # Postprocess the event.
+                    update = await app.postprocess(state, event, update)
+                    yield update.json() + "\n"
+
+        # Stream updates to client
+        return StreamingResponse(
+            _ndjson_updates(),
+            media_type="application/x-ndjson",
+        )
 
 
     return upload_file
     return upload_file
 
 

+ 8 - 2
reflex/components/forms/__init__.py

@@ -46,12 +46,18 @@ from .select import Option, Select
 from .slider import Slider, SliderFilledTrack, SliderMark, SliderThumb, SliderTrack
 from .slider import Slider, SliderFilledTrack, SliderMark, SliderThumb, SliderTrack
 from .switch import Switch
 from .switch import Switch
 from .textarea import TextArea
 from .textarea import TextArea
-from .upload import Upload, clear_selected_files, selected_files
+from .upload import (
+    Upload,
+    cancel_upload,
+    clear_selected_files,
+    selected_files,
+)
 
 
 helpers = [
 helpers = [
     "color_mode_cond",
     "color_mode_cond",
-    "selected_files",
+    "cancel_upload",
     "clear_selected_files",
     "clear_selected_files",
+    "selected_files",
 ]
 ]
 
 
 __all__ = [f for f in dir() if f[0].isupper()] + helpers  # type: ignore
 __all__ = [f for f in dir() if f[0].isupper()] + helpers  # type: ignore

+ 80 - 17
reflex/components/forms/upload.py

@@ -3,26 +3,75 @@ from __future__ import annotations
 
 
 from typing import Any, Dict, List, Optional, Union
 from typing import Any, Dict, List, Optional, Union
 
 
+from reflex import constants
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.forms.input import Input
 from reflex.components.forms.input import Input
 from reflex.components.layout.box import Box
 from reflex.components.layout.box import Box
-from reflex.constants import EventTriggers
-from reflex.event import EventChain
-from reflex.vars import BaseVar, Var
+from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
+from reflex.utils import imports
+from reflex.vars import BaseVar, CallableVar, ImportVar, Var
 
 
-files_state: str = "const [files, setFiles] = useState([]);"
-upload_file: BaseVar = BaseVar(
-    _var_name="e => setFiles((files) => e)", _var_type=EventChain
-)
+DEFAULT_UPLOAD_ID: str = "default"
 
 
-# Use this var along with the Upload component to render the list of selected files.
-selected_files: BaseVar = BaseVar(
-    _var_name="files.map((f) => f.name)", _var_type=List[str]
-)
 
 
-clear_selected_files: BaseVar = BaseVar(
-    _var_name="_e => setFiles((files) => [])", _var_type=EventChain
-)
+@CallableVar
+def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
+    """Get the file upload drop trigger.
+
+    This var is passed to the dropzone component to update the file list when a
+    drop occurs.
+
+    Args:
+        id_: The id of the upload to get the drop trigger for.
+
+    Returns:
+        A var referencing the file upload drop trigger.
+    """
+    return BaseVar(
+        _var_name=f"e => upload_files.{id_}[1]((files) => e)",
+        _var_type=EventChain,
+    )
+
+
+@CallableVar
+def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
+    """Get the list of selected files.
+
+    Args:
+        id_: The id of the upload to get the selected files for.
+
+    Returns:
+        A var referencing the list of selected file paths.
+    """
+    return BaseVar(
+        _var_name=f"(upload_files.{id_} ? upload_files.{id_}[0]?.map((f) => (f.path || f.name)) : [])",
+        _var_type=List[str],
+    )
+
+
+@CallableEventSpec
+def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec:
+    """Clear the list of selected files.
+
+    Args:
+        id_: The id of the upload to clear.
+
+    Returns:
+        An event spec that clears the list of selected files when triggered.
+    """
+    return call_script(f"upload_files.{id_}[1]((files) => [])")
+
+
+def cancel_upload(upload_id: str) -> EventSpec:
+    """Cancel an upload.
+
+    Args:
+        upload_id: The id of the upload to cancel.
+
+    Returns:
+        An event spec that cancels the upload when triggered.
+    """
+    return call_script(f"upload_controllers[{upload_id!r}]?.abort()")
 
 
 
 
 class Upload(Component):
 class Upload(Component):
@@ -94,7 +143,10 @@ class Upload(Component):
         zone.special_props = {BaseVar(_var_name="{...getRootProps()}", _var_type=None)}
         zone.special_props = {BaseVar(_var_name="{...getRootProps()}", _var_type=None)}
 
 
         # Create the component.
         # Create the component.
-        return super().create(zone, on_drop=upload_file, **upload_props)
+        upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID)
+        return super().create(
+            zone, on_drop=upload_file(upload_props["id"]), **upload_props
+        )
 
 
     def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
     def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
         """Get the event triggers that pass the component's value to the handler.
         """Get the event triggers that pass the component's value to the handler.
@@ -104,7 +156,7 @@ class Upload(Component):
         """
         """
         return {
         return {
             **super().get_event_triggers(),
             **super().get_event_triggers(),
-            EventTriggers.ON_DROP: lambda e0: [e0],
+            constants.EventTriggers.ON_DROP: lambda e0: [e0],
         }
         }
 
 
     def _render(self):
     def _render(self):
@@ -113,4 +165,15 @@ class Upload(Component):
         return out
         return out
 
 
     def _get_hooks(self) -> str | None:
     def _get_hooks(self) -> str | None:
-        return (super()._get_hooks() or "") + files_state
+        return (
+            (super()._get_hooks() or "")
+            + f"""
+        upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);
+        """
+        )
+
+    def _get_imports(self) -> imports.ImportDict:
+        return {
+            **super()._get_imports(),
+            f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="upload_files")},
+        }

+ 13 - 7
reflex/components/forms/upload.pyi

@@ -8,17 +8,23 @@ from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
 from reflex.style import Style
 from typing import Any, Dict, List, Optional, Union
 from typing import Any, Dict, List, Optional, Union
+from reflex import constants
 from reflex.components.component import Component
 from reflex.components.component import Component
 from reflex.components.forms.input import Input
 from reflex.components.forms.input import Input
 from reflex.components.layout.box import Box
 from reflex.components.layout.box import Box
-from reflex.constants import EventTriggers
-from reflex.event import EventChain
-from reflex.vars import BaseVar, Var
+from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
+from reflex.utils import imports
+from reflex.vars import BaseVar, CallableVar, ImportVar, Var
 
 
-files_state: str
-upload_file: BaseVar
-selected_files: BaseVar
-clear_selected_files: BaseVar
+DEFAULT_UPLOAD_ID: str
+
+@CallableVar
+def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ...
+@CallableVar
+def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ...
+@CallableEventSpec
+def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ...
+def cancel_upload(upload_id: str) -> EventSpec: ...
 
 
 class Upload(Component):
 class Upload(Component):
     @overload
     @overload

+ 115 - 8
reflex/event.py

@@ -174,13 +174,7 @@ class EventHandler(EventActionsMixin):
         for arg in args:
         for arg in args:
             # Special case for file uploads.
             # Special case for file uploads.
             if isinstance(arg, FileUpload):
             if isinstance(arg, FileUpload):
-                return EventSpec(
-                    handler=self,
-                    client_handler_name="uploadFiles",
-                    # `files` is defined in the Upload component's _use_hooks
-                    args=((Var.create_safe("files"), Var.create_safe("files")),),
-                    event_actions=self.event_actions.copy(),
-                )
+                return arg.as_event_spec(handler=self)
 
 
             # Otherwise, convert to JSON.
             # Otherwise, convert to JSON.
             try:
             try:
@@ -236,6 +230,50 @@ class EventSpec(EventActionsMixin):
         )
         )
 
 
 
 
+class CallableEventSpec(EventSpec):
+    """Decorate an EventSpec-returning function to act as both a EventSpec and a function.
+
+    This is used as a compatibility shim for replacing EventSpec objects in the
+    API with functions that return a family of EventSpec.
+    """
+
+    fn: Optional[Callable[..., EventSpec]] = None
+
+    def __init__(self, fn: Callable[..., EventSpec] | None = None, **kwargs):
+        """Initialize a CallableEventSpec.
+
+        Args:
+            fn: The function to decorate.
+            **kwargs: The kwargs to pass to pydantic initializer
+        """
+        if fn is not None:
+            default_event_spec = fn()
+            super().__init__(
+                fn=fn,  # type: ignore
+                **default_event_spec.dict(),
+                **kwargs,
+            )
+        else:
+            super().__init__(**kwargs)
+
+    def __call__(self, *args, **kwargs) -> EventSpec:
+        """Call the decorated function.
+
+        Args:
+            *args: The args to pass to the function.
+            **kwargs: The kwargs to pass to the function.
+
+        Returns:
+            The EventSpec returned from calling the function.
+
+        Raises:
+            TypeError: If the CallableEventSpec has no associated function.
+        """
+        if self.fn is None:
+            raise TypeError("CallableEventSpec has no associated function.")
+        return self.fn(*args, **kwargs)
+
+
 class EventChain(EventActionsMixin):
 class EventChain(EventActionsMixin):
     """Container for a chain of events that will be executed in order."""
     """Container for a chain of events that will be executed in order."""
 
 
@@ -267,7 +305,76 @@ class FrontendEvent(Base):
 class FileUpload(Base):
 class FileUpload(Base):
     """Class to represent a file upload."""
     """Class to represent a file upload."""
 
 
-    pass
+    upload_id: Optional[str] = None
+    on_upload_progress: Optional[Union[EventHandler, Callable]] = None
+
+    @staticmethod
+    def on_upload_progress_args_spec(_prog: dict[str, int | float | bool]):
+        """Args spec for on_upload_progress event handler.
+
+        Returns:
+            The arg mapping passed to backend event handler
+        """
+        return [_prog]
+
+    def as_event_spec(self, handler: EventHandler) -> EventSpec:
+        """Get the EventSpec for the file upload.
+
+        Args:
+            handler: The event handler.
+
+        Returns:
+            The event spec for the handler.
+
+        Raises:
+            ValueError: If the on_upload_progress is not a valid event handler.
+        """
+        from reflex.components.forms.upload import DEFAULT_UPLOAD_ID
+
+        upload_id = self.upload_id or DEFAULT_UPLOAD_ID
+
+        spec_args = [
+            # `upload_files` is defined in state.js and assigned in the Upload component's _use_hooks
+            (Var.create_safe("files"), Var.create_safe(f"upload_files.{upload_id}[0]")),
+            (
+                Var.create_safe("upload_id"),
+                Var.create_safe(upload_id, _var_is_string=True),
+            ),
+        ]
+        if self.on_upload_progress is not None:
+            on_upload_progress = self.on_upload_progress
+            if isinstance(on_upload_progress, EventHandler):
+                events = [
+                    call_event_handler(
+                        on_upload_progress,
+                        self.on_upload_progress_args_spec,
+                    ),
+                ]
+            elif isinstance(on_upload_progress, Callable):
+                # Call the lambda to get the event chain.
+                events = call_event_fn(on_upload_progress, self.on_upload_progress_args_spec)  # type: ignore
+            else:
+                raise ValueError(f"{on_upload_progress} is not a valid event handler.")
+            on_upload_progress_chain = EventChain(
+                events=events,
+                args_spec=self.on_upload_progress_args_spec,
+            )
+            formatted_chain = str(format.format_prop(on_upload_progress_chain))
+            spec_args.append(
+                (
+                    Var.create_safe("on_upload_progress"),
+                    BaseVar(
+                        _var_name=formatted_chain.strip("{}"),
+                        _var_type=EventChain,
+                    ),
+                ),
+            )
+        return EventSpec(
+            handler=handler,
+            client_handler_name="uploadFiles",
+            args=tuple(spec_args),
+            event_actions=handler.event_actions.copy(),
+        )
 
 
 
 
 # Alias for rx.upload_files
 # Alias for rx.upload_files

+ 30 - 0
reflex/vars.py

@@ -1590,3 +1590,33 @@ class NoRenderImportVar(ImportVar):
     """A import that doesn't need to be rendered."""
     """A import that doesn't need to be rendered."""
 
 
     render: Optional[bool] = False
     render: Optional[bool] = False
+
+
+class CallableVar(BaseVar):
+    """Decorate a Var-returning function to act as both a Var and a function.
+
+    This is used as a compatibility shim for replacing Var objects in the
+    API with functions that return a family of Var.
+    """
+
+    def __init__(self, fn: Callable[..., BaseVar]):
+        """Initialize a CallableVar.
+
+        Args:
+            fn: The function to decorate (must return Var)
+        """
+        self.fn = fn
+        default_var = fn()
+        super().__init__(**dataclasses.asdict(default_var))
+
+    def __call__(self, *args, **kwargs) -> BaseVar:
+        """Call the decorated function.
+
+        Args:
+            *args: The args to pass to the function.
+            **kwargs: The kwargs to pass to the function.
+
+        Returns:
+            The Var returned from calling the function.
+        """
+        return self.fn(*args, **kwargs)

+ 4 - 0
reflex/vars.pyi

@@ -137,3 +137,7 @@ class NoRenderImportVar(ImportVar):
     """A import that doesn't need to be rendered."""
     """A import that doesn't need to be rendered."""
 
 
 def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ...
 def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ...
+
+class CallableVar(BaseVar):
+    def __init__(self, fn: Callable[..., BaseVar]): ...
+    def __call__(self, *args, **kwargs) -> BaseVar: ...

+ 6 - 2
tests/components/forms/test_uploads.py

@@ -52,8 +52,10 @@ def test_upload_component_render(upload_component):
     # upload
     # upload
     assert upload["name"] == "ReactDropzone"
     assert upload["name"] == "ReactDropzone"
     assert upload["props"] == [
     assert upload["props"] == [
+        "id={`default`}",
         "multiple={true}",
         "multiple={true}",
-        "onDrop={e => setFiles((files) => e)}",
+        "onDrop={e => upload_files.default[1]((files) => e)}",
+        "ref={ref_default}",
     ]
     ]
     assert upload["args"] == ("getRootProps", "getInputProps")
     assert upload["args"] == ("getRootProps", "getInputProps")
 
 
@@ -89,8 +91,10 @@ def test_upload_component_with_props_render(upload_component_with_props):
     upload = upload_component_with_props.render()
     upload = upload_component_with_props.render()
 
 
     assert upload["props"] == [
     assert upload["props"] == [
+        "id={`default`}",
         "maxFiles={2}",
         "maxFiles={2}",
         "multiple={true}",
         "multiple={true}",
         "noDrag={true}",
         "noDrag={true}",
-        "onDrop={e => setFiles((files) => e)}",
+        "onDrop={e => upload_files.default[1]((files) => e)}",
+        "ref={ref_default}",
     ]
     ]

+ 30 - 30
tests/states/upload.py

@@ -49,16 +49,7 @@ class FileUploadState(rx.State):
         Args:
         Args:
             files: The uploaded files.
             files: The uploaded files.
         """
         """
-        for file in files:
-            upload_data = await file.read()
-            outfile = f"{self._tmp_path}/{file.filename}"
-
-            # Save the file.
-            with open(outfile, "wb") as file_object:
-                file_object.write(upload_data)
-
-            # Update the img var.
-            self.img_list.append(file.filename)
+        pass
 
 
     async def multi_handle_upload(self, files: List[rx.UploadFile]):
     async def multi_handle_upload(self, files: List[rx.UploadFile]):
         """Handle the upload of a file.
         """Handle the upload of a file.
@@ -78,6 +69,15 @@ class FileUploadState(rx.State):
             assert file.filename is not None
             assert file.filename is not None
             self.img_list.append(file.filename)
             self.img_list.append(file.filename)
 
 
+    @rx.background
+    async def bg_upload(self, files: List[rx.UploadFile]):
+        """Background task cannot be upload handler.
+
+        Args:
+            files: The uploaded files.
+        """
+        pass
+
 
 
 class FileStateBase1(rx.State):
 class FileStateBase1(rx.State):
     """The base state for a child FileUploadState."""
     """The base state for a child FileUploadState."""
@@ -97,16 +97,7 @@ class ChildFileUploadState(FileStateBase1):
         Args:
         Args:
             files: The uploaded files.
             files: The uploaded files.
         """
         """
-        for file in files:
-            upload_data = await file.read()
-            outfile = f"{self._tmp_path}/{file.filename}"
-
-            # Save the file.
-            with open(outfile, "wb") as file_object:
-                file_object.write(upload_data)
-
-            # Update the img var.
-            self.img_list.append(file.filename)
+        pass
 
 
     async def multi_handle_upload(self, files: List[rx.UploadFile]):
     async def multi_handle_upload(self, files: List[rx.UploadFile]):
         """Handle the upload of a file.
         """Handle the upload of a file.
@@ -126,6 +117,15 @@ class ChildFileUploadState(FileStateBase1):
             assert file.filename is not None
             assert file.filename is not None
             self.img_list.append(file.filename)
             self.img_list.append(file.filename)
 
 
+    @rx.background
+    async def bg_upload(self, files: List[rx.UploadFile]):
+        """Background task cannot be upload handler.
+
+        Args:
+            files: The uploaded files.
+        """
+        pass
+
 
 
 class FileStateBase2(FileStateBase1):
 class FileStateBase2(FileStateBase1):
     """The parent state for a grandchild FileUploadState."""
     """The parent state for a grandchild FileUploadState."""
@@ -145,16 +145,7 @@ class GrandChildFileUploadState(FileStateBase2):
         Args:
         Args:
             files: The uploaded files.
             files: The uploaded files.
         """
         """
-        for file in files:
-            upload_data = await file.read()
-            outfile = f"{self._tmp_path}/{file.filename}"
-
-            # Save the file.
-            with open(outfile, "wb") as file_object:
-                file_object.write(upload_data)
-
-            # Update the img var.
-            self.img_list.append(file.filename)
+        pass
 
 
     async def multi_handle_upload(self, files: List[rx.UploadFile]):
     async def multi_handle_upload(self, files: List[rx.UploadFile]):
         """Handle the upload of a file.
         """Handle the upload of a file.
@@ -173,3 +164,12 @@ class GrandChildFileUploadState(FileStateBase2):
             # Update the img var.
             # Update the img var.
             assert file.filename is not None
             assert file.filename is not None
             self.img_list.append(file.filename)
             self.img_list.append(file.filename)
+
+    @rx.background
+    async def bg_upload(self, files: List[rx.UploadFile]):
+        """Background task cannot be upload handler.
+
+        Args:
+            files: The uploaded files.
+        """
+        pass

+ 56 - 25
tests/test_app.py

@@ -746,23 +746,28 @@ async def test_upload_file(tmp_path, state, delta, token: str):
     bio.write(data)
     bio.write(data)
 
 
     state_name = state.get_full_name().partition(".")[2] or state.get_name()
     state_name = state.get_full_name().partition(".")[2] or state.get_name()
-    handler_prefix = f"{token}:{state_name}"
+    request_mock = unittest.mock.Mock()
+    request_mock.headers = {
+        "reflex-client-token": token,
+        "reflex-event-handler": f"{state_name}.multi_handle_upload",
+    }
 
 
     file1 = UploadFile(
     file1 = UploadFile(
-        filename=f"{handler_prefix}.multi_handle_upload:True:image1.jpg",
+        filename=f"image1.jpg",
         file=bio,
         file=bio,
     )
     )
     file2 = UploadFile(
     file2 = UploadFile(
-        filename=f"{handler_prefix}.multi_handle_upload:True:image2.jpg",
+        filename=f"image2.jpg",
         file=bio,
         file=bio,
     )
     )
     upload_fn = upload(app)
     upload_fn = upload(app)
-    await upload_fn([file1, file2])
-    state_update = StateUpdate(delta=delta, events=[], final=True)
+    streaming_response = await upload_fn(request_mock, [file1, file2])
+    async for state_update in streaming_response.body_iterator:
+        assert (
+            state_update
+            == StateUpdate(delta=delta, events=[], final=True).json() + "\n"
+        )
 
 
-    app.event_namespace.emit.assert_called_with(  # type: ignore
-        "event", state_update.json(), to=current_state.router.session.session_id
-    )
     current_state = await app.state_manager.get_state(token)
     current_state = await app.state_manager.get_state(token)
     state_dict = current_state.dict()
     state_dict = current_state.dict()
     for substate in state.get_full_name().split(".")[1:]:
     for substate in state.get_full_name().split(".")[1:]:
@@ -789,30 +794,20 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
         tmp_path: Temporary path.
         tmp_path: Temporary path.
         token: a Token.
         token: a Token.
     """
     """
-    data = b"This is binary data"
-
-    # Create a binary IO object and write data to it
-    bio = io.BytesIO()
-    bio.write(data)
-
     state._tmp_path = tmp_path
     state._tmp_path = tmp_path
     # The App state must be the "root" of the state tree
     # The App state must be the "root" of the state tree
     app = App(state=state if state is FileUploadState else FileStateBase1)
     app = App(state=state if state is FileUploadState else FileStateBase1)
 
 
     state_name = state.get_full_name().partition(".")[2] or state.get_name()
     state_name = state.get_full_name().partition(".")[2] or state.get_name()
-    handler_prefix = f"{token}:{state_name}"
-
-    file1 = UploadFile(
-        filename=f"{handler_prefix}.handle_upload2:True:image1.jpg",
-        file=bio,
-    )
-    file2 = UploadFile(
-        filename=f"{handler_prefix}.handle_upload2:True:image2.jpg",
-        file=bio,
-    )
+    request_mock = unittest.mock.Mock()
+    request_mock.headers = {
+        "reflex-client-token": token,
+        "reflex-event-handler": f"{state_name}.handle_upload2",
+    }
+    file_mock = unittest.mock.Mock(filename="image1.jpg")
     fn = upload(app)
     fn = upload(app)
     with pytest.raises(ValueError) as err:
     with pytest.raises(ValueError) as err:
-        await fn([file1, file2])
+        await fn(request_mock, [file_mock])
     assert (
     assert (
         err.value.args[0]
         err.value.args[0]
         == f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
         == f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
@@ -822,6 +817,42 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
         await app.state_manager.redis.close()
         await app.state_manager.redis.close()
 
 
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "state",
+    [FileUploadState, ChildFileUploadState, GrandChildFileUploadState],
+)
+async def test_upload_file_background(state, tmp_path, token):
+    """Test that an error is thrown handler is a background task.
+
+    Args:
+        state: The state class.
+        tmp_path: Temporary path.
+        token: a Token.
+    """
+    state._tmp_path = tmp_path
+    # The App state must be the "root" of the state tree
+    app = App(state=state if state is FileUploadState else FileStateBase1)
+
+    state_name = state.get_full_name().partition(".")[2] or state.get_name()
+    request_mock = unittest.mock.Mock()
+    request_mock.headers = {
+        "reflex-client-token": token,
+        "reflex-event-handler": f"{state_name}.bg_upload",
+    }
+    file_mock = unittest.mock.Mock(filename="image1.jpg")
+    fn = upload(app)
+    with pytest.raises(TypeError) as err:
+        await fn(request_mock, [file_mock])
+    assert (
+        err.value.args[0]
+        == f"@rx.background is not supported for upload handler `{state_name}.bg_upload`."
+    )
+
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
+
 class DynamicState(State):
 class DynamicState(State):
     """State class for testing dynamic route var.
     """State class for testing dynamic route var.