Kaynağa Gözat

[REF-723+] Upload with progress and cancellation (#1899)

Masen Furer 1 yıl önce
ebeveyn
işleme
7eccc6d988

+ 135 - 11
integration/test_upload.py

@@ -1,13 +1,14 @@
 """Integration tests for file upload."""
 from __future__ import annotations
 
+import asyncio
 import time
 from typing import Generator
 
 import pytest
 from selenium.webdriver.common.by import By
 
-from reflex.testing import AppHarness
+from reflex.testing import AppHarness, WebDriver
 
 
 def UploadFile():
@@ -16,12 +17,28 @@ def UploadFile():
 
     class UploadState(rx.State):
         _file_data: dict[str, str] = {}
+        event_order: list[str] = []
+        progress_dicts: list[dict] = []
 
         async def handle_upload(self, files: list[rx.UploadFile]):
             for file in files:
                 upload_data = await file.read()
                 self._file_data[file.filename or ""] = upload_data.decode("utf-8")
 
+        async def handle_upload_secondary(self, files: list[rx.UploadFile]):
+            for file in files:
+                upload_data = await file.read()
+                self._file_data[file.filename or ""] = upload_data.decode("utf-8")
+                yield UploadState.chain_event
+
+        def upload_progress(self, progress):
+            assert progress
+            self.event_order.append("upload_progress")
+            self.progress_dicts.append(progress)
+
+        def chain_event(self):
+            self.event_order.append("chain_event")
+
     def index():
         return rx.vstack(
             rx.input(
@@ -29,6 +46,7 @@ def UploadFile():
                 is_read_only=True,
                 id="token",
             ),
+            rx.heading("Default Upload"),
             rx.upload(
                 rx.vstack(
                     rx.button("Select File"),
@@ -52,6 +70,47 @@ def UploadFile():
                 on_click=rx.clear_selected_files,
                 id="clear_button",
             ),
+            rx.heading("Secondary Upload"),
+            rx.upload(
+                rx.vstack(
+                    rx.button("Select File"),
+                    rx.text("Drag and drop files here or click to select files"),
+                ),
+                id="secondary",
+            ),
+            rx.button(
+                "Upload",
+                on_click=UploadState.handle_upload_secondary(  # type: ignore
+                    rx.upload_files(
+                        upload_id="secondary",
+                        on_upload_progress=UploadState.upload_progress,
+                    ),
+                ),
+                id="upload_button_secondary",
+            ),
+            rx.box(
+                rx.foreach(
+                    rx.selected_files("secondary"),
+                    lambda f: rx.text(f),
+                ),
+                id="selected_files_secondary",
+            ),
+            rx.button(
+                "Clear",
+                on_click=rx.clear_selected_files("secondary"),
+                id="clear_button_secondary",
+            ),
+            rx.vstack(
+                rx.foreach(
+                    UploadState.progress_dicts,  # type: ignore
+                    lambda d: rx.text(d.to_string()),
+                )
+            ),
+            rx.button(
+                "Cancel",
+                on_click=rx.cancel_upload("secondary"),
+                id="cancel_button_secondary",
+            ),
         )
 
     app = rx.App(state=UploadState)
@@ -94,14 +153,18 @@ def driver(upload_file: AppHarness):
         driver.quit()
 
 
+@pytest.mark.parametrize("secondary", [False, True])
 @pytest.mark.asyncio
-async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
+async def test_upload_file(
+    tmp_path, upload_file: AppHarness, driver: WebDriver, secondary: bool
+):
     """Submit a file upload and check that it arrived on the backend.
 
     Args:
         tmp_path: pytest tmp_path fixture
         upload_file: harness for UploadFile app.
         driver: WebDriver instance.
+        secondary: whether to use the secondary upload form
     """
     assert upload_file.app_instance is not None
     token_input = driver.find_element(By.ID, "token")
@@ -110,9 +173,13 @@ async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
     token = upload_file.poll_for_value(token_input)
     assert token is not None
 
-    upload_box = driver.find_element(By.XPATH, "//input[@type='file']")
+    suffix = "_secondary" if secondary else ""
+
+    upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[
+        1 if secondary else 0
+    ]
     assert upload_box
-    upload_button = driver.find_element(By.ID, "upload_button")
+    upload_button = driver.find_element(By.ID, f"upload_button{suffix}")
     assert upload_button
 
     exp_name = "test.txt"
@@ -132,9 +199,15 @@ async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
     assert file_data[exp_name] == exp_contents
 
     # check that the selected files are displayed
-    selected_files = driver.find_element(By.ID, "selected_files")
+    selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
     assert selected_files.text == exp_name
 
+    state = await upload_file.get_state(token)
+    if secondary:
+        # only the secondary form tracks progress and chain events
+        assert state.event_order.count("upload_progress") == 1
+        assert state.event_order.count("chain_event") == 1
+
 
 @pytest.mark.asyncio
 async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
@@ -186,13 +259,17 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
         assert file_data[exp_name] == exp_contents
 
 
-def test_clear_files(tmp_path, upload_file: AppHarness, driver):
+@pytest.mark.parametrize("secondary", [False, True])
+def test_clear_files(
+    tmp_path, upload_file: AppHarness, driver: WebDriver, secondary: bool
+):
     """Select then clear several file uploads and check that they are cleared.
 
     Args:
         tmp_path: pytest tmp_path fixture
         upload_file: harness for UploadFile app.
         driver: WebDriver instance.
+        secondary: whether to use the secondary upload form.
     """
     assert upload_file.app_instance is not None
     token_input = driver.find_element(By.ID, "token")
@@ -201,9 +278,13 @@ def test_clear_files(tmp_path, upload_file: AppHarness, driver):
     token = upload_file.poll_for_value(token_input)
     assert token is not None
 
-    upload_box = driver.find_element(By.XPATH, "//input[@type='file']")
+    suffix = "_secondary" if secondary else ""
+
+    upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[
+        1 if secondary else 0
+    ]
     assert upload_box
-    upload_button = driver.find_element(By.ID, "upload_button")
+    upload_button = driver.find_element(By.ID, f"upload_button{suffix}")
     assert upload_button
 
     exp_files = {
@@ -219,13 +300,56 @@ def test_clear_files(tmp_path, upload_file: AppHarness, driver):
     time.sleep(0.2)
 
     # check that the selected files are displayed
-    selected_files = driver.find_element(By.ID, "selected_files")
+    selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
     assert selected_files.text == "\n".join(exp_files)
 
-    clear_button = driver.find_element(By.ID, "clear_button")
+    clear_button = driver.find_element(By.ID, f"clear_button{suffix}")
     assert clear_button
     clear_button.click()
 
     # check that the selected files are cleared
-    selected_files = driver.find_element(By.ID, "selected_files")
+    selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
     assert selected_files.text == ""
+
+
+# TODO: drag and drop directory
+# https://gist.github.com/florentbr/349b1ab024ca9f3de56e6bf8af2ac69e
+
+
+@pytest.mark.asyncio
+async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDriver):
+    """Submit a large file upload and cancel it.
+
+    Args:
+        tmp_path: pytest tmp_path fixture
+        upload_file: harness for UploadFile app.
+        driver: WebDriver instance.
+    """
+    assert upload_file.app_instance is not None
+    token_input = driver.find_element(By.ID, "token")
+    assert token_input
+    # wait for the backend connection to send the token
+    token = upload_file.poll_for_value(token_input)
+    assert token is not None
+
+    upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1]
+    upload_button = driver.find_element(By.ID, f"upload_button_secondary")
+    cancel_button = driver.find_element(By.ID, f"cancel_button_secondary")
+
+    exp_name = "large.txt"
+    target_file = tmp_path / exp_name
+    with target_file.open("wb") as f:
+        f.seek(1024 * 1024 * 256)
+        f.write(b"0")
+
+    upload_box.send_keys(str(target_file))
+    upload_button.click()
+    await asyncio.sleep(0.3)
+    cancel_button.click()
+
+    # look up the backend state and assert on progress
+    state = await upload_file.get_state(token)
+    assert state.progress_dicts
+    assert exp_name not in state._file_data
+
+    target_file.unlink()

+ 82 - 33
reflex/.templates/web/utils/state.js

@@ -32,6 +32,11 @@ let event_processing = false
 // Array holding pending events to be processed.
 const event_queue = [];
 
+// Pending upload promises, by id
+const upload_controllers = {};
+// Upload files state by id
+export const upload_files = {};
+
 /**
  * Generate a UUID (Used for session tokens).
  * Taken from: https://stackoverflow.com/questions/105034/how-do-i-create-a-guid-uuid
@@ -235,14 +240,22 @@ export const applyEvent = async (event, socket) => {
 /**
  * Send an event to the server via REST.
  * @param event The current event.
- * @param state The state with the event queue.
+ * @param socket The socket object to send the response event(s) on.
  *
  * @returns Whether the event was sent.
  */
-export const applyRestEvent = async (event) => {
+export const applyRestEvent = async (event, socket) => {
   let eventSent = false;
   if (event.handler == "uploadFiles") {
-    eventSent = await uploadFiles(event.name, event.payload.files);
+    // Start upload, but do not wait for it, which would block other events.
+    uploadFiles(
+      event.name,
+      event.payload.files,
+      event.payload.upload_id,
+      event.payload.on_upload_progress,
+      socket
+    );
+    return false;
   }
   return eventSent;
 };
@@ -283,7 +296,7 @@ export const processEvent = async (
   let eventSent = false
   // Process events with handlers via REST and all others via websockets.
   if (event.handler) {
-    eventSent = await applyRestEvent(event);
+    eventSent = await applyRestEvent(event, socket);
   } else {
     eventSent = await applyEvent(event, socket);
   }
@@ -347,50 +360,86 @@ export const connect = async (
  *
  * @param state The state to apply the delta to.
  * @param handler The handler to use.
+ * @param upload_id The upload id to use.
+ * @param on_upload_progress The function to call on upload progress.
+ * @param socket the websocket connection
  *
- * @returns Whether the files were uploaded.
+ * @returns The response from posting to the UPLOADURL endpoint.
  */
-export const uploadFiles = async (handler, files) => {
+export const uploadFiles = async (handler, files, upload_id, on_upload_progress, socket) => {
   // return if there's no file to upload
   if (files.length == 0) {
     return false;
   }
 
-  const headers = {
-    "Content-Type": files[0].type,
-  };
+  if (upload_controllers[upload_id]) {
+    console.log("Upload already in progress for ", upload_id)
+    return false;
+  }
+
+  let resp_idx = 0;
+  const eventHandler = (progressEvent) => {
+    // handle any delta / event streamed from the upload event handler
+    const chunks = progressEvent.event.target.responseText.trim().split("\n")
+    chunks.slice(resp_idx).map((chunk) => {
+      try {
+        socket._callbacks.$event.map((f) => {
+          f(chunk)
+        })
+        resp_idx += 1
+      } catch (e) {
+        console.log("Error parsing chunk", chunk, e)
+        return
+      }
+    })
+  }
+
+  const controller = new AbortController()
+  const config = {
+    headers: {
+      "Reflex-Client-Token": getToken(),
+      "Reflex-Event-Handler": handler,
+    },
+    signal: controller.signal,
+    onDownloadProgress: eventHandler,
+  }
+  if (on_upload_progress) {
+    config["onUploadProgress"] = on_upload_progress
+  }
   const formdata = new FormData();
 
   // Add the token and handler to the file name.
-  for (let i = 0; i < files.length; i++) {
+  files.forEach((file) => {
     formdata.append(
       "files",
-      files[i],
-      getToken() + ":" + handler + ":" + files[i].name
+      file,
+      file.path || file.name
     );
-  }
+  })
 
   // Send the file to the server.
-  await axios.post(UPLOADURL, formdata, headers)
-    .then(() => { return true; })
-    .catch(
-      error => {
-        if (error.response) {
-          // The request was made and the server responded with a status code
-          // that falls out of the range of 2xx
-          console.log(error.response.data);
-        } else if (error.request) {
-          // The request was made but no response was received
-          // `error.request` is an instance of XMLHttpRequest in the browser and an instance of
-          // http.ClientRequest in node.js
-          console.log(error.request);
-        } else {
-          // Something happened in setting up the request that triggered an Error
-          console.log(error.message);
-        }
-        return false;
-      }
-    )
+  upload_controllers[upload_id] = controller
+
+  try {
+    return await axios.post(UPLOADURL, formdata, config)
+  } catch (error) {
+    if (error.response) {
+      // The request was made and the server responded with a status code
+      // that falls out of the range of 2xx
+      console.log(error.response.data);
+    } else if (error.request) {
+      // The request was made but no response was received
+      // `error.request` is an instance of XMLHttpRequest in the browser and an instance of
+      // http.ClientRequest in node.js
+      console.log(error.request);
+    } else {
+      // Something happened in setting up the request that triggered an Error
+      console.log(error.message);
+    }
+    return false;
+  } finally {
+    delete upload_controllers[upload_id]
+  }
 };
 
 /**

+ 1 - 0
reflex/__init__.py

@@ -229,6 +229,7 @@ _ALL_COMPONENTS = [
 
 _ALL_COMPONENTS += [to_snake_case(component) for component in _ALL_COMPONENTS]
 _ALL_COMPONENTS += [
+    "cancel_upload",
     "components",
     "color_mode_cond",
     "desktop_only",

+ 10 - 0
reflex/__init__.pyi

@@ -58,6 +58,9 @@ from reflex.components import ConnectionModal as ConnectionModal
 from reflex.components import Container as Container
 from reflex.components import DataTable as DataTable
 from reflex.components import DataEditor as DataEditor
+from reflex.components import DataEditorTheme as DataEditorTheme
+from reflex.components import DatePicker as DatePicker
+from reflex.components import DateTimePicker as DateTimePicker
 from reflex.components import DebounceInput as DebounceInput
 from reflex.components import Divider as Divider
 from reflex.components import Drawer as Drawer
@@ -265,6 +268,9 @@ from reflex.components import connection_modal as connection_modal
 from reflex.components import container as container
 from reflex.components import data_table as data_table
 from reflex.components import data_editor as data_editor
+from reflex.components import data_editor_theme as data_editor_theme
+from reflex.components import date_picker as date_picker
+from reflex.components import date_time_picker as date_time_picker
 from reflex.components import debounce_input as debounce_input
 from reflex.components import divider as divider
 from reflex.components import drawer as drawer
@@ -421,7 +427,9 @@ from reflex.components import visually_hidden as visually_hidden
 from reflex.components import vstack as vstack
 from reflex.components import wrap as wrap
 from reflex.components import wrap_item as wrap_item
+from reflex.components import cancel_upload as cancel_upload
 from reflex import components as components
+from reflex.components import color_mode_cond as color_mode_cond
 from reflex.components import desktop_only as desktop_only
 from reflex.components import mobile_only as mobile_only
 from reflex.components import tablet_only as tablet_only
@@ -429,7 +437,9 @@ from reflex.components import mobile_and_tablet as mobile_and_tablet
 from reflex.components import tablet_and_desktop as tablet_and_desktop
 from reflex.components import selected_files as selected_files
 from reflex.components import clear_selected_files as clear_selected_files
+from reflex.components import EditorButtonList as EditorButtonList
 from reflex.components import EditorOptions as EditorOptions
+from reflex.components import NoSSRComponent as NoSSRComponent
 from reflex.components.component import memo as memo
 from reflex.components.graphing import recharts as recharts
 from reflex import config as config

+ 76 - 45
reflex/app.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import asyncio
 import contextlib
-import inspect
+import functools
 import os
 from multiprocessing.pool import ThreadPool
 from typing import (
@@ -17,10 +17,13 @@ from typing import (
     Set,
     Type,
     Union,
+    get_args,
+    get_type_hints,
 )
 
-from fastapi import FastAPI, UploadFile
+from fastapi import FastAPI, HTTPException, Request, UploadFile
 from fastapi.middleware import cors
+from fastapi.responses import StreamingResponse
 from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
 from socketio import ASGIApp, AsyncNamespace, AsyncServer
 from starlette_admin.contrib.sqla.admin import Admin
@@ -880,62 +883,90 @@ def upload(app: App):
         The upload function.
     """
 
-    async def upload_file(files: List[UploadFile]):
+    async def upload_file(request: Request, files: List[UploadFile]):
         """Upload a file.
 
         Args:
+            request: The FastAPI request object.
             files: The file(s) to upload.
 
+        Returns:
+            StreamingResponse yielding newline-delimited JSON of StateUpdate
+            emitted by the upload handler.
+
         Raises:
             ValueError: if there are no args with supported annotation.
+            TypeError: if a background task is used as the handler.
+            HTTPException: when the request does not include token / handler headers.
         """
-        assert files[0].filename is not None
-        token, handler = files[0].filename.split(":")[:2]
-        for file in files:
-            assert file.filename is not None
-            file.filename = file.filename.split(":")[-1]
+        token = request.headers.get("reflex-client-token")
+        handler = request.headers.get("reflex-event-handler")
+
+        if not token or not handler:
+            raise HTTPException(
+                status_code=400,
+                detail="Missing reflex-client-token or reflex-event-handler header.",
+            )
 
         # Get the state for the session.
-        async with app.state_manager.modify_state(token) as state:
-            # get the current session ID
-            sid = state.router.session.session_id
-            # get the current state(parent state/substate)
-            path = handler.split(".")[:-1]
-            current_state = state.get_substate(path)
-            handler_upload_param = ()
-
-            # get handler function
-            func = getattr(current_state, handler.split(".")[-1])
-
-            # check if there exists any handler args with annotation, List[UploadFile]
-            for k, v in inspect.getfullargspec(
-                func.fn if isinstance(func, EventHandler) else func
-            ).annotations.items():
-                if types.is_generic_alias(v) and types._issubclass(
-                    v.__args__[0], UploadFile
-                ):
-                    handler_upload_param = (k, v)
-                    break
-
-            if not handler_upload_param:
-                raise ValueError(
-                    f"`{handler}` handler should have a parameter annotated as List["
-                    f"rx.UploadFile]"
+        state = await app.state_manager.get_state(token)
+
+        # get the current session ID
+        # get the current state(parent state/substate)
+        path = handler.split(".")[:-1]
+        current_state = state.get_substate(path)
+        handler_upload_param = ()
+
+        # get handler function
+        func = getattr(type(current_state), handler.split(".")[-1])
+
+        # check if there exists any handler args with annotation, List[UploadFile]
+        if isinstance(func, EventHandler):
+            if func.is_background:
+                raise TypeError(
+                    f"@rx.background is not supported for upload handler `{handler}`.",
                 )
+            func = func.fn
+        if isinstance(func, functools.partial):
+            func = func.func
+        for k, v in get_type_hints(func).items():
+            if types.is_generic_alias(v) and types._issubclass(
+                get_args(v)[0],
+                UploadFile,
+            ):
+                handler_upload_param = (k, v)
+                break
 
-            event = Event(
-                token=token,
-                name=handler,
-                payload={handler_upload_param[0]: files},
+        if not handler_upload_param:
+            raise ValueError(
+                f"`{handler}` handler should have a parameter annotated as "
+                "List[rx.UploadFile]"
             )
-            async for update in state._process(event):
-                # Postprocess the event.
-                update = await app.postprocess(state, event, update)
-                # Send update to client
-                await app.event_namespace.emit_update(  # type: ignore
-                    update=update,
-                    sid=sid,
-                )
+
+        event = Event(
+            token=token,
+            name=handler,
+            payload={handler_upload_param[0]: files},
+        )
+
+        async def _ndjson_updates():
+            """Process the upload event, generating ndjson updates.
+
+            Yields:
+                Each state update as JSON followed by a new line.
+            """
+            # Process the event.
+            async with app.state_manager.modify_state(token) as state:
+                async for update in state._process(event):
+                    # Postprocess the event.
+                    update = await app.postprocess(state, event, update)
+                    yield update.json() + "\n"
+
+        # Stream updates to client
+        return StreamingResponse(
+            _ndjson_updates(),
+            media_type="application/x-ndjson",
+        )
 
     return upload_file
 

+ 8 - 2
reflex/components/forms/__init__.py

@@ -46,12 +46,18 @@ from .select import Option, Select
 from .slider import Slider, SliderFilledTrack, SliderMark, SliderThumb, SliderTrack
 from .switch import Switch
 from .textarea import TextArea
-from .upload import Upload, clear_selected_files, selected_files
+from .upload import (
+    Upload,
+    cancel_upload,
+    clear_selected_files,
+    selected_files,
+)
 
 helpers = [
     "color_mode_cond",
-    "selected_files",
+    "cancel_upload",
     "clear_selected_files",
+    "selected_files",
 ]
 
 __all__ = [f for f in dir() if f[0].isupper()] + helpers  # type: ignore

+ 80 - 17
reflex/components/forms/upload.py

@@ -3,26 +3,75 @@ from __future__ import annotations
 
 from typing import Any, Dict, List, Optional, Union
 
+from reflex import constants
 from reflex.components.component import Component
 from reflex.components.forms.input import Input
 from reflex.components.layout.box import Box
-from reflex.constants import EventTriggers
-from reflex.event import EventChain
-from reflex.vars import BaseVar, Var
+from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
+from reflex.utils import imports
+from reflex.vars import BaseVar, CallableVar, ImportVar, Var
 
-files_state: str = "const [files, setFiles] = useState([]);"
-upload_file: BaseVar = BaseVar(
-    _var_name="e => setFiles((files) => e)", _var_type=EventChain
-)
+DEFAULT_UPLOAD_ID: str = "default"
 
-# Use this var along with the Upload component to render the list of selected files.
-selected_files: BaseVar = BaseVar(
-    _var_name="files.map((f) => f.name)", _var_type=List[str]
-)
 
-clear_selected_files: BaseVar = BaseVar(
-    _var_name="_e => setFiles((files) => [])", _var_type=EventChain
-)
+@CallableVar
+def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
+    """Get the file upload drop trigger.
+
+    This var is passed to the dropzone component to update the file list when a
+    drop occurs.
+
+    Args:
+        id_: The id of the upload to get the drop trigger for.
+
+    Returns:
+        A var referencing the file upload drop trigger.
+    """
+    return BaseVar(
+        _var_name=f"e => upload_files.{id_}[1]((files) => e)",
+        _var_type=EventChain,
+    )
+
+
+@CallableVar
+def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
+    """Get the list of selected files.
+
+    Args:
+        id_: The id of the upload to get the selected files for.
+
+    Returns:
+        A var referencing the list of selected file paths.
+    """
+    return BaseVar(
+        _var_name=f"(upload_files.{id_} ? upload_files.{id_}[0]?.map((f) => (f.path || f.name)) : [])",
+        _var_type=List[str],
+    )
+
+
+@CallableEventSpec
+def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec:
+    """Clear the list of selected files.
+
+    Args:
+        id_: The id of the upload to clear.
+
+    Returns:
+        An event spec that clears the list of selected files when triggered.
+    """
+    return call_script(f"upload_files.{id_}[1]((files) => [])")
+
+
+def cancel_upload(upload_id: str) -> EventSpec:
+    """Cancel an upload.
+
+    Args:
+        upload_id: The id of the upload to cancel.
+
+    Returns:
+        An event spec that cancels the upload when triggered.
+    """
+    return call_script(f"upload_controllers[{upload_id!r}]?.abort()")
 
 
 class Upload(Component):
@@ -94,7 +143,10 @@ class Upload(Component):
         zone.special_props = {BaseVar(_var_name="{...getRootProps()}", _var_type=None)}
 
         # Create the component.
-        return super().create(zone, on_drop=upload_file, **upload_props)
+        upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID)
+        return super().create(
+            zone, on_drop=upload_file(upload_props["id"]), **upload_props
+        )
 
     def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
         """Get the event triggers that pass the component's value to the handler.
@@ -104,7 +156,7 @@ class Upload(Component):
         """
         return {
             **super().get_event_triggers(),
-            EventTriggers.ON_DROP: lambda e0: [e0],
+            constants.EventTriggers.ON_DROP: lambda e0: [e0],
         }
 
     def _render(self):
@@ -113,4 +165,15 @@ class Upload(Component):
         return out
 
     def _get_hooks(self) -> str | None:
-        return (super()._get_hooks() or "") + files_state
+        return (
+            (super()._get_hooks() or "")
+            + f"""
+        upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);
+        """
+        )
+
+    def _get_imports(self) -> imports.ImportDict:
+        return {
+            **super()._get_imports(),
+            f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="upload_files")},
+        }

+ 13 - 7
reflex/components/forms/upload.pyi

@@ -8,17 +8,23 @@ from reflex.vars import Var, BaseVar, ComputedVar
 from reflex.event import EventChain, EventHandler, EventSpec
 from reflex.style import Style
 from typing import Any, Dict, List, Optional, Union
+from reflex import constants
 from reflex.components.component import Component
 from reflex.components.forms.input import Input
 from reflex.components.layout.box import Box
-from reflex.constants import EventTriggers
-from reflex.event import EventChain
-from reflex.vars import BaseVar, Var
+from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
+from reflex.utils import imports
+from reflex.vars import BaseVar, CallableVar, ImportVar, Var
 
-files_state: str
-upload_file: BaseVar
-selected_files: BaseVar
-clear_selected_files: BaseVar
+DEFAULT_UPLOAD_ID: str
+
+@CallableVar
+def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ...
+@CallableVar
+def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ...
+@CallableEventSpec
+def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ...
+def cancel_upload(upload_id: str) -> EventSpec: ...
 
 class Upload(Component):
     @overload

+ 115 - 8
reflex/event.py

@@ -174,13 +174,7 @@ class EventHandler(EventActionsMixin):
         for arg in args:
             # Special case for file uploads.
             if isinstance(arg, FileUpload):
-                return EventSpec(
-                    handler=self,
-                    client_handler_name="uploadFiles",
-                    # `files` is defined in the Upload component's _use_hooks
-                    args=((Var.create_safe("files"), Var.create_safe("files")),),
-                    event_actions=self.event_actions.copy(),
-                )
+                return arg.as_event_spec(handler=self)
 
             # Otherwise, convert to JSON.
             try:
@@ -236,6 +230,50 @@ class EventSpec(EventActionsMixin):
         )
 
 
+class CallableEventSpec(EventSpec):
+    """Decorate an EventSpec-returning function to act as both a EventSpec and a function.
+
+    This is used as a compatibility shim for replacing EventSpec objects in the
+    API with functions that return a family of EventSpec.
+    """
+
+    fn: Optional[Callable[..., EventSpec]] = None
+
+    def __init__(self, fn: Callable[..., EventSpec] | None = None, **kwargs):
+        """Initialize a CallableEventSpec.
+
+        Args:
+            fn: The function to decorate.
+            **kwargs: The kwargs to pass to pydantic initializer
+        """
+        if fn is not None:
+            default_event_spec = fn()
+            super().__init__(
+                fn=fn,  # type: ignore
+                **default_event_spec.dict(),
+                **kwargs,
+            )
+        else:
+            super().__init__(**kwargs)
+
+    def __call__(self, *args, **kwargs) -> EventSpec:
+        """Call the decorated function.
+
+        Args:
+            *args: The args to pass to the function.
+            **kwargs: The kwargs to pass to the function.
+
+        Returns:
+            The EventSpec returned from calling the function.
+
+        Raises:
+            TypeError: If the CallableEventSpec has no associated function.
+        """
+        if self.fn is None:
+            raise TypeError("CallableEventSpec has no associated function.")
+        return self.fn(*args, **kwargs)
+
+
 class EventChain(EventActionsMixin):
     """Container for a chain of events that will be executed in order."""
 
@@ -267,7 +305,76 @@ class FrontendEvent(Base):
 class FileUpload(Base):
     """Class to represent a file upload."""
 
-    pass
+    upload_id: Optional[str] = None
+    on_upload_progress: Optional[Union[EventHandler, Callable]] = None
+
+    @staticmethod
+    def on_upload_progress_args_spec(_prog: dict[str, int | float | bool]):
+        """Args spec for on_upload_progress event handler.
+
+        Returns:
+            The arg mapping passed to backend event handler
+        """
+        return [_prog]
+
+    def as_event_spec(self, handler: EventHandler) -> EventSpec:
+        """Get the EventSpec for the file upload.
+
+        Args:
+            handler: The event handler.
+
+        Returns:
+            The event spec for the handler.
+
+        Raises:
+            ValueError: If the on_upload_progress is not a valid event handler.
+        """
+        from reflex.components.forms.upload import DEFAULT_UPLOAD_ID
+
+        upload_id = self.upload_id or DEFAULT_UPLOAD_ID
+
+        spec_args = [
+            # `upload_files` is defined in state.js and assigned in the Upload component's _use_hooks
+            (Var.create_safe("files"), Var.create_safe(f"upload_files.{upload_id}[0]")),
+            (
+                Var.create_safe("upload_id"),
+                Var.create_safe(upload_id, _var_is_string=True),
+            ),
+        ]
+        if self.on_upload_progress is not None:
+            on_upload_progress = self.on_upload_progress
+            if isinstance(on_upload_progress, EventHandler):
+                events = [
+                    call_event_handler(
+                        on_upload_progress,
+                        self.on_upload_progress_args_spec,
+                    ),
+                ]
+            elif isinstance(on_upload_progress, Callable):
+                # Call the lambda to get the event chain.
+                events = call_event_fn(on_upload_progress, self.on_upload_progress_args_spec)  # type: ignore
+            else:
+                raise ValueError(f"{on_upload_progress} is not a valid event handler.")
+            on_upload_progress_chain = EventChain(
+                events=events,
+                args_spec=self.on_upload_progress_args_spec,
+            )
+            formatted_chain = str(format.format_prop(on_upload_progress_chain))
+            spec_args.append(
+                (
+                    Var.create_safe("on_upload_progress"),
+                    BaseVar(
+                        _var_name=formatted_chain.strip("{}"),
+                        _var_type=EventChain,
+                    ),
+                ),
+            )
+        return EventSpec(
+            handler=handler,
+            client_handler_name="uploadFiles",
+            args=tuple(spec_args),
+            event_actions=handler.event_actions.copy(),
+        )
 
 
 # Alias for rx.upload_files

+ 30 - 0
reflex/vars.py

@@ -1590,3 +1590,33 @@ class NoRenderImportVar(ImportVar):
     """A import that doesn't need to be rendered."""
 
     render: Optional[bool] = False
+
+
+class CallableVar(BaseVar):
+    """Decorate a Var-returning function to act as both a Var and a function.
+
+    This is used as a compatibility shim for replacing Var objects in the
+    API with functions that return a family of Var.
+    """
+
+    def __init__(self, fn: Callable[..., BaseVar]):
+        """Initialize a CallableVar.
+
+        Args:
+            fn: The function to decorate (must return Var)
+        """
+        self.fn = fn
+        default_var = fn()
+        super().__init__(**dataclasses.asdict(default_var))
+
+    def __call__(self, *args, **kwargs) -> BaseVar:
+        """Call the decorated function.
+
+        Args:
+            *args: The args to pass to the function.
+            **kwargs: The kwargs to pass to the function.
+
+        Returns:
+            The Var returned from calling the function.
+        """
+        return self.fn(*args, **kwargs)

+ 4 - 0
reflex/vars.pyi

@@ -137,3 +137,7 @@ class NoRenderImportVar(ImportVar):
     """A import that doesn't need to be rendered."""
 
 def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ...
+
+class CallableVar(BaseVar):
+    def __init__(self, fn: Callable[..., BaseVar]): ...
+    def __call__(self, *args, **kwargs) -> BaseVar: ...

+ 6 - 2
tests/components/forms/test_uploads.py

@@ -52,8 +52,10 @@ def test_upload_component_render(upload_component):
     # upload
     assert upload["name"] == "ReactDropzone"
     assert upload["props"] == [
+        "id={`default`}",
         "multiple={true}",
-        "onDrop={e => setFiles((files) => e)}",
+        "onDrop={e => upload_files.default[1]((files) => e)}",
+        "ref={ref_default}",
     ]
     assert upload["args"] == ("getRootProps", "getInputProps")
 
@@ -89,8 +91,10 @@ def test_upload_component_with_props_render(upload_component_with_props):
     upload = upload_component_with_props.render()
 
     assert upload["props"] == [
+        "id={`default`}",
         "maxFiles={2}",
         "multiple={true}",
         "noDrag={true}",
-        "onDrop={e => setFiles((files) => e)}",
+        "onDrop={e => upload_files.default[1]((files) => e)}",
+        "ref={ref_default}",
     ]

+ 30 - 30
tests/states/upload.py

@@ -49,16 +49,7 @@ class FileUploadState(rx.State):
         Args:
             files: The uploaded files.
         """
-        for file in files:
-            upload_data = await file.read()
-            outfile = f"{self._tmp_path}/{file.filename}"
-
-            # Save the file.
-            with open(outfile, "wb") as file_object:
-                file_object.write(upload_data)
-
-            # Update the img var.
-            self.img_list.append(file.filename)
+        pass
 
     async def multi_handle_upload(self, files: List[rx.UploadFile]):
         """Handle the upload of a file.
@@ -78,6 +69,15 @@ class FileUploadState(rx.State):
             assert file.filename is not None
             self.img_list.append(file.filename)
 
+    @rx.background
+    async def bg_upload(self, files: List[rx.UploadFile]):
+        """Background task cannot be upload handler.
+
+        Args:
+            files: The uploaded files.
+        """
+        pass
+
 
 class FileStateBase1(rx.State):
     """The base state for a child FileUploadState."""
@@ -97,16 +97,7 @@ class ChildFileUploadState(FileStateBase1):
         Args:
             files: The uploaded files.
         """
-        for file in files:
-            upload_data = await file.read()
-            outfile = f"{self._tmp_path}/{file.filename}"
-
-            # Save the file.
-            with open(outfile, "wb") as file_object:
-                file_object.write(upload_data)
-
-            # Update the img var.
-            self.img_list.append(file.filename)
+        pass
 
     async def multi_handle_upload(self, files: List[rx.UploadFile]):
         """Handle the upload of a file.
@@ -126,6 +117,15 @@ class ChildFileUploadState(FileStateBase1):
             assert file.filename is not None
             self.img_list.append(file.filename)
 
+    @rx.background
+    async def bg_upload(self, files: List[rx.UploadFile]):
+        """Background task cannot be upload handler.
+
+        Args:
+            files: The uploaded files.
+        """
+        pass
+
 
 class FileStateBase2(FileStateBase1):
     """The parent state for a grandchild FileUploadState."""
@@ -145,16 +145,7 @@ class GrandChildFileUploadState(FileStateBase2):
         Args:
             files: The uploaded files.
         """
-        for file in files:
-            upload_data = await file.read()
-            outfile = f"{self._tmp_path}/{file.filename}"
-
-            # Save the file.
-            with open(outfile, "wb") as file_object:
-                file_object.write(upload_data)
-
-            # Update the img var.
-            self.img_list.append(file.filename)
+        pass
 
     async def multi_handle_upload(self, files: List[rx.UploadFile]):
         """Handle the upload of a file.
@@ -173,3 +164,12 @@ class GrandChildFileUploadState(FileStateBase2):
             # Update the img var.
             assert file.filename is not None
             self.img_list.append(file.filename)
+
+    @rx.background
+    async def bg_upload(self, files: List[rx.UploadFile]):
+        """Background task cannot be upload handler.
+
+        Args:
+            files: The uploaded files.
+        """
+        pass

+ 56 - 25
tests/test_app.py

@@ -746,23 +746,28 @@ async def test_upload_file(tmp_path, state, delta, token: str):
     bio.write(data)
 
     state_name = state.get_full_name().partition(".")[2] or state.get_name()
-    handler_prefix = f"{token}:{state_name}"
+    request_mock = unittest.mock.Mock()
+    request_mock.headers = {
+        "reflex-client-token": token,
+        "reflex-event-handler": f"{state_name}.multi_handle_upload",
+    }
 
     file1 = UploadFile(
-        filename=f"{handler_prefix}.multi_handle_upload:True:image1.jpg",
+        filename=f"image1.jpg",
         file=bio,
     )
     file2 = UploadFile(
-        filename=f"{handler_prefix}.multi_handle_upload:True:image2.jpg",
+        filename=f"image2.jpg",
         file=bio,
     )
     upload_fn = upload(app)
-    await upload_fn([file1, file2])
-    state_update = StateUpdate(delta=delta, events=[], final=True)
+    streaming_response = await upload_fn(request_mock, [file1, file2])
+    async for state_update in streaming_response.body_iterator:
+        assert (
+            state_update
+            == StateUpdate(delta=delta, events=[], final=True).json() + "\n"
+        )
 
-    app.event_namespace.emit.assert_called_with(  # type: ignore
-        "event", state_update.json(), to=current_state.router.session.session_id
-    )
     current_state = await app.state_manager.get_state(token)
     state_dict = current_state.dict()
     for substate in state.get_full_name().split(".")[1:]:
@@ -789,30 +794,20 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
         tmp_path: Temporary path.
         token: a Token.
     """
-    data = b"This is binary data"
-
-    # Create a binary IO object and write data to it
-    bio = io.BytesIO()
-    bio.write(data)
-
     state._tmp_path = tmp_path
     # The App state must be the "root" of the state tree
     app = App(state=state if state is FileUploadState else FileStateBase1)
 
     state_name = state.get_full_name().partition(".")[2] or state.get_name()
-    handler_prefix = f"{token}:{state_name}"
-
-    file1 = UploadFile(
-        filename=f"{handler_prefix}.handle_upload2:True:image1.jpg",
-        file=bio,
-    )
-    file2 = UploadFile(
-        filename=f"{handler_prefix}.handle_upload2:True:image2.jpg",
-        file=bio,
-    )
+    request_mock = unittest.mock.Mock()
+    request_mock.headers = {
+        "reflex-client-token": token,
+        "reflex-event-handler": f"{state_name}.handle_upload2",
+    }
+    file_mock = unittest.mock.Mock(filename="image1.jpg")
     fn = upload(app)
     with pytest.raises(ValueError) as err:
-        await fn([file1, file2])
+        await fn(request_mock, [file_mock])
     assert (
         err.value.args[0]
         == f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
@@ -822,6 +817,42 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
         await app.state_manager.redis.close()
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "state",
+    [FileUploadState, ChildFileUploadState, GrandChildFileUploadState],
+)
+async def test_upload_file_background(state, tmp_path, token):
+    """Test that an error is thrown handler is a background task.
+
+    Args:
+        state: The state class.
+        tmp_path: Temporary path.
+        token: a Token.
+    """
+    state._tmp_path = tmp_path
+    # The App state must be the "root" of the state tree
+    app = App(state=state if state is FileUploadState else FileStateBase1)
+
+    state_name = state.get_full_name().partition(".")[2] or state.get_name()
+    request_mock = unittest.mock.Mock()
+    request_mock.headers = {
+        "reflex-client-token": token,
+        "reflex-event-handler": f"{state_name}.bg_upload",
+    }
+    file_mock = unittest.mock.Mock(filename="image1.jpg")
+    fn = upload(app)
+    with pytest.raises(TypeError) as err:
+        await fn(request_mock, [file_mock])
+    assert (
+        err.value.args[0]
+        == f"@rx.background is not supported for upload handler `{state_name}.bg_upload`."
+    )
+
+    if isinstance(app.state_manager, StateManagerRedis):
+        await app.state_manager.redis.close()
+
+
 class DynamicState(State):
     """State class for testing dynamic route var.