浏览代码

Define default classes while subclassing (#3843)

This PR introduces `default_classes`, `default_style` and
`default_props` arguments for `Element.__init_subclass__` so that
defaults like "nicegui-*" classes can be defined while subclassing. This
way they can be removed by the user calling something like
`ui.link.default_classes(replace='text-red-500')`, solving issue #3826.
Falko Schindler 7 月之前
父节点
当前提交
70081770b4

+ 6 - 0
nicegui/element.py

@@ -84,6 +84,9 @@ class Element(Visibility):
                           libraries: List[Union[str, Path]] = [],  # noqa: B006  # DEPRECATED
                           exposed_libraries: List[Union[str, Path]] = [],  # noqa: B006  # DEPRECATED
                           extra_libraries: List[Union[str, Path]] = [],  # noqa: B006  # DEPRECATED
+                          default_classes: Optional[str] = None,
+                          default_style: Optional[str] = None,
+                          default_props: Optional[str] = None,
                           ) -> None:
         super().__init_subclass__()
         base = Path(inspect.getfile(cls)).parent
@@ -127,6 +130,9 @@ class Element(Visibility):
         cls._default_props = copy(cls._default_props)
         cls._default_classes = copy(cls._default_classes)
         cls._default_style = copy(cls._default_style)
+        cls.default_classes(default_classes)
+        cls.default_style(default_style)
+        cls.default_props(default_props)
 
     def add_resource(self, path: Union[str, Path]) -> None:
         """Add a resource to the element.

+ 4 - 2
nicegui/elements/aggrid.py

@@ -13,7 +13,10 @@ if importlib.util.find_spec('pandas'):
         import pandas as pd
 
 
-class AgGrid(Element, component='aggrid.js', dependencies=['lib/aggrid/ag-grid-community.min.js']):
+class AgGrid(Element,
+             component='aggrid.js',
+             dependencies=['lib/aggrid/ag-grid-community.min.js'],
+             default_classes='nicegui-aggrid'):
 
     def __init__(self,
                  options: Dict, *,
@@ -36,7 +39,6 @@ class AgGrid(Element, component='aggrid.js', dependencies=['lib/aggrid/ag-grid-c
         self._props['options'] = options
         self._props['html_columns'] = html_columns[:]
         self._props['auto_size_columns'] = auto_size_columns
-        self._classes.append('nicegui-aggrid')
         self._classes.append(f'ag-theme-{theme}')
 
     @classmethod

+ 1 - 2
nicegui/elements/card.py

@@ -5,7 +5,7 @@ from typing_extensions import Self
 from ..element import Element
 
 
-class Card(Element):
+class Card(Element, default_classes='nicegui-card'):
 
     def __init__(self, *,
                  align_items: Optional[Literal['start', 'end', 'center', 'baseline', 'stretch']] = None,
@@ -23,7 +23,6 @@ class Card(Element):
         :param align_items: alignment of the items in the card ("start", "end", "center", "baseline", or "stretch"; default: `None`)
         """
         super().__init__('q-card')
-        self._classes.append('nicegui-card')
         if align_items:
             self._classes.append(f'items-{align_items}')
 

+ 1 - 2
nicegui/elements/carousel.py

@@ -52,7 +52,7 @@ class Carousel(ValueElement):
         self.run_method('previous')
 
 
-class CarouselSlide(DisableableElement):
+class CarouselSlide(DisableableElement, default_classes='nicegui-carousel-slide'):
 
     def __init__(self, name: Optional[str] = None) -> None:
         """Carousel Slide
@@ -66,6 +66,5 @@ class CarouselSlide(DisableableElement):
         self.carousel = cast(ValueElement, context.slot.parent)
         name = name or f'slide_{len(self.carousel.default_slot.children)}'
         self._props['name'] = name
-        self._classes.append('nicegui-carousel-slide')
         if self.carousel.value is None:
             self.carousel.value = name

+ 1 - 2
nicegui/elements/code.py

@@ -10,7 +10,7 @@ from .mixins.content_element import ContentElement
 from .timer import Timer as timer
 
 
-class Code(ContentElement):
+class Code(ContentElement, default_classes='nicegui-code'):
 
     def __init__(self, content: str = '', *, language: Optional[str] = 'python') -> None:
         """Code
@@ -23,7 +23,6 @@ class Code(ContentElement):
         :param language: language of the code (default: "python")
         """
         super().__init__(content=remove_indentation(content))
-        self._classes.append('nicegui-code')
 
         with self:
             self.markdown = markdown().classes('overflow-auto') \

+ 1 - 3
nicegui/elements/codemirror.py

@@ -245,7 +245,7 @@ SUPPORTED_THEMES = Literal[
 ]
 
 
-class CodeMirror(ValueElement, DisableableElement, component='codemirror.js'):
+class CodeMirror(ValueElement, DisableableElement, component='codemirror.js', default_classes='nicegui-codemirror'):
     VALUE_PROP = 'value'
     LOOPBACK = None
 
@@ -283,8 +283,6 @@ class CodeMirror(ValueElement, DisableableElement, component='codemirror.js'):
         super().__init__(value=value, on_value_change=on_change)
         self.add_resource(Path(__file__).parent / 'lib' / 'codemirror')
 
-        self._classes.append('nicegui-codemirror')
-
         self._props['language'] = language
         self._props['theme'] = theme
         self._props['indent'] = indent

+ 1 - 2
nicegui/elements/column.py

@@ -3,7 +3,7 @@ from typing import Literal, Optional
 from ..element import Element
 
 
-class Column(Element):
+class Column(Element, default_classes='nicegui-column'):
 
     def __init__(self, *,
                  wrap: bool = False,
@@ -17,7 +17,6 @@ class Column(Element):
         :param align_items: alignment of the items in the column ("start", "end", "center", "baseline", or "stretch"; default: `None`)
         """
         super().__init__('div')
-        self._classes.append('nicegui-column')
         if align_items:
             self._classes.append(f'items-{align_items}')
 

+ 4 - 2
nicegui/elements/echart.py

@@ -17,7 +17,10 @@ except ImportError:
     pass
 
 
-class EChart(Element, component='echart.js', dependencies=['lib/echarts/echarts.min.js', 'lib/echarts-gl/echarts-gl.min.js']):
+class EChart(Element,
+             component='echart.js',
+             dependencies=['lib/echarts/echarts.min.js', 'lib/echarts-gl/echarts-gl.min.js'],
+             default_classes='nicegui-echart'):
 
     def __init__(self, options: Dict, on_point_click: Optional[Handler[EChartPointClickEventArguments]] = None, *, enable_3d: bool = False) -> None:
         """Apache EChart
@@ -33,7 +36,6 @@ class EChart(Element, component='echart.js', dependencies=['lib/echarts/echarts.
         super().__init__()
         self._props['options'] = options
         self._props['enable_3d'] = enable_3d or any('3D' in key for key in options)
-        self._classes.append('nicegui-echart')
 
         if on_point_click:
             self.on_point_click(on_point_click)

+ 1 - 2
nicegui/elements/editor.py

@@ -5,7 +5,7 @@ from .mixins.disableable_element import DisableableElement
 from .mixins.value_element import ValueElement
 
 
-class Editor(ValueElement, DisableableElement, component='editor.js'):
+class Editor(ValueElement, DisableableElement, component='editor.js', default_classes='nicegui-editor'):
     VALUE_PROP: str = 'value'
     LOOPBACK = False
 
@@ -24,7 +24,6 @@ class Editor(ValueElement, DisableableElement, component='editor.js'):
         :param on_change: callback to be invoked when the value changes
         """
         super().__init__(value=value, on_value_change=on_change)
-        self._classes.append('nicegui-editor')
         if placeholder is not None:
             self._props['placeholder'] = placeholder
 

+ 1 - 2
nicegui/elements/expansion.py

@@ -7,7 +7,7 @@ from .mixins.text_element import TextElement
 from .mixins.value_element import ValueElement
 
 
-class Expansion(IconElement, TextElement, ValueElement, DisableableElement):
+class Expansion(IconElement, TextElement, ValueElement, DisableableElement, default_classes='nicegui-expansion'):
 
     def __init__(self,
                  text: str = '', *,
@@ -33,7 +33,6 @@ class Expansion(IconElement, TextElement, ValueElement, DisableableElement):
             self._props['caption'] = caption
         if group is not None:
             self._props['group'] = group
-        self._classes.append('nicegui-expansion')
 
     def open(self) -> None:
         """Open the expansion."""

+ 1 - 2
nicegui/elements/grid.py

@@ -3,7 +3,7 @@ from typing import Optional, Union
 from ..element import Element
 
 
-class Grid(Element):
+class Grid(Element, default_classes='nicegui-grid'):
 
     def __init__(self,
                  *,
@@ -18,7 +18,6 @@ class Grid(Element):
         :param columns: number of columns in the grid or a string with the grid-template-columns CSS property (e.g. 'auto 1fr')
         """
         super().__init__('div')
-        self._classes.append('nicegui-grid')
 
         if isinstance(rows, int):
             self._style['grid-template-rows'] = f'repeat({rows}, minmax(0, 1fr))'

+ 5 - 3
nicegui/elements/joystick.py

@@ -6,14 +6,17 @@ from ..element import Element
 from ..events import GenericEventArguments, Handler, JoystickEventArguments, handle_event
 
 
-class Joystick(Element, component='joystick.vue', dependencies=['lib/nipplejs/nipplejs.js']):
+class Joystick(Element,
+               component='joystick.vue',
+               dependencies=['lib/nipplejs/nipplejs.js'],
+               default_classes='nicegui-joystick'):
 
     def __init__(self, *,
                  on_start: Optional[Handler[JoystickEventArguments]] = None,
                  on_move: Optional[Handler[JoystickEventArguments]] = None,
                  on_end: Optional[Handler[JoystickEventArguments]] = None,
                  throttle: float = 0.05,
-                 ** options: Any) -> None:
+                 **options: Any) -> None:
         """Joystick
 
         Create a joystick based on `nipple.js <https://yoannmoi.net/nipplejs/>`_.
@@ -26,7 +29,6 @@ class Joystick(Element, component='joystick.vue', dependencies=['lib/nipplejs/ni
         """
         super().__init__()
         self._props['options'] = options
-        self._classes.append('nicegui-joystick')
         self.active = False
 
         self._start_handlers = [on_start] if on_start else []

+ 1 - 2
nicegui/elements/leaflet.py

@@ -11,7 +11,7 @@ from ..events import GenericEventArguments
 from .leaflet_layer import Layer
 
 
-class Leaflet(Element, component='leaflet.js'):
+class Leaflet(Element, component='leaflet.js', default_classes='nicegui-leaflet'):
     # pylint: disable=import-outside-toplevel
     from .leaflet_layers import GenericLayer as generic_layer
     from .leaflet_layers import Marker as marker
@@ -40,7 +40,6 @@ class Leaflet(Element, component='leaflet.js'):
         """
         super().__init__()
         self.add_resource(Path(__file__).parent / 'lib' / 'leaflet')
-        self._classes.append('nicegui-leaflet')
 
         self.layers: List[Layer] = []
         self.is_initialized = False

+ 1 - 2
nicegui/elements/link.py

@@ -5,7 +5,7 @@ from ..element import Element
 from .mixins.text_element import TextElement
 
 
-class Link(TextElement, component='link.js'):
+class Link(TextElement, component='link.js', default_classes='nicegui-link'):
 
     def __init__(self,
                  text: str = '',
@@ -31,7 +31,6 @@ class Link(TextElement, component='link.js'):
         elif callable(target):
             self._props['href'] = Client.page_routes[target]
         self._props['target'] = '_blank' if new_tab else '_self'
-        self._classes.append('nicegui-link')
 
 
 class LinkTarget(Element):

+ 1 - 2
nicegui/elements/log.py

@@ -4,7 +4,7 @@ from ..element import Element
 from .label import Label
 
 
-class Log(Element):
+class Log(Element, default_classes='nicegui-log'):
 
     def __init__(self, max_lines: Optional[int] = None) -> None:
         """Log View
@@ -15,7 +15,6 @@ class Log(Element):
         """
         super().__init__()
         self.max_lines = max_lines
-        self._classes.append('nicegui-log')
 
     def push(self, line: Any) -> None:
         """Add a new line to the log.

+ 1 - 2
nicegui/elements/markdown.py

@@ -14,7 +14,7 @@ from .mixins.content_element import ContentElement
 CODEHILITE_CSS_URL = f'/_nicegui/{__version__}/codehilite.css'
 
 
-class Markdown(ContentElement, component='markdown.js'):
+class Markdown(ContentElement, component='markdown.js', default_classes='nicegui-markdown'):
 
     def __init__(self,
                  content: str = '', *,
@@ -29,7 +29,6 @@ class Markdown(ContentElement, component='markdown.js'):
         """
         self.extras = extras[:]
         super().__init__(content=content)
-        self._classes.append('nicegui-markdown')
         if 'mermaid' in extras:
             self._props['use_mermaid'] = True
             self.libraries.append(Mermaid.exposed_libraries[0])

+ 1 - 2
nicegui/elements/pyplot.py

@@ -33,7 +33,7 @@ except ImportError:
     pass
 
 
-class Pyplot(Element):
+class Pyplot(Element, default_classes='nicegui-pyplot'):
 
     def __init__(self, *, close: bool = True, **kwargs: Any) -> None:
         """Pyplot Context
@@ -47,7 +47,6 @@ class Pyplot(Element):
             raise ImportError('Matplotlib is not installed. Please run "pip install matplotlib".')
 
         super().__init__('div')
-        self._classes.append('nicegui-pyplot')
         self.close = close
         self.fig = plt.figure(**kwargs)
         self._convert_to_html()

+ 1 - 2
nicegui/elements/row.py

@@ -3,7 +3,7 @@ from typing import Literal, Optional
 from ..element import Element
 
 
-class Row(Element):
+class Row(Element, default_classes='nicegui-row'):
 
     def __init__(self, *,
                  wrap: bool = True,
@@ -17,7 +17,6 @@ class Row(Element):
         :param align_items: alignment of the items in the row ("start", "end", "center", "baseline", or "stretch"; default: `None`)
         """
         super().__init__('div')
-        self._classes.append('nicegui-row')
         self._classes.append('row')  # NOTE: for compatibility with Quasar's col-* classes
         if align_items:
             self._classes.append(f'items-{align_items}')

+ 2 - 2
nicegui/elements/scene.py

@@ -50,7 +50,8 @@ class Scene(Element,
                 'lib/three/modules/OrbitControls.js',
                 'lib/three/modules/STLLoader.js',
                 'lib/tween/tween.umd.js',
-            ]):
+            ],
+            default_classes='nicegui-scene'):
     # pylint: disable=import-outside-toplevel
     from .scene_objects import AxesHelper as axes_helper
     from .scene_objects import Box as box
@@ -119,7 +120,6 @@ class Scene(Element,
         self.on('dragstart', self._handle_drag)
         self.on('dragend', self._handle_drag)
         self._props['drag_constraints'] = drag_constraints
-        self._classes.append('nicegui-scene')
 
     def on_click(self, callback: Handler[SceneClickEventArguments]) -> Self:
         """Add a callback to be invoked when a 3D object is clicked."""

+ 2 - 2
nicegui/elements/scene_view.py

@@ -20,7 +20,8 @@ class SceneView(Element,
                 dependencies=[
                     'lib/tween/tween.umd.js',
                     'lib/three/three.module.js',
-                ]):
+                ],
+                default_classes='nicegui-scene-view'):
 
     def __init__(self,
                  scene: Scene,
@@ -53,7 +54,6 @@ class SceneView(Element,
         self._click_handlers = [on_click] if on_click else []
         self.on('init', self._handle_init)
         self.on('click3d', self._handle_click)
-        self._classes.append('nicegui-scene-view')
 
     def on_click(self, callback: Handler[ClickEventArguments]) -> Self:
         """Add a callback to be invoked when a 3D object is clicked."""

+ 1 - 2
nicegui/elements/scroll_area.py

@@ -6,7 +6,7 @@ from ..element import Element
 from ..events import GenericEventArguments, Handler, ScrollEventArguments, handle_event
 
 
-class ScrollArea(Element):
+class ScrollArea(Element, default_classes='nicegui-scroll-area'):
 
     def __init__(self, *, on_scroll: Optional[Handler[ScrollEventArguments]] = None) -> None:
         """Scroll Area
@@ -17,7 +17,6 @@ class ScrollArea(Element):
         :param on_scroll: function to be called when the scroll position changes
         """
         super().__init__('q-scroll-area')
-        self._classes.append('nicegui-scroll-area')
 
         if on_scroll:
             self.on_scroll(on_scroll)

+ 1 - 2
nicegui/elements/separator.py

@@ -1,7 +1,7 @@
 from ..element import Element
 
 
-class Separator(Element):
+class Separator(Element, default_classes='nicegui-separator'):
 
     def __init__(self) -> None:
         """Separator
@@ -11,4 +11,3 @@ class Separator(Element):
         It serves as a separator for cards, menus and other component containers and is similar to HTML's <hr> tag.
         """
         super().__init__('q-separator')
-        self._classes.append('nicegui-separator')

+ 1 - 2
nicegui/elements/splitter.py

@@ -5,7 +5,7 @@ from .mixins.disableable_element import DisableableElement
 from .mixins.value_element import ValueElement
 
 
-class Splitter(ValueElement, DisableableElement):
+class Splitter(ValueElement, DisableableElement, default_classes='nicegui-splitter'):
 
     def __init__(self, *,
                  horizontal: Optional[bool] = False,
@@ -35,7 +35,6 @@ class Splitter(ValueElement, DisableableElement):
         self._props['horizontal'] = horizontal
         self._props['limits'] = limits
         self._props['reverse'] = reverse
-        self._classes.append('nicegui-splitter')
 
         self.before = self.add_slot('before')
         self.after = self.add_slot('after')

+ 2 - 4
nicegui/elements/stepper.py

@@ -10,7 +10,7 @@ from .mixins.icon_element import IconElement
 from .mixins.value_element import ValueElement
 
 
-class Stepper(ValueElement):
+class Stepper(ValueElement, default_classes='nicegui-stepper'):
 
     def __init__(self, *,
                  value: Union[str, Step, None] = None,
@@ -32,7 +32,6 @@ class Stepper(ValueElement):
         """
         super().__init__(tag='q-stepper', value=value, on_value_change=on_value_change)
         self._props['keep-alive'] = keep_alive
-        self._classes.append('nicegui-stepper')
 
     def _value_to_model_value(self, value: Any) -> Any:
         return value.props['name'] if isinstance(value, Step) else value
@@ -53,7 +52,7 @@ class Stepper(ValueElement):
         self.run_method('previous')
 
 
-class Step(IconElement, DisableableElement):
+class Step(IconElement, DisableableElement, default_classes='nicegui-step'):
 
     def __init__(self, name: str, title: Optional[str] = None, icon: Optional[str] = None) -> None:
         """Step
@@ -68,7 +67,6 @@ class Step(IconElement, DisableableElement):
         super().__init__(tag='q-step', icon=icon)
         self._props['name'] = name
         self._props['title'] = title if title is not None else name
-        self._classes.append('nicegui-step')
         self.stepper = cast(ValueElement, context.slot.parent)
         if self.stepper.value is None:
             self.stepper.value = name

+ 1 - 2
nicegui/elements/tabs.py

@@ -81,7 +81,7 @@ class TabPanels(ValueElement):
         return value.props['name'] if isinstance(value, (Tab, TabPanel)) else value
 
 
-class TabPanel(DisableableElement):
+class TabPanel(DisableableElement, default_classes='nicegui-tab-panel'):
 
     def __init__(self, name: Union[Tab, str]) -> None:
         """Tab Panel
@@ -93,4 +93,3 @@ class TabPanel(DisableableElement):
         """
         super().__init__(tag='q-tab-panel')
         self._props['name'] = name.props['name'] if isinstance(name, Tab) else name
-        self._classes.append('nicegui-tab-panel')

+ 1 - 2
nicegui/elements/timeline.py

@@ -27,7 +27,7 @@ class Timeline(Element):
             self._props['color'] = color
 
 
-class TimelineEntry(IconElement):
+class TimelineEntry(IconElement, default_classes='nicegui-timeline-entry'):
 
     def __init__(self,
                  body: Optional[str] = None,
@@ -70,4 +70,3 @@ class TimelineEntry(IconElement):
             self._props['title'] = title
         if subtitle is not None:
             self._props['subtitle'] = subtitle
-        self._classes.append('nicegui-timeline-entry')

+ 3 - 6
nicegui/page_layout.py

@@ -19,7 +19,7 @@ PageStickyPositions = Literal[
 ]
 
 
-class Header(ValueElement):
+class Header(ValueElement, default_classes='nicegui-header'):
 
     def __init__(self, *,
                  value: bool = True,
@@ -48,7 +48,6 @@ class Header(ValueElement):
         _check_current_slot(self)
         with context.client.layout:
             super().__init__(tag='q-header', value=value, on_value_change=None)
-        self._classes.append('nicegui-header')
         self._props['bordered'] = bordered
         self._props['elevated'] = elevated
         if wrap:
@@ -84,7 +83,7 @@ class Header(ValueElement):
         self.value = False
 
 
-class Drawer(Element):
+class Drawer(Element, default_classes='nicegui-drawer'):
 
     def __init__(self,
                  side: DrawerSides, *,
@@ -121,7 +120,6 @@ class Drawer(Element):
         self._props['side'] = side
         self._props['bordered'] = bordered
         self._props['elevated'] = elevated
-        self._classes.append('nicegui-drawer')
         code = list(self.client.layout.props['view'])
         code[0 if side == 'left' else 2] = side[0].lower() if top_corner else 'h'
         code[4 if side == 'left' else 6] = side[0].upper() if fixed else side[0].lower()
@@ -212,7 +210,7 @@ class RightDrawer(Drawer):
                          bottom_corner=bottom_corner)
 
 
-class Footer(ValueElement):
+class Footer(ValueElement, default_classes='nicegui-footer'):
 
     def __init__(self, *,
                  value: bool = True,
@@ -239,7 +237,6 @@ class Footer(ValueElement):
         _check_current_slot(self)
         with context.client.layout:
             super().__init__(tag='q-footer', value=value, on_value_change=None)
-        self.classes('nicegui-footer')
         self._props['bordered'] = bordered
         self._props['elevated'] = elevated
         if wrap:

+ 25 - 8
nicegui/testing/general_fixtures.py

@@ -1,4 +1,5 @@
 import importlib
+from copy import copy
 from typing import Generator, List, Type
 
 import pytest
@@ -36,14 +37,13 @@ def nicegui_reset_globals() -> Generator[None, None, None]:
             app.routes.remove(route)
     importlib.reload(core)
     importlib.reload(run)
-    element_classes: List[Type[ui.element]] = [ui.element]
-    while element_classes:
-        parent = element_classes.pop()
-        for cls in parent.__subclasses__():
-            cls._default_props = {}  # pylint: disable=protected-access
-            cls._default_style = {}  # pylint: disable=protected-access
-            cls._default_classes = []  # pylint: disable=protected-access
-            element_classes.append(cls)
+
+    # capture initial defaults
+    element_types: List[Type[ui.element]] = [ui.element, *find_all_subclasses(ui.element)]
+    default_classes = {t: copy(t._default_classes) for t in element_types}  # pylint: disable=protected-access
+    default_styles = {t: copy(t._default_style) for t in element_types}  # pylint: disable=protected-access
+    default_props = {t: copy(t._default_props) for t in element_types}  # pylint: disable=protected-access
+
     Client.instances.clear()
     Client.page_routes.clear()
     app.reset()
@@ -51,9 +51,26 @@ def nicegui_reset_globals() -> Generator[None, None, None]:
     # NOTE we need to re-add the auto index route because we removed all routes above
     app.get('/')(Client.auto_index_client.build_response)
     binding.reset()
+
     yield
+
     app.reset()
 
+    # restore initial defaults
+    for t in element_types:
+        t._default_classes = default_classes[t]  # pylint: disable=protected-access
+        t._default_style = default_styles[t]  # pylint: disable=protected-access
+        t._default_props = default_props[t]  # pylint: disable=protected-access
+
+
+def find_all_subclasses(cls: Type) -> List[Type]:
+    """Find all subclasses of a class."""
+    subclasses = []
+    for subclass in cls.__subclasses__():
+        subclasses.append(subclass)
+        subclasses.extend(find_all_subclasses(subclass))
+    return subclasses
+
 
 def prepare_simulation(request: pytest.FixtureRequest) -> None:
     """Prepare a simulation to be started.