@@ -11,6 +11,7 @@ import inspect
 import json
 import pickle
 import sys
+import time
 import typing
 import uuid
 from abc import ABC, abstractmethod
@@ -39,6 +40,7 @@ from typing import (
     get_type_hints,
 )
 
+from redis.asyncio.client import PubSub
 from sqlalchemy.orm import DeclarativeBase
 from typing_extensions import Self
 
@@ -69,6 +71,11 @@ try:
 except ModuleNotFoundError:
     BaseModelV1 = BaseModelV2
 
+try:
+    from pydantic.v1 import validator
+except ModuleNotFoundError:
+    from pydantic import validator
+
 import wrapt
 from redis.asyncio import Redis
 from redis.exceptions import ResponseError
@@ -92,6 +99,7 @@ from reflex.utils.exceptions import (
     DynamicRouteArgShadowsStateVar,
     EventHandlerShadowsBuiltInStateMethod,
     ImmutableStateError,
+    InvalidLockWarningThresholdError,
     InvalidStateManagerMode,
     LockExpiredError,
     ReflexRuntimeError,
@@ -429,9 +437,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                 )
 
         # Create a fresh copy of the backend variables for this instance
-        self._backend_vars = copy.deepcopy(
-            {name: item for name, item in self.backend_vars.items()}
-        )
+        self._backend_vars = copy.deepcopy(self.backend_vars)
 
     def __repr__(self) -> str:
         """Get the string representation of the state.
@@ -515,9 +521,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             cls.inherited_backend_vars = parent_state.backend_vars
 
             # Check if another substate class with the same name has already been defined.
-            if cls.get_name() in set(
-                c.get_name() for c in parent_state.class_subclasses
-            ):
+            if cls.get_name() in {c.get_name() for c in parent_state.class_subclasses}:
                 # This should not happen, since we have added module prefix to state names in #3214
                 raise StateValueError(
                     f"The substate class '{cls.get_name()}' has been defined multiple times. "
@@ -780,11 +784,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         )
 
         # ComputedVar with cache=False always need to be recomputed
-        cls._always_dirty_computed_vars = set(
+        cls._always_dirty_computed_vars = {
             cvar_name
             for cvar_name, cvar in cls.computed_vars.items()
             if not cvar._cache
-        )
+        }
 
         # Any substate containing a ComputedVar with cache=False always needs to be recomputed
         if cls._always_dirty_computed_vars:
@@ -1095,6 +1099,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         if (
             not field.required
             and field.default is None
+            and field.default_factory is None
             and not types.is_optional(prop._var_type)
         ):
             # Ensure frontend uses null coalescing when accessing.
@@ -1235,13 +1240,16 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         if not super().__getattribute__("__dict__"):
             return super().__getattribute__(name)
 
-        inherited_vars = {
-            **super().__getattribute__("inherited_vars"),
-            **super().__getattribute__("inherited_backend_vars"),
-        }
+        # Fast path for dunder
+        if name.startswith("__"):
+            return super().__getattribute__(name)
 
         # For now, handle router_data updates as a special case.
-        if name in inherited_vars or name == constants.ROUTER_DATA:
+        if (
+            name == constants.ROUTER_DATA
+            or name in super().__getattribute__("inherited_vars")
+            or name in super().__getattribute__("inherited_backend_vars")
+        ):
             parent_state = super().__getattribute__("parent_state")
             if parent_state is not None:
                 return getattr(parent_state, name)
@@ -1296,15 +1304,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             value = value.__wrapped__
 
         # Set the var on the parent state.
-        inherited_vars = {**self.inherited_vars, **self.inherited_backend_vars}
-        if name in inherited_vars:
+        if name in self.inherited_vars or name in self.inherited_backend_vars:
             setattr(self.parent_state, name, value)
             return
 
         if name in self.backend_vars:
-            # abort if unchanged
-            if self._backend_vars.get(name) == value:
-                return
             self._backend_vars.__setitem__(name, value)
             self.dirty_vars.add(name)
             self._mark_dirty()
@@ -1853,11 +1857,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             Set of computed vars to include in the delta.
         """
-        return set(
+        return {
             cvar
             for cvar in self.computed_vars
             if self.computed_vars[cvar].needs_update(instance=self)
-        )
+        }
 
     def _dirty_computed_vars(
         self, from_vars: set[str] | None = None, include_backend: bool = True
@@ -1871,12 +1875,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         Returns:
             Set of computed vars to include in the delta.
         """
-        return set(
+        return {
             cvar
             for dirty_var in from_vars or self.dirty_vars
             for cvar in self._computed_var_dependencies[dirty_var]
             if include_backend or not self.computed_vars[cvar]._backend
-        )
+        }
 
     @classmethod
     def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
@@ -1886,16 +1890,16 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             Set of State classes that may need to be fetched to recalc computed vars.
         """
         # _always_dirty_substates need to be fetched to recalc computed vars.
-        fetch_substates = set(
+        fetch_substates = {
             cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
             for substate_name in cls._always_dirty_substates
-        )
+        }
         for dependent_substates in cls._substate_var_dependencies.values():
             fetch_substates.update(
-                set(
+                {
                     cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
                     for substate_name in dependent_substates
-                )
+                }
             )
         return fetch_substates
 
@@ -2122,14 +2126,26 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         state["__dict__"].pop("router", None)
         state["__dict__"].pop("router_data", None)
         # Never serialize parent_state or substates.
-        state["__dict__"]["parent_state"] = None
-        state["__dict__"]["substates"] = {}
+        state["__dict__"].pop("parent_state", None)
+        state["__dict__"].pop("substates", None)
         state["__dict__"].pop("_was_touched", None)
         # Remove all inherited vars.
         for inherited_var_name in self.inherited_vars:
             state["__dict__"].pop(inherited_var_name, None)
         return state
 
+    def __setstate__(self, state: dict[str, Any]):
+        """Set the state from redis deserialization.
+
+        This method is called by pickle to deserialize the object.
+
+        Args:
+            state: The state dict for deserialization.
+        """
+        state["__dict__"]["parent_state"] = None
+        state["__dict__"]["substates"] = {}
+        super().__setstate__(state)
+
     def _check_state_size(
         self,
         pickle_state_size: int,
@@ -2185,7 +2201,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
         return md5(
             pickle.dumps(
-                list(sorted(_field_tuple(field_name) for field_name in cls.base_vars))
+                sorted(_field_tuple(field_name) for field_name in cls.base_vars)
             )
         ).hexdigest()
 
@@ -2819,6 +2835,7 @@ class StateManager(Base, ABC):
                     redis=redis,
                     token_expiration=config.redis_token_expiration,
                     lock_expiration=config.redis_lock_expiration,
+                    lock_warning_threshold=config.redis_lock_warning_threshold,
                 )
         raise InvalidStateManagerMode(
             f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
@@ -3188,6 +3205,15 @@ def _default_lock_expiration() -> int:
     return get_config().redis_lock_expiration
 
 
+def _default_lock_warning_threshold() -> int:
+    """Get the default lock warning threshold.
+
+    Returns:
+        The default lock warning threshold.
+    """
+    return get_config().redis_lock_warning_threshold
+
+
 class StateManagerRedis(StateManager):
     """A state manager that stores states in redis."""
 
@@ -3200,6 +3226,11 @@ class StateManagerRedis(StateManager):
     # The maximum time to hold a lock (ms).
     lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
 
+    # The maximum time to hold a lock (ms) before warning.
+    lock_warning_threshold: int = pydantic.Field(
+        default_factory=_default_lock_warning_threshold
+    )
+
     # The keyspace subscription string when redis is waiting for lock to be released
     _redis_notify_keyspace_events: str = (
         "K"  # Enable keyspace notifications (target a particular key)
@@ -3318,7 +3349,7 @@ class StateManagerRedis(StateManager):
             state_cls = self.state.get_class_substate(state_path)
         else:
             raise RuntimeError(
-                "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
+                f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
             )
 
         # The deserialized or newly created (sub)state instance.
@@ -3387,6 +3418,17 @@ class StateManagerRedis(StateManager):
                 f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
                 "or use `@rx.event(background=True)` decorator for long-running tasks."
             )
+        elif lock_id is not None:
+            time_taken = self.lock_expiration / 1000 - (
+                await self.redis.ttl(self._lock_key(token))
+            )
+            if time_taken > self.lock_warning_threshold / 1000:
+                console.warn(
+                    f"Lock for token {token} was held too long {time_taken=}s, "
+                    f"use `@rx.event(background=True)` decorator for long-running tasks.",
+                    dedupe=True,
+                )
+
         client_token, substate_name = _split_substate_key(token)
         # If the substate name on the token doesn't match the instance name, it cannot have a parent.
         if state.parent_state is not None and state.get_full_name() != substate_name:
@@ -3395,17 +3437,16 @@ class StateManagerRedis(StateManager):
             )
 
         # Recursively set_state on all known substates.
-        tasks = []
-        for substate in state.substates.values():
-            tasks.append(
-                asyncio.create_task(
-                    self.set_state(
-                        token=_substate_key(client_token, substate),
-                        state=substate,
-                        lock_id=lock_id,
-                    )
+        tasks = [
+            asyncio.create_task(
+                self.set_state(
+                    _substate_key(client_token, substate),
+                    substate,
+                    lock_id,
                 )
             )
+            for substate in state.substates.values()
+        ]
         # Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
         if state._get_was_touched():
             pickle_state = state._serialize()
@@ -3436,6 +3477,27 @@ class StateManagerRedis(StateManager):
             yield state
             await self.set_state(token, state, lock_id)
 
+    @validator("lock_warning_threshold")
+    @classmethod
+    def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
+        """Validate the lock warning threshold.
+
+        Args:
+            lock_warning_threshold: The lock warning threshold.
+            values: The validated attributes.
+
+        Returns:
+            The lock warning threshold.
+
+        Raises:
+            InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
+        """
+        if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
+            raise InvalidLockWarningThresholdError(
+                f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
+            )
+        return lock_warning_threshold
+
     @staticmethod
     def _lock_key(token: str) -> bytes:
         """Get the redis key for a token's lock.
@@ -3467,6 +3529,35 @@ class StateManagerRedis(StateManager):
             nx=True,  # only set if it doesn't exist
         )
 
+    async def _get_pubsub_message(
+        self, pubsub: PubSub, timeout: float | None = None
+    ) -> None:
+        """Get lock release events from the pubsub.
+
+        Args:
+            pubsub: The pubsub to get a message from.
+            timeout: Remaining time to wait for a message.
+
+        Returns:
+            The message.
+        """
+        if timeout is None:
+            timeout = self.lock_expiration / 1000.0
+
+        started = time.time()
+        message = await pubsub.get_message(
+            ignore_subscribe_messages=True,
+            timeout=timeout,
+        )
+        if (
+            message is None
+            or message["data"] not in self._redis_keyspace_lock_release_events
+        ):
+            remaining = timeout - (time.time() - started)
+            if remaining <= 0:
+                return
+            await self._get_pubsub_message(pubsub, timeout=remaining)
+
     async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
         """Wait for a redis lock to be released via pubsub.
 
@@ -3479,7 +3570,6 @@ class StateManagerRedis(StateManager):
         Raises:
             ResponseError: when the keyspace config cannot be set.
         """
-        state_is_locked = False
         lock_key_channel = f"__keyspace@0__:{lock_key.decode()}"
         # Enable keyspace notifications for the lock key, so we know when it is available.
         try:
@@ -3493,20 +3583,13 @@ class StateManagerRedis(StateManager):
                 raise
         async with self.redis.pubsub() as pubsub:
             await pubsub.psubscribe(lock_key_channel)
-            while not state_is_locked:
-                # wait for the lock to be released
-                while True:
-                    if not await self.redis.exists(lock_key):
-                        break  # key was removed, try to get the lock again
-                    message = await pubsub.get_message(
-                        ignore_subscribe_messages=True,
-                        timeout=self.lock_expiration / 1000.0,
-                    )
-                    if message is None:
-                        continue
-                    if message["data"] in self._redis_keyspace_lock_release_events:
-                        break
-                state_is_locked = await self._try_get_lock(lock_key, lock_id)
+            # wait for the lock to be released
+            while True:
+                # fast path
+                if await self._try_get_lock(lock_key, lock_id):
+                    return
+                # wait for lock events
+                await self._get_pubsub_message(pubsub)
 
     @contextlib.asynccontextmanager
     async def _lock(self, token: str):
@@ -3565,33 +3648,30 @@ class MutableProxy(wrapt.ObjectProxy):
     """A proxy for a mutable object that tracks changes."""
 
     # Methods on wrapped objects which should mark the state as dirty.
-    __mark_dirty_attrs__ = set(
-        [
-            "add",
-            "append",
-            "clear",
-            "difference_update",
-            "discard",
-            "extend",
-            "insert",
-            "intersection_update",
-            "pop",
-            "popitem",
-            "remove",
-            "reverse",
-            "setdefault",
-            "sort",
-            "symmetric_difference_update",
-            "update",
-        ]
-    )
+    __mark_dirty_attrs__ = {
+        "add",
+        "append",
+        "clear",
+        "difference_update",
+        "discard",
+        "extend",
+        "insert",
+        "intersection_update",
+        "pop",
+        "popitem",
+        "remove",
+        "reverse",
+        "setdefault",
+        "sort",
+        "symmetric_difference_update",
+        "update",
+    }
+
     # Methods on wrapped objects might return mutable objects that should be tracked.
-    __wrap_mutable_attrs__ = set(
-        [
-            "get",
-            "setdefault",
-        ]
-    )
+    __wrap_mutable_attrs__ = {
+        "get",
+        "setdefault",
+    }
 
     # These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
     __never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
@@ -3634,7 +3714,7 @@ class MutableProxy(wrapt.ObjectProxy):
         self,
         wrapped=None,
         instance=None,
-        args=tuple(),
+        args=(),
         kwargs=None,
     ) -> Any:
         """Mark the state as dirty, then call a wrapped function.
@@ -3890,7 +3970,7 @@ class ImmutableMutableProxy(MutableProxy):
         self,
         wrapped=None,
         instance=None,
-        args=tuple(),
+        args=(),
        kwargs=None,
     ) -> Any:
         """Raise an exception when an attempt is made to modify the object.