Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions interactions/api/events/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class AutoModExec(BaseEvent):

@attrs.define(eq=False, order=False, hash=False, kw_only=False)
class AutoModCreated(BaseEvent):
"""Dispatched when an auto mod rule is created"""

guild: "Guild" = attrs.field(repr=False, metadata=docs("The guild the rule was modified in"))
rule: "AutoModRule" = attrs.field(repr=False, metadata=docs("The rule that was modified"))

Expand Down
98 changes: 64 additions & 34 deletions interactions/models/discord/auto_mod.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
from typing import Any, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional, Union

import attrs

from interactions.client.const import get_logger, MISSING, Absent
from interactions.client.const import MISSING, Absent, get_logger
from interactions.client.mixins.serialization import DictSerializationMixin
from interactions.client.utils import list_converter, optional
from interactions.client.utils.attr_utils import docs
from interactions.models.discord.base import ClientObject, DiscordObject
from interactions.models.discord.enums import (
AutoModTriggerType,
AutoModAction,
AutoModEvent,
AutoModLanuguageType,
AutoModTriggerType,
)
from interactions.models.discord.snowflake import to_snowflake_list, to_snowflake
from interactions.models.discord.snowflake import to_snowflake, to_snowflake_list

if TYPE_CHECKING:
from interactions import Snowflake_Type, Guild, GuildText, Message, Client, Member, User
from interactions import (
Client,
Guild,
GuildText,
Member,
Message,
Snowflake_Type,
User,
)

__all__ = ("AutoModerationAction", "AutoModRule")

Expand Down Expand Up @@ -71,7 +79,7 @@ def _process_dict(cls, data: dict[str, Any]) -> dict[str, Any]:
return data

@classmethod
def from_dict_factory(cls, data: dict) -> "BaseAction":
def from_dict_factory(cls, data: dict) -> "TYPE_ALL_TRIGGER":
trigger_class = TRIGGER_MAPPING.get(data.get("trigger_type"))
meta = data.get("trigger_metadata", {})
if not trigger_class:
Expand All @@ -97,16 +105,22 @@ def _keyword_converter(filter: str | list[str]) -> list[str]:
class KeywordTrigger(BaseTrigger):
"""A trigger that checks if content contains words from a user defined list of keywords"""

type: AutoModTriggerType = attrs.field(
default=AutoModTriggerType.KEYWORD,
converter=AutoModTriggerType,
keyword_filter: list[str] = attrs.field(
factory=list,
repr=True,
metadata=docs("The type of trigger"),
metadata=docs("Substrings which will be searched for in content"),
converter=_keyword_converter,
)
keyword_filter: str | list[str] = attrs.field(
regex_patterns: list[str] = attrs.field(
factory=list,
repr=True,
metadata=docs("Regular expression patterns which will be matched against content"),
converter=_keyword_converter,
)
allow_list: list[str] = attrs.field(
factory=list,
repr=True,
metadata=docs("What words will trigger this"),
metadata=docs("Substrings which should not trigger the rule"),
converter=_keyword_converter,
)

Expand All @@ -127,63 +141,74 @@ class HarmfulLinkFilter(BaseTrigger):
class KeywordPresetTrigger(BaseTrigger):
"""A trigger that checks if content contains words from internal pre-defined wordsets"""

type: AutoModTriggerType = attrs.field(
default=AutoModTriggerType.KEYWORD_PRESET,
converter=AutoModTriggerType,
repr=True,
metadata=docs("The type of trigger"),
)
keyword_lists: list[AutoModLanuguageType] = attrs.field(
factory=list,
converter=list_converter(AutoModLanuguageType),
repr=True,
metadata=docs("The preset list of keywords that will trigger this"),
metadata=docs("The internally pre-defined wordsets which will be searched for in content"),
)


@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class MentionSpamTrigger(BaseTrigger):
"""A trigger that checks if content contains more mentions than allowed"""
"""A trigger that checks if content contains more unique mentions than allowed"""

mention_total_limit: int = attrs.field(
default=3, repr=True, metadata=docs("The maximum number of mentions allowed")
)
mention_raid_protection_enabled: bool = attrs.field(
repr=True, metadata=docs("Whether to automatically detect mention raids")
)


@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class MemberProfileTrigger(BaseTrigger):
"""A trigger that checks if member profile contains words from a user defined list of keywords"""

regex_patterns: list[str] = attrs.field(
factory=list, repr=True, metadata=docs("The regex patterns to check against")
factory=list,
repr=True,
metadata=docs("Regular expression patterns which will be matched against content"),
converter=_keyword_converter,
)
keyword_filter: str | list[str] = attrs.field(
factory=list, repr=True, metadata=docs("The keywords to check against")
keyword_filter: list[str] = attrs.field(
factory=list,
repr=True,
metadata=docs("Substrings which will be searched for in content"),
converter=_keyword_converter,
)
allow_list: list["Snowflake_Type"] = attrs.field(
factory=list, repr=True, metadata=docs("The roles exempt from this rule")
allow_list: list[str] = attrs.field(
factory=list,
repr=True,
metadata=docs("Substrings which should not trigger the rule"),
converter=_keyword_converter,
)


@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class SpamTrigger(BaseTrigger):
"""A trigger that checks if content represents generic spam"""


@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class BlockMessage(BaseAction):
"""blocks the content of a message according to the rule"""
"""Blocks the content of a message according to the rule"""

type: AutoModAction = attrs.field(repr=False, default=AutoModAction.BLOCK_MESSAGE, converter=AutoModAction)
custom_message: Optional[str] = attrs.field(repr=True, default=None)


@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class AlertMessage(BaseAction):
"""logs user content to a specified channel"""

channel_id: "Snowflake_Type" = attrs.field(repr=True)
type: AutoModAction = attrs.field(repr=False, default=AutoModAction.ALERT_MESSAGE, converter=AutoModAction)


@attrs.define(eq=False, order=False, hash=False, kw_only=False)
class TimeoutUser(BaseAction):
"""timeout user for a specified duration"""

duration_seconds: int = attrs.field(repr=True, default=60)
type: AutoModAction = attrs.field(repr=False, default=AutoModAction.TIMEOUT_USER, converter=AutoModAction)


@attrs.define(eq=False, order=False, hash=False, kw_only=False)
Expand All @@ -204,13 +229,13 @@ class AutoModRule(DiscordObject):
enabled: bool = attrs.field(repr=False, default=False)
"""whether the rule is enabled"""

actions: list[BaseAction] = attrs.field(repr=False, factory=list)
actions: list["TYPE_ALL_ACTION"] = attrs.field(repr=False, factory=list)
"""the actions which will execute when the rule is triggered"""
event_type: AutoModEvent = attrs.field(
repr=False,
)
"""the rule event type"""
trigger: BaseTrigger = attrs.field(
trigger: "TYPE_ALL_TRIGGER" = attrs.field(
repr=False,
)
"""The trigger for this rule"""
Expand Down Expand Up @@ -262,10 +287,10 @@ async def modify(
self,
*,
name: Absent[str] = MISSING,
trigger: Absent[BaseTrigger] = MISSING,
trigger: Absent["TYPE_ALL_TRIGGER"] = MISSING,
trigger_type: Absent[AutoModTriggerType] = MISSING,
trigger_metadata: Absent[dict] = MISSING,
actions: Absent[list[BaseAction]] = MISSING,
actions: Absent[list["TYPE_ALL_ACTION"]] = MISSING,
exempt_channels: Absent[list["Snowflake_Type"]] = MISSING,
exempt_roles: Absent[list["Snowflake_Type"]] = MISSING,
event_type: Absent[AutoModEvent] = MISSING,
Expand Down Expand Up @@ -318,7 +343,7 @@ class AutoModerationAction(ClientObject):
repr=False,
)

action: BaseAction = attrs.field(default=MISSING, repr=True)
action: "TYPE_ALL_ACTION" = attrs.field(default=MISSING, repr=True)

matched_keyword: str = attrs.field(repr=True)
matched_content: Optional[str] = attrs.field(repr=False, default=None)
Expand Down Expand Up @@ -369,7 +394,12 @@ def member(self) -> "Optional[Member]":
TRIGGER_MAPPING = {
AutoModTriggerType.KEYWORD: KeywordTrigger,
AutoModTriggerType.HARMFUL_LINK: HarmfulLinkFilter,
AutoModTriggerType.SPAM: SpamTrigger,
AutoModTriggerType.KEYWORD_PRESET: KeywordPresetTrigger,
AutoModTriggerType.MENTION_SPAM: MentionSpamTrigger,
AutoModTriggerType.MEMBER_PROFILE: MemberProfileTrigger,
}

TYPE_ALL_TRIGGER = Union[KeywordTrigger, SpamTrigger, KeywordPresetTrigger, MentionSpamTrigger, MemberProfileTrigger]

TYPE_ALL_ACTION = Union[BlockMessage, AlertMessage, TimeoutUser, BlockMemberInteraction]