Merged
Changes from 6 commits
.github/workflows/test.yml (4 changes: 2 additions & 2 deletions)

@@ -25,7 +25,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version: [3.9, '3.10', 3.11, 3.12, 3.13]
-        redis-version: ['6.2.6-v9', 'latest', '8.0-M03']
+        redis-version: ['6.2.6-v9', 'latest', '8.0.2']
 
     steps:
       - name: Check out repository
@@ -49,7 +49,7 @@ jobs:
 
       - name: Set Redis image name
         run: |
-          if [[ "${{ matrix.redis-version }}" == "8.0-M03" ]]; then
+          if [[ "${{ matrix.redis-version }}" == "8.0.2" ]]; then
            echo "REDIS_IMAGE=redis:${{ matrix.redis-version }}" >> $GITHUB_ENV
          else
            echo "REDIS_IMAGE=redis/redis-stack-server:${{ matrix.redis-version }}" >> $GITHUB_ENV
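
Note: the split above exists because Redis 8 ships the former Stack modules (RedisJSON, RediSearch) in the core `redis` image, while older versions need `redis-stack-server`. A quick, illustrative Python rendering of the step's branch logic (`image_for` is a hypothetical helper, not part of the workflow):

```python
def image_for(version: str) -> str:
    # Mirrors the shell conditional in the "Set Redis image name" step.
    if version == "8.0.2":
        return f"redis:{version}"
    return f"redis/redis-stack-server:{version}"

assert image_for("8.0.2") == "redis:8.0.2"
assert image_for("6.2.6-v9") == "redis/redis-stack-server:6.2.6-v9"
assert image_for("latest") == "redis/redis-stack-server:latest"
```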
langgraph/checkpoint/redis/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -280,7 +280,7 @@ def put(
         # store at top-level for filters in list()
         if all(key in metadata for key in ["source", "step"]):
             checkpoint_data["source"] = metadata["source"]
-            checkpoint_data["step"] = metadata["step"]  # type: ignore
+            checkpoint_data["step"] = metadata["step"]
 
         # Create the checkpoint key
         checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
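
The top-level `source` and `step` fields exist so that `list()` can filter checkpoints without deserializing metadata. A minimal sketch of such a filter with redisvl (field names from the snippet above; the index is assumed to define `source` as a tag field and `step` as numeric):

```python
from redisvl.query import FilterQuery
from redisvl.query.filter import Num, Tag

# Match checkpoints written by the "input" source at step 3.
flt = (Tag("source") == "input") & (Num("step") == 3)
query = FilterQuery(filter_expression=flt, return_fields=["checkpoint_id"])
```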
langgraph/checkpoint/redis/aio.py (55 changes: 23 additions & 32 deletions)

@@ -7,7 +7,6 @@
 import logging
 import os
 from contextlib import asynccontextmanager
-from functools import partial
 from types import TracebackType
 from typing import (
     Any,
@@ -39,7 +38,6 @@
 from redisvl.index import AsyncSearchIndex
 from redisvl.query import FilterQuery
 from redisvl.query.filter import Num, Tag
-from redisvl.redis.connection import RedisConnectionFactory
 
 from langgraph.checkpoint.redis.base import BaseRedisSaver
 from langgraph.checkpoint.redis.util import (
@@ -63,14 +61,14 @@ async def _write_obj_tx(
     exists: int = await pipe.exists(key)
     if upsert_case:
         if exists:
-            await pipe.json().set(key, "$.channel", write_obj["channel"])
-            await pipe.json().set(key, "$.type", write_obj["type"])
-            await pipe.json().set(key, "$.blob", write_obj["blob"])
+            pipe.json().set(key, "$.channel", write_obj["channel"])
+            pipe.json().set(key, "$.type", write_obj["type"])
+            pipe.json().set(key, "$.blob", write_obj["blob"])
         else:
-            await pipe.json().set(key, "$", write_obj)
+            pipe.json().set(key, "$", write_obj)
     else:
         if not exists:
-            await pipe.json().set(key, "$", write_obj)
+            pipe.json().set(key, "$", write_obj)
 
 
 class AsyncRedisSaver(
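
This hunk is the heart of the PR: in redis-py's asyncio API, commands invoked on a pipeline only queue locally and return the pipeline object itself; nothing hits the network until `execute()`. Dropping the `await`s changes no behavior, it just stops suggesting I/O where none happens. A minimal sketch of the intended pattern (assumes a reachable Redis with RedisJSON loaded):

```python
import asyncio
from redis.asyncio import Redis

async def main() -> None:
    r = Redis()
    pipe = r.pipeline(transaction=True)
    # Queued synchronously; each call returns the pipeline, not a coroutine.
    pipe.json().set("doc:1", "$", {"channel": "a", "type": "json", "blob": "x"})
    pipe.json().set("doc:1", "$.type", "msgpack")
    # The single round-trip happens here.
    await pipe.execute()
    await r.aclose()

asyncio.run(main())
```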
@@ -568,7 +566,7 @@ async def aput(
         # store at top-level for filters in list()
         if all(key in metadata for key in ["source", "step"]):
             checkpoint_data["source"] = metadata["source"]
-            checkpoint_data["step"] = metadata["step"]  # type: ignore
+            checkpoint_data["step"] = metadata["step"]
 
         # Prepare checkpoint key
         checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
@@ -587,11 +585,11 @@
 
         if self.cluster_mode:
             # For cluster mode, execute operations individually
-            await self._redis.json().set(checkpoint_key, "$", checkpoint_data)
+            await self._redis.json().set(checkpoint_key, "$", checkpoint_data)  # type: ignore[misc]
 
             if blobs:
                 for key, data in blobs:
-                    await self._redis.json().set(key, "$", data)
+                    await self._redis.json().set(key, "$", data)  # type: ignore[misc]
 
             # Apply TTL if configured
             if self.ttl_config and "default_ttl" in self.ttl_config:
@@ -604,12 +602,12 @@
             pipeline = self._redis.pipeline(transaction=True)
 
             # Add checkpoint data to pipeline
-            await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
+            pipeline.json().set(checkpoint_key, "$", checkpoint_data)
 
             if blobs:
                 # Add all blob operations to the pipeline
                 for key, data in blobs:
-                    await pipeline.json().set(key, "$", data)
+                    pipeline.json().set(key, "$", data)
 
             # Execute all operations atomically
             await pipeline.execute()
@@ -654,13 +652,13 @@ async def aput(
 
                 if self.cluster_mode:
                     # For cluster mode, execute operation directly
-                    await self._redis.json().set(
+                    await self._redis.json().set(  # type: ignore[misc]
                         checkpoint_key, "$", checkpoint_data
                     )
                 else:
                     # For non-cluster mode, use pipeline
                     pipeline = self._redis.pipeline(transaction=True)
-                    await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
+                    pipeline.json().set(checkpoint_key, "$", checkpoint_data)
                     await pipeline.execute()
             except Exception:
                 # If this also fails, we just propagate the original cancellation
@@ -739,24 +737,19 @@ async def aput_writes(
                     exists = await self._redis.exists(key)
                     if exists:
                         # Update existing key
-                        await self._redis.json().set(
-                            key, "$.channel", write_obj["channel"]
-                        )
-                        await self._redis.json().set(
-                            key, "$.type", write_obj["type"]
-                        )
-                        await self._redis.json().set(
-                            key, "$.blob", write_obj["blob"]
-                        )
+                        pipeline = self._redis.pipeline(transaction=True)
+                        pipeline.json().set(key, "$.channel", write_obj["channel"])  # type: ignore[arg-type]
+                        pipeline.json().set(key, "$.type", write_obj["type"])  # type: ignore[arg-type]
+                        pipeline.json().set(key, "$.blob", write_obj["blob"])  # type: ignore[arg-type]
                     else:
                         # Create new key
-                        await self._redis.json().set(key, "$", write_obj)
+                        pipeline.json().set(key, "$", write_obj)
                         created_keys.append(key)
                 else:
                     # For non-upsert case, only set if key doesn't exist
                     exists = await self._redis.exists(key)
                     if not exists:
-                        await self._redis.json().set(key, "$", write_obj)
+                        pipeline.json().set(key, "$", write_obj)
                        created_keys.append(key)
 
             # Apply TTL to newly created keys
Expand Down Expand Up @@ -788,20 +781,18 @@ async def aput_writes(
exists = await self._redis.exists(key)
if exists:
# Update existing key
await pipeline.json().set(
key, "$.channel", write_obj["channel"]
)
await pipeline.json().set(key, "$.type", write_obj["type"])
await pipeline.json().set(key, "$.blob", write_obj["blob"])
pipeline.json().set(key, "$.channel", write_obj["channel"]) # type: ignore[arg-type]
pipeline.json().set(key, "$.type", write_obj["type"]) # type: ignore[arg-type]
pipeline.json().set(key, "$.blob", write_obj["blob"]) # type: ignore[arg-type]
else:
# Create new key
await pipeline.json().set(key, "$", write_obj)
pipeline.json().set(key, "$", write_obj)
created_keys.append(key)
else:
# For non-upsert case, only set if key doesn't exist
exists = await self._redis.exists(key)
if not exists:
await pipeline.json().set(key, "$", write_obj)
pipeline.json().set(key, "$", write_obj)
created_keys.append(key)

# Execute all operations atomically
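
The recurring shape in `aput` and `aput_writes` after this change: calls on the client keep their `await`, calls on a pipeline lose it. Condensed, for illustration (names as in the code above):

```python
if self.cluster_mode:
    # Keys may hash to different slots, so a cross-slot MULTI/EXEC is not
    # available; await each command directly.
    await self._redis.json().set(checkpoint_key, "$", checkpoint_data)
else:
    # Single node: queue synchronously, then await a single execute().
    pipeline = self._redis.pipeline(transaction=True)
    pipeline.json().set(checkpoint_key, "$", checkpoint_data)
    await pipeline.execute()
```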
langgraph/checkpoint/redis/ashallow.py (32 changes: 16 additions & 16 deletions)

@@ -90,11 +90,11 @@
 async def _write_obj_tx(pipe: Pipeline, key: str, write_obj: dict[str, Any]) -> None:
     exists: int = await pipe.exists(key)
     if exists:
-        await pipe.json().set(key, "$.channel", write_obj["channel"])
-        await pipe.json().set(key, "$.type", write_obj["type"])
-        await pipe.json().set(key, "$.blob", write_obj["blob"])
+        pipe.json().set(key, "$.channel", write_obj["channel"])
+        pipe.json().set(key, "$.type", write_obj["type"])
+        pipe.json().set(key, "$.blob", write_obj["blob"])
     else:
-        await pipe.json().set(key, "$", write_obj)
+        pipe.json().set(key, "$", write_obj)
 
 
 class AsyncShallowRedisSaver(BaseRedisSaver[AsyncRedis, AsyncSearchIndex]):
@@ -240,7 +240,7 @@ async def aput(
             )
 
             # Add checkpoint data to pipeline
-            await pipeline.json().set(checkpoint_key, "$", checkpoint_data)
+            pipeline.json().set(checkpoint_key, "$", checkpoint_data)
 
             # Before storing the new blobs, clean up old ones that won't be needed
             # - Get a list of all blob keys for this thread_id and checkpoint_ns
@@ -274,7 +274,7 @@
                         continue
                     else:
                         # This is an old version, delete it
-                        await pipeline.delete(blob_key)
+                        pipeline.delete(blob_key)
 
             # Store the new blob values
             blobs = self._dump_blobs(
@@ -287,7 +287,7 @@
             if blobs:
                 # Add all blob data to pipeline
                 for key, data in blobs:
-                    await pipeline.json().set(key, "$", data)
+                    pipeline.json().set(key, "$", data)
 
             # Execute all operations atomically
             await pipeline.execute()
@@ -571,7 +571,7 @@ async def aput_writes(
 
                 # If the write is for a different checkpoint_id, delete it
                 if key_checkpoint_id != checkpoint_id:
-                    await pipeline.delete(write_key)
+                    pipeline.delete(write_key)
 
         # Add new writes to the pipeline
         upsert_case = all(w[0] in WRITES_IDX_MAP for w in writes)
@@ -589,17 +589,15 @@
                 exists = await self._redis.exists(key)
                 if exists:
                     # Update existing key
-                    await pipeline.json().set(
-                        key, "$.channel", write_obj["channel"]
-                    )
-                    await pipeline.json().set(key, "$.type", write_obj["type"])
-                    await pipeline.json().set(key, "$.blob", write_obj["blob"])
+                    pipeline.json().set(key, "$.channel", write_obj["channel"])
+                    pipeline.json().set(key, "$.type", write_obj["type"])
+                    pipeline.json().set(key, "$.blob", write_obj["blob"])
                 else:
                     # Create new key
-                    await pipeline.json().set(key, "$", write_obj)
+                    pipeline.json().set(key, "$", write_obj)
             else:
                 # For shallow implementation, always set the full object
-                await pipeline.json().set(key, "$", write_obj)
+                pipeline.json().set(key, "$", write_obj)
 
         # Execute all operations atomically
         await pipeline.execute()
@@ -722,7 +720,9 @@ async def _aload_pending_writes(
                 (
                     parsed_key["task_id"],
                     parsed_key["idx"],
-                ): await self._redis.json().get(key)
+                ): await self._redis.json().get(
+                    key
+                )  # type: ignore[misc]
                 for key, parsed_key in sorted(
                     zip(matching_keys, parsed_keys), key=lambda x: x[1]["idx"]
                 )
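
The blob-cleanup loop in `aput` above reduces to one rule: for every channel being rewritten, keep only the blob carrying the new version. An illustrative reduction with hypothetical data:

```python
# Hypothetical channel -> version map for the incoming checkpoint.
new_versions = {"messages": "v3"}
# Hypothetical (channel, version) pairs already stored for this thread.
existing = [("messages", "v2"), ("messages", "v3"), ("state", "v1")]

# Stale versions of rewritten channels are queued for deletion.
to_delete = [
    (ch, v) for ch, v in existing if ch in new_versions and v != new_versions[ch]
]
assert to_delete == [("messages", "v2")]
```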
langgraph/checkpoint/redis/base.py (4 changes: 2 additions & 2 deletions)

@@ -383,7 +383,7 @@ def _dump_metadata(self, metadata: CheckpointMetadata) -> str:
         # NOTE: we're using JSON serializer (not msgpack), so we need to remove null characters before writing
         return serialized_metadata.decode().replace("\\u0000", "")
 
-    def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str:
+    def get_next_version(self, current: Optional[str], channel: None = None) -> str:
         """Generate next version number."""
         if current is None:
             current_v = 0
@@ -420,7 +420,7 @@ def _load_writes_from_redis(self, write_key: str) -> List[Tuple[str, str, Any]]:
             return []
 
         writes = []
-        for write in result["writes"]:
+        for write in result["writes"]:  # type: ignore[call-overload]
             writes.append(
                 (
                     write["task_id"],
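
For reference, the version strings this method produces follow the convention used across langgraph checkpointers: a zero-padded counter plus a random tiebreak suffix, so versions sort lexically. A hedged reconstruction of the whole method from the fragment above (padding widths assumed from that convention):

```python
import random
from typing import Optional

def get_next_version(current: Optional[str], channel: None = None) -> str:
    """Generate next version number."""
    if current is None:
        current_v = 0
    else:
        # "00...01.xxx" -> 1; only the integer part carries ordering.
        current_v = int(current.split(".")[0])
    next_v = current_v + 1
    return f"{next_v:032}.{random.random():016}"
```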
langgraph/store/redis/__init__.py (14 changes: 5 additions & 9 deletions)

@@ -515,7 +515,7 @@ def _batch_search_ops(
                 if not isinstance(store_doc, dict):
                     try:
                         store_doc = json.loads(
-                            store_doc
+                            store_doc  # type: ignore[arg-type]
                         )  # Attempt to parse if it's a JSON string
                     except (json.JSONDecodeError, TypeError):
                         logger.error(f"Failed to parse store_doc: {store_doc}")
@@ -578,16 +578,14 @@ def _batch_search_ops(
         if self.cluster_mode:
             for key in refresh_keys:
                 ttl = self._redis.ttl(key)
-                if ttl > 0:  # type: ignore
+                if ttl > 0:
                     self._redis.expire(key, ttl_seconds)
         else:
             pipeline = self._redis.pipeline(transaction=True)
             for key in refresh_keys:
                 # Only refresh TTL if the key exists and has a TTL
                 ttl = self._redis.ttl(key)
-                if (
-                    ttl > 0
-                ):  # Only refresh if key exists and has TTL  # type: ignore
+                if ttl > 0:  # Only refresh if key exists and has TTL
                     pipeline.expire(key, ttl_seconds)
             if pipeline.command_stack:
                 pipeline.execute()
@@ -645,16 +643,14 @@
         if self.cluster_mode:
             for key in refresh_keys:
                 ttl = self._redis.ttl(key)
-                if ttl > 0:  # type: ignore
+                if ttl > 0:
                     self._redis.expire(key, ttl_seconds)
         else:
             pipeline = self._redis.pipeline(transaction=True)
             for key in refresh_keys:
                 # Only refresh TTL if the key exists and has a TTL
                 ttl = self._redis.ttl(key)
-                if (
-                    ttl > 0
-                ):  # Only refresh if key exists and has TTL  # type: ignore
+                if ttl > 0:  # Only refresh if key exists and has TTL
                     pipeline.expire(key, ttl_seconds)
             if pipeline.command_stack:
                 pipeline.execute()
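
The `ttl > 0` guard leans on how the TTL command reports key state; the dropped `# type: ignore` comments were papering over typing noise, not a logic issue. A minimal sketch with the synchronous client used in this file (`ttl_seconds` stands in for the refresh window from the surrounding code):

```python
from redis import Redis

r = Redis()
ttl_seconds = 3600  # hypothetical refresh window

ttl = r.ttl("some:key")  # -2: key missing, -1: key exists but never expires
if ttl > 0:
    # Only keys that already carry an expiry get refreshed.
    r.expire("some:key", ttl_seconds)
```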
langgraph/store/redis/aio.py (4 changes: 2 additions & 2 deletions)

@@ -87,7 +87,7 @@ def __init__(
 
         # Set up store configuration
         self.index_config = index
-        self.ttl_config = ttl  # type: ignore
+        self.ttl_config = ttl
 
         if self.index_config:
             self.index_config = self.index_config.copy()
@@ -744,7 +744,7 @@ async def _batch_search_ops(
                     store_key = f"{STORE_PREFIX}{REDIS_KEY_SEPARATOR}{doc_uuid}"
                     result_map[store_key] = doc
                     # Fetch individually in cluster mode
-                    store_doc_item = await self._redis.json().get(store_key)
+                    store_doc_item = await self._redis.json().get(store_key)  # type: ignore
                     store_docs.append(store_doc_item)
                 store_docs_raw = store_docs
             else:
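
Cluster mode fetches documents one `JSON.GET` at a time because store keys can land on different nodes. On a single instance the same reads could be batched; a sketch of both shapes (`store_keys` stands in for the keys computed above):

```python
# Cluster mode: one round-trip per key.
docs = [await self._redis.json().get(k) for k in store_keys]

# Single-instance alternative: one batched round-trip via JSON.MGET.
docs = await self._redis.json().mget(store_keys, "$")
```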