Răsfoiți Sursa

Add optional catch all handling (#260)

Thomas Brandého 2 ani în urmă
părinte
comite
2e41303b25
4 a modificat fișierele cu 120 adăugiri și 34 ștergeri
  1. 32 1
      pynecone/app.py
  2. 23 0
      pynecone/constants.py
  3. 6 6
      pynecone/state.py
  4. 59 27
      pynecone/utils.py

+ 32 - 1
pynecone/app.py

@@ -199,10 +199,41 @@ class App(Base):
 
         # Format the route.
         route = utils.format_route(path)
-
         # Add the page.
+        self._check_routes_conflict(route)
         self.pages[route] = component
 
+    def _check_routes_conflict(self, new_route: str):
+        """Verify if there is any conflict between the new route and any existing route.
+
+        Based on conflicts that NextJS would throw if not intercepted.
+
+        Raises:
+            ValueError: exception showing which conflict exist with the path to be added
+
+        Args:
+            new_route: the route being newly added.
+        """
+        newroute_catchall = utils.catchall_in_route(new_route)
+        if not newroute_catchall:
+            return
+
+        for route in self.pages:
+            route = "" if route == "index" else route
+
+            if new_route.startswith(route + "/[[..."):
+                raise ValueError(
+                    f"You cannot define a route with the same specificity as a optional catch-all route ('{route}' and '{new_route}')"
+                )
+
+            route_catchall = utils.catchall_in_route(route)
+            if route_catchall and newroute_catchall:
+                # both route have a catchall, check if preceding path is the same
+                if utils.catchall_prefix(route) == utils.catchall_prefix(new_route):
+                    raise ValueError(
+                        f"You cannot use multiple catchall for the same dynamic path ({route} !== {new_route})"
+                    )
+
     def compile(self, force_compile: bool = False):
         """Compile the app and output it to the pages folder.
 

+ 23 - 0
pynecone/constants.py

@@ -1,7 +1,10 @@
 """Constants used throughout the package."""
 
 import os
+import re
+
 from enum import Enum
+from types import SimpleNamespace
 
 import pkg_resources
 
@@ -188,3 +191,23 @@ class Endpoint(Enum):
 
         # Return the url.
         return url
+
+
+class PathArgType(SimpleNamespace):
+    """Type of pathArg extracted from URI path."""
+
+    SINGLE = str("arg_single")
+    LIST = str("arg_list")
+
+
+# ROUTE REGEXs
+class RouteRegex(SimpleNamespace):
+    """Regex used for extracting path args in path."""
+
+    ARG = re.compile(r"\[(?!\.)([^\[\]]+)\]")
+    # group return the catchall pattern (i.e "[[..slug]]")
+    CATCHALL = re.compile(r"(\[?\[\.{3}(?![0-9]).*\]?\])")
+    # group return the argname (i.e "slug")
+    STRICT_CATCHALL = re.compile(r"\[\.{3}([a-zA-Z_][\w]*)\]")
+    # group return the argname (i.e "slug")
+    OPT_CATCHALL = re.compile(r"\[\[\.{3}([a-zA-Z_][\w]*)\]\]")

+ 6 - 6
pynecone/state.py

@@ -287,14 +287,14 @@ class State(Base, ABC):
             args: a dict of args
         """
 
-        def param_factory(param):
+        def argsingle_factory(param):
             @ComputedVar
             def inner_func(self) -> str:
                 return self.get_query_params().get(param, "")
 
             return inner_func
 
-        def catchall_factory(param):
+        def arglist_factory(param):
             @ComputedVar
             def inner_func(self) -> List:
                 return self.get_query_params().get(param, [])
@@ -303,10 +303,10 @@ class State(Base, ABC):
 
         for param, value in args.items():
 
-            if value == "catchall":
-                func = catchall_factory(param)
-            elif value == "patharg":
-                func = param_factory(param)
+            if value == constants.PathArgType.SINGLE:
+                func = argsingle_factory(param)
+            elif value == constants.PathArgType.LIST:
+                func = arglist_factory(param)
             else:
                 continue
             cls.computed_vars[param] = func.set_state(cls)  # type: ignore

+ 59 - 27
pynecone/utils.py

@@ -783,14 +783,9 @@ def verify_path_validity(path: str) -> None:
     Raises:
         ValueError: explains what is wrong with the path.
     """
-    check_catchall = re.compile(r"^\[\.\.\.(.+)\]$")
-    catchall_found = False
-    for part in path.split("/"):
-        if catchall_found:
-            raise ValueError(f"Catch-all must be the last part of the URL: {path}")
-        match = check_catchall.match(part)
-        if match:
-            catchall_found = True
+    pattern = catchall_in_route(path)
+    if pattern and not path.endswith(pattern):
+        raise ValueError(f"Catch-all must be the last part of the URL: {path}")
 
 
 def get_path_args(path: str) -> Dict[str, str]:
@@ -799,42 +794,79 @@ def get_path_args(path: str) -> Dict[str, str]:
     Args:
         path: The path to get the arguments for.
 
-    Raises:
-        ValueError: explains what is wrong with the path.
-
     Returns:
         The path arguments.
     """
-    # Import here to avoid circular imports.
-    from pynecone.var import BaseVar
+
+    def add_path_arg(match: re.Match[str], type_: str):
+        """Add arg from regex search result.
+
+        Args:
+            match: result of a regex search
+            type_: the assigned type for this arg
+
+        Raises:
+            ValueError: explains what is wrong with the path.
+        """
+        arg_name = match.groups()[0]
+        if arg_name in args:
+            raise ValueError(
+                f"arg name [{arg_name}] is used more than once in this URL"
+            )
+        args[arg_name] = type_
 
     # Regex to check for path args.
-    check = re.compile(r"^\[(.+)\]$")
-    check_catchall = re.compile(r"^\[\.\.\.(.+)\]$")
+    check = constants.RouteRegex.ARG
+    check_strict_catchall = constants.RouteRegex.STRICT_CATCHALL
+    check_opt_catchall = constants.RouteRegex.OPT_CATCHALL
 
     # Iterate over the path parts and check for path args.
     args = {}
     for part in path.split("/"):
-        match = check_catchall.match(part)
-        if match:
-            arg_name = match.groups()[0]
-            if arg_name in args:
-                raise ValueError(f"arg [{arg_name}] is used more than once in this URL")
+        match_opt = check_opt_catchall.match(part)
+        if match_opt:
+            add_path_arg(match_opt, constants.PathArgType.LIST)
+            break
 
-            args[arg_name] = "catchall"
-            continue
+        match_strict = check_strict_catchall.match(part)
+        if match_strict:
+            add_path_arg(match_strict, constants.PathArgType.LIST)
+            break
 
         match = check.match(part)
         if match:
             # Add the path arg to the list.
-            arg_name = match.groups()[0]
-            if arg_name in args:
-                raise ValueError(f"arg [{arg_name}] is used more than once in this URL")
-            args[arg_name] = "patharg"
+            add_path_arg(match, constants.PathArgType.SINGLE)
     return args
 
 
-def format_route(route: str):
+def catchall_in_route(route: str) -> str:
+    """Extract the catchall part from a route.
+
+    Args:
+        route: the route from which to extract
+
+    Returns:
+        str: the catchall part of the URI
+    """
+    match_ = constants.RouteRegex.CATCHALL.search(route)
+    return match_.group() if match_ else ""
+
+
+def catchall_prefix(route: str) -> str:
+    """Extract the prefix part from a route that contains a catchall.
+
+    Args:
+        route: the route from which to extract
+
+    Returns:
+        str: the prefix part of the URI
+    """
+    pattern = catchall_in_route(route)
+    return route.replace(pattern, "") if pattern else ""
+
+
+def format_route(route: str) -> str:
     """Format the given route.
 
     Args: