diff --git a/examples/scan_keys.rs b/examples/scan_keys.rs index f1d84911..92e28369 100644 --- a/examples/scan_keys.rs +++ b/examples/scan_keys.rs @@ -1,7 +1,18 @@ +// This example shows the usage of the scan functionality of the Rust Redis Module API Wrapper. +// +// The example implements three commands: +// +// 1. `scan_keys` - scans all keys in the database and returns their names as an array of RedisString. +// 2. `scan_key ` - scans all fields by using a closure and a while loop, thus allowing an early stop. Don't use the early stop but collects all the field/value pairs as an array of RedisString. +// 3. `scan_key_for_each ` - scans all fields and values in a hash key using a closure that stores the field/value pairs as an array of RedisString. + use redis_module::{ - key::RedisKey, redis_module, Context, KeysCursor, RedisResult, RedisString, RedisValue, + key::{KeyFlags, RedisKey}, + redis_module, Context, KeysCursor, RedisError, RedisResult, RedisString, RedisValue, + ScanKeyCursor, }; +/// Scans all keys in the database and returns their names as an array of RedisString. fn scan_keys(ctx: &Context, _args: Vec) -> RedisResult { let cursor = KeysCursor::new(); let mut res = Vec::new(); @@ -16,6 +27,54 @@ fn scan_keys(ctx: &Context, _args: Vec) -> RedisResult { Ok(RedisValue::Array(res)) } +fn scan_key(ctx: &Context, args: Vec) -> RedisResult { + // only argument is the key name + if args.len() != 2 { + return Err(RedisError::WrongArity); + } + + let key_name = &args[1]; + let key = ctx.open_key_with_flags( + key_name, + KeyFlags::NOEFFECTS | KeyFlags::NOEXPIRE | KeyFlags::ACCESS_EXPIRED, + ); + let cursor = ScanKeyCursor::new(key); + + let mut res = Vec::new(); + while cursor.scan(|_key, field, value| { + res.push(RedisValue::BulkRedisString(field.clone())); + res.push(RedisValue::BulkRedisString(value.clone())); + }) { + // here we could do something between scans if needed, like an early stop + } + + Ok(RedisValue::Array(res)) +} + +/// Scans all fields and values in a hash key and returns them as an array of RedisString. +/// The command takes one argument: the name of the hash key to scan. +fn scan_key_for_each(ctx: &Context, args: Vec) -> RedisResult { + // only argument is the key name + if args.len() != 2 { + return Err(RedisError::WrongArity); + } + + let key_name = &args[1]; + let key = ctx.open_key_with_flags( + key_name, + KeyFlags::NOEFFECTS | KeyFlags::NOEXPIRE | KeyFlags::ACCESS_EXPIRED, + ); + let cursor = ScanKeyCursor::new(key); + + let mut res = Vec::new(); + cursor.for_each(|_key, field, value| { + res.push(RedisValue::BulkRedisString(field.clone())); + res.push(RedisValue::BulkRedisString(value.clone())); + }); + + Ok(RedisValue::Array(res)) +} + ////////////////////////////////////////////////////// redis_module! { @@ -25,5 +84,7 @@ redis_module! { data_types: [], commands: [ ["scan_keys", scan_keys, "readonly", 0, 0, 0, ""], + ["scan_key", scan_key, "readonly", 0, 0, 0, ""], + ["scan_key_for_each", scan_key_for_each, "readonly", 0, 0, 0, ""], ], } diff --git a/src/context/call_reply.rs b/src/context/call_reply.rs index 6c739b6f..8dc46044 100644 --- a/src/context/call_reply.rs +++ b/src/context/call_reply.rs @@ -31,6 +31,11 @@ impl<'root> StringCallReply<'root> { }; unsafe { slice::from_raw_parts(reply_string, len) } } + + /// Return the raw pointer to the underlying [RedisModuleCallReply]. + pub fn get_raw(&self) -> *mut RedisModuleCallReply { + self.reply.as_ptr() + } } impl<'root> Drop for StringCallReply<'root> { @@ -77,6 +82,11 @@ impl<'root> ErrorCallReply<'root> { }; unsafe { slice::from_raw_parts(reply_string, len) } } + + /// Return the raw pointer to the underlying [RedisModuleCallReply]. + pub fn get_raw(&self) -> *mut RedisModuleCallReply { + self.reply.as_ptr() + } } impl<'root> Drop for ErrorCallReply<'root> { @@ -150,6 +160,11 @@ impl<'root> I64CallReply<'root> { pub fn to_i64(&self) -> i64 { call_reply_integer(self.reply.as_ptr()) } + + /// Return the raw pointer to the underlying [RedisModuleCallReply]. + pub fn get_raw(&self) -> *mut RedisModuleCallReply { + self.reply.as_ptr() + } } impl<'root> Drop for I64CallReply<'root> { @@ -204,6 +219,11 @@ impl<'root> ArrayCallReply<'root> { pub fn len(&self) -> usize { call_reply_length(self.reply.as_ptr()) } + + /// Return the raw pointer to the underlying [RedisModuleCallReply]. + pub fn get_raw(&self) -> *mut RedisModuleCallReply { + self.reply.as_ptr() + } } pub struct ArrayCallReplyIterator<'root, 'curr> { @@ -254,6 +274,13 @@ pub struct NullCallReply<'root> { _dummy: PhantomData<&'root ()>, } +impl<'root> NullCallReply<'root> { + /// Return the raw pointer to the underlying [RedisModuleCallReply]. + pub fn get_raw(&self) -> *mut RedisModuleCallReply { + self.reply.as_ptr() + } +} + impl<'root> Drop for NullCallReply<'root> { fn drop(&mut self) { free_call_reply(self.reply.as_ptr()); @@ -303,6 +330,11 @@ impl<'root> MapCallReply<'root> { pub fn len(&self) -> usize { call_reply_length(self.reply.as_ptr()) } + + /// Return the raw pointer to the underlying [RedisModuleCallReply]. + pub fn get_raw(&self) -> *mut RedisModuleCallReply { + self.reply.as_ptr() + } } pub struct MapCallReplyIterator<'root, 'curr> { @@ -384,6 +416,11 @@ impl<'root> SetCallReply<'root> { pub fn len(&self) -> usize { call_reply_length(self.reply.as_ptr()) } + + /// Return the raw pointer to the underlying [RedisModuleCallReply]. + pub fn get_raw(&self) -> *mut RedisModuleCallReply { + self.reply.as_ptr() + } } pub struct SetCallReplyIterator<'root, 'curr> { @@ -445,6 +482,11 @@ impl<'root> BoolCallReply<'root> { pub fn to_bool(&self) -> bool { call_reply_bool(self.reply.as_ptr()) } + + /// Return the raw pointer to the underlying [RedisModuleCallReply]. + pub fn get_raw(&self) -> *mut RedisModuleCallReply { + self.reply.as_ptr() + } } impl<'root> Drop for BoolCallReply<'root> { @@ -478,6 +520,11 @@ impl<'root> DoubleCallReply<'root> { pub fn to_double(&self) -> f64 { call_reply_double(self.reply.as_ptr()) } + + /// Return the raw pointer to the underlying [RedisModuleCallReply]. + pub fn get_raw(&self) -> *mut RedisModuleCallReply { + self.reply.as_ptr() + } } impl<'root> Drop for DoubleCallReply<'root> { @@ -512,6 +559,11 @@ impl<'root> BigNumberCallReply<'root> { pub fn to_string(&self) -> Option { call_reply_big_number(self.reply.as_ptr()) } + + /// Return the raw pointer to the underlying [RedisModuleCallReply]. + pub fn get_raw(&self) -> *mut RedisModuleCallReply { + self.reply.as_ptr() + } } impl<'root> Drop for BigNumberCallReply<'root> { @@ -540,6 +592,13 @@ pub struct VerbatimStringCallReply<'root> { _dummy: PhantomData<&'root ()>, } +impl<'root> VerbatimStringCallReply<'root> { + /// Return the raw pointer to the underlying [RedisModuleCallReply]. + pub fn get_raw(&self) -> *mut RedisModuleCallReply { + self.reply.as_ptr() + } +} + /// RESP3 state that the verbatim string format must be of length 3. const VERBATIM_FORMAT_LENGTH: usize = 3; /// The string format of a verbatim string ([VerbatimStringCallReply]). @@ -639,6 +698,25 @@ pub enum CallReply<'root> { VerbatimString(VerbatimStringCallReply<'root>), } +impl<'root> CallReply<'root> { + /// Return the raw pointer to the underlying [RedisModuleCallReply], or `None` if this is the `Unknown` variant. + pub fn get_raw(&self) -> Option<*mut RedisModuleCallReply> { + match self { + CallReply::Unknown => None, + CallReply::I64(inner) => Some(inner.get_raw()), + CallReply::String(inner) => Some(inner.get_raw()), + CallReply::Array(inner) => Some(inner.get_raw()), + CallReply::Null(inner) => Some(inner.get_raw()), + CallReply::Map(inner) => Some(inner.get_raw()), + CallReply::Set(inner) => Some(inner.get_raw()), + CallReply::Bool(inner) => Some(inner.get_raw()), + CallReply::Double(inner) => Some(inner.get_raw()), + CallReply::BigNumber(inner) => Some(inner.get_raw()), + CallReply::VerbatimString(inner) => Some(inner.get_raw()), + } + } +} + /// Send implementation to [CallReply]. /// We need to implements this trait because [CallReply] hold /// raw pointers to C data which does not auto implement the [Send] trait. diff --git a/src/context/key_cursor.rs b/src/context/key_cursor.rs new file mode 100644 index 00000000..3e0684aa --- /dev/null +++ b/src/context/key_cursor.rs @@ -0,0 +1,131 @@ +use std::{ + ffi::c_void, + ptr::{self}, +}; + +use crate::{key::RedisKey, raw, RedisString}; + +/// A cursor to scan field/value pairs of a (hash) key. +/// +/// It provides access via a closure given to [`ScanKeyCursor::for_each`] or if you need more control, you can use [`ScanKeyCursor::scan`] +/// and implement your own loop, e.g. to allow an early stop. +/// +/// ## Example usage +/// +/// Here we show how to extract values to communicate them back to the Redis client. We assume that the following hash key is setup before: +/// +/// ```text +/// HSET user:123 name Alice age 29 location Austin +/// ``` +/// +/// The following example command implementation scans all fields and values in the hash key and returns them as an array of RedisString. +/// +/// ```ignore +/// fn example_scan_key_for_each(ctx: &Context) -> RedisResult { +/// let key = ctx.open_key_with_flags("user:123", KeyFlags::NOEFFECTS | KeyFlags::NOEXPIRE | KeyFlags::ACCESS_EXPIRED ); +/// let cursor = ScanKeyCursor::new(key); +/// +/// let res = RefCell::new(Vec::new()); +/// cursor.for_each(|_key, field, value| { +/// let mut res = res.borrow_mut(); +/// res.push(RedisValue::BulkRedisString(field.clone())); +/// res.push(RedisValue::BulkRedisString(value.clone())); +/// }); +/// +/// Ok(RedisValue::Array(res.take())) +/// } +/// ``` +/// +/// The method will produce the following output: +/// +/// ```text +/// 1) "name" +/// 2) "Alice" +/// 3) "age" +/// 4) "29" +/// 5) "location" +/// 6) "Austin" +/// ``` +pub struct ScanKeyCursor { + key: RedisKey, + inner_cursor: *mut raw::RedisModuleScanCursor, +} + +impl ScanKeyCursor { + /// Creates a new scan cursor for the given key. + pub fn new(key: RedisKey) -> Self { + let inner_cursor = unsafe { raw::RedisModule_ScanCursorCreate.unwrap()() }; + Self { key, inner_cursor } + } + + /// Restarts the cursor from the beginning. + pub fn restart(&self) { + unsafe { raw::RedisModule_ScanCursorRestart.unwrap()(self.inner_cursor) }; + } + + /// Implements a call to `RedisModule_ScanKey` and calls the given closure for each callback invocation by ScanKey. + /// Returns `true` if there are more fields to scan, `false` otherwise. + /// + /// The callback may be called multiple times per `RedisModule_ScanKey` invocation. + /// + /// ## Example + /// + /// ```ignore + /// while cursor.scan(|_key, field, value| { + /// // do something with field and value + /// }) { + /// // do something between scans if needed, like an early stop + /// } + pub fn scan(&self, f: F) -> bool { + unsafe extern "C" fn scan_callback( + key: *mut raw::RedisModuleKey, + field: *mut raw::RedisModuleString, + value: *mut raw::RedisModuleString, + data: *mut c_void, + ) { + let ctx = ptr::null_mut(); + let key = RedisKey::from_raw_parts(ctx, key); + + let field = RedisString::from_redis_module_string(ctx, field); + let value = RedisString::from_redis_module_string(ctx, value); + + let callback = unsafe { &mut *(data.cast::()) }; + callback(&key, &field, &value); + + // we're not the owner of field and value strings + field.take(); + value.take(); + + key.take(); // we're not the owner of the key either + } + + // Safety: The c-side initialized the function ptr and it is is never changed, + // i.e. after module initialization the function pointers stay valid till the end of the program. + let scan_key = unsafe { raw::RedisModule_ScanKey.unwrap() }; + + let res = unsafe { + scan_key( + self.key.key_inner, + self.inner_cursor, + Some(scan_callback::), + &f as *const F as *mut c_void, + ) + }; + + res != 0 + } + + /// Implements a callback based for_each loop over all fields and values in the hash key. + /// If you need more control, e.g. stopping after a scan invocation, then use [`ScanKeyCursor::scan`] directly. + pub fn for_each(&self, mut f: F) { + while self.scan(&mut f) { + // do nothing, the callback does the work + } + } +} + +impl Drop for ScanKeyCursor { + fn drop(&mut self) { + unsafe { raw::RedisModule_ScanCursorDestroy.unwrap()(self.inner_cursor) }; + } +} diff --git a/src/context/mod.rs b/src/context/mod.rs index 00371d3f..c1818171 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -32,6 +32,7 @@ pub mod call_reply; pub mod commands; pub mod defrag; pub mod info; +pub mod key_cursor; pub mod keys_cursor; pub mod server_events; pub mod thread_safe; diff --git a/src/key.rs b/src/key.rs index 9e363a11..c16f3d0c 100644 --- a/src/key.rs +++ b/src/key.rs @@ -70,7 +70,7 @@ impl RedisKey { Self { ctx, key_inner } } - pub(crate) fn open_with_flags( + pub fn open_with_flags( ctx: *mut raw::RedisModuleCtx, key: &RedisString, flags: KeyFlags, @@ -80,7 +80,7 @@ impl RedisKey { Self { ctx, key_inner } } - pub(crate) const fn from_raw_parts( + pub const fn from_raw_parts( ctx: *mut raw::RedisModuleCtx, key_inner: *mut raw::RedisModuleKey, ) -> Self { @@ -206,7 +206,7 @@ impl RedisKeyWritable { Self { ctx, key_inner } } - pub(crate) fn open_with_flags( + pub fn open_with_flags( ctx: *mut raw::RedisModuleCtx, key: &RedisString, flags: KeyFlags, diff --git a/src/lib.rs b/src/lib.rs index b4195b78..21e4c17c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,6 +31,7 @@ pub use crate::context::call_reply::FutureCallReply; pub use crate::context::call_reply::{CallReply, CallResult, ErrorReply, PromiseCallReply}; pub use crate::context::commands; pub use crate::context::defrag; +pub use crate::context::key_cursor::ScanKeyCursor; pub use crate::context::keys_cursor::KeysCursor; pub use crate::context::server_events; pub use crate::context::AclCategory; diff --git a/src/redismodule.rs b/src/redismodule.rs index 46ed3ec1..b23cb1b7 100644 --- a/src/redismodule.rs +++ b/src/redismodule.rs @@ -161,6 +161,24 @@ impl RedisString { Self { ctx, inner } } + /// Create a RedisString from a raw C string and length. The provided C String will be copied. + /// + /// # Safety + /// The caller must ensure that the provided pointer is valid and points to a memory region + /// that is at least `len` bytes long. + #[allow(clippy::not_unsafe_ptr_arg_deref)] + pub unsafe fn from_raw_parts( + ctx: Option>, + s: *const c_char, + len: libc::size_t, + ) -> Self { + let ctx = ctx.map_or(std::ptr::null_mut(), |v| v.as_ptr()); + + let inner = unsafe { raw::RedisModule_CreateString.unwrap()(ctx, s, len) }; + + Self { ctx, inner } + } + #[allow(clippy::not_unsafe_ptr_arg_deref)] pub fn create_from_slice(ctx: *mut raw::RedisModuleCtx, s: &[u8]) -> Self { let inner = unsafe { @@ -201,6 +219,13 @@ impl RedisString { len == 0 } + #[must_use] + pub fn as_cstr_ptr_and_len(&self) -> (*const c_char, usize) { + let mut len: usize = 0; + let ptr = raw::string_ptr_len(self.inner, &mut len); + (ptr, len) + } + pub fn try_as_str<'a>(&self) -> Result<&'a str, RedisError> { Self::from_ptr(self.inner).map_err(|_| RedisError::Str("Couldn't parse as UTF-8 string")) } diff --git a/tests/integration.rs b/tests/integration.rs index f5a1795d..793e2405 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -200,6 +200,42 @@ fn test_scan() -> Result<()> { Ok(()) } +#[test] +fn test_scan_key() -> Result<()> { + let mut con = TestConnection::new("scan_keys"); + redis::cmd("hset") + .arg(&[ + "user:123", "name", "Alice", "age", "29", "location", "Austin", + ]) + .query::<()>(&mut con) + .with_context(|| "failed to hset")?; + + let res: Vec = redis::cmd("scan_key") + .arg(&["user:123"]) + .query(&mut con) + .with_context(|| "failed scan_key")?; + assert_eq!(&res, &["name", "Alice", "age", "29", "location", "Austin"]); + Ok(()) +} + +#[test] +fn test_scan_key_for_each() -> Result<()> { + let mut con = TestConnection::new("scan_keys"); + redis::cmd("hset") + .arg(&[ + "user:123", "name", "Alice", "age", "29", "location", "Austin", + ]) + .query::<()>(&mut con) + .with_context(|| "failed to hset")?; + + let res: Vec = redis::cmd("scan_key_for_each") + .arg(&["user:123"]) + .query(&mut con) + .with_context(|| "failed scan_key_for_each")?; + assert_eq!(&res, &["name", "Alice", "age", "29", "location", "Austin"]); + Ok(()) +} + #[test] fn test_stream_reader() -> Result<()> { let mut con = TestConnection::new("stream");