12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131 |
- """General utility functions."""
- from __future__ import annotations
- import inspect
- import json
- import os
- import platform
- import random
- import re
- import shutil
- import signal
- import string
- import subprocess
- import sys
- import uvicorn
- from urllib.parse import urlparse
- from collections import defaultdict
- from subprocess import PIPE
- from types import ModuleType
- from typing import _GenericAlias # type: ignore
- from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Dict,
- List,
- Optional,
- Tuple,
- Type,
- Union,
- )
- import plotly.graph_objects as go
- import typer
- from plotly.io import to_json
- from redis import Redis
- from rich.console import Console
- from pynecone import constants
- from pynecone.base import Base
- if TYPE_CHECKING:
- from pynecone.app import App
- from pynecone.components.component import ImportDict
- from pynecone.config import Config
- from pynecone.event import Event, EventHandler, EventSpec
- from pynecone.var import Var
# Convenience alias: join strings with the platform line separator.
join = os.linesep.join

# Shared rich console used for pretty printing.
console = Console()

# Either a plain class or a typing generic alias.
GenericType = Union[Type, _GenericAlias]

# Types that are accepted as state vars.
PrimitiveType = Union[int, float, bool, str, list, dict, tuple]
StateVar = Union[PrimitiveType, Base, None]
def get_args(alias: _GenericAlias) -> Tuple[Type, ...]:
    """Return the type parameters of a generic alias.

    Args:
        alias: The type alias to inspect.

    Returns:
        The contents of the alias's ``__args__`` attribute.
    """
    type_args = alias.__args__
    return type_args
def is_generic_alias(cls: GenericType) -> bool:
    """Tell whether the given class is a generic alias.

    Args:
        cls: The class to check.

    Returns:
        True if the class is one of the known generic alias kinds.
    """
    # typing's private alias type (older Python versions).
    if isinstance(cls, _GenericAlias):
        return True

    # Not every Python version exposes _SpecialGenericAlias.
    try:
        from typing import _SpecialGenericAlias  # type: ignore
    except ImportError:
        pass
    else:
        if isinstance(cls, _SpecialGenericAlias):
            return True

    # Builtin generic aliases (newer Python versions).
    try:
        from types import GenericAlias  # type: ignore
    except ImportError:
        return False
    return isinstance(cls, GenericAlias)
def is_union(cls: GenericType) -> bool:
    """Tell whether the given class is a Union.

    Args:
        cls: The class to check.

    Returns:
        True if the class is a Union, False otherwise.
    """
    try:
        from typing import _UnionGenericAlias  # type: ignore
    except ImportError:
        # No dedicated union alias type: inspect the alias origin instead.
        return is_generic_alias(cls) and cls.__origin__ == Union
    return isinstance(cls, _UnionGenericAlias)
def get_base_class(cls: GenericType) -> Type:
    """Resolve a (possibly generic) type down to its plain base class.

    Args:
        cls: The class to resolve.

    Returns:
        The base class. For a Union, a tuple holding the base class of
        every member is returned instead.
    """
    if is_union(cls):
        member_bases = [get_base_class(member) for member in get_args(cls)]
        return tuple(member_bases)

    if not is_generic_alias(cls):
        return cls

    # Strip the subscript and resolve the origin recursively.
    return get_base_class(cls.__origin__)
def _issubclass(cls: GenericType, cls_check: GenericType) -> bool:
    """Subclass check that also understands generic aliases and Any.

    Args:
        cls: The class to check.
        cls_check: The class to check against.

    Returns:
        True if ``cls`` counts as a subclass of ``cls_check``.
    """
    # Everything is a subclass of Any.
    if cls_check == Any:
        return True

    # Any and Callable are never subclasses of a concrete class.
    if cls == Any or cls == Callable:
        return False

    base = get_base_class(cls)
    check_base = get_base_class(cls_check)
    if check_base == Any:
        return True
    return issubclass(base, check_base)
def _isinstance(obj: Any, cls: GenericType) -> bool:
    """Instance check that also understands generic aliases.

    Args:
        obj: The object to check.
        cls: The class to check against.

    Returns:
        True if the object is an instance of the resolved base class.
    """
    target = get_base_class(cls)
    return isinstance(obj, target)
def rm(path: str):
    """Remove a file, a directory tree, or a symbolic link.

    A symbolic link is removed itself and its target is left untouched.
    Paths that do not exist are ignored.

    Args:
        path: The path to the file or directory.
    """
    # Check for links first: os.path.isdir follows symlinks, and
    # shutil.rmtree refuses to operate on a symbolic link.
    if os.path.islink(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    elif os.path.isfile(path):
        os.remove(path)
def cp(src: str, dest: str, overwrite: bool = True) -> bool:
    """Copy a file or a directory tree.

    Args:
        src: The path to the file or directory.
        dest: The path to the destination.
        overwrite: Whether to overwrite the destination.

    Returns:
        True if something was copied, False if the copy was skipped.
    """
    # Copying a path onto itself is a no-op.
    if src == dest:
        return False

    # Leave an existing destination alone unless overwriting is allowed.
    if os.path.exists(dest) and not overwrite:
        return False

    if not os.path.isdir(src):
        shutil.copyfile(src, dest)
        return True

    # Directories: clear the destination, then copy the whole tree.
    rm(dest)
    shutil.copytree(src, dest)
    return True
def mv(src: str, dest: str, overwrite: bool = True) -> bool:
    """Move a file or a directory.

    Args:
        src: The path to the file or directory.
        dest: The path to the destination.
        overwrite: Whether to overwrite the destination.

    Returns:
        True if the move happened, False if it was skipped.
    """
    same_path = src == dest
    blocked = os.path.exists(dest) and not overwrite
    if same_path or blocked:
        return False

    # Clear whatever is at the destination before moving.
    rm(dest)
    shutil.move(src, dest)
    return True
def mkdir(path: str):
    """Create a directory, including any missing parent directories.

    Nothing happens if the path already exists.

    Args:
        path: The path to the directory.
    """
    if not os.path.exists(path):
        # exist_ok guards against another process creating the directory
        # between the existence check and this call.
        os.makedirs(path, exist_ok=True)
def ln(src: str, dest: str, overwrite: bool = False) -> bool:
    """Create a symbolic link at ``dest`` pointing to ``src``.

    Args:
        src: The path to the file or directory.
        dest: The path to the destination.
        overwrite: Whether to overwrite the destination.

    Returns:
        True if the link was created, False if it was skipped.
    """
    if src == dest:
        return False
    if not overwrite and (os.path.exists(dest) or os.path.islink(dest)):
        return False

    # Clear the destination for file and directory sources alike, since
    # os.symlink fails when dest exists. Links are unlinked rather than
    # followed, because shutil.rmtree refuses symbolic links.
    if os.path.islink(dest) or os.path.isfile(dest):
        os.remove(dest)
    elif os.path.isdir(dest):
        shutil.rmtree(dest)

    os.symlink(src, dest, target_is_directory=os.path.isdir(src))
    return True
def kill(pid):
    """Send SIGTERM to a process.

    Args:
        pid: The process ID.
    """
    term_signal = signal.SIGTERM
    os.kill(pid, term_signal)
def which(program: str) -> Optional[str]:
    """Locate an executable by name.

    Args:
        program: The name of the executable.

    Returns:
        The path to the executable, or None if it cannot be found.
    """
    located = shutil.which(program)
    return located
def get_config() -> Config:
    """Get the app config.

    The config module is imported from the current working directory.
    If it cannot be imported, a default config with an empty app name is
    returned.

    Returns:
        The app config.
    """
    from pynecone.config import Config

    # Make the working directory importable, but only add it once:
    # this function is called many times and sys.path would otherwise
    # grow with every call.
    cwd = os.getcwd()
    if cwd not in sys.path:
        sys.path.append(cwd)

    try:
        return __import__(constants.CONFIG_MODULE).config
    except ImportError:
        return Config(app_name="")
def check_node_version(min_version):
    """Check that the installed Node.js is at least the given version.

    Version components are compared numerically, so "9.0.0" is correctly
    treated as older than "12.0.0".

    Args:
        min_version: The minimum version of Node.js required, as a
            dotted string such as "12.22.0".

    Returns:
        True if Node.js is installed and its version is high enough.
        False if it is too old, missing, or either version is unparsable.
    """
    try:
        # "node -v" prints something like "v18.16.0".
        result = subprocess.run(
            ["node", "-v"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        found = re.search(r"(\d+(?:\.\d+)*)", result.stdout.decode())
        if found is None:
            return False

        current = [int(part) for part in found.group(1).split(".")]
        required = [int(part) for part in min_version.split(".")]

        # Pad the shorter version with zeros so "18" compares like "18.0.0".
        width = max(len(current), len(required))
        current += [0] * (width - len(current))
        required += [0] * (width - len(required))
        return current >= required
    except (OSError, ValueError, AttributeError, TypeError):
        # Node is not installed, or a version string could not be parsed.
        return False
def get_package_manager() -> str:
    """Get the package manager executable.

    Returns:
        The path to the package manager: npm on Windows, otherwise the
        configured bun path.

    Raises:
        FileNotFoundError: If npm is not installed on Windows.
        Exit: If the installed Node.js version is too old.
    """
    # Refuse to continue with an unsupported Node.js version.
    if not check_node_version(constants.MIN_NODE_VERSION):
        console.print(
            f"[red]Node.js version {constants.MIN_NODE_VERSION} or higher is required to run Pynecone."
        )
        raise typer.Exit()

    # Everywhere except Windows, bun is used.
    if platform.system() != "Windows":
        return os.path.expandvars(get_config().bun_path)

    # On Windows, npm is used instead of bun.
    npm_path = which("npm")
    if npm_path is None:
        raise FileNotFoundError("Pynecone requires npm to be installed on Windows.")
    return npm_path
def get_app() -> ModuleType:
    """Import the app module named by the default config.

    Returns:
        The imported ``<app_name>.<app_name>`` module.
    """
    app_name = get_config().app_name
    module_path = ".".join((app_name, app_name))
    return __import__(module_path, fromlist=(constants.APP_VAR,))
def create_config(app_name: str):
    """Write a new pcconfig file for the app.

    Args:
        app_name: The name of the app.
    """
    # Imported here to avoid circular imports.
    from pynecone.compiler import templates

    with open(constants.CONFIG_FILE, "w") as config_file:
        rendered = templates.PCCONFIG.format(app_name=app_name)
        config_file.write(rendered)
def initialize_app_directory(app_name: str):
    """Set up the app directory on pc init.

    Args:
        app_name: The name of the app.
    """
    console.log("Initializing the app directory.")

    # Copy the app template into a directory named after the app.
    cp(constants.APP_TEMPLATE_DIR, app_name)

    # Rename the template file to <app_name>.py.
    template_file = os.path.join(app_name, constants.APP_TEMPLATE_FILE)
    app_file = os.path.join(app_name, app_name + constants.PY_EXT)
    mv(template_file, app_file)

    # Copy the default assets.
    cp(constants.ASSETS_TEMPLATE_DIR, constants.APP_ASSETS_DIR)
def initialize_web_directory():
    """Set up the web directory on pc init."""
    console.log("Initializing the web directory.")

    # Remove installed modules and the package lock from the template
    # before copying it.
    for stale in (constants.NODE_MODULES, constants.PACKAGE_LOCK):
        rm(os.path.join(constants.WEB_TEMPLATE_DIR, stale))

    cp(constants.WEB_TEMPLATE_DIR, constants.WEB_DIR)
def install_bun():
    """Install bun onto the user's system if it is missing."""
    # Bun is not supported on Windows.
    if platform.system() == "Windows":
        console.log("Skipping bun installation on Windows.")
        return

    # Nothing to do when bun is already present.
    bun_path = get_package_manager()
    if os.path.exists(bun_path):
        return

    console.log("Installing bun...")
    os.system(constants.INSTALL_BUN)
def install_frontend_packages():
    """Install the base frontend packages and any app-specific ones."""
    # Base packages.
    base_command = [get_package_manager(), "install"]
    subprocess.run(base_command, cwd=constants.WEB_DIR, stdout=PIPE)

    # App packages listed in the config, if any.
    packages = get_config().frontend_packages
    if len(packages) == 0:
        return

    add_command = [get_package_manager(), "add", *packages]
    subprocess.run(add_command, cwd=constants.WEB_DIR, stdout=PIPE)
def is_initialized() -> bool:
    """Check whether the app is initialized in the current directory.

    Returns:
        True if both the config file and the web directory exist.
    """
    if not os.path.exists(constants.CONFIG_FILE):
        return False
    return os.path.exists(constants.WEB_DIR)
def is_latest_template() -> bool:
    """Check whether the app is using the latest template.

    Returns:
        True if the app's version file content compares greater than or
        equal to the template's. False if the app has no version file.
    """
    # Context managers close the files promptly instead of leaking the
    # handles until garbage collection.
    with open(constants.PCVERSION_TEMPLATE_FILE) as template_file:
        template_version = template_file.read()

    if not os.path.exists(constants.PCVERSION_APP_FILE):
        return False

    with open(constants.PCVERSION_APP_FILE) as app_file:
        app_version = app_file.read()

    # NOTE(review): this is a plain string comparison of the file
    # contents; confirm the version format makes that ordering valid.
    return app_version >= template_version
def export_app(app: App, zip: bool = False):
    """Compile and export the app, optionally zipping it for deployment.

    Args:
        app: The app.
        zip: Whether to zip the app.
    """
    # Always recompile before exporting.
    app.compile(force_compile=True)

    # Drop any previously exported static folder.
    rm(constants.WEB_STATIC_DIR)

    # Run the export script in the web directory.
    export_command = [get_package_manager(), "run", "export"]
    subprocess.run(export_command, cwd=constants.WEB_DIR)

    if not zip:
        return

    # Produce frontend.zip and backend.zip in the app root.
    zip_command = r"cd .web/_static && zip -r ../../frontend.zip ./* && cd ../.. && zip -r backend.zip ./* -x .web/\* ./assets\* ./frontend.zip\* ./backend.zip\*"
    os.system(zip_command)
def setup_frontend():
    """Prepare the web directory, its packages and the assets link."""
    # Create the web directory from the template if it is missing.
    cp(constants.WEB_TEMPLATE_DIR, constants.WEB_DIR, overwrite=False)

    # Install the frontend packages.
    console.rule("[bold]Installing frontend packages")
    install_frontend_packages()

    # Link the app's assets folder into the web directory.
    assets_source = os.path.join("..", constants.APP_ASSETS_DIR)
    ln(src=assets_source, dest=constants.WEB_ASSETS_DIR)
def run_frontend(app: App):
    """Run the frontend in development mode.

    Args:
        app: The app.
    """
    # Prepare the web directory, then compile the app into it.
    setup_frontend()
    app.compile(force_compile=True)

    console.rule("[bold green]App Running")

    # The dev server picks up its port from the PORT environment variable.
    os.environ["PORT"] = get_config().port
    dev_command = [get_package_manager(), "run", "dev"]
    subprocess.Popen(dev_command, cwd=constants.WEB_DIR, env=os.environ)
def run_frontend_prod(app: App):
    """Run the frontend in production mode.

    Args:
        app: The app.
    """
    # Prepare the web directory, then export the app.
    setup_frontend()
    export_app(app)

    # Pass the configured port through the PORT environment variable.
    os.environ["PORT"] = get_config().port
    prod_command = [get_package_manager(), "run", "prod"]
    subprocess.Popen(prod_command, cwd=constants.WEB_DIR, env=os.environ)
def get_num_workers() -> int:
    """Get the number of backend worker processes.

    Returns:
        1 when no redis is configured, otherwise twice the CPU count
        plus one.
    """
    if get_redis() is not None:
        cores = os.cpu_count() or 1
        return cores * 2 + 1

    # Without redis, a single worker is used.
    return 1
def get_api_port() -> int:
    """Get the API port.

    Returns:
        The port of the configured API URL, falling back to the port of
        the default API URL when the configured URL has none.
    """
    configured = urlparse(get_config().api_url).port
    if configured is not None:
        port = configured
    else:
        port = urlparse(constants.API_URL).port
    assert port is not None
    return port
def run_backend(app_name: str, loglevel: constants.LogLevel = constants.LogLevel.ERROR):
    """Run the backend with uvicorn, with auto-reload enabled.

    Args:
        app_name: The app name.
        loglevel: The log level.
    """
    app_path = f"{app_name}:{constants.APP_VAR}.{constants.API_VAR}"
    uvicorn.run(
        app_path,
        host=constants.BACKEND_HOST,
        port=get_api_port(),
        log_level=loglevel,
        reload=True,
    )
def run_backend_prod(
    app_name: str, loglevel: constants.LogLevel = constants.LogLevel.ERROR
):
    """Run the backend in production mode.

    Args:
        app_name: The app name.
        loglevel: The log level.
    """
    num_workers = get_num_workers()
    bind_address = f"0.0.0.0:{get_api_port()}"

    # The same count is used for worker processes and threads.
    options = [
        "--bind",
        bind_address,
        "--workers",
        str(num_workers),
        "--threads",
        str(num_workers),
        "--log-level",
        str(loglevel),
        f"{app_name}:{constants.APP_VAR}()",
    ]
    subprocess.run(constants.RUN_BACKEND_PROD + options)
def get_production_backend_url() -> str:
    """Build the production backend URL from the app config.

    Returns:
        The production backend URL template filled in with the
        configured username and app name.
    """
    config = get_config()
    url_fields = {"username": config.username, "app_name": config.app_name}
    return constants.PRODUCTION_BACKEND_URL.format(**url_fields)
def to_snake_case(text: str) -> str:
    """Convert a string to snake case.

    Word boundaries get an underscore inserted and the whole result is
    lowercased.

    Args:
        text: The string to convert.

    Returns:
        The snake case string.
    """
    # Split before a capitalized word ("xYyy" -> "x_Yyy").
    partly_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", text)
    # Split between a lowercase letter or digit and an uppercase letter.
    fully_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", partly_split)
    return fully_split.lower()
def to_camel_case(text: str) -> str:
    """Convert a string to camel case.

    The first word is lowercased, every following word is capitalized,
    and the underscores between them are removed. A leading underscore
    is kept as a single underscore. Text without underscores is returned
    unchanged.

    Args:
        text: The string to convert.

    Returns:
        The camel case string.
    """
    if "_" not in text:
        return text

    first, *rest = text.lstrip("_").split("_")
    camel = first.lower() + "".join(word.capitalize() for word in rest)

    if text.startswith("_"):
        return "_" + camel
    return camel
def to_title_case(text: str) -> str:
    """Convert a snake case string to title case.

    Args:
        text: The string to convert.

    Returns:
        The words capitalized and joined without separators.
    """
    words = text.split("_")
    return "".join(map(str.capitalize, words))
# Maps each supported opening character to its closing character.
WRAP_MAP = {
    "{": "}",
    "(": ")",
    "[": "]",
    "<": ">",
    '"': '"',
    "'": "'",
    "`": "`",
}


def get_close_char(open: str, close: Optional[str] = None) -> str:
    """Resolve the closing character for an opening character.

    Args:
        open: The open character.
        close: The close character if provided.

    Returns:
        The explicit close character when given, otherwise the one that
        matches ``open`` in WRAP_MAP.

    Raises:
        ValueError: If the open character is not a valid brace.
    """
    # An explicit close character always wins.
    if close is not None:
        return close

    matching = WRAP_MAP.get(open)
    if matching is None:
        raise ValueError(f"Invalid wrap open: {open}, must be one of {WRAP_MAP.keys()}")
    return matching
def is_wrapped(text: str, open: str, close: Optional[str] = None) -> bool:
    """Check whether the text starts and ends with the given characters.

    Args:
        text: The text to check.
        open: The open character.
        close: The close character.

    Returns:
        True if the text starts with ``open`` and ends with the close
        character.
    """
    closing = get_close_char(open, close)
    if not text.startswith(open):
        return False
    return text.endswith(closing)
def wrap(
    text: str,
    open: str,
    close: Optional[str] = None,
    check_first: bool = True,
    num: int = 1,
) -> str:
    """Surround the text with the given open and close characters.

    Args:
        text: The text to wrap.
        open: The open character.
        close: The close character.
        check_first: Whether to check if the text is already wrapped.
        num: The number of times to wrap the text.

    Returns:
        The wrapped text, or the text unchanged when ``check_first`` is
        set and it is already wrapped.
    """
    closing = get_close_char(open, close)

    # Optionally skip text that is already wrapped.
    already_wrapped = check_first and is_wrapped(text=text, open=open, close=closing)
    if already_wrapped:
        return text

    prefix = open * num
    suffix = closing * num
    return f"{prefix}{text}{suffix}"
def indent(text: str, indent_level: int = 2) -> str:
    """Indent every line of a multi-line text.

    Text with fewer than two lines is returned unchanged.

    Args:
        text: The text to indent.
        indent_level: The number of spaces to indent by.

    Returns:
        The indented text, terminated by a line separator.
    """
    lines = text.splitlines()
    if len(lines) < 2:
        return text

    padding = " " * indent_level
    indented = [f"{padding}{line}" for line in lines]
    return os.linesep.join(indented) + os.linesep
def get_path_args(path: str) -> List[str]:
    """Get the path arguments for the given path.

    A path argument is a path segment written in square brackets, e.g.
    ``[pid]`` in ``posts/[pid]``. Every segment of the path is examined.

    Args:
        path: The path to get the arguments for.

    Returns:
        One BaseVar per bracketed segment, in path order, named after
        the text inside the brackets.
    """
    # Import here to avoid circular imports.
    from pynecone.var import BaseVar

    # Regex to check for path args.
    check = re.compile(r"^\[(.+)\]$")

    # os.path.split only yields a (head, tail) pair, so split on the
    # separator instead to visit every segment.
    parts = path.replace(os.path.sep, "/").split("/")

    args = []
    for part in parts:
        match = check.match(part)
        if match is None:
            continue
        args.append(
            BaseVar(
                name=match.group(1),
                type_=str,
                state=f"{constants.ROUTER}.query",
            )
        )
    return args
def format_route(route: str):
    """Normalize a route string.

    Args:
        route: The route to format.

    Returns:
        The route with surrounding path separators removed, converted to
        snake case with underscores replaced by dashes. An empty result
        becomes the index route.
    """
    trimmed = route.strip(os.path.sep)
    formatted = to_snake_case(trimmed).replace("_", "-")
    return formatted if formatted != "" else constants.INDEX_ROUTE
def format_cond(
    cond: str, true_value: str, false_value: str = '""', is_nested: bool = False
) -> str:
    """Build a ternary conditional expression.

    Args:
        cond: The cond.
        true_value: The value to return if the cond is true.
        false_value: The value to return if the cond is false.
        is_nested: Whether the cond is nested.

    Returns:
        The ternary expression, wrapped in braces unless it is nested.
    """
    ternary = f"{cond} ? {true_value} : {false_value}"
    if is_nested:
        return ternary
    return wrap(ternary, "{")
def format_event_handler(handler: EventHandler) -> str:
    """Format an event handler as a dotted name.

    Args:
        handler: The event handler to format.

    Returns:
        The bare function name when the function has no enclosing class.
        Otherwise the state's full name joined with the function name,
        or the function's qualified name when the state class cannot be
        found in the function's module.
    """
    # Split the qualified name into its enclosing scopes and the name.
    parts = handler.fn.__qualname__.split(".")

    # If there's no enclosing class, just return the function name.
    if len(parts) == 1:
        return parts[-1]

    # Get the state and the function name.
    state_name, name = parts[-2:]

    try:
        # Try to get the state from the module.
        state = vars(sys.modules[handler.fn.__module__])[state_name]
    except KeyError:
        # Either the module or the state is missing. Catching only
        # KeyError lets KeyboardInterrupt and SystemExit propagate.
        return handler.fn.__qualname__

    # Construct the full event handler name.
    return ".".join([state.get_full_name(), name])
def format_event(event_spec: EventSpec) -> str:
    """Compile an event spec into its ``E(...)`` call string.

    Args:
        event_spec: The event to format.

    Returns:
        The compiled event.
    """
    pairs = [":".join((name, val)) for name, val in event_spec.args]
    args = ",".join(pairs)
    handler_name = format_event_handler(event_spec.handler)
    payload = wrap(args, "{")
    return f'E("{handler_name}", {payload})'
# Names already handed out by get_unique_variable_name.
USED_VARIABLES = set()


def get_unique_variable_name() -> str:
    """Generate a variable name that has not been returned before.

    Returns:
        A random name of 8 lowercase ASCII letters, recorded in
        USED_VARIABLES.
    """
    # Keep drawing until a name that has not been used yet comes up.
    while True:
        name = "".join([random.choice(string.ascii_lowercase) for _ in range(8)])
        if name not in USED_VARIABLES:
            USED_VARIABLES.add(name)
            return name
def get_default_app_name() -> str:
    """Get the default app name.

    Returns:
        The name of the current working directory with dashes replaced
        by underscores.
    """
    path_parts = os.getcwd().split(os.path.sep)
    directory_name = path_parts[-1]
    return directory_name.replace("-", "_")
def is_dataframe(value: Type) -> bool:
    """Check whether the given type is named ``DataFrame``.

    Only the type's name is compared.

    Args:
        value: The type to check.

    Returns:
        True if the type's ``__name__`` is "DataFrame".
    """
    type_name = value.__name__
    return type_name == "DataFrame"
def format_state(value: Any) -> Dict:
    """Recursively convert a state value into a serializable form.

    Args:
        value: The state to format.

    Returns:
        The formatted state.

    Raises:
        TypeError: If the given value is not a valid state.
    """
    # Dicts are formatted value by value.
    if isinstance(value, dict):
        formatted = {}
        for key, item in value.items():
            formatted[key] = format_state(item)
        return formatted

    # Valid state var types are passed through untouched.
    if isinstance(value, StateBases):
        return value

    # Plotly figures are reduced to the "data" part of their JSON form.
    if isinstance(value, go.Figure):
        figure = json.loads(to_json(value))
        return figure["data"]

    # Dataframes are converted to their columns and rows.
    if is_dataframe(type(value)):
        return {
            "columns": value.columns.tolist(),
            "data": value.values.tolist(),
        }

    raise TypeError(
        "State vars must be primitive Python types, "
        "or subclasses of pc.Base. "
        f"Got var of type {type(value)}."
    )
def get_event(state, event):
    """Build the qualified name of an event on a state.

    Args:
        state: The state.
        event: The event.

    Returns:
        The state's name and the event joined by a dot.
    """
    return "{}.{}".format(state.get_name(), event)
def format_string(string: str) -> str:
    """Format the given string as a JS template string literal.

    Args:
        string: The string to format.

    Returns:
        The string with backticks escaped, wrapped as {`string`}.
    """
    # Unescape first so already-escaped backticks are not escaped twice.
    escaped = string.replace(r"\`", "`").replace("`", r"\`")

    # Wrap in backticks, then in braces.
    quoted = wrap(escaped, "`")
    return wrap(quoted, "{")
def call_event_handler(event_handler: EventHandler, arg: Var) -> EventSpec:
    """Call an event handler to get the event spec.

    The handler function's signature decides how it is called: with one
    positional parameter the handler is called with no args, with two it
    is called with ``arg``.

    Args:
        event_handler: The event handler.
        arg: The argument to pass to the event handler.

    Returns:
        The event spec from calling the event handler.
    """
    num_params = len(inspect.getfullargspec(event_handler.fn).args)
    if num_params == 1:
        return event_handler()

    assert (
        num_params == 2
    ), f"Event handler {event_handler.fn} must have 1 or 2 arguments."
    return event_handler(arg)
def call_event_fn(fn: Callable, arg: Var) -> List[EventSpec]:
    """Call a function and return its result as a list of event specs.

    The function should return either a single EventSpec or a list of
    EventSpecs. It is called with ``arg`` if it takes one positional
    parameter and with no args if it takes none.

    Args:
        fn: The function to call.
        arg: The argument to pass to the function.

    Returns:
        The event specs from calling the function.

    Raises:
        ValueError: If the lambda has an invalid signature.
    """
    num_params = len(inspect.getfullargspec(fn).args)
    if num_params > 1:
        raise ValueError(f"Lambda {fn} must have 0 or 1 arguments.")

    result = fn(arg) if num_params == 1 else fn()

    # A single result is wrapped in a list.
    return result if isinstance(result, List) else [result]
def get_handler_args(event_spec: EventSpec, arg: Var) -> Tuple[Tuple[str, str], ...]:
    """Get the handler args for the given event spec.

    Args:
        event_spec: The event spec.
        arg: The controlled event argument.

    Returns:
        The spec's own args when the handler function takes more than
        two positional parameters, otherwise a single pair of the
        handler's second parameter name and the name of ``arg``.
    """
    params = inspect.getfullargspec(event_spec.handler.fn).args
    if len(params) <= 2:
        return ((params[1], arg.name),)
    return event_spec.args
def fix_events(events: Optional[List[Event]], token: str) -> List[Event]:
    """Normalize what an event handler returned into a list of events.

    Args:
        events: The events to fix.
        token: The user token.

    Returns:
        The fixed events, each carrying the given token.
    """
    from pynecone.event import Event, EventHandler, EventSpec

    # A handler that returns nothing produces no events.
    if events is None:
        return []

    # A single returned event is treated as a one-element list.
    if not isinstance(events, List):
        events = [events]

    fixed = []
    for item in events:
        if isinstance(item, Event):
            # Existing events keep their name and payload.
            name, payload = item.name, item.payload
        else:
            # Bare handlers are called to obtain their event spec.
            spec = item() if isinstance(item, EventHandler) else item
            assert isinstance(spec, EventSpec), f"Unexpected event type, {type(spec)}."
            name = format_event_handler(spec.handler)
            payload = dict(spec.args)

        fixed.append(Event(token=token, name=name, payload=payload))
    return fixed
def merge_imports(*imports) -> ImportDict:
    """Merge any number of import dicts into one.

    Args:
        *imports: The list of import dicts to merge.

    Returns:
        A mapping from each library to the set of all fields listed for
        it across the given dicts.
    """
    merged = defaultdict(set)
    lib_fields = (
        (lib, field)
        for import_dict in imports
        for lib, fields in import_dict.items()
        for field in fields
    )
    for lib, field in lib_fields:
        merged[lib].add(field)
    return merged
def get_hydrate_event(state) -> str:
    """Get the name of the hydrate event for the state.

    Args:
        state: The state.

    Returns:
        The hydrate event name qualified with the state's name.
    """
    hydrate = constants.HYDRATE
    return get_event(state, hydrate)
def get_redis() -> Optional[Redis]:
    """Get the redis client.

    Returns:
        A client for the configured ``host:port`` redis URL, or None
        when no redis URL is configured.
    """
    config = get_config()
    if config.redis_url is None:
        return None

    # The URL is expected in "host:port" form.
    host, port = config.redis_url.split(":")
    print("Using redis at", config.redis_url)
    return Redis(host=host, port=int(port), db=0)
# Computed once at import time so format_state does not resolve the
# base classes of StateVar on every call.
StateBases = get_base_class(StateVar)
|