|
@@ -11,6 +11,7 @@ import inspect
|
|
|
import json
|
|
|
import pickle
|
|
|
import sys
|
|
|
+import time
|
|
|
import typing
|
|
|
import uuid
|
|
|
from abc import ABC, abstractmethod
|
|
@@ -39,6 +40,7 @@ from typing import (
|
|
|
get_type_hints,
|
|
|
)
|
|
|
|
|
|
+from redis.asyncio.client import PubSub
|
|
|
from sqlalchemy.orm import DeclarativeBase
|
|
|
from typing_extensions import Self
|
|
|
|
|
@@ -3479,6 +3481,35 @@ class StateManagerRedis(StateManager):
|
|
|
nx=True, # only set if it doesn't exist
|
|
|
)
|
|
|
|
|
|
+ async def _get_pubsub_message(
|
|
|
+ self, pubsub: PubSub, timeout: float | None = None
|
|
|
+ ) -> None:
|
|
|
+ """Get lock release events from the pubsub.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ pubsub: The pubsub to get a message from.
|
|
|
+ timeout: Remaining time to wait for a message.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The message.
|
|
|
+ """
|
|
|
+ if timeout is None:
|
|
|
+ timeout = self.lock_expiration / 1000.0
|
|
|
+
|
|
|
+ started = time.time()
|
|
|
+ message = await pubsub.get_message(
|
|
|
+ ignore_subscribe_messages=True,
|
|
|
+ timeout=timeout,
|
|
|
+ )
|
|
|
+ if (
|
|
|
+ message is None
|
|
|
+ or message["data"] not in self._redis_keyspace_lock_release_events
|
|
|
+ ):
|
|
|
+ remaining = timeout - (time.time() - started)
|
|
|
+ if remaining <= 0:
|
|
|
+ return
|
|
|
+ await self._get_pubsub_message(pubsub, timeout=remaining)
|
|
|
+
|
|
|
async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
|
|
|
"""Wait for a redis lock to be released via pubsub.
|
|
|
|
|
@@ -3491,7 +3522,6 @@ class StateManagerRedis(StateManager):
|
|
|
Raises:
|
|
|
ResponseError: when the keyspace config cannot be set.
|
|
|
"""
|
|
|
- state_is_locked = False
|
|
|
lock_key_channel = f"__keyspace@0__:{lock_key.decode()}"
|
|
|
# Enable keyspace notifications for the lock key, so we know when it is available.
|
|
|
try:
|
|
@@ -3505,20 +3535,13 @@ class StateManagerRedis(StateManager):
|
|
|
raise
|
|
|
async with self.redis.pubsub() as pubsub:
|
|
|
await pubsub.psubscribe(lock_key_channel)
|
|
|
- while not state_is_locked:
|
|
|
- # wait for the lock to be released
|
|
|
- while True:
|
|
|
- if not await self.redis.exists(lock_key):
|
|
|
- break # key was removed, try to get the lock again
|
|
|
- message = await pubsub.get_message(
|
|
|
- ignore_subscribe_messages=True,
|
|
|
- timeout=self.lock_expiration / 1000.0,
|
|
|
- )
|
|
|
- if message is None:
|
|
|
- continue
|
|
|
- if message["data"] in self._redis_keyspace_lock_release_events:
|
|
|
- break
|
|
|
- state_is_locked = await self._try_get_lock(lock_key, lock_id)
|
|
|
+ # wait for the lock to be released
|
|
|
+ while True:
|
|
|
+ # fast path
|
|
|
+ if await self._try_get_lock(lock_key, lock_id):
|
|
|
+ return
|
|
|
+ # wait for lock events
|
|
|
+ await self._get_pubsub_message(pubsub)
|
|
|
|
|
|
@contextlib.asynccontextmanager
|
|
|
async def _lock(self, token: str):
|