Skip to content

Commit 670f9a4

Browse files
MOD-8036 Add tests for 393 and 394 (#395)
* add test * add comment * comments from Meir * add else * tests * formatting * implement `enum AclCategory` * using `TestConnection` API
1 parent 08281fe commit 670f9a4

File tree

7 files changed

+212
-28
lines changed

7 files changed

+212
-28
lines changed

examples/acl.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use redis_module::{
2-
redis_module, AclPermissions, Context, NextArg, RedisError, RedisResult, RedisString,
3-
RedisValue,
2+
redis_module, AclCategory, AclPermissions, Context, NextArg, RedisError, RedisResult,
3+
RedisString, RedisValue,
44
};
55

66
fn verify_key_access_for_user(ctx: &Context, args: Vec<RedisString>) -> RedisResult {
@@ -25,8 +25,9 @@ redis_module! {
2525
version: 1,
2626
allocator: (redis_module::alloc::RedisAlloc, redis_module::alloc::RedisAlloc),
2727
data_types: [],
28+
acl_categories: [AclCategory::from("acl"), ],
2829
commands: [
29-
["verify_key_access_for_user", verify_key_access_for_user, "", 0, 0, 0, ""],
30-
["get_current_user", get_current_user, "", 0, 0, 0, ""],
30+
["verify_key_access_for_user", verify_key_access_for_user, "", 0, 0, 0, AclCategory::Read, AclCategory::from("acl")],
31+
["get_current_user", get_current_user, "", 0, 0, 0, vec![AclCategory::Read, AclCategory::Fast], AclCategory::from("acl")],
3132
],
3233
}

src/context/mod.rs

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,113 @@ bitflags! {
926926
}
927927
}
928928

929+
#[derive(Debug, Clone, PartialEq, Eq, Default)]
930+
pub enum AclCategory {
931+
#[default]
932+
None,
933+
Keyspace,
934+
Read,
935+
Write,
936+
Set,
937+
SortedSet,
938+
List,
939+
Hash,
940+
String,
941+
Bitmap,
942+
HyperLogLog,
943+
Geo,
944+
Stream,
945+
PubSub,
946+
Admin,
947+
Fast,
948+
Slow,
949+
Blocking,
950+
Dangerous,
951+
Connection,
952+
Transaction,
953+
Scripting,
954+
Single(String),
955+
Multi(Vec<AclCategory>),
956+
}
957+
958+
impl From<Vec<AclCategory>> for AclCategory {
959+
fn from(value: Vec<AclCategory>) -> Self {
960+
AclCategory::Multi(value)
961+
}
962+
}
963+
964+
impl From<&str> for AclCategory {
965+
fn from(value: &str) -> Self {
966+
match value {
967+
"" => AclCategory::None,
968+
"keyspace" => AclCategory::Keyspace,
969+
"read" => AclCategory::Read,
970+
"write" => AclCategory::Write,
971+
"set" => AclCategory::Set,
972+
"sortedset" => AclCategory::SortedSet,
973+
"list" => AclCategory::List,
974+
"hash" => AclCategory::Hash,
975+
"string" => AclCategory::String,
976+
"bitmap" => AclCategory::Bitmap,
977+
"hyperloglog" => AclCategory::HyperLogLog,
978+
"geo" => AclCategory::Geo,
979+
"stream" => AclCategory::Stream,
980+
"pubsub" => AclCategory::PubSub,
981+
"admin" => AclCategory::Admin,
982+
"fast" => AclCategory::Fast,
983+
"slow" => AclCategory::Slow,
984+
"blocking" => AclCategory::Blocking,
985+
"dangerous" => AclCategory::Dangerous,
986+
"connection" => AclCategory::Connection,
987+
"transaction" => AclCategory::Transaction,
988+
"scripting" => AclCategory::Scripting,
989+
_ if !value.contains(" ") => AclCategory::Single(value.to_string()),
990+
_ => AclCategory::Multi(value.split_whitespace().map(AclCategory::from).collect()),
991+
}
992+
}
993+
}
994+
995+
impl From<AclCategory> for String {
996+
fn from(value: AclCategory) -> Self {
997+
match value {
998+
AclCategory::None => "".to_string(),
999+
AclCategory::Keyspace => "keyspace".to_string(),
1000+
AclCategory::Read => "read".to_string(),
1001+
AclCategory::Write => "write".to_string(),
1002+
AclCategory::Set => "set".to_string(),
1003+
AclCategory::SortedSet => "sortedset".to_string(),
1004+
AclCategory::List => "list".to_string(),
1005+
AclCategory::Hash => "hash".to_string(),
1006+
AclCategory::String => "string".to_string(),
1007+
AclCategory::Bitmap => "bitmap".to_string(),
1008+
AclCategory::HyperLogLog => "hyperloglog".to_string(),
1009+
AclCategory::Geo => "geo".to_string(),
1010+
AclCategory::Stream => "stream".to_string(),
1011+
AclCategory::PubSub => "pubsub".to_string(),
1012+
AclCategory::Admin => "admin".to_string(),
1013+
AclCategory::Fast => "fast".to_string(),
1014+
AclCategory::Slow => "slow".to_string(),
1015+
AclCategory::Blocking => "blocking".to_string(),
1016+
AclCategory::Dangerous => "dangerous".to_string(),
1017+
AclCategory::Connection => "connection".to_string(),
1018+
AclCategory::Transaction => "transaction".to_string(),
1019+
AclCategory::Scripting => "scripting".to_string(),
1020+
AclCategory::Single(s) => s,
1021+
AclCategory::Multi(v) => v
1022+
.into_iter()
1023+
.map(String::from)
1024+
.collect::<Vec<_>>()
1025+
.join(" "),
1026+
}
1027+
}
1028+
}
1029+
1030+
impl std::fmt::Display for AclCategory {
1031+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1032+
write!(f, "{}", String::from(self.clone()))
1033+
}
1034+
}
1035+
9291036
/// The values allowed in the "info" sections and dictionaries.
9301037
#[derive(Debug, Clone)]
9311038
pub enum InfoContextBuilderFieldBottomLevelValue {

src/include/redismodule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,7 @@ static void RedisModule_InitAPI(RedisModuleCtx *ctx) {
13261326
REDISMODULE_GET_API(CreateSubcommand);
13271327
REDISMODULE_GET_API(SetCommandInfo);
13281328
REDISMODULE_GET_API(SetCommandACLCategories);
1329+
REDISMODULE_GET_API(AddACLCategory);
13291330
REDISMODULE_GET_API(SetModuleAttribs);
13301331
REDISMODULE_GET_API(IsModuleNameBusy);
13311332
REDISMODULE_GET_API(WrongArity);

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ pub use crate::context::commands;
3333
pub use crate::context::defrag;
3434
pub use crate::context::keys_cursor::KeysCursor;
3535
pub use crate::context::server_events;
36+
pub use crate::context::AclCategory;
3637
pub use crate::context::AclPermissions;
3738
#[cfg(any(
3839
feature = "min-redis-compatibility-version-7-4",

src/macros.rs

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ macro_rules! redis_command {
77
$firstkey:expr,
88
$lastkey:expr,
99
$keystep:expr,
10-
$acl_categories:expr) => {{
10+
$mandatory_acl_categories:expr
11+
$(, $optional_acl_categories:expr)?
12+
) => {{
13+
use redis_module::AclCategory;
14+
1115
let name = CString::new($command_name).unwrap();
1216
let flags = CString::new($command_flags).unwrap();
1317

@@ -37,34 +41,67 @@ macro_rules! redis_command {
3741
)
3842
} == $crate::raw::Status::Err as c_int
3943
{
44+
$crate::raw::redis_log(
45+
$ctx,
46+
&format!("Error: failed to create command {}", $command_name),
47+
);
4048
return $crate::raw::Status::Err as c_int;
4149
}
4250

43-
if $acl_categories != "" {
44-
let acl_categories = CString::new($acl_categories).unwrap();
51+
let command =
52+
unsafe { $crate::raw::RedisModule_GetCommand.unwrap()($ctx, name.as_ptr()) };
53+
if command.is_null() {
54+
$crate::raw::redis_log(
55+
$ctx,
56+
&format!("Error: failed to get command {}", $command_name),
57+
);
58+
return $crate::raw::Status::Err as c_int;
59+
}
4560

46-
let command =
47-
unsafe { $crate::raw::RedisModule_GetCommand.unwrap()($ctx, name.as_ptr()) };
48-
if command.is_null() {
49-
return $crate::raw::Status::Err as c_int;
50-
}
61+
let mandatory_acl_categories = AclCategory::from($mandatory_acl_categories);
62+
if let Some(RM_SetCommandACLCategories) = $crate::raw::RedisModule_SetCommandACLCategories {
63+
let mut acl_categories = CString::default();
64+
$(
65+
let optional_acl_categories = AclCategory::from($optional_acl_categories);
66+
if mandatory_acl_categories != AclCategory::None && optional_acl_categories != AclCategory::None {
67+
acl_categories = CString::new(format!("{} {}", mandatory_acl_categories, optional_acl_categories)).unwrap();
68+
} else if optional_acl_categories != AclCategory::None {
69+
acl_categories = CString::new(format!("{}", $optional_acl_categories)).unwrap();
70+
}
71+
// Warn if optional ACL categories are not set, but don't fail.
72+
if RM_SetCommandACLCategories(command, acl_categories.as_ptr()) == $crate::raw::Status::Err as c_int {
73+
$crate::raw::redis_log(
74+
$ctx,
75+
&format!(
76+
"Warning: failed to set command `{}` ACL categories `{}`",
77+
$command_name, acl_categories.to_str().unwrap()
78+
),
79+
);
80+
} else
81+
)?
82+
if mandatory_acl_categories != AclCategory::None {
83+
acl_categories = CString::new(format!("{}", mandatory_acl_categories)).unwrap();
5184

52-
if let Some(RM_SetCommandACLCategories) =
53-
$crate::raw::RedisModule_SetCommandACLCategories
54-
{
85+
// Fail if mandatory ACL categories are not set.
5586
if RM_SetCommandACLCategories(command, acl_categories.as_ptr())
5687
== $crate::raw::Status::Err as c_int
5788
{
5889
$crate::raw::redis_log(
5990
$ctx,
6091
&format!(
61-
"Error: failed to set command {} ACL categories {}",
62-
$command_name, $acl_categories
92+
"Error: failed to set command `{}` mandatory ACL categories `{}`",
93+
$command_name, mandatory_acl_categories
6394
),
6495
);
6596
return $crate::raw::Status::Err as c_int;
6697
}
6798
}
99+
} else if mandatory_acl_categories != AclCategory::None {
100+
$crate::raw::redis_log(
101+
$ctx,
102+
"Error: Redis version does not support ACL categories",
103+
);
104+
return $crate::raw::Status::Err as c_int;
68105
}
69106
}};
70107
}
@@ -134,7 +171,11 @@ macro_rules! redis_module {
134171
data_types: [
135172
$($data_type:ident),* $(,)*
136173
],
137-
$(acl_category: $acl_category:expr,)* $(,)*
174+
// eg: `acl_category: [ "name_of_module_acl_category", ],`
175+
// This will add the specified (optional) ACL categories.
176+
$(acl_categories: [
177+
$($module_acl_category:expr,)*
178+
],)?
138179
$(init: $init_func:ident,)* $(,)*
139180
$(deinit: $deinit_func:ident,)* $(,)*
140181
$(info: $info_func:ident,)?
@@ -146,7 +187,8 @@ macro_rules! redis_module {
146187
$firstkey:expr,
147188
$lastkey:expr,
148189
$keystep:expr,
149-
$acl_categories:expr
190+
$mandatory_command_acl_categories:expr
191+
$(, $optional_command_acl_categories:expr)?
150192
]),* $(,)*
151193
] $(,)*
152194
$(event_handlers: [
@@ -271,17 +313,24 @@ macro_rules! redis_module {
271313
)*
272314

273315
$(
274-
let category = CString::new($acl_category).unwrap();
275-
if let Some(RM_AddACLCategory) = raw::RedisModule_AddACLCategory {
276-
if RM_AddACLCategory(ctx, category.as_ptr()) == raw::Status::Err as c_int {
277-
raw::redis_log(ctx, &format!("Error: failed to add ACL category {}", $acl_category));
278-
return raw::Status::Err as c_int;
316+
$(
317+
if let Some(RM_AddACLCategory) = raw::RedisModule_AddACLCategory {
318+
let module_acl_category = AclCategory::from($module_acl_category);
319+
if module_acl_category != AclCategory::None {
320+
let category = CString::new(format!("{}", $module_acl_category)).unwrap();
321+
if RM_AddACLCategory(ctx, category.as_ptr()) == raw::Status::Err as c_int {
322+
raw::redis_log(ctx, &format!("Error: failed to add ACL category `{}`", $module_acl_category));
323+
return raw::Status::Err as c_int;
324+
}
325+
}
326+
} else {
327+
raw::redis_log(ctx, "Warning: Redis version does not support adding new ACL categories");
279328
}
280-
}
281-
)*
329+
)*
330+
)?
282331

283332
$(
284-
$crate::redis_command!(ctx, $name, $command, $flags, $firstkey, $lastkey, $keystep, $acl_categories);
333+
$crate::redis_command!(ctx, $name, $command, $flags, $firstkey, $lastkey, $keystep, $mandatory_command_acl_categories $(, $optional_command_acl_categories)?);
285334
)*
286335

287336
if $crate::commands::register_commands(&context) == raw::Status::Err {

test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
#!/usr/bin/env sh
2-
cargo test --all --all-targets --no-default-features
2+
cargo test --all --all-targets --no-default-features --features min-redis-compatibility-version-7-4

tests/integration.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,31 @@ fn test_get_current_user() -> Result<()> {
275275
Ok(())
276276
}
277277

278+
#[test]
279+
#[cfg(feature = "min-redis-compatibility-version-7-4")]
280+
fn test_set_acl_categories() -> Result<()> {
281+
let mut con = TestConnection::new("acl");
282+
283+
let res: Vec<String> = redis::cmd("ACL").arg("CAT").query(&mut con)?;
284+
assert!(res.contains(&"acl".to_owned()));
285+
286+
Ok(())
287+
}
288+
289+
#[test]
290+
#[cfg(feature = "min-redis-compatibility-version-8-0")]
291+
fn test_set_acl_categories_commands() -> Result<()> {
292+
let mut con = TestConnection::new("acl");
293+
294+
let res: Vec<String> = redis::cmd("ACL").arg("CAT").arg("acl").query(&mut con)?;
295+
assert!(
296+
res.contains(&"verify_key_access_for_user".to_owned())
297+
&& res.contains(&"get_current_user".to_owned())
298+
);
299+
300+
Ok(())
301+
}
302+
278303
#[test]
279304
fn test_verify_acl_on_user() -> Result<()> {
280305
let mut con = TestConnection::new("acl");

0 commit comments

Comments
 (0)