From 71cd360ed31915b873420f8bb23111acbe3b9711 Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Fri, 15 Aug 2025 11:04:16 -0700 Subject: [PATCH] feat: add TTL removal support for pinning checkpoints (#66) Add support for removing TTL from Redis checkpoints to make them persistent. This enables "pinning" specific threads that should never expire while allowing others to be cleaned up automatically. Changes: - Add support for `ttl_minutes=-1` parameter to trigger Redis PERSIST command - Implement TTL removal in both sync and async checkpoint savers - Apply PERSIST to main key and all related keys (blobs, writes) - Add comprehensive test coverage for TTL removal functionality - Update README with documentation for the pinning feature --- README.md | 31 +++- langgraph/checkpoint/redis/aio.py | 27 ++++ langgraph/checkpoint/redis/base.py | 30 ++++ tests/test_ttl_removal.py | 234 +++++++++++++++++++++++++++++ 4 files changed, 320 insertions(+), 2 deletions(-) create mode 100644 tests/test_ttl_removal.py diff --git a/README.md b/README.md index 045cba3..6f37934 100644 --- a/README.md +++ b/README.md @@ -249,7 +249,32 @@ with RedisSaver.from_conn_string("redis://localhost:6379", ttl=ttl_config) as ch # Use the checkpointer... ``` -This makes it easy to manage storage and ensure ephemeral data is automatically cleaned up. +### Removing TTL (Pinning Threads) + +You can make specific checkpoints persistent by removing their TTL. This is useful for "pinning" important threads that should never expire: + +```python +from langgraph.checkpoint.redis import RedisSaver + +# Create saver with default TTL +saver = RedisSaver.from_conn_string("redis://localhost:6379", ttl={"default_ttl": 60}) +saver.setup() + +# Save a checkpoint +config = {"configurable": {"thread_id": "important-thread", "checkpoint_ns": ""}} +saved_config = saver.put(config, checkpoint, metadata, {}) + +# Remove TTL from the checkpoint to make it persistent +checkpoint_id = saved_config["configurable"]["checkpoint_id"] +checkpoint_key = f"checkpoint:important-thread:__empty__:{checkpoint_id}" +saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1) + +# The checkpoint is now persistent and won't expire +``` + +When no TTL configuration is provided, checkpoints are persistent by default (no expiration). + +This makes it easy to manage storage and ensure ephemeral data is automatically cleaned up while keeping important data persistent. ## Redis Stores @@ -370,11 +395,13 @@ For Redis Stores with vector search: Both Redis checkpoint savers and stores leverage Redis's native key expiration: -- **Native Redis TTL**: Uses Redis's built-in `EXPIRE` command +- **Native Redis TTL**: Uses Redis's built-in `EXPIRE` command for setting TTL +- **TTL Removal**: Uses Redis's `PERSIST` command to remove TTL (with `ttl_minutes=-1`) - **Automatic Cleanup**: Redis automatically removes expired keys - **Configurable Default TTL**: Set a default TTL for all keys in minutes - **TTL Refresh on Read**: Optionally refresh TTL when keys are accessed - **Applied to All Related Keys**: TTL is applied to all related keys (checkpoint, blobs, writes) +- **Persistent by Default**: When no TTL is configured, keys are persistent (no expiration) ## Contributing diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py index 20a6555..4aba524 100644 --- a/langgraph/checkpoint/redis/aio.py +++ b/langgraph/checkpoint/redis/aio.py @@ -295,6 +295,7 @@ async def _apply_ttl_to_keys( main_key: The primary Redis key related_keys: Additional Redis keys that should expire at the same time ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided + Use -1 to remove TTL (make keys persistent) Returns: Result of the Redis operation @@ -305,6 +306,32 @@ async def _apply_ttl_to_keys( ttl_minutes = self.ttl_config.get("default_ttl") if ttl_minutes is not None: + # Special case: -1 means remove TTL (make persistent) + if ttl_minutes == -1: + if self.cluster_mode: + # For cluster mode, execute PERSIST operations individually + await self._redis.persist(main_key) + + if related_keys: + for key in related_keys: + await self._redis.persist(key) + + return True + else: + # For non-cluster mode, use pipeline for efficiency + pipeline = self._redis.pipeline() + + # Remove TTL for main key + pipeline.persist(main_key) + + # Remove TTL for related keys + if related_keys: + for key in related_keys: + pipeline.persist(key) + + return await pipeline.execute() + + # Regular TTL setting ttl_seconds = int(ttl_minutes * 60) if self.cluster_mode: diff --git a/langgraph/checkpoint/redis/base.py b/langgraph/checkpoint/redis/base.py index 55a54f9..ae00384 100644 --- a/langgraph/checkpoint/redis/base.py +++ b/langgraph/checkpoint/redis/base.py @@ -238,6 +238,7 @@ def _apply_ttl_to_keys( main_key: The primary Redis key related_keys: Additional Redis keys that should expire at the same time ttl_minutes: Time-to-live in minutes, overrides default_ttl if provided + Use -1 to remove TTL (make keys persistent) Returns: Result of the Redis operation @@ -248,6 +249,35 @@ def _apply_ttl_to_keys( ttl_minutes = self.ttl_config.get("default_ttl") if ttl_minutes is not None: + # Special case: -1 means remove TTL (make persistent) + if ttl_minutes == -1: + # Check if cluster mode is detected (for sync checkpoint savers) + cluster_mode = getattr(self, "cluster_mode", False) + + if cluster_mode: + # For cluster mode, execute PERSIST operations individually + self._redis.persist(main_key) + + if related_keys: + for key in related_keys: + self._redis.persist(key) + + return True + else: + # For non-cluster mode, use pipeline for efficiency + pipeline = self._redis.pipeline() + + # Remove TTL for main key + pipeline.persist(main_key) + + # Remove TTL for related keys + if related_keys: + for key in related_keys: + pipeline.persist(key) + + return pipeline.execute() + + # Regular TTL setting ttl_seconds = int(ttl_minutes * 60) # Check if cluster mode is detected (for sync checkpoint savers) diff --git a/tests/test_ttl_removal.py b/tests/test_ttl_removal.py new file mode 100644 index 0000000..696b0ed --- /dev/null +++ b/tests/test_ttl_removal.py @@ -0,0 +1,234 @@ +"""Tests for TTL removal feature (issue #66).""" + +import time +from uuid import uuid4 + +import pytest +from langgraph.checkpoint.base import create_checkpoint, empty_checkpoint + +from langgraph.checkpoint.redis import AsyncRedisSaver, RedisSaver + + +def test_ttl_removal_with_negative_one(redis_url: str) -> None: + """Test that ttl_minutes=-1 removes TTL from keys.""" + saver = RedisSaver(redis_url, ttl={"default_ttl": 1}) # 1 minute default TTL + saver.setup() + + thread_id = str(uuid4()) + checkpoint = create_checkpoint( + checkpoint=empty_checkpoint(), channels={"messages": ["test"]}, step=1 + ) + checkpoint["channel_values"]["messages"] = ["test"] + + config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + + # Save checkpoint (will have TTL) + saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {}) + + checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}" + + # Verify TTL is set + ttl = saver._redis.ttl(checkpoint_key) + assert 50 <= ttl <= 60, f"TTL should be around 60 seconds, got {ttl}" + + # Remove TTL using -1 + saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1) + + # Verify TTL is removed + ttl_after = saver._redis.ttl(checkpoint_key) + assert ttl_after == -1, "Key should be persistent after setting ttl_minutes=-1" + + +def test_ttl_removal_with_related_keys(redis_url: str) -> None: + """Test that TTL removal works for main key and related keys.""" + saver = RedisSaver(redis_url, ttl={"default_ttl": 1}) + saver.setup() + + thread_id = str(uuid4()) + + # Create a checkpoint with writes (to have related keys) + checkpoint = create_checkpoint( + checkpoint=empty_checkpoint(), channels={"messages": ["test"]}, step=1 + ) + checkpoint["channel_values"]["messages"] = ["test"] + + config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": "", + "checkpoint_id": "test-checkpoint", + } + } + + # Save checkpoint and writes + saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {}) + saver.put_writes( + saved_config, [("channel1", "value1"), ("channel2", "value2")], "task-1" + ) + + # Get the keys + checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}" + write_key1 = f"checkpoint_write:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}:task-1:0" + write_key2 = f"checkpoint_write:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}:task-1:1" + + # All keys should have TTL + assert 50 <= saver._redis.ttl(checkpoint_key) <= 60 + assert 50 <= saver._redis.ttl(write_key1) <= 60 + assert 50 <= saver._redis.ttl(write_key2) <= 60 + + # Remove TTL from all keys + saver._apply_ttl_to_keys(checkpoint_key, [write_key1, write_key2], ttl_minutes=-1) + + # All keys should be persistent + assert saver._redis.ttl(checkpoint_key) == -1 + assert saver._redis.ttl(write_key1) == -1 + assert saver._redis.ttl(write_key2) == -1 + + +def test_no_ttl_means_persistent(redis_url: str) -> None: + """Test that no TTL configuration means keys are persistent.""" + # Create saver with no TTL config + saver = RedisSaver(redis_url) # No TTL config + saver.setup() + + thread_id = str(uuid4()) + checkpoint = create_checkpoint( + checkpoint=empty_checkpoint(), channels={"messages": ["test"]}, step=1 + ) + checkpoint["channel_values"]["messages"] = ["test"] + + config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + + # Save checkpoint + saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {}) + + # Check TTL + checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}" + ttl = saver._redis.ttl(checkpoint_key) + + # Should be -1 (persistent) when no TTL configured + assert ttl == -1, "Key should be persistent when no TTL configured" + + +def test_ttl_removal_preserves_data(redis_url: str) -> None: + """Test that removing TTL doesn't affect the data.""" + saver = RedisSaver(redis_url, ttl={"default_ttl": 1}) + saver.setup() + + thread_id = str(uuid4()) + checkpoint = create_checkpoint( + checkpoint=empty_checkpoint(), channels={"messages": ["original data"]}, step=1 + ) + checkpoint["channel_values"]["messages"] = ["original data"] + + config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + + # Save checkpoint + saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {}) + + # Load data before TTL removal + loaded_before = saver.get_tuple(saved_config) + assert loaded_before.checkpoint["channel_values"]["messages"] == ["original data"] + + # Remove TTL + checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}" + saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1) + + # Load data after TTL removal + loaded_after = saver.get_tuple(saved_config) + assert loaded_after.checkpoint["channel_values"]["messages"] == ["original data"] + + # Verify TTL is removed + assert saver._redis.ttl(checkpoint_key) == -1 + + +@pytest.mark.asyncio +async def test_async_ttl_removal(redis_url: str) -> None: + """Test TTL removal with async saver.""" + async with AsyncRedisSaver.from_conn_string( + redis_url, ttl={"default_ttl": 1} + ) as saver: + thread_id = str(uuid4()) + checkpoint = create_checkpoint( + checkpoint=empty_checkpoint(), channels={"messages": ["async test"]}, step=1 + ) + checkpoint["channel_values"]["messages"] = ["async test"] + + config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + + # Save checkpoint + saved_config = await saver.aput( + config, checkpoint, {"source": "test", "step": 1}, {} + ) + + checkpoint_key = f"checkpoint:{thread_id}:__empty__:{saved_config['configurable']['checkpoint_id']}" + + # Verify TTL is set + ttl = await saver._redis.ttl(checkpoint_key) + assert 50 <= ttl <= 60, f"TTL should be around 60 seconds, got {ttl}" + + # Remove TTL using -1 + await saver._apply_ttl_to_keys(checkpoint_key, ttl_minutes=-1) + + # Verify TTL is removed + ttl_after = await saver._redis.ttl(checkpoint_key) + assert ttl_after == -1, "Key should be persistent after setting ttl_minutes=-1" + + +def test_pin_thread_use_case(redis_url: str) -> None: + """Test the 'pin thread' use case from issue #66. + + This simulates pinning a specific thread by removing its TTL, + making it persistent while other threads expire. + """ + saver = RedisSaver( + redis_url, ttl={"default_ttl": 0.1} + ) # 6 seconds TTL for quick test + saver.setup() + + # Create two threads + thread_to_pin = str(uuid4()) + thread_to_expire = str(uuid4()) + + # Store checkpoint IDs to avoid using wildcards (more efficient and precise) + checkpoint_ids = {} + + for thread_id in [thread_to_pin, thread_to_expire]: + checkpoint = create_checkpoint( + checkpoint=empty_checkpoint(), + channels={"messages": [f"Thread {thread_id}"]}, + step=1, + ) + checkpoint["channel_values"]["messages"] = [f"Thread {thread_id}"] + + config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + + saved_config = saver.put(config, checkpoint, {"source": "test", "step": 1}, {}) + checkpoint_ids[thread_id] = saved_config["configurable"]["checkpoint_id"] + + # Pin the first thread by removing its TTL using exact key + pinned_checkpoint_key = ( + f"checkpoint:{thread_to_pin}:__empty__:{checkpoint_ids[thread_to_pin]}" + ) + saver._apply_ttl_to_keys(pinned_checkpoint_key, ttl_minutes=-1) + + # Verify pinned thread has no TTL + assert saver._redis.exists(pinned_checkpoint_key) == 1 + assert saver._redis.ttl(pinned_checkpoint_key) == -1 + + # Verify other thread still has TTL + expiring_checkpoint_key = ( + f"checkpoint:{thread_to_expire}:__empty__:{checkpoint_ids[thread_to_expire]}" + ) + assert saver._redis.exists(expiring_checkpoint_key) == 1 + ttl = saver._redis.ttl(expiring_checkpoint_key) + assert 0 < ttl <= 6 + + # Wait for expiring thread to expire + time.sleep(7) + + # Pinned thread should still exist + assert saver._redis.exists(pinned_checkpoint_key) == 1 + + # Expiring thread should be gone + assert saver._redis.exists(expiring_checkpoint_key) == 0