Răsfoiți Sursa

[REF-1365] Radix + Tailwind Compatibility (and allow tailwind to be completely disabled) (#2246)

Masen Furer 1 an în urmă
părinte
comite
accaf6dc52

+ 108 - 0
integration/test_tailwind.py

@@ -0,0 +1,108 @@
+"""Test case for disabling tailwind in the config."""
+
+import functools
+from typing import Generator
+
+import pytest
+from selenium.webdriver.common.by import By
+
+from reflex.testing import AppHarness
+
+PARAGRAPH_TEXT = "Tailwind Is Cool"
+PARAGRAPH_CLASS_NAME = "text-red-500"
+GLOBAL_PARAGRAPH_COLOR = "rgba(0, 0, 242, 1)"
+
+
+def TailwindApp(
+    tailwind_disabled: bool = False,
+    paragraph_text: str = PARAGRAPH_TEXT,
+    paragraph_class_name: str = PARAGRAPH_CLASS_NAME,
+    global_paragraph_color: str = GLOBAL_PARAGRAPH_COLOR,
+):
+    """App with tailwind optionally disabled.
+
+    Args:
+        tailwind_disabled: Whether tailwind is disabled for the app.
+        paragraph_text: Text for the paragraph.
+        paragraph_class_name: Tailwind class_name for the paragraph.
+        global_paragraph_color: Color for the paragraph set in global app styles.
+    """
+    import reflex as rx
+    import reflex.components.radix.themes as rdxt
+
+    def index():
+        return rx.el.div(
+            rx.text(paragraph_text, class_name=paragraph_class_name),
+            rx.el.p(paragraph_text, class_name=paragraph_class_name),
+            rdxt.text(paragraph_text, as_="p", class_name=paragraph_class_name),
+            id="p-content",
+        )
+
+    app = rx.App(style={"p": {"color": global_paragraph_color}})
+    app.add_page(index)
+    if tailwind_disabled:
+        config = rx.config.get_config()
+        config.tailwind = None
+
+
+@pytest.fixture(params=[False, True], ids=["tailwind_enabled", "tailwind_disabled"])
+def tailwind_disabled(request) -> bool:
+    """Tailwind disabled fixture.
+
+    Args:
+        request: pytest request fixture.
+
+    Returns:
+        True if tailwind is disabled, False otherwise.
+    """
+    return request.param
+
+
+@pytest.fixture()
+def tailwind_app(tmp_path, tailwind_disabled) -> Generator[AppHarness, None, None]:
+    """Start TailwindApp app at tmp_path via AppHarness with tailwind disabled via config.
+
+    Args:
+        tmp_path: pytest tmp_path fixture
+        tailwind_disabled: Whether tailwind is disabled for the app.
+
+    Yields:
+        running AppHarness instance
+    """
+    with AppHarness.create(
+        root=tmp_path,
+        app_source=functools.partial(TailwindApp, tailwind_disabled=tailwind_disabled),  # type: ignore
+        app_name="tailwind_disabled_app" if tailwind_disabled else "tailwind_app",
+    ) as harness:
+        yield harness
+
+
+def test_tailwind_app(tailwind_app: AppHarness, tailwind_disabled: bool):
+    """Test that the app can compile without tailwind.
+
+    Args:
+        tailwind_app: AppHarness instance.
+        tailwind_disabled: Whether tailwind is disabled for the app.
+    """
+    assert tailwind_app.app_instance is not None
+    assert tailwind_app.backend is not None
+
+    driver = tailwind_app.frontend()
+
+    # Assert the app is stateless.
+    with pytest.raises(ValueError) as errctx:
+        _ = tailwind_app.app_instance.state_manager
+    errctx.match("The state manager has not been initialized.")
+
+    # Assert content is visible (and not some error)
+    content = driver.find_element(By.ID, "p-content")
+    paragraphs = content.find_elements(By.TAG_NAME, "p")
+    assert len(paragraphs) == 3
+    for p in paragraphs:
+        assert tailwind_app.poll_for_content(p, exp_not_equal="") == PARAGRAPH_TEXT
+        if tailwind_disabled:
+            # expect "blue" color from global stylesheet, not "text-red-500" from tailwind utility class
+            assert p.value_of_css_property("color") == GLOBAL_PARAGRAPH_COLOR
+        else:
+            # expect "text-red-500" from tailwind utility class
+            assert p.value_of_css_property("color") == "rgba(239, 68, 68, 1)"

+ 4 - 1
reflex/.templates/jinja/web/pages/_app.js.jinja2

@@ -1,9 +1,12 @@
 {% extends "web/pages/base_page.js.jinja2" %}
 
+{% block early_imports %}
+import '/styles/styles.css'
+{% endblock %}
+
 {% block declaration %}
 import { EventLoopProvider, StateProvider } from "/utils/context.js";
 import { ThemeProvider } from 'next-themes'
-import '/styles/styles.css'
 
 
 {% for custom_code in custom_codes %}

+ 3 - 1
reflex/.templates/jinja/web/pages/base_page.js.jinja2

@@ -1,7 +1,9 @@
 {% import 'web/pages/utils.js.jinja2' as utils %}
-
 /** @jsxImportSource @emotion/react */
 
+{% block early_imports %}
+{% endblock %}
+
 {%- block imports_libs %}
 
 {% for module in imports%}

+ 3 - 1
reflex/app.py

@@ -749,7 +749,9 @@ class App(Base):
                 config.tailwind["content"] = config.tailwind.get(
                     "content", constants.Tailwind.CONTENT
                 )
-            submit_work(compiler.compile_tailwind, config.tailwind)
+                submit_work(compiler.compile_tailwind, config.tailwind)
+            else:
+                submit_work(compiler.remove_tailwind_from_postcss)
 
             # Get imports from AppWrap components.
             all_imports.update(app_root.get_imports())

+ 19 - 0
reflex/compiler/compiler.py

@@ -432,6 +432,25 @@ def compile_tailwind(
     return output_path, code
 
 
+def remove_tailwind_from_postcss() -> tuple[str, str]:
+    """If tailwind is not to be used, remove it from postcss.config.js.
+
+    Returns:
+        The path and code of the compiled postcss.config.js.
+    """
+    # Get the path for the output file.
+    output_path = constants.Dirs.POSTCSS_JS
+
+    code = [
+        line
+        for line in Path(output_path).read_text().splitlines(keepends=True)
+        if "tailwindcss: " not in line
+    ]
+
+    # Compile the config.
+    return output_path, "".join(code)
+
+
 def purge_web_pages_dir():
     """Empty out .web directory."""
     utils.empty_dir(constants.Dirs.WEB_PAGES, keep_files=["_app.js"])

+ 2 - 0
reflex/constants/base.py

@@ -43,6 +43,8 @@ class Dirs(SimpleNamespace):
     ENV_JSON = os.path.join(WEB, "env.json")
     # The reflex json file.
     REFLEX_JSON = os.path.join(WEB, "reflex.json")
+    # The path to postcss.config.js
+    POSTCSS_JS = os.path.join(WEB, "postcss.config.js")
 
 
 class Reflex(SimpleNamespace):

+ 1 - 0
reflex/constants/installer.py

@@ -102,6 +102,7 @@ class PackageJson(SimpleNamespace):
     PATH = os.path.join(Dirs.WEB, "package.json")
 
     DEPENDENCIES = {
+        "@emotion/react": "11.11.1",
         "axios": "1.4.0",
         "json5": "2.2.3",
         "next": "14.0.1",

+ 56 - 8
reflex/testing.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 import asyncio
 import contextlib
 import dataclasses
+import functools
 import inspect
 import os
 import pathlib
@@ -20,6 +21,7 @@ import types
 from http.server import SimpleHTTPRequestHandler
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncIterator,
     Callable,
     Coroutine,
@@ -135,6 +137,8 @@ class AppHarness:
         if app_name is None:
             if app_source is None:
                 app_name = root.name.lower()
+            elif isinstance(app_source, functools.partial):
+                app_name = app_source.func.__name__.lower()
             else:
                 app_name = app_source.__name__.lower()
         return cls(
@@ -144,13 +148,54 @@ class AppHarness:
             app_module_path=root / app_name / f"{app_name}.py",
         )
 
+    def _get_globals_from_signature(self, func: Any) -> dict[str, Any]:
+        """Get the globals from a function or module object.
+
+        Args:
+            func: function or module object
+
+        Returns:
+            dict of globals
+        """
+        overrides = {}
+        glbs = {}
+        if not callable(func):
+            return glbs
+        if isinstance(func, functools.partial):
+            overrides = func.keywords
+            func = func.func
+        for param in inspect.signature(func).parameters.values():
+            if param.default is not inspect.Parameter.empty:
+                glbs[param.name] = param.default
+        glbs.update(overrides)
+        return glbs
+
+    def _get_source_from_func(self, func: Any) -> str:
+        """Get the source from a function or module object.
+
+        Args:
+            func: function or module object
+
+        Returns:
+            source code
+        """
+        source = inspect.getsource(func)
+        source = re.sub(r"^\s*def\s+\w+\s*\(.*?\):", "", source, flags=re.DOTALL)
+        return textwrap.dedent(source)
+
     def _initialize_app(self):
         os.environ["TELEMETRY_ENABLED"] = ""  # disable telemetry reporting for tests
         self.app_path.mkdir(parents=True, exist_ok=True)
         if self.app_source is not None:
+            app_globals = self._get_globals_from_signature(self.app_source)
+            if isinstance(self.app_source, functools.partial):
+                self.app_source = self.app_source.func  # type: ignore
             # get the source from a function or module object
-            source_code = textwrap.dedent(
-                "".join(inspect.getsource(self.app_source).splitlines(True)[1:]),
+            source_code = "\n".join(
+                [
+                    "\n".join(f"{k} = {v!r}" for k, v in app_globals.items()),
+                    self._get_source_from_func(self.app_source),
+                ]
             )
             with chdir(self.app_path):
                 reflex.reflex._init(
@@ -167,11 +212,11 @@ class AppHarness:
             # self.app_module.app.
             self.app_module = reflex.utils.prerequisites.get_compiled_app(reload=True)
         self.app_instance = self.app_module.app
-        if isinstance(self.app_instance.state_manager, StateManagerRedis):
+        if isinstance(self.app_instance._state_manager, StateManagerRedis):
             # Create our own redis connection for testing.
             self.state_manager = StateManagerRedis.create(self.app_instance.state)
         else:
-            self.state_manager = self.app_instance.state_manager
+            self.state_manager = self.app_instance._state_manager
 
     def _get_backend_shutdown_handler(self):
         if self.backend is None:
@@ -181,10 +226,13 @@ class AppHarness:
 
         async def _shutdown_redis(*args, **kwargs) -> None:
             # ensure redis is closed before event loop
-            if self.app_instance is not None and isinstance(
-                self.app_instance.state_manager, StateManagerRedis
-            ):
-                await self.app_instance.state_manager.close()
+            try:
+                if self.app_instance is not None and isinstance(
+                    self.app_instance.state_manager, StateManagerRedis
+                ):
+                    await self.app_instance.state_manager.close()
+            except ValueError:
+                pass
             await original_shutdown(*args, **kwargs)
 
         return _shutdown_redis