diff --git a/discord/__init__.py b/discord/__init__.py index bcffff183b..3b2440411b 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -33,6 +33,7 @@ from .automod import * from .bot import * from .channel import * +from .channel.thread import * from .client import * from .cog import * from .collectibles import * @@ -72,7 +73,6 @@ from .sticker import * from .team import * from .template import * -from .threads import * from .user import * from .voice_client import * from .webhook import * diff --git a/discord/abc.py b/discord/abc.py index 3a40dbf4ea..0dbccca5e6 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -26,15 +26,14 @@ from __future__ import annotations import asyncio -import copy import time from typing import ( TYPE_CHECKING, - Any, Callable, Iterable, Protocol, Sequence, + TypeAlias, TypeVar, Union, overload, @@ -43,17 +42,11 @@ from . import utils from .context_managers import Typing -from .enums import ChannelType from .errors import ClientException, InvalidArgument from .file import File, VoiceMessage -from .flags import ChannelFlags, MessageFlags -from .invite import Invite +from .flags import MessageFlags from .iterators import HistoryIterator, MessagePinIterator from .mentions import AllowedMentions -from .partial_emoji import PartialEmoji, _EmojiTag -from .permissions import PermissionOverwrite, Permissions -from .role import Role -from .scheduled_events import ScheduledEvent from .sticker import GuildSticker, StickerItem from .utils.private import warn_deprecated from .voice_client import VoiceClient, VoiceProtocol @@ -62,7 +55,6 @@ "Snowflake", "User", "PrivateChannel", - "GuildChannel", "Messageable", "Connectable", "Mentionable", @@ -73,9 +65,9 @@ if TYPE_CHECKING: from datetime import datetime + from .app.state import ConnectionState from .asset import Asset from .channel import ( - CategoryChannel, DMChannel, GroupChannel, PartialMessageable, @@ -83,17 +75,11 @@ TextChannel, VoiceChannel, ) + from .channel.thread import Thread from .client import Client from .embeds import Embed - from .enums import InviteTarget - from .guild import Guild - from .member import Member from .message import Message, MessageReference, PartialMessage from .poll import Poll - from .state import ConnectionState - from .threads import Thread - from .types.channel import Channel as ChannelPayload - from .types.channel import GuildChannel as GuildChannelPayload from .types.channel import OverwriteType from .types.channel import PermissionOverwrite as PermissionOverwritePayload from .ui.view import View @@ -101,7 +87,7 @@ PartialMessageableChannel = TextChannel | VoiceChannel | StageChannel | Thread | DMChannel | PartialMessageable MessageableChannel = PartialMessageableChannel | GroupChannel - SnowflakeTime = "Snowflake" | datetime +SnowflakeTime: TypeAlias = "Snowflake | datetime" MISSING = utils.MISSING @@ -295,968 +281,6 @@ def is_member(self) -> bool: return self.type == self.MEMBER -GCH = TypeVar("GCH", bound="GuildChannel") - - -class GuildChannel: - """An ABC that details the common operations on a Discord guild channel. - - The following implement this ABC: - - - :class:`~discord.TextChannel` - - :class:`~discord.VoiceChannel` - - :class:`~discord.CategoryChannel` - - :class:`~discord.StageChannel` - - :class:`~discord.ForumChannel` - - This ABC must also implement :class:`~discord.abc.Snowflake`. - - Attributes - ---------- - name: :class:`str` - The channel name. - guild: :class:`~discord.Guild` - The guild the channel belongs to. - position: :class:`int` - The position in the channel list. This is a number that starts at 0. - e.g. the top channel is position 0. - """ - - __slots__ = () - - id: int - name: str - guild: Guild - type: ChannelType - position: int - category_id: int | None - flags: ChannelFlags - _state: ConnectionState - _overwrites: list[_Overwrites] - - if TYPE_CHECKING: - - def __init__(self, *, state: ConnectionState, guild: Guild, data: dict[str, Any]): ... - - def __str__(self) -> str: - return self.name - - @property - def _sorting_bucket(self) -> int: - raise NotImplementedError - - def _update(self, guild: Guild, data: dict[str, Any]) -> None: - raise NotImplementedError - - async def _move( - self, - position: int, - parent_id: Any | None = None, - lock_permissions: bool = False, - *, - reason: str | None, - ) -> None: - if position < 0: - raise InvalidArgument("Channel position cannot be less than 0.") - - http = self._state.http - bucket = self._sorting_bucket - channels: list[GuildChannel] = [c for c in self.guild.channels if c._sorting_bucket == bucket] - - channels.sort(key=lambda c: c.position) - - try: - # remove ourselves from the channel list - channels.remove(self) - except ValueError: - # not there somehow lol - return - else: - index = next( - (i for i, c in enumerate(channels) if c.position >= position), - len(channels), - ) - # add ourselves at our designated position - channels.insert(index, self) - - payload = [] - for index, c in enumerate(channels): - d: dict[str, Any] = {"id": c.id, "position": index} - if parent_id is not MISSING and c.id == self.id: - d.update(parent_id=parent_id, lock_permissions=lock_permissions) - payload.append(d) - - await http.bulk_channel_update(self.guild.id, payload, reason=reason) - - async def _edit(self, options: dict[str, Any], reason: str | None) -> ChannelPayload | None: - try: - parent = options.pop("category") - except KeyError: - parent_id = MISSING - else: - parent_id = parent and parent.id - - try: - options["rate_limit_per_user"] = options.pop("slowmode_delay") - except KeyError: - pass - - try: - options["default_thread_rate_limit_per_user"] = options.pop("default_thread_slowmode_delay") - except KeyError: - pass - - try: - options["flags"] = options.pop("flags").value - except KeyError: - pass - - try: - options["available_tags"] = [tag.to_dict() for tag in options.pop("available_tags")] - except KeyError: - pass - - try: - rtc_region = options.pop("rtc_region") - except KeyError: - pass - else: - options["rtc_region"] = None if rtc_region is None else str(rtc_region) - - try: - video_quality_mode = options.pop("video_quality_mode") - except KeyError: - pass - else: - options["video_quality_mode"] = int(video_quality_mode) - - lock_permissions = options.pop("sync_permissions", False) - - try: - position = options.pop("position") - except KeyError: - if parent_id is not MISSING: - if lock_permissions: - category = self.guild.get_channel(parent_id) - if category: - options["permission_overwrites"] = [c._asdict() for c in category._overwrites] - options["parent_id"] = parent_id - elif lock_permissions and self.category_id is not None: - # if we're syncing permissions on a pre-existing channel category without changing it - # we need to update the permissions to point to the pre-existing category - category = self.guild.get_channel(self.category_id) - if category: - options["permission_overwrites"] = [c._asdict() for c in category._overwrites] - else: - await self._move( - position, - parent_id=parent_id, - lock_permissions=lock_permissions, - reason=reason, - ) - - overwrites = options.get("overwrites") - if overwrites is not None: - perms = [] - for target, perm in overwrites.items(): - if not isinstance(perm, PermissionOverwrite): - raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}") - - allow, deny = perm.pair() - payload = { - "allow": allow.value, - "deny": deny.value, - "id": target.id, - "type": (_Overwrites.ROLE if isinstance(target, Role) else _Overwrites.MEMBER), - } - - perms.append(payload) - options["permission_overwrites"] = perms - - try: - ch_type = options["type"] - except KeyError: - pass - else: - if not isinstance(ch_type, ChannelType): - raise InvalidArgument("type field must be of type ChannelType") - options["type"] = ch_type.value - - try: - default_reaction_emoji = options["default_reaction_emoji"] - except KeyError: - pass - else: - if isinstance(default_reaction_emoji, _EmojiTag): # GuildEmoji, PartialEmoji - default_reaction_emoji = default_reaction_emoji._to_partial() - elif isinstance(default_reaction_emoji, int): - default_reaction_emoji = PartialEmoji(name=None, id=default_reaction_emoji) - elif isinstance(default_reaction_emoji, str): - default_reaction_emoji = PartialEmoji.from_str(default_reaction_emoji) - elif default_reaction_emoji is None: - pass - else: - raise InvalidArgument("default_reaction_emoji must be of type: GuildEmoji | int | str | None") - - options["default_reaction_emoji"] = ( - default_reaction_emoji._to_forum_reaction_payload() if default_reaction_emoji else None - ) - - if options: - return await self._state.http.edit_channel(self.id, reason=reason, **options) - - def _fill_overwrites(self, data: GuildChannelPayload) -> None: - self._overwrites = [] - everyone_index = 0 - everyone_id = self.guild.id - - for index, overridden in enumerate(data.get("permission_overwrites", [])): - overwrite = _Overwrites(overridden) - self._overwrites.append(overwrite) - - if overwrite.type == _Overwrites.MEMBER: - continue - - if overwrite.id == everyone_id: - # the @everyone role is not guaranteed to be the first one - # in the list of permission overwrites, however the permission - # resolution code kind of requires that it is the first one in - # the list since it is special. So we need the index so we can - # swap it to be the first one. - everyone_index = index - - # do the swap - tmp = self._overwrites - if tmp: - tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index] - - @property - def changed_roles(self) -> list[Role]: - """Returns a list of roles that have been overridden from - their default values in the :attr:`~discord.Guild.roles` attribute. - """ - ret = [] - g = self.guild - for overwrite in filter(lambda o: o.is_role(), self._overwrites): - role = g.get_role(overwrite.id) - if role is None: - continue - - role = copy.copy(role) - role.permissions.handle_overwrite(overwrite.allow, overwrite.deny) - ret.append(role) - return ret - - @property - def mention(self) -> str: - """The string that allows you to mention the channel.""" - return f"<#{self.id}>" - - @property - def jump_url(self) -> str: - """Returns a URL that allows the client to jump to the channel. - - .. versionadded:: 2.0 - """ - return f"https://discord.com/channels/{self.guild.id}/{self.id}" - - @property - def created_at(self) -> datetime: - """Returns the channel's creation time in UTC.""" - return utils.snowflake_time(self.id) - - def overwrites_for(self, obj: Role | User) -> PermissionOverwrite: - """Returns the channel-specific overwrites for a member or a role. - - Parameters - ---------- - obj: Union[:class:`~discord.Role`, :class:`~discord.abc.User`] - The role or user denoting - whose overwrite to get. - - Returns - ------- - :class:`~discord.PermissionOverwrite` - The permission overwrites for this object. - """ - - if isinstance(obj, User): - predicate = lambda p: p.is_member() - elif isinstance(obj, Role): - predicate = lambda p: p.is_role() - else: - predicate = lambda p: True - - for overwrite in filter(predicate, self._overwrites): - if overwrite.id == obj.id: - allow = Permissions(overwrite.allow) - deny = Permissions(overwrite.deny) - return PermissionOverwrite.from_pair(allow, deny) - - return PermissionOverwrite() - - @property - def overwrites(self) -> dict[Role | Member, PermissionOverwrite]: - """Returns all of the channel's overwrites. - - This is returned as a dictionary where the key contains the target which - can be either a :class:`~discord.Role` or a :class:`~discord.Member` and the value is the - overwrite as a :class:`~discord.PermissionOverwrite`. - - Returns - ------- - Dict[Union[:class:`~discord.Role`, :class:`~discord.Member`], :class:`~discord.PermissionOverwrite`] - The channel's permission overwrites. - """ - ret = {} - for ow in self._overwrites: - allow = Permissions(ow.allow) - deny = Permissions(ow.deny) - overwrite = PermissionOverwrite.from_pair(allow, deny) - target = None - - if ow.is_role(): - target = self.guild.get_role(ow.id) - elif ow.is_member(): - target = self.guild.get_member(ow.id) - - # TODO: There is potential data loss here in the non-chunked - # case, i.e. target is None because get_member returned nothing. - # This can be fixed with a slight breaking change to the return type, - # i.e. adding discord.Object to the list of it - # However, for now this is an acceptable compromise. - if target is not None: - ret[target] = overwrite - return ret - - @property - def category(self) -> CategoryChannel | None: - """The category this channel belongs to. - - If there is no category then this is ``None``. - """ - return self.guild.get_channel(self.category_id) # type: ignore - - @property - def permissions_synced(self) -> bool: - """Whether the permissions for this channel are synced with the - category it belongs to. - - If there is no category then this is ``False``. - - .. versionadded:: 1.3 - """ - if self.category_id is None: - return False - - category = self.guild.get_channel(self.category_id) - return bool(category and category.overwrites == self.overwrites) - - def permissions_for(self, obj: Member | Role, /) -> Permissions: - """Handles permission resolution for the :class:`~discord.Member` - or :class:`~discord.Role`. - - This function takes into consideration the following cases: - - - Guild owner - - Guild roles - - Channel overrides - - Member overrides - - If a :class:`~discord.Role` is passed, then it checks the permissions - someone with that role would have, which is essentially: - - - The default role permissions - - The permissions of the role used as a parameter - - The default role permission overwrites - - The permission overwrites of the role used as a parameter - - .. versionchanged:: 2.0 - The object passed in can now be a role object. - - Parameters - ---------- - obj: Union[:class:`~discord.Member`, :class:`~discord.Role`] - The object to resolve permissions for. This could be either - a member or a role. If it's a role then member overwrites - are not computed. - - Returns - ------- - :class:`~discord.Permissions` - The resolved permissions for the member or role. - """ - - # The current cases can be explained as: - # Guild owner get all permissions -- no questions asked. Otherwise... - # The @everyone role gets the first application. - # After that, the applied roles that the user has in the channel - # (or otherwise) are then OR'd together. - # After the role permissions are resolved, the member permissions - # have to take into effect. - # After all that is done, you have to do the following: - - # If manage permissions is True, then all permissions are set to True. - - # The operation first takes into consideration the denied - # and then the allowed. - - if self.guild.owner_id == obj.id: - return Permissions.all() - - default = self.guild.default_role - base = Permissions(default.permissions.value if default else 0) - - # Handle the role case first - if isinstance(obj, Role): - base.value |= obj._permissions - - if base.administrator: - return Permissions.all() - - # Apply @everyone allow/deny first since it's special - try: - maybe_everyone = self._overwrites[0] - if maybe_everyone.id == self.guild.id: - base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) - except IndexError: - pass - - if obj.is_default(): - return base - - overwrite = utils.find(lambda o: o.type == _Overwrites.ROLE and o.id == obj.id, self._overwrites) - if overwrite is not None: - base.handle_overwrite(overwrite.allow, overwrite.deny) - - return base - - roles = obj._roles - get_role = self.guild.get_role - - # Apply guild roles that the member has. - for role_id in roles: - role = get_role(role_id) - if role is not None: - base.value |= role._permissions - - # Guild-wide Administrator -> True for everything - # Bypass all channel-specific overrides - if base.administrator: - return Permissions.all() - - # Apply @everyone allow/deny first since it's special - try: - maybe_everyone = self._overwrites[0] - if maybe_everyone.id == self.guild.id: - base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) - remaining_overwrites = self._overwrites[1:] - else: - remaining_overwrites = self._overwrites - except IndexError: - remaining_overwrites = self._overwrites - - denies = 0 - allows = 0 - - # Apply channel specific role permission overwrites - for overwrite in remaining_overwrites: - if overwrite.is_role() and roles.has(overwrite.id): - denies |= overwrite.deny - allows |= overwrite.allow - - base.handle_overwrite(allow=allows, deny=denies) - - # Apply member specific permission overwrites - for overwrite in remaining_overwrites: - if overwrite.is_member() and overwrite.id == obj.id: - base.handle_overwrite(allow=overwrite.allow, deny=overwrite.deny) - break - - # if you can't send a message in a channel then you can't have certain - # permissions as well - if not base.send_messages: - base.send_tts_messages = False - base.mention_everyone = False - base.embed_links = False - base.attach_files = False - - # if you can't read a channel then you have no permissions there - if not base.read_messages: - denied = Permissions.all_channel() - base.value &= ~denied.value - - return base - - async def delete(self, *, reason: str | None = None) -> None: - """|coro| - - Deletes the channel. - - You must have :attr:`~discord.Permissions.manage_channels` permission to use this. - - Parameters - ---------- - reason: Optional[:class:`str`] - The reason for deleting this channel. - Shows up on the audit log. - - Raises - ------ - ~discord.Forbidden - You do not have proper permissions to delete the channel. - ~discord.NotFound - The channel was not found or was already deleted. - ~discord.HTTPException - Deleting the channel failed. - """ - await self._state.http.delete_channel(self.id, reason=reason) - - @overload - async def set_permissions( - self, - target: Member | Role, - *, - overwrite: PermissionOverwrite | None = ..., - reason: str | None = ..., - ) -> None: ... - - @overload - async def set_permissions( - self, - target: Member | Role, - *, - reason: str | None = ..., - **permissions: bool, - ) -> None: ... - - async def set_permissions(self, target, *, overwrite=MISSING, reason=None, **permissions): - r"""|coro| - - Sets the channel specific permission overwrites for a target in the - channel. - - The ``target`` parameter should either be a :class:`~discord.Member` or a - :class:`~discord.Role` that belongs to guild. - - The ``overwrite`` parameter, if given, must either be ``None`` or - :class:`~discord.PermissionOverwrite`. For convenience, you can pass in - keyword arguments denoting :class:`~discord.Permissions` attributes. If this is - done, then you cannot mix the keyword arguments with the ``overwrite`` - parameter. - - If the ``overwrite`` parameter is ``None``, then the permission - overwrites are deleted. - - You must have the :attr:`~discord.Permissions.manage_roles` permission to use this. - - .. note:: - - This method *replaces* the old overwrites with the ones given. - - Examples - ---------- - - Setting allow and deny: :: - - await message.channel.set_permissions(message.author, read_messages=True, send_messages=False) - - Deleting overwrites :: - - await channel.set_permissions(member, overwrite=None) - - Using :class:`~discord.PermissionOverwrite` :: - - overwrite = discord.PermissionOverwrite() - overwrite.send_messages = False - overwrite.read_messages = True - await channel.set_permissions(member, overwrite=overwrite) - - Parameters - ----------- - target: Union[:class:`~discord.Member`, :class:`~discord.Role`] - The member or role to overwrite permissions for. - overwrite: Optional[:class:`~discord.PermissionOverwrite`] - The permissions to allow and deny to the target, or ``None`` to - delete the overwrite. - \*\*permissions - A keyword argument list of permissions to set for ease of use. - Cannot be mixed with ``overwrite``. - reason: Optional[:class:`str`] - The reason for doing this action. Shows up on the audit log. - - Raises - ------- - ~discord.Forbidden - You do not have permissions to edit channel specific permissions. - ~discord.HTTPException - Editing channel specific permissions failed. - ~discord.NotFound - The role or member being edited is not part of the guild. - ~discord.InvalidArgument - The overwrite parameter invalid or the target type was not - :class:`~discord.Role` or :class:`~discord.Member`. - """ - - http = self._state.http - - if isinstance(target, User): - perm_type = _Overwrites.MEMBER - elif isinstance(target, Role): - perm_type = _Overwrites.ROLE - else: - raise InvalidArgument("target parameter must be either Member or Role") - - if overwrite is MISSING: - if len(permissions) == 0: - raise InvalidArgument("No overwrite provided.") - try: - overwrite = PermissionOverwrite(**permissions) - except (ValueError, TypeError) as e: - raise InvalidArgument("Invalid permissions given to keyword arguments.") from e - elif len(permissions) > 0: - raise InvalidArgument("Cannot mix overwrite and keyword arguments.") - - # TODO: wait for event - - if overwrite is None: - await http.delete_channel_permissions(self.id, target.id, reason=reason) - elif isinstance(overwrite, PermissionOverwrite): - (allow, deny) = overwrite.pair() - await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason) - else: - raise InvalidArgument("Invalid overwrite type provided.") - - async def _clone_impl( - self: GCH, - base_attrs: dict[str, Any], - *, - name: str | None = None, - reason: str | None = None, - ) -> GCH: - base_attrs["permission_overwrites"] = [x._asdict() for x in self._overwrites] - base_attrs["parent_id"] = self.category_id - base_attrs["name"] = name or self.name - guild_id = self.guild.id - cls = self.__class__ - data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs) - obj = cls(state=self._state, guild=self.guild, data=data) - - # temporarily add it to the cache - self.guild._channels[obj.id] = obj # type: ignore - return obj - - async def clone(self: GCH, *, name: str | None = None, reason: str | None = None) -> GCH: - """|coro| - - Clones this channel. This creates a channel with the same properties - as this channel. - - You must have the :attr:`~discord.Permissions.manage_channels` permission to - do this. - - .. versionadded:: 1.1 - - Parameters - ---------- - name: Optional[:class:`str`] - The name of the new channel. If not provided, defaults to this - channel name. - reason: Optional[:class:`str`] - The reason for cloning this channel. Shows up on the audit log. - - Returns - ------- - :class:`.abc.GuildChannel` - The channel that was created. - - Raises - ------ - ~discord.Forbidden - You do not have the proper permissions to create this channel. - ~discord.HTTPException - Creating the channel failed. - """ - raise NotImplementedError - - @overload - async def move( - self, - *, - beginning: bool, - offset: int | utils.Undefined = MISSING, - category: Snowflake | None | utils.Undefined = MISSING, - sync_permissions: bool | utils.Undefined = MISSING, - reason: str | None | utils.Undefined = MISSING, - ) -> None: ... - - @overload - async def move( - self, - *, - end: bool, - offset: int | utils.Undefined = MISSING, - category: Snowflake | None | utils.Undefined = MISSING, - sync_permissions: bool | utils.Undefined = MISSING, - reason: str | utils.Undefined = MISSING, - ) -> None: ... - - @overload - async def move( - self, - *, - before: Snowflake, - offset: int | utils.Undefined = MISSING, - category: Snowflake | None | utils.Undefined = MISSING, - sync_permissions: bool | utils.Undefined = MISSING, - reason: str | utils.Undefined = MISSING, - ) -> None: ... - - @overload - async def move( - self, - *, - after: Snowflake, - offset: int | utils.Undefined = MISSING, - category: Snowflake | None | utils.Undefined = MISSING, - sync_permissions: bool | utils.Undefined = MISSING, - reason: str | utils.Undefined = MISSING, - ) -> None: ... - - async def move(self, **kwargs) -> None: - """|coro| - - A rich interface to help move a channel relative to other channels. - - If exact position movement is required, ``edit`` should be used instead. - - You must have the :attr:`~discord.Permissions.manage_channels` permission to - do this. - - .. note:: - - Voice channels will always be sorted below text channels. - This is a Discord limitation. - - .. versionadded:: 1.7 - - Parameters - ---------- - beginning: :class:`bool` - Whether to move the channel to the beginning of the - channel list (or category if given). - This is mutually exclusive with ``end``, ``before``, and ``after``. - end: :class:`bool` - Whether to move the channel to the end of the - channel list (or category if given). - This is mutually exclusive with ``beginning``, ``before``, and ``after``. - before: :class:`~discord.abc.Snowflake` - The channel that should be before our current channel. - This is mutually exclusive with ``beginning``, ``end``, and ``after``. - after: :class:`~discord.abc.Snowflake` - The channel that should be after our current channel. - This is mutually exclusive with ``beginning``, ``end``, and ``before``. - offset: :class:`int` - The number of channels to offset the move by. For example, - an offset of ``2`` with ``beginning=True`` would move - it 2 after the beginning. A positive number moves it below - while a negative number moves it above. Note that this - number is relative and computed after the ``beginning``, - ``end``, ``before``, and ``after`` parameters. - category: Optional[:class:`~discord.abc.Snowflake`] - The category to move this channel under. - If ``None`` is given then it moves it out of the category. - This parameter is ignored if moving a category channel. - sync_permissions: :class:`bool` - Whether to sync the permissions with the category (if given). - reason: :class:`str` - The reason for the move. - - Raises - ------ - InvalidArgument - An invalid position was given or a bad mix of arguments was passed. - Forbidden - You do not have permissions to move the channel. - HTTPException - Moving the channel failed. - """ - - if not kwargs: - return - - beginning, end = kwargs.get("beginning"), kwargs.get("end") - before, after = kwargs.get("before"), kwargs.get("after") - offset = kwargs.get("offset", 0) - if sum(bool(a) for a in (beginning, end, before, after)) > 1: - raise InvalidArgument("Only one of [before, after, end, beginning] can be used.") - - bucket = self._sorting_bucket - parent_id = kwargs.get("category", MISSING) - channels: list[GuildChannel] - if parent_id not in (MISSING, None): - parent_id = parent_id.id - channels = [ - ch for ch in self.guild.channels if ch._sorting_bucket == bucket and ch.category_id == parent_id - ] - else: - channels = [ - ch for ch in self.guild.channels if ch._sorting_bucket == bucket and ch.category_id == self.category_id - ] - - channels.sort(key=lambda c: (c.position, c.id)) - - try: - # Try to remove ourselves from the channel list - channels.remove(self) - except ValueError: - # If we're not there then it's probably due to not being in the category - pass - - index = None - if beginning: - index = 0 - elif end: - index = len(channels) - elif before: - index = next((i for i, c in enumerate(channels) if c.id == before.id), None) - elif after: - index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None) - - if index is None: - raise InvalidArgument("Could not resolve appropriate move position") - - channels.insert(max((index + offset), 0), self) - payload = [] - lock_permissions = kwargs.get("sync_permissions", False) - reason = kwargs.get("reason") - for index, channel in enumerate(channels): - d = {"id": channel.id, "position": index} - if parent_id is not MISSING and channel.id == self.id: - d.update(parent_id=parent_id, lock_permissions=lock_permissions) - payload.append(d) - - await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) - - async def create_invite( - self, - *, - reason: str | None = None, - max_age: int = 0, - max_uses: int = 0, - temporary: bool = False, - unique: bool = True, - target_event: ScheduledEvent | None = None, - target_type: InviteTarget | None = None, - target_user: User | None = None, - target_application_id: int | None = None, - ) -> Invite: - """|coro| - - Creates an instant invite from a text or voice channel. - - You must have the :attr:`~discord.Permissions.create_instant_invite` permission to - do this. - - Parameters - ---------- - max_age: :class:`int` - How long the invite should last in seconds. If it's 0 then the invite - doesn't expire. Defaults to ``0``. - max_uses: :class:`int` - How many uses the invite could be used for. If it's 0 then there - are unlimited uses. Defaults to ``0``. - temporary: :class:`bool` - Denotes that the invite grants temporary membership - (i.e. they get kicked after they disconnect). Defaults to ``False``. - unique: :class:`bool` - Indicates if a unique invite URL should be created. Defaults to True. - If this is set to ``False`` then it will return a previously created - invite. - reason: Optional[:class:`str`] - The reason for creating this invite. Shows up on the audit log. - target_type: Optional[:class:`.InviteTarget`] - The type of target for the voice channel invite, if any. - - .. versionadded:: 2.0 - - target_user: Optional[:class:`User`] - The user whose stream to display for this invite, required if `target_type` is `TargetType.stream`. - The user must be streaming in the channel. - - .. versionadded:: 2.0 - - target_application_id: Optional[:class:`int`] - The id of the embedded application for the invite, required if `target_type` is - `TargetType.embedded_application`. - - .. versionadded:: 2.0 - - target_event: Optional[:class:`.ScheduledEvent`] - The scheduled event object to link to the event. - Shortcut to :meth:`.Invite.set_scheduled_event` - - See :meth:`.Invite.set_scheduled_event` for more - info on event invite linking. - - .. versionadded:: 2.0 - - Returns - ------- - :class:`~discord.Invite` - The invite that was created. - - Raises - ------ - ~discord.HTTPException - Invite creation failed. - - ~discord.NotFound - The channel that was passed is a category or an invalid channel. - """ - - data = await self._state.http.create_invite( - self.id, - reason=reason, - max_age=max_age, - max_uses=max_uses, - temporary=temporary, - unique=unique, - target_type=target_type.value if target_type else None, - target_user_id=target_user.id if target_user else None, - target_application_id=target_application_id, - ) - invite = Invite.from_incomplete(data=data, state=self._state) - if target_event: - invite.set_scheduled_event(target_event) - return invite - - async def invites(self) -> list[Invite]: - """|coro| - - Returns a list of all active instant invites from this channel. - - You must have :attr:`~discord.Permissions.manage_channels` to get this information. - - Returns - ------- - List[:class:`~discord.Invite`] - The list of invites that are currently active. - - Raises - ------ - ~discord.Forbidden - You do not have proper permissions to get the information. - ~discord.HTTPException - An error occurred while fetching the information. - """ - - state = self._state - data = await state.http.invites_from_channel(self.id) - guild = self.guild - return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data] - - class Messageable: """An ABC that details the common operations on a model that can send messages. @@ -1527,7 +551,7 @@ async def send( if reference is not None: try: _reference = reference.to_message_reference_dict() - from .message import MessageReference # noqa: PLC0415 + from .message import MessageReference if not isinstance(reference, MessageReference): warn_deprecated( @@ -1611,7 +635,7 @@ async def send( ret = state.create_message(channel=channel, data=data) if view: if view.is_dispatchable(): - state.store_view(view, ret.id) + await state.store_view(view, ret.id) view.message = ret view.refresh(ret.components) @@ -1961,6 +985,10 @@ async def connect( return voice -class Mentionable: - # TODO: documentation, methods if needed - pass +@runtime_checkable +class Mentionable(Protocol): + """An ABC that details the common operations on an object that can + be mentioned. + """ + + def mention(self) -> str: ... diff --git a/discord/app/cache.py b/discord/app/cache.py new file mode 100644 index 0000000000..7f97b5f837 --- /dev/null +++ b/discord/app/cache.py @@ -0,0 +1,425 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from collections import OrderedDict, defaultdict, deque +from typing import TYPE_CHECKING, Deque, Protocol, TypeVar + +from discord import utils +from discord.member import Member +from discord.message import Message +from discord.soundboard import SoundboardSound + +from ..channel import DMChannel +from ..emoji import AppEmoji, GuildEmoji +from ..guild import Guild +from ..poll import Poll +from ..sticker import GuildSticker, Sticker +from ..types.channel import DMChannel as DMChannelPayload +from ..types.emoji import Emoji as EmojiPayload +from ..types.message import Message as MessagePayload +from ..types.sticker import GuildSticker as GuildStickerPayload +from ..types.user import User as UserPayload +from ..ui.modal import Modal +from ..ui.view import View +from ..user import User + +if TYPE_CHECKING: + from discord.app.state import ConnectionState + + from ..abc import MessageableChannel, PrivateChannel + +T = TypeVar("T") + + +class Cache(Protocol): + def __init__(self): + self.__state: ConnectionState | None = None + + @property + def _state(self) -> "ConnectionState": + if self.__state is None: + raise RuntimeError("Cache state has not been initialized.") + return self.__state + + @_state.setter + def _state(self, state: "ConnectionState") -> None: + self.__state = state + + # users + async def get_all_users(self) -> list[User]: ... + + async def store_user(self, payload: UserPayload) -> User: ... + + async def delete_user(self, user_id: int) -> None: ... + + async def get_user(self, user_id: int) -> User | None: ... + + # stickers + + async def get_all_stickers(self) -> list[GuildSticker]: ... + + async def get_sticker(self, sticker_id: int) -> GuildSticker: ... + + async def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker: ... + + async def delete_sticker(self, sticker_id: int) -> None: ... + + # interactions + + async def store_view(self, view: View, message_id: int | None) -> None: ... + + async def delete_view_on(self, message_id: int) -> None: ... + + async def get_all_views(self) -> list[View]: ... + + async def store_modal(self, modal: Modal, user_id: int) -> None: ... + + async def delete_modal(self, custom_id: str) -> None: ... + + async def get_all_modals(self) -> list[Modal]: ... + + # guilds + + async def get_all_guilds(self) -> list[Guild]: ... + + async def get_guild(self, id: int) -> Guild | None: ... + + async def add_guild(self, guild: Guild) -> None: ... + + async def delete_guild(self, guild: Guild) -> None: ... + + # emojis + + async def store_guild_emoji(self, guild: Guild, data: EmojiPayload) -> GuildEmoji: ... + + async def store_app_emoji(self, application_id: int, data: EmojiPayload) -> AppEmoji: ... + + async def get_all_emojis(self) -> list[GuildEmoji | AppEmoji]: ... + + async def get_emoji(self, emoji_id: int | None) -> GuildEmoji | AppEmoji | None: ... + + async def delete_emoji(self, emoji: GuildEmoji | AppEmoji) -> None: ... + + # polls + + async def get_all_polls(self) -> list[Poll]: ... + + async def get_poll(self, message_id: int) -> Poll: ... + + async def store_poll(self, poll: Poll, message_id: int) -> None: ... + + # private channels + + async def get_private_channels(self) -> "list[PrivateChannel]": ... + + async def get_private_channel(self, channel_id: int) -> "PrivateChannel": ... + + async def get_private_channel_by_user(self, user_id: int) -> "PrivateChannel | None": ... + + async def store_private_channel(self, channel: "PrivateChannel") -> None: ... + + # messages + + async def store_message(self, message: MessagePayload, channel: "MessageableChannel") -> Message: ... + + async def store_built_message(self, message: Message) -> None: ... + + async def upsert_message(self, message: Message) -> None: ... + + async def delete_message(self, message_id: int) -> None: ... + + async def get_message(self, message_id: int) -> Message | None: ... + + async def get_all_messages(self) -> list[Message]: ... + + # guild members + + async def store_member(self, member: Member) -> None: ... + + async def get_member(self, guild_id: int, user_id: int) -> Member | None: ... + + async def delete_member(self, guild_id: int, user_id: int) -> None: ... + + async def delete_guild_members(self, guild_id: int) -> None: ... + + async def get_guild_members(self, guild_id: int) -> list[Member]: ... + + async def get_all_members(self) -> list[Member]: ... + + async def clear(self, views: bool = True) -> None: ... + + async def store_sound(self, sound: SoundboardSound) -> None: ... + + async def get_sound(self, sound_id: int) -> SoundboardSound | None: ... + + async def get_all_sounds(self) -> list[SoundboardSound]: ... + + async def delete_sound(self, sound_id: int) -> None: ... + + +class MemoryCache(Cache): + def __init__(self, max_messages: int | None = None) -> None: + self.__state: ConnectionState | None = None + self.max_messages = max_messages + self._users: dict[int, User] = {} + self._guilds: dict[int, Guild] = {} + self._polls: dict[int, Poll] = {} + self._stickers: dict[int, list[GuildSticker]] = {} + self._views: dict[str, View] = {} + self._modals: dict[str, Modal] = {} + self._sounds: dict[int, SoundboardSound] = {} + self._messages: Deque[Message] = deque(maxlen=self.max_messages) + + self._emojis: dict[int, list[GuildEmoji | AppEmoji]] = {} + + self._private_channels: OrderedDict[int, PrivateChannel] = OrderedDict() + self._private_channels_by_user: dict[int, DMChannel] = {} + + self._guild_members: dict[int, dict[int, Member]] = defaultdict(dict) + + def _flatten(self, matrix: list[list[T]]) -> list[T]: + return [item for row in matrix for item in row] + + async def clear(self, views: bool = True) -> None: + self._users: dict[int, User] = {} + self._guilds: dict[int, Guild] = {} + self._polls: dict[int, Poll] = {} + self._stickers: dict[int, list[GuildSticker]] = {} + if views: + self._views: dict[str, View] = {} + self._modals: dict[str, Modal] = {} + self._messages: Deque[Message] = deque(maxlen=self.max_messages) + + self._emojis: dict[int, list[GuildEmoji | AppEmoji]] = {} + + self._private_channels: OrderedDict[int, PrivateChannel] = OrderedDict() + self._private_channels_by_user: dict[int, DMChannel] = {} + + self._guild_members: dict[int, dict[int, Member]] = defaultdict(dict) + + # users + async def get_all_users(self) -> list[User]: + return list(self._users.values()) + + async def store_user(self, payload: UserPayload) -> User: + user_id = int(payload["id"]) + try: + return self._users[user_id] + except KeyError: + user = User(state=self._state, data=payload) + if user.discriminator != "0000": + self._users[user_id] = user + user._stored = True + return user + + async def delete_user(self, user_id: int) -> None: + self._users.pop(user_id, None) + + async def get_user(self, user_id: int) -> User | None: + return self._users.get(user_id) + + # stickers + + async def get_all_stickers(self) -> list[GuildSticker]: + return self._flatten(list(self._stickers.values())) + + async def get_sticker(self, sticker_id: int) -> GuildSticker | None: + stickers = self._flatten(list(self._stickers.values())) + for sticker in stickers: + if sticker.id == sticker_id: + return sticker + + async def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker: + sticker = GuildSticker(state=self._state, data=data) + try: + self._stickers[guild.id].append(sticker) + except KeyError: + self._stickers[guild.id] = [sticker] + return sticker + + async def delete_sticker(self, sticker_id: int) -> None: + self._stickers.pop(sticker_id, None) + + # interactions + + async def delete_view_on(self, message_id: int) -> View | None: + for view in await self.get_all_views(): + if view.message and view.message.id == message_id: + return view + + async def store_view(self, view: View, message_id: int) -> None: + self._views[str(message_id or view.id)] = view + + async def get_all_views(self) -> list[View]: + return list(self._views.values()) + + async def store_modal(self, modal: Modal) -> None: + self._modals[modal.custom_id] = modal + + async def get_all_modals(self) -> list[Modal]: + return list(self._modals.values()) + + # guilds + + async def get_all_guilds(self) -> list[Guild]: + return list(self._guilds.values()) + + async def get_guild(self, id: int) -> Guild | None: + return self._guilds.get(id) + + async def add_guild(self, guild: Guild) -> None: + self._guilds[guild.id] = guild + + async def delete_guild(self, guild: Guild) -> None: + self._guilds.pop(guild.id, None) + + # emojis + + async def store_guild_emoji(self, guild: Guild, data: EmojiPayload) -> GuildEmoji: + emoji = GuildEmoji(guild=guild, state=self._state, data=data) + try: + self._emojis[guild.id].append(emoji) + except KeyError: + self._emojis[guild.id] = [emoji] + return emoji + + async def store_app_emoji(self, application_id: int, data: EmojiPayload) -> AppEmoji: + emoji = AppEmoji(application_id=application_id, state=self._state, data=data) + try: + self._emojis[application_id].append(emoji) + except KeyError: + self._emojis[application_id] = [emoji] + return emoji + + async def get_all_emojis(self) -> list[GuildEmoji | AppEmoji]: + return self._flatten(list(self._emojis.values())) + + async def get_emoji(self, emoji_id: int | None) -> GuildEmoji | AppEmoji | None: + emojis = self._flatten(list(self._emojis.values())) + for emoji in emojis: + if emoji.id == emoji_id: + return emoji + + async def delete_emoji(self, emoji: GuildEmoji | AppEmoji) -> None: + if isinstance(emoji, AppEmoji): + self._emojis[emoji.application_id].remove(emoji) + else: + self._emojis[emoji.guild_id].remove(emoji) + + # polls + + async def get_all_polls(self) -> list[Poll]: + return list(self._polls.values()) + + async def get_poll(self, message_id: int) -> Poll | None: + return self._polls.get(message_id) + + async def store_poll(self, poll: Poll, message_id: int) -> None: + self._polls[message_id] = poll + + # private channels + + async def get_private_channels(self) -> "list[PrivateChannel]": + return list(self._private_channels.values()) + + async def get_private_channel(self, channel_id: int) -> "PrivateChannel | None": + try: + channel = self._private_channels[channel_id] + except KeyError: + return None + else: + self._private_channels.move_to_end(channel_id) + return channel + + async def store_private_channel(self, channel: "PrivateChannel") -> None: + channel_id = channel.id + self._private_channels[channel_id] = channel + + if len(self._private_channels) > 128: + _, to_remove = self._private_channels.popitem(last=False) + if isinstance(to_remove, DMChannel) and to_remove.recipient: + self._private_channels_by_user.pop(to_remove.recipient.id, None) + + if isinstance(channel, DMChannel) and channel.recipient: + self._private_channels_by_user[channel.recipient.id] = channel + + async def get_private_channel_by_user(self, user_id: int) -> "PrivateChannel | None": + return self._private_channels_by_user.get(user_id) + + # messages + + async def upsert_message(self, message: Message) -> None: + self._messages.append(message) + + async def store_message(self, message: MessagePayload, channel: "MessageableChannel") -> Message: + msg = await Message._from_data(state=self._state, channel=channel, data=message) + self.store_built_message(msg) + return msg + + async def store_built_message(self, message: Message) -> None: + self._messages.append(message) + + async def delete_message(self, message_id: int) -> None: + self._messages.remove(utils.find(lambda m: m.id == message_id, reversed(self._messages))) + + async def get_message(self, message_id: int) -> Message | None: + return utils.find(lambda m: m.id == message_id, reversed(self._messages)) + + async def get_all_messages(self) -> list[Message]: + return list(self._messages) + + async def delete_modal(self, custom_id: str) -> None: + self._modals.pop(custom_id, None) + + # guild members + + async def store_member(self, member: Member) -> None: + self._guild_members[member.guild.id][member.id] = member + + async def get_member(self, guild_id: int, user_id: int) -> Member | None: + return self._guild_members[guild_id].get(user_id) + + async def delete_member(self, guild_id: int, user_id: int) -> None: + self._guild_members[guild_id].pop(user_id, None) + + async def delete_guild_members(self, guild_id: int) -> None: + self._guild_members.pop(guild_id, None) + + async def get_guild_members(self, guild_id: int) -> list[Member]: + return list(self._guild_members.get(guild_id, {}).values()) + + async def get_all_members(self) -> list[Member]: + return self._flatten([list(members.values()) for members in self._guild_members.values()]) + + async def store_sound(self, sound: SoundboardSound) -> None: + self._sounds[sound.id] = sound + + async def get_sound(self, sound_id: int) -> SoundboardSound | None: + return self._sounds.get(sound_id) + + async def get_all_sounds(self) -> list[SoundboardSound]: + return list(self._sounds.values()) + + async def delete_sound(self, sound_id: int) -> None: + self._sounds.pop(sound_id, None) diff --git a/discord/app/event_emitter.py b/discord/app/event_emitter.py new file mode 100644 index 0000000000..5af1b6a34a --- /dev/null +++ b/discord/app/event_emitter.py @@ -0,0 +1,123 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import asyncio +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Awaitable, Coroutine +from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeAlias, TypeVar + +from typing_extensions import Self + +if TYPE_CHECKING: + from .state import ConnectionState + +T = TypeVar("T", bound="Event") + + +class Event(ABC): + __event_name__: str + + @classmethod + @abstractmethod + async def __load__(cls, data: Any, state: "ConnectionState") -> Self | None: ... + + def _populate_from_slots(self, obj: Any) -> None: + """ + Populate this event instance with attributes from another object. + + Handles both __slots__ and __dict__ based objects. + + Parameters + ---------- + obj: Any + The object to copy attributes from. + """ + # Collect all slots from the object's class hierarchy + slots = set() + for klass in type(obj).__mro__: + if hasattr(klass, "__slots__"): + slots.update(klass.__slots__) + + # Copy slot attributes + for slot in slots: + if hasattr(obj, slot): + try: + setattr(self, slot, getattr(obj, slot)) + except AttributeError: + # Some slots might be read-only or not settable + pass + + # Also copy __dict__ if it exists + if hasattr(obj, "__dict__"): + for key, value in obj.__dict__.items(): + try: + setattr(self, key, value) + except AttributeError: + pass + + +ListenerCallback: TypeAlias = Callable[[Event], Any] + + +class EventReciever(Protocol): + def __call__(self, event: Event) -> Awaitable[Any]: ... + + +class EventEmitter: + def __init__(self, state: "ConnectionState") -> None: + self._receivers: list[EventReciever] = [] + self._events: dict[str, list[type[Event]]] = defaultdict(list) + self._state: ConnectionState = state + + from ..events import ALL_EVENTS + + for event_cls in ALL_EVENTS: + self.add_event(event_cls) + + def add_event(self, event: type[Event]) -> None: + self._events[event.__event_name__].append(event) + + def remove_event(self, event: type[Event]) -> list[type[Event]] | None: + return self._events.pop(event.__event_name__, None) + + def add_receiver(self, receiver: EventReciever) -> None: + self._receivers.append(receiver) + + def remove_receiver(self, receiver: EventReciever) -> None: + self._receivers.remove(receiver) + + async def emit(self, event_str: str, data: Any) -> None: + events = self._events.get(event_str, []) + + coros: list[Awaitable[None]] = [] + for event_cls in events: + event = await event_cls.__load__(data=data, state=self._state) + + if event is None: + continue + + coros.extend(receiver(event) for receiver in self._receivers) + + await asyncio.gather(*coros) diff --git a/discord/app/state.py b/discord/app/state.py new file mode 100644 index 0000000000..d62828e729 --- /dev/null +++ b/discord/app/state.py @@ -0,0 +1,760 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import asyncio +import copy +import inspect +import itertools +import logging +import os +from collections import OrderedDict, deque +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Deque, + Sequence, + TypeVar, + Union, + cast, +) + +from discord.soundboard import SoundboardSound + +from .. import utils +from ..activity import BaseActivity +from ..automod import AutoModRule +from ..channel import * +from ..channel import _channel_factory +from ..channel.thread import Thread, ThreadMember +from ..emoji import AppEmoji, GuildEmoji +from ..enums import ChannelType, InteractionType, Status, try_enum +from ..flags import ApplicationFlags, Intents, MemberCacheFlags +from ..guild import Guild +from ..interactions import Interaction +from ..invite import Invite +from ..member import Member +from ..mentions import AllowedMentions +from ..message import Message +from ..monetization import Entitlement, Subscription +from ..object import Object +from ..partial_emoji import PartialEmoji +from ..poll import Poll, PollAnswerCount +from ..raw_models import * +from ..role import Role +from ..sticker import GuildSticker +from ..ui.modal import Modal +from ..ui.view import View +from ..user import ClientUser, User +from ..utils.private import get_as_snowflake, parse_time, sane_wait_for +from .cache import Cache +from .event_emitter import EventEmitter + +if TYPE_CHECKING: + from ..abc import PrivateChannel + from ..client import Client + from ..gateway import DiscordWebSocket + from ..guild import GuildChannel, VocalGuildChannel + from ..http import HTTPClient + from ..message import MessageableChannel + from ..types.activity import Activity as ActivityPayload + from ..types.channel import DMChannel as DMChannelPayload + from ..types.emoji import Emoji as EmojiPayload + from ..types.guild import Guild as GuildPayload + from ..types.message import Message as MessagePayload + from ..types.poll import Poll as PollPayload + from ..types.sticker import GuildSticker as GuildStickerPayload + from ..types.user import User as UserPayload + from ..voice_client import VoiceClient + + T = TypeVar("T") + CS = TypeVar("CS", bound="ConnectionState") + Channel = GuildChannel | VocalGuildChannel | PrivateChannel | PartialMessageable + + +class ChunkRequest: + def __init__( + self, + guild_id: int, + loop: asyncio.AbstractEventLoop, + resolver: Callable[[int], Any], + *, + cache: bool = True, + ) -> None: + self.guild_id: int = guild_id + self.resolver: Callable[[int], Any] = resolver + self.loop: asyncio.AbstractEventLoop = loop + self.cache: bool = cache + self.nonce: str = os.urandom(16).hex() + self.buffer: list[Member] = [] + self.waiters: list[asyncio.Future[list[Member]]] = [] + + async def add_members(self, members: list[Member]) -> None: + self.buffer.extend(members) + if self.cache: + guild = self.resolver(self.guild_id) + if inspect.isawaitable(guild): + guild = await guild + if guild is None: + return + + for member in members: + existing = await guild.get_member(member.id) + if existing is None or existing.joined_at is None: + await guild._add_member(member) + + async def wait(self) -> list[Member]: + future = self.loop.create_future() + self.waiters.append(future) + try: + return await future + finally: + self.waiters.remove(future) + + def get_future(self) -> asyncio.Future[list[Member]]: + future = self.loop.create_future() + self.waiters.append(future) + return future + + def done(self) -> None: + for future in self.waiters: + if not future.done(): + future.set_result(self.buffer) + + +_log = logging.getLogger(__name__) + + +async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> None: + try: + await coroutine + except Exception: + _log.exception("Exception occurred during %s", info) + + +class ConnectionState: + if TYPE_CHECKING: + _get_websocket: Callable[..., DiscordWebSocket] + _get_client: Callable[..., Client] + _parsers: dict[str, Callable[[dict[str, Any]], None]] + + def __init__( + self, + *, + cache: Cache, + handlers: dict[str, Callable], + hooks: dict[str, Callable], + http: HTTPClient, + loop: asyncio.AbstractEventLoop, + **options: Any, + ) -> None: + self.loop: asyncio.AbstractEventLoop = loop + self.http: HTTPClient = http + self.max_messages: int | None = options.get("max_messages", 1000) + if self.max_messages is not None and self.max_messages <= 0: + self.max_messages = 1000 + + self.handlers: dict[str, Callable] = handlers + self.hooks: dict[str, Callable] = hooks + self.shard_count: int | None = None + self._ready_task: asyncio.Task | None = None + self.application_id: int | None = get_as_snowflake(options, "application_id") + self.application_flags: ApplicationFlags | None = None + self.heartbeat_timeout: float = options.get("heartbeat_timeout", 60.0) + self.guild_ready_timeout: float = options.get("guild_ready_timeout", 2.0) + if self.guild_ready_timeout < 0: + raise ValueError("guild_ready_timeout cannot be negative") + + allowed_mentions = options.get("allowed_mentions") + + if allowed_mentions is not None and not isinstance(allowed_mentions, AllowedMentions): + raise TypeError("allowed_mentions parameter must be AllowedMentions") + + self.allowed_mentions: AllowedMentions | None = allowed_mentions + self._chunk_requests: dict[int | str, ChunkRequest] = {} + + activity = options.get("activity", None) + if activity: + if not isinstance(activity, BaseActivity): + raise TypeError("activity parameter must derive from BaseActivity.") + + activity = activity.to_dict() + + status = options.get("status", None) + if status: + status = "invisible" if status is Status.offline else str(status) + intents = options.get("intents", None) + if intents is None: + intents = Intents.default() + + elif not isinstance(intents, Intents): + raise TypeError(f"intents parameter must be Intent not {type(intents)!r}") + if not intents.guilds: + _log.warning("Guilds intent seems to be disabled. This may cause state related issues.") + + self._chunk_guilds: bool = options.get("chunk_guilds_at_startup", intents.members) + + # Ensure these two are set properly + if not intents.members and self._chunk_guilds: + raise ValueError("Intents.members must be enabled to chunk guilds at startup.") + + cache_flags = options.get("member_cache_flags", None) + if cache_flags is None: + cache_flags = MemberCacheFlags.from_intents(intents) + elif not isinstance(cache_flags, MemberCacheFlags): + raise TypeError(f"member_cache_flags parameter must be MemberCacheFlags not {type(cache_flags)!r}") + + else: + cache_flags._verify_intents(intents) + + self.member_cache_flags: MemberCacheFlags = cache_flags + self._activity: ActivityPayload | None = activity + self._status: str | None = status + self._intents: Intents = intents + self._voice_clients: dict[int, VoiceClient] = {} + + if not intents.members or cache_flags._empty: + self.store_user = self.create_user_async # type: ignore + self.deref_user = self.deref_user_no_intents # type: ignore + + self.cache_app_emojis: bool = options.get("cache_app_emojis", False) + + self.emitter: EventEmitter = EventEmitter(self) + + self.cache: Cache = cache + self.cache._state = self + + async def clear(self, *, views: bool = True) -> None: + self.user: ClientUser | None = None + await self.cache.clear() + self._voice_clients = {} + + async def process_chunk_requests( + self, guild_id: int, nonce: str | None, members: list[Member], complete: bool + ) -> None: + removed = [] + for key, request in self._chunk_requests.items(): + if request.guild_id == guild_id and request.nonce == nonce: + await request.add_members(members) + if complete: + request.done() + removed.append(key) + + for key in removed: + del self._chunk_requests[key] + + def call_handlers(self, key: str, *args: Any, **kwargs: Any) -> None: + try: + func = self.handlers[key] + except KeyError: + pass + else: + func(*args, **kwargs) + + async def call_hooks(self, key: str, *args: Any, **kwargs: Any) -> None: + try: + coro = self.hooks[key] + except KeyError: + pass + else: + await coro(*args, **kwargs) + + @property + def self_id(self) -> int | None: + u = self.user + return u.id if u else None + + @property + def intents(self) -> Intents: + ret = Intents.none() + ret.value = self._intents.value + return ret + + @property + def voice_clients(self) -> list[VoiceClient]: + return list(self._voice_clients.values()) + + def _get_voice_client(self, guild_id: int | None) -> VoiceClient | None: + # the keys of self._voice_clients are ints + return self._voice_clients.get(guild_id) # type: ignore + + def _add_voice_client(self, guild_id: int, voice: VoiceClient) -> None: + self._voice_clients[guild_id] = voice + + def _remove_voice_client(self, guild_id: int) -> None: + self._voice_clients.pop(guild_id, None) + + def _update_references(self, ws: DiscordWebSocket) -> None: + for vc in self.voice_clients: + vc.main_ws = ws # type: ignore + + async def store_user(self, data: UserPayload) -> User: + return await self.cache.store_user(data) + + async def deref_user(self, user_id: int) -> None: + return await self.cache.delete_user(user_id) + + def create_user(self, data: UserPayload) -> User: + return User(state=self, data=data) + + async def create_user_async(self, data: UserPayload) -> User: + return User(state=self, data=data) + + def deref_user_no_intents(self, user_id: int) -> None: + return + + async def get_user(self, id: int | None) -> User | None: + return await self.cache.get_user(cast(int, id)) + + async def store_emoji(self, guild: Guild, data: EmojiPayload) -> GuildEmoji: + return await self.cache.store_guild_emoji(guild, data) + + async def maybe_store_app_emoji(self, application_id: int, data: EmojiPayload) -> AppEmoji: + # the id will be present here + emoji = AppEmoji(application_id=application_id, state=self, data=data) + if self.cache_app_emojis: + await self.cache.store_app_emoji(application_id, data) + return emoji + + async def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker: + return await self.cache.store_sticker(guild, data) + + async def store_view(self, view: View, message_id: int | None = None) -> None: + await self.cache.store_view(view, message_id) + + async def store_modal(self, modal: Modal, user_id: int) -> None: + await self.cache.store_modal(modal, user_id) + + async def prevent_view_updates_for(self, message_id: int) -> View | None: + return await self.cache.delete_view_on(message_id) + + async def get_persistent_views(self) -> Sequence[View]: + views = await self.cache.get_all_views() + persistent_views = {view.id: view for view in views if view.is_persistent()} + return list(persistent_views.values()) + + async def get_guilds(self) -> list[Guild]: + return await self.cache.get_all_guilds() + + async def _get_guild(self, guild_id: int | None) -> Guild | None: + return await self.cache.get_guild(cast(int, guild_id)) + + async def _add_guild(self, guild: Guild) -> None: + await self.cache.add_guild(guild) + + async def _remove_guild(self, guild: Guild) -> None: + await self.cache.delete_guild(guild) + + for emoji in guild.emojis: + await self.cache.delete_emoji(emoji) + + for sticker in guild.stickers: + await self.cache.delete_sticker(sticker.id) + + del guild + + async def _add_default_sounds(self) -> None: + default_sounds = await self.http.get_default_sounds() + for default_sound in default_sounds: + sound = SoundboardSound(state=self, http=self.http, data=default_sound) + await self._add_sound(sound) + + async def _add_sound(self, sound: SoundboardSound) -> None: + await self.cache.store_sound(sound) + + async def _remove_sound(self, sound: SoundboardSound) -> None: + await self.cache.delete_sound(sound.id) + + async def get_sounds(self) -> list[SoundboardSound]: + return list(await self.cache.get_all_sounds()) + + async def get_emojis(self) -> list[GuildEmoji | AppEmoji]: + return await self.cache.get_all_emojis() + + async def get_stickers(self) -> list[GuildSticker]: + return await self.cache.get_all_stickers() + + async def get_emoji(self, emoji_id: int | None) -> GuildEmoji | AppEmoji | None: + return await self.cache.get_emoji(emoji_id) + + async def _remove_emoji(self, emoji: GuildEmoji | AppEmoji) -> None: + await self.cache.delete_emoji(emoji) + + async def get_sticker(self, sticker_id: int | None) -> GuildSticker | None: + return await self.cache.get_sticker(cast(int, sticker_id)) + + async def get_polls(self) -> list[Poll]: + return await self.cache.get_all_polls() + + async def store_poll(self, poll: Poll, message_id: int): + await self.cache.store_poll(poll, message_id) + + async def get_poll(self, message_id: int): + return await self.cache.get_poll(message_id) + + async def get_private_channels(self) -> list[PrivateChannel]: + return await self.cache.get_private_channels() + + async def _get_private_channel(self, channel_id: int | None) -> PrivateChannel | None: + return await self.cache.get_private_channel(cast(int, channel_id)) + + async def _get_private_channel_by_user(self, user_id: int | None) -> DMChannel | None: + return cast(DMChannel | None, await self.cache.get_private_channel_by_user(cast(int, user_id))) + + async def _add_private_channel(self, channel: PrivateChannel) -> None: + await self.cache.store_private_channel(channel) + + async def add_dm_channel(self, data: DMChannelPayload) -> DMChannel: + # self.user is *always* cached when this is called + channel = DMChannel(me=self.user, state=self, data=data) # type: ignore + await channel._load() + await self._add_private_channel(channel) + return channel + + async def _get_message(self, msg_id: int | None) -> Message | None: + return await self.cache.get_message(cast(int, msg_id)) + + def _guild_needs_chunking(self, guild: Guild) -> bool: + # If presences are enabled then we get back the old guild.large behaviour + return self._chunk_guilds and not guild.chunked and not (self._intents.presences and not guild.large) + + async def _get_guild_channel( + self, data: MessagePayload, guild_id: int | None = None + ) -> tuple[Channel | Thread, Guild | None]: + channel_id = int(data["channel_id"]) + try: + # guild_id is in data + guild = await self._get_guild(int(guild_id or data["guild_id"])) # type: ignore + except KeyError: + channel = DMChannel(id=channel_id, state=self) + guild = None + else: + channel = guild and guild._resolve_channel(channel_id) + + return channel or PartialMessageable(state=self, id=channel_id), guild + + async def chunker( + self, + guild_id: int, + query: str = "", + limit: int = 0, + presences: bool = False, + *, + nonce: str | None = None, + ) -> None: + ws = self._get_websocket(guild_id) # This is ignored upstream + await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) + + async def query_members( + self, + guild: Guild, + query: str | None, + limit: int, + user_ids: list[int] | None, + cache: bool, + presences: bool, + ): + guild_id = guild.id + ws = self._get_websocket(guild_id) + if ws is None: + raise RuntimeError("Somehow do not have a websocket for this guild_id") + + request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) + self._chunk_requests[request.nonce] = request + + try: + # start the query operation + await ws.request_chunks( + guild_id, + query=query, + limit=limit, + user_ids=user_ids, + presences=presences, + nonce=request.nonce, + ) + return await asyncio.wait_for(request.wait(), timeout=30.0) + except asyncio.TimeoutError: + _log.warning( + ("Timed out waiting for chunks with query %r and limit %d for guild_id %d"), + query, + limit, + guild_id, + ) + raise + + async def _get_create_guild(self, data): + if data.get("unavailable") is False: + # GUILD_CREATE with unavailable in the response + # usually means that the guild has become available + # and is therefore in the cache + guild = await self._get_guild(int(data["id"])) + if guild is not None: + guild.unavailable = False + await guild._from_data(data, self) + return guild + + return self._add_guild_from_data(data) + + def is_guild_evicted(self, guild) -> bool: + return guild.id not in self._guilds + + async def chunk_guild(self, guild, *, wait=True, cache=None): + # Note: This method makes an API call without timeout, and should be used in + # conjunction with `asyncio.wait_for(..., timeout=...)`. + cache = cache or self.member_cache_flags.joined + request = self._chunk_requests.get(guild.id) # nosec B113 + if request is None: + self._chunk_requests[guild.id] = request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) + await self.chunker(guild.id, nonce=request.nonce) + + if wait: + return await request.wait() + return request.get_future() + + async def _chunk_and_dispatch(self, guild, unavailable): + try: + await asyncio.wait_for(self.chunk_guild(guild), timeout=60.0) + except asyncio.TimeoutError: + _log.info("Somehow timed out waiting for chunks.") + + if unavailable is False: + self.dispatch("guild_available", guild) + else: + self.dispatch("guild_join", guild) + + def parse_guild_members_chunk(self, data) -> None: + guild_id = int(data["guild_id"]) + guild = self._get_guild(guild_id) + presences = data.get("presences", []) + + # the guild won't be None here + members = [Member(guild=guild, data=member, state=self) for member in data.get("members", [])] # type: ignore + _log.debug("Processed a chunk for %s members in guild ID %s.", len(members), guild_id) + + if presences: + member_dict = {str(member.id): member for member in members} + for presence in presences: + user = presence["user"] + member_id = user["id"] + member = member_dict.get(member_id) + if member is not None: + member._presence_update(presence, user) + + complete = data.get("chunk_index", 0) + 1 == data.get("chunk_count") + self.process_chunk_requests(guild_id, data.get("nonce"), members, complete) + + async def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> User | Member | None: + if isinstance(channel, TextChannel): + return await channel.guild.get_member(user_id) + return await self.get_user(user_id) + + async def get_reaction_emoji(self, data) -> GuildEmoji | AppEmoji | PartialEmoji: + emoji_id = get_as_snowflake(data, "id") + + if not emoji_id: + return data["name"] + + try: + return await self.cache.get_emoji(emoji_id) + except KeyError: + return PartialEmoji.with_state( + self, + animated=data.get("animated", False), + id=emoji_id, + name=data["name"], + ) + + async def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> GuildEmoji | AppEmoji | PartialEmoji | str: + emoji_id = emoji.id + if not emoji_id: + return emoji.name + return await self.cache.get_emoji(emoji_id) or emoji + + async def get_channel(self, id: int | None) -> Channel | Thread | None: + if id is None: + return None + + pm = await self._get_private_channel(id) + if pm is not None: + return pm + + for guild in await self.cache.get_all_guilds(): + channel = guild._resolve_channel(id) + if channel is not None: + return channel + + def create_message( + self, + *, + channel: MessageableChannel, + data: MessagePayload, + ) -> Message: + return Message(state=self, channel=channel, data=data) + + +class AutoShardedConnectionState(ConnectionState): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.shard_ids: list[int] | range = [] + self.shards_launched: asyncio.Event = asyncio.Event() + + def _update_message_references(self) -> None: + # self._messages won't be None when this is called + for msg in self._messages: # type: ignore + if not msg.guild: + continue + + new_guild = self._get_guild(msg.guild.id) + if new_guild is not None and new_guild is not msg.guild: + channel_id = msg.channel.id + channel = new_guild._resolve_channel(channel_id) or Object(id=channel_id) + # channel will either be a TextChannel, Thread or Object + msg._rebind_cached_references(new_guild, channel) # type: ignore + + async def chunker( + self, + guild_id: int, + query: str = "", + limit: int = 0, + presences: bool = False, + *, + shard_id: int | None = None, + nonce: str | None = None, + ) -> None: + ws = self._get_websocket(guild_id, shard_id=shard_id) + await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) + + async def _delay_ready(self) -> None: + await self.shards_launched.wait() + processed = [] + max_concurrency = len(self.shard_ids) * 2 + current_bucket = [] + while True: + # this snippet of code is basically waiting N seconds + # until the last GUILD_CREATE was sent + try: + guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) + except asyncio.TimeoutError: + break + else: + if self._guild_needs_chunking(guild): + _log.debug( + ("Guild ID %d requires chunking, will be done in the background."), + guild.id, + ) + if len(current_bucket) >= max_concurrency: + try: + await sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) + except asyncio.TimeoutError: + fmt = "Shard ID %s failed to wait for chunks from a sub-bucket with length %d" + _log.warning(fmt, guild.shard_id, len(current_bucket)) + finally: + current_bucket = [] + + # Chunk the guild in the background while we wait for GUILD_CREATE streaming + future = asyncio.ensure_future(self.chunk_guild(guild)) + current_bucket.append(future) + else: + await self._add_default_sounds() + future = self.loop.create_future() + future.set_result([]) + + processed.append((guild, future)) + + guilds = sorted(processed, key=lambda g: g[0].shard_id) + for shard_id, info in itertools.groupby(guilds, key=lambda g: g[0].shard_id): + children, futures = zip(*info, strict=True) + # 110 reqs/minute w/ 1 req/guild plus some buffer + timeout = 61 * (len(children) / 110) + try: + await sane_wait_for(futures, timeout=timeout) + except asyncio.TimeoutError: + _log.warning( + ("Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds"), + shard_id, + timeout, + len(guilds), + ) + for guild in children: + if guild.unavailable is False: + self.dispatch("guild_available", guild) + else: + self.dispatch("guild_join", guild) + + self.dispatch("shard_ready", shard_id) + + if self.cache_app_emojis and self.application_id: + data = await self.http.get_all_application_emojis(self.application_id) + for e in data.get("items", []): + await self.maybe_store_app_emoji(self.application_id, e) + + # remove the state + try: + del self._ready_state + except AttributeError: + pass # already been deleted somehow + + # clear the current task + self._ready_task = None + + # dispatch the event + self.call_handlers("ready") + self.dispatch("ready") + + def parse_ready(self, data) -> None: + if not hasattr(self, "_ready_state"): + self._ready_state = asyncio.Queue() + + self.user = user = ClientUser(state=self, data=data["user"]) + # self._users is a list of Users, we're setting a ClientUser + self._users[user.id] = user # type: ignore + + if self.application_id is None: + try: + application = data["application"] + except KeyError: + pass + else: + self.application_id = get_as_snowflake(application, "id") + self.application_flags = ApplicationFlags._from_value(application["flags"]) + + for guild_data in data["guilds"]: + self._add_guild_from_data(guild_data) + + if self._messages: + self._update_message_references() + + self.dispatch("connect") + self.dispatch("shard_connect", data["__shard_id__"]) + + if self._ready_task is None: + self._ready_task = asyncio.create_task(self._delay_ready()) + + def parse_resumed(self, data) -> None: + self.dispatch("resumed") + self.dispatch("shard_resumed", data["__shard_id__"]) diff --git a/discord/appinfo.py b/discord/appinfo.py index bccc9ad883..e6e328d6fb 100644 --- a/discord/appinfo.py +++ b/discord/appinfo.py @@ -33,8 +33,8 @@ from .utils.private import get_as_snowflake, warn_deprecated if TYPE_CHECKING: + from .app.state import ConnectionState from .guild import Guild - from .state import ConnectionState from .types.appinfo import AppInfo as AppInfoPayload from .types.appinfo import AppInstallParams as AppInstallParamsPayload from .types.appinfo import PartialAppInfo as PartialAppInfoPayload @@ -183,7 +183,7 @@ class AppInfo: ) def __init__(self, state: ConnectionState, data: AppInfoPayload): - from .team import Team # noqa: PLC0415 + from .team import Team self._state: ConnectionState = state self.id: int = int(data["id"]) @@ -243,14 +243,13 @@ def cover_image(self) -> Asset | None: return None return Asset._from_cover_image(self._state, self.id, self._cover_image) - @property - def guild(self) -> Guild | None: + async def get_guild(self) -> Guild | None: """If this application is a game sold on Discord, this field will be the guild to which it has been linked. .. versionadded:: 1.3 """ - return self._state._get_guild(self.guild_id) + return await self._state._get_guild(self.guild_id) @property def summary(self) -> str | None: diff --git a/discord/asset.py b/discord/asset.py index afdd47f3aa..684dcd1dd6 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -34,12 +34,15 @@ from . import utils from .errors import DiscordException, InvalidArgument +if TYPE_CHECKING: + from .app.state import ConnectionState + __all__ = ("Asset",) if TYPE_CHECKING: ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"] ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"] - from .state import ConnectionState + from .app.state import ConnectionState VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"}) @@ -172,7 +175,7 @@ def __init__(self, state, *, url: str, key: str, animated: bool = False): self._key = key @classmethod - def _from_default_avatar(cls, state, index: int) -> Asset: + def _from_default_avatar(cls, state: ConnectionState, index: int) -> Asset: return cls( state, url=f"{cls.BASE}/embed/avatars/{index}.png", @@ -181,7 +184,7 @@ def _from_default_avatar(cls, state, index: int) -> Asset: ) @classmethod - def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: + def _from_avatar(cls, state: ConnectionState, user_id: int, avatar: str) -> Asset: animated = avatar.startswith("a_") format = "gif" if animated else "png" return cls( @@ -192,7 +195,7 @@ def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: ) @classmethod - def _from_avatar_decoration(cls, state, user_id: int, avatar_decoration: str) -> Asset: + def _from_avatar_decoration(cls, state: ConnectionState, user_id: int, avatar_decoration: str) -> Asset: animated = avatar_decoration.startswith("a_") endpoint = ( "avatar-decoration-presets" @@ -232,7 +235,7 @@ def _from_user_primary_guild_tag(cls, state: ConnectionState, identity_guild_id: ) @classmethod - def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset: + def _from_guild_avatar(cls, state: ConnectionState, guild_id: int, member_id: int, avatar: str) -> Asset: animated = avatar.startswith("a_") format = "gif" if animated else "png" return cls( @@ -243,7 +246,7 @@ def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) - ) @classmethod - def _from_guild_banner(cls, state, guild_id: int, member_id: int, banner: str) -> Asset: + def _from_guild_banner(cls, state: ConnectionState, guild_id: int, member_id: int, banner: str) -> Asset: animated = banner.startswith("a_") format = "gif" if animated else "png" return cls( @@ -254,7 +257,7 @@ def _from_guild_banner(cls, state, guild_id: int, member_id: int, banner: str) - ) @classmethod - def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: + def _from_icon(cls, state: ConnectionState, object_id: int, icon_hash: str, path: str) -> Asset: return cls( state, url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024", @@ -263,7 +266,7 @@ def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: ) @classmethod - def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset: + def _from_cover_image(cls, state: ConnectionState, object_id: int, cover_image_hash: str) -> Asset: return cls( state, url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024", @@ -282,7 +285,7 @@ def _from_collectible(cls, state: ConnectionState, asset: str, animated: bool = ) @classmethod - def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset: + def _from_guild_image(cls, state: ConnectionState, guild_id: int, image: str, path: str) -> Asset: animated = False format = "png" if path == "banners": @@ -297,7 +300,7 @@ def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset ) @classmethod - def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: + def _from_guild_icon(cls, state: ConnectionState, guild_id: int, icon_hash: str) -> Asset: animated = icon_hash.startswith("a_") format = "gif" if animated else "png" return cls( @@ -308,7 +311,7 @@ def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: ) @classmethod - def _from_sticker_banner(cls, state, banner: int) -> Asset: + def _from_sticker_banner(cls, state: ConnectionState, banner: int) -> Asset: return cls( state, url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png", @@ -317,7 +320,7 @@ def _from_sticker_banner(cls, state, banner: int) -> Asset: ) @classmethod - def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: + def _from_user_banner(cls, state: ConnectionState, user_id: int, banner_hash: str) -> Asset: animated = banner_hash.startswith("a_") format = "gif" if animated else "png" return cls( @@ -328,7 +331,7 @@ def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: ) @classmethod - def _from_scheduled_event_image(cls, state, event_id: int, cover_hash: str) -> Asset: + def _from_scheduled_event_image(cls, state: ConnectionState, event_id: int, cover_hash: str) -> Asset: return cls( state, url=f"{cls.BASE}/guild-events/{event_id}/{cover_hash}.png", @@ -337,7 +340,7 @@ def _from_scheduled_event_image(cls, state, event_id: int, cover_hash: str) -> A ) @classmethod - def _from_soundboard_sound(cls, state, sound_id: int) -> Asset: + def _from_soundboard_sound(cls, state: ConnectionState, sound_id: int) -> Asset: return cls( state, url=f"{cls.BASE}/soundboard-sounds/{sound_id}", diff --git a/discord/audit_logs.py b/discord/audit_logs.py index 097685565f..a59efed1e9 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -27,6 +27,7 @@ import datetime from functools import cached_property +from inspect import isawaitable from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generator, TypeVar from . import enums, utils @@ -47,16 +48,16 @@ if TYPE_CHECKING: - from . import abc + from .app.state import ConnectionState + from .channel.base import GuildChannel + from .channel.thread import Thread from .emoji import GuildEmoji from .guild import Guild from .member import Member from .role import Role from .scheduled_events import ScheduledEvent from .stage_instance import StageInstance - from .state import ConnectionState from .sticker import GuildSticker - from .threads import Thread from .types.audit_log import AuditLogChange as AuditLogChangePayload from .types.audit_log import AuditLogEntry as AuditLogEntryPayload from .types.automod import AutoModAction as AutoModActionPayload @@ -79,13 +80,13 @@ def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int: return int(data) -def _transform_channel(entry: AuditLogEntry, data: Snowflake | None) -> abc.GuildChannel | Object | None: +def _transform_channel(entry: AuditLogEntry, data: Snowflake | None) -> GuildChannel | Object | None: if data is None: return None return entry.guild.get_channel(int(data)) or Object(id=data) -def _transform_channels(entry: AuditLogEntry, data: list[Snowflake] | None) -> list[abc.GuildChannel | Object] | None: +def _transform_channels(entry: AuditLogEntry, data: list[Snowflake] | None) -> list[GuildChannel | Object] | None: if data is None: return None return [_transform_channel(entry, channel) for channel in data] @@ -103,10 +104,10 @@ def _transform_member_id(entry: AuditLogEntry, data: Snowflake | None) -> Member return entry._get_member(int(data)) -def _transform_guild_id(entry: AuditLogEntry, data: Snowflake | None) -> Guild | None: +async def _transform_guild_id(entry: AuditLogEntry, data: Snowflake | None) -> Guild | None: if data is None: return None - return entry._state._get_guild(data) + return await entry._state._get_guild(data) def _transform_overwrites( @@ -277,7 +278,14 @@ class AuditLogChanges: "communication_disabled_until": (None, _transform_communication_disabled_until), } - def __init__( + @staticmethod + async def _maybe_await(func: Any) -> Any: + if isawaitable(func): + return await func + else: + return func + + async def _from_data( self, entry: AuditLogEntry, data: list[AuditLogChangePayload], @@ -340,10 +348,10 @@ def __init__( before = None else: if transformer: - before = transformer(entry, before) + before = await self._maybe_await(transformer(entry, before)) if attr == "location" and hasattr(self.before, "location_type"): - from .scheduled_events import ScheduledEventLocation # noqa: PLC0415 + from .scheduled_events import ScheduledEventLocation if self.before.location_type is enums.ScheduledEventLocationType.external: before = ScheduledEventLocation(state=state, value=before) @@ -358,10 +366,10 @@ def __init__( after = None else: if transformer: - after = transformer(entry, after) + after = await self._maybe_await(transformer(entry, after)) if attr == "location" and hasattr(self.after, "location_type"): - from .scheduled_events import ScheduledEventLocation # noqa: PLC0415 + from .scheduled_events import ScheduledEventLocation if self.after.location_type is enums.ScheduledEventLocationType.external: after = ScheduledEventLocation(state=state, value=after) @@ -430,7 +438,7 @@ class _AuditLogProxyMemberPrune: class _AuditLogProxyMemberMoveOrMessageDelete: - channel: abc.GuildChannel + channel: GuildChannel count: int @@ -439,12 +447,12 @@ class _AuditLogProxyMemberDisconnect: class _AuditLogProxyPinAction: - channel: abc.GuildChannel + channel: GuildChannel message_id: int class _AuditLogProxyStageInstanceAction: - channel: abc.GuildChannel + channel: GuildChannel class AuditLogEntry(Hashable): @@ -570,8 +578,8 @@ def _from_data(self, data: AuditLogEntryPayload) -> None: self.user = self._get_member(get_as_snowflake(data, "user_id")) # type: ignore self._target_id = get_as_snowflake(data, "target_id") - def _get_member(self, user_id: int) -> Member | User | None: - return self.guild.get_member(user_id) or self._users.get(user_id) + async def _get_member(self, user_id: int) -> Member | User | None: + return await self.guild.get_member(user_id) or self._users.get(user_id) def __repr__(self) -> str: return f"" @@ -581,12 +589,11 @@ def created_at(self) -> datetime.datetime: """Returns the entry's creation time in UTC.""" return utils.snowflake_time(self.id) - @cached_property - def target( + async def get_target( self, ) -> ( Guild - | abc.GuildChannel + | GuildChannel | Member | User | Role @@ -603,17 +610,19 @@ def target( except AttributeError: return Object(id=self._target_id) else: - return converter(self._target_id) + r = converter(self._target_id) + if isawaitable(r): + r = await r + return r @property def category(self) -> enums.AuditLogActionCategory: """The category of the action, if applicable.""" return self.action.category - @cached_property - def changes(self) -> AuditLogChanges: + async def changes(self) -> AuditLogChanges: """The list of changes this entry has.""" - obj = AuditLogChanges(self, self._changes, state=self._state) + obj = AuditLogChanges().from_data(self, self._changes, state=self._state) del self._changes return obj @@ -630,11 +639,11 @@ def after(self) -> AuditLogDiff: def _convert_target_guild(self, target_id: int) -> Guild: return self.guild - def _convert_target_channel(self, target_id: int) -> abc.GuildChannel | Object: + def _convert_target_channel(self, target_id: int) -> GuildChannel | Object: return self.guild.get_channel(target_id) or Object(id=target_id) - def _convert_target_user(self, target_id: int) -> Member | User | None: - return self._get_member(target_id) + async def _convert_target_user(self, target_id: int) -> Member | User | None: + return await self._get_member(target_id) def _convert_target_role(self, target_id: int) -> Role | Object: return self.guild.get_role(target_id) or Object(id=target_id) @@ -659,17 +668,17 @@ def _convert_target_invite(self, target_id: int) -> Invite: pass return obj - def _convert_target_emoji(self, target_id: int) -> GuildEmoji | Object: - return self._state.get_emoji(target_id) or Object(id=target_id) + async def _convert_target_emoji(self, target_id: int) -> GuildEmoji | Object: + return (await self._state.get_emoji(target_id)) or Object(id=target_id) - def _convert_target_message(self, target_id: int) -> Member | User | None: - return self._get_member(target_id) + async def _convert_target_message(self, target_id: int) -> Member | User | None: + return await self._get_member(target_id) def _convert_target_stage_instance(self, target_id: int) -> StageInstance | Object: return self.guild.get_stage_instance(target_id) or Object(id=target_id) - def _convert_target_sticker(self, target_id: int) -> GuildSticker | Object: - return self._state.get_sticker(target_id) or Object(id=target_id) + async def _convert_target_sticker(self, target_id: int) -> GuildSticker | Object: + return (await self._state.get_sticker(target_id)) or Object(id=target_id) def _convert_target_thread(self, target_id: int) -> Thread | Object: return self.guild.get_thread(target_id) or Object(id=target_id) diff --git a/discord/automod.py b/discord/automod.py index 9f990d8bbc..2bf7869351 100644 --- a/discord/automod.py +++ b/discord/automod.py @@ -48,11 +48,11 @@ if TYPE_CHECKING: from .abc import Snowflake + from .app.state import ConnectionState from .channel import ForumChannel, TextChannel, VoiceChannel from .guild import Guild from .member import Member from .role import Role - from .state import ConnectionState from .types.automod import AutoModAction as AutoModActionPayload from .types.automod import AutoModActionMetadata as AutoModActionMetadataPayload from .types.automod import AutoModRule as AutoModRulePayload @@ -406,17 +406,15 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.name - @property - def guild(self) -> Guild | None: + async def get_guild(self) -> Guild | None: """The guild this rule belongs to.""" - return self._state._get_guild(self.guild_id) + return await self._state._get_guild(self.guild_id) - @property - def creator(self) -> Member | None: + async def get_creator(self) -> Member | None: """The member who created this rule.""" if self.guild is None: return None - return self.guild.get_member(self.creator_id) + return await self.guild.get_member(self.creator_id) @cached_property def exempt_roles(self) -> list[Role | Object]: diff --git a/discord/bot.py b/discord/bot.py index 41b126674a..bc9c683a43 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -46,7 +46,6 @@ ) from .client import Client -from .cog import CogMixin from .commands import ( ApplicationCommand, ApplicationContext, @@ -59,6 +58,7 @@ ) from .enums import IntegrationType, InteractionContextType, InteractionType from .errors import CheckFailure, DiscordException +from .events import InteractionCreate from .interactions import Interaction from .shard import AutoShardedClient from .types import interactions @@ -1082,7 +1082,7 @@ async def invoke_application_command(self, ctx: ApplicationContext) -> None: ctx: :class:`.ApplicationCommand` The invocation context to invoke. """ - self._bot.dispatch("application_command", ctx) + # self._bot.dispatch("application_command", ctx) # TODO: Remove when moving away from ApplicationContext try: if await self._bot.can_run(ctx, call_once=True): await ctx.command.invoke(ctx) @@ -1091,14 +1091,15 @@ async def invoke_application_command(self, ctx: ApplicationContext) -> None: except DiscordException as exc: await ctx.command.dispatch_error(ctx, exc) else: - self._bot.dispatch("application_command_completion", ctx) + # self._bot.dispatch("application_command_completion", ctx) # TODO: Remove when moving away from ApplicationContext + pass @property @abstractmethod def _bot(self) -> Bot | AutoShardedBot: ... -class BotBase(ApplicationCommandMixin, CogMixin, ABC): +class BotBase(ApplicationCommandMixin, ABC): _supports_prefixed_commands = False def __init__(self, description=None, *args, **options): @@ -1152,11 +1153,13 @@ def __init__(self, description=None, *args, **options): self._before_invoke = None self._after_invoke = None + self._bot.add_listener(self.on_interaction, event=InteractionCreate) + async def on_connect(self): if self.auto_sync_commands: await self.sync_commands() - async def on_interaction(self, interaction): + async def on_interaction(self, interaction: InteractionCreate): await self.process_application_commands(interaction) async def on_application_command_error(self, context: ApplicationContext, exception: DiscordException) -> None: diff --git a/discord/channel/__init__.py b/discord/channel/__init__.py new file mode 100644 index 0000000000..9ef560c65c --- /dev/null +++ b/discord/channel/__init__.py @@ -0,0 +1,119 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from ..enums import ChannelType, try_enum +from .base import ( + BaseChannel, + GuildChannel, + GuildMessageableChannel, + GuildPostableChannel, + GuildThreadableChannel, + GuildTopLevelChannel, +) +from .category import CategoryChannel +from .dm import DMChannel +from .dm import GroupDMChannel as GroupChannel +from .forum import ForumChannel +from .media import MediaChannel +from .news import NewsChannel +from .partial import PartialMessageable +from .stage import StageChannel +from .text import TextChannel +from .thread import Thread +from .voice import VoiceChannel + +__all__ = ( + "BaseChannel", + "CategoryChannel", + "DMChannel", + "ForumChannel", + "GroupChannel", + "GuildChannel", + "GuildMessageableChannel", + "GuildPostableChannel", + "GuildThreadableChannel", + "GuildTopLevelChannel", + "MediaChannel", + "NewsChannel", + "PartialMessageable", + "StageChannel", + "TextChannel", + "Thread", + "VoiceChannel", +) + + +def _guild_channel_factory(channel_type: int): + value = try_enum(ChannelType, channel_type) + if value is ChannelType.text: + return TextChannel, value + elif value is ChannelType.voice: + return VoiceChannel, value + elif value is ChannelType.category: + return CategoryChannel, value + elif value is ChannelType.news: + return NewsChannel, value + elif value is ChannelType.stage_voice: + return StageChannel, value + elif value is ChannelType.directory: + return None, value # todo: Add DirectoryChannel when applicable + elif value is ChannelType.forum: + return ForumChannel, value + elif value is ChannelType.media: + return MediaChannel, value + else: + return None, value + + +def _channel_factory(channel_type: int): + cls, value = _guild_channel_factory(channel_type) + if value is ChannelType.private: + return DMChannel, value + elif value is ChannelType.group: + return GroupChannel, value + else: + return cls, value + + +def _threaded_channel_factory(channel_type: int): + cls, value = _channel_factory(channel_type) + if value in ( + ChannelType.private_thread, + ChannelType.public_thread, + ChannelType.news_thread, + ): + return Thread, value + return cls, value + + +def _threaded_guild_channel_factory(channel_type: int): + cls, value = _guild_channel_factory(channel_type) + if value in ( + ChannelType.private_thread, + ChannelType.public_thread, + ChannelType.news_thread, + ): + return Thread, value + return cls, value diff --git a/discord/channel/base.py b/discord/channel/base.py new file mode 100644 index 0000000000..f446833eee --- /dev/null +++ b/discord/channel/base.py @@ -0,0 +1,2025 @@ +""" +The MIT License (MIT) + +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +import copy +import datetime +import logging +from abc import ABC, abstractmethod +from collections.abc import Collection, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload + +from typing_extensions import Self, TypeVar, override + +from ..abc import Messageable, Snowflake, SnowflakeTime, User, _Overwrites, _purge_messages_helper +from ..emoji import GuildEmoji, PartialEmoji +from ..enums import ChannelType, InviteTarget, SortOrder, try_enum +from ..errors import ClientException +from ..flags import ChannelFlags, MessageFlags +from ..iterators import ArchivedThreadIterator +from ..mixins import Hashable +from ..utils import MISSING, Undefined, find, snowflake_time +from ..utils.private import SnowflakeList, bytes_to_base64_data, copy_doc, get_as_snowflake + +if TYPE_CHECKING: + from ..embeds import Embed + from ..errors import InvalidArgument + from ..file import File + from ..guild import Guild + from ..invite import Invite + from ..member import Member + from ..mentions import AllowedMentions + from ..message import EmojiInputType, Message, PartialMessage + from ..object import Object + from ..partial_emoji import _EmojiTag + from ..permissions import PermissionOverwrite, Permissions + from ..role import Role + from ..scheduled_events import ScheduledEvent + from ..sticker import GuildSticker, StickerItem + from ..types.channel import CategoryChannel as CategoryChannelPayload + from ..types.channel import Channel as ChannelPayload + from ..types.channel import ForumChannel as ForumChannelPayload + from ..types.channel import ForumTag as ForumTagPayload + from ..types.channel import GuildChannel as GuildChannelPayload + from ..types.channel import MediaChannel as MediaChannelPayload + from ..types.channel import NewsChannel as NewsChannelPayload + from ..types.channel import StageChannel as StageChannelPayload + from ..types.channel import TextChannel as TextChannelPayload + from ..types.channel import VoiceChannel as VoiceChannelPayload + from ..types.guild import ChannelPositionUpdate as ChannelPositionUpdatePayload + from ..ui.view import View + from ..webhook import Webhook + from .category import CategoryChannel + from .channel import ForumTag + from .text import TextChannel + from .thread import Thread + +_log = logging.getLogger(__name__) + +if TYPE_CHECKING: + from ..app.state import ConnectionState + + +P = TypeVar("P", bound="ChannelPayload") + + +class BaseChannel(ABC, Generic[P]): + __slots__: tuple[str, ...] = ("id", "_type", "_state", "_data") # pyright: ignore [reportIncompatibleUnannotatedOverride] + + def __init__(self, id: int, state: ConnectionState): + self.id: int = id + self._state: ConnectionState = state + self._data: P = {} # type: ignore + + async def _update(self, data: P) -> None: + self._type: int = data["type"] + self._data = self._data | data # type: ignore + + @classmethod + async def _from_data(cls, *, data: P, state: ConnectionState, **kwargs) -> Self: + if kwargs: + _log.warning("Unexpected keyword arguments passed to %s._from_data: %r", cls.__name__, kwargs) + self = cls(int(data["id"]), state) + await self._update(data) + return self + + @property + def type(self) -> ChannelType: + """The channel's Discord channel type.""" + return try_enum(ChannelType, self._type) + + async def _get_channel(self) -> Self: + return self + + @property + def created_at(self) -> datetime.datetime: + """The channel's creation time in UTC.""" + return snowflake_time(self.id) + + @abstractmethod + @override + def __repr__(self) -> str: ... + + @property + @abstractmethod + def jump_url(self) -> str: ... + + +P_guild = TypeVar( + "P_guild", + bound="TextChannelPayload | NewsChannelPayload | VoiceChannelPayload | CategoryChannelPayload | StageChannelPayload | ForumChannelPayload", + default="TextChannelPayload | NewsChannelPayload | VoiceChannelPayload | CategoryChannelPayload | StageChannelPayload | ForumChannelPayload", +) + + +class GuildChannel(BaseChannel[P_guild], ABC, Generic[P_guild]): + """Represents a Discord guild channel.""" + + """An ABC that details the common operations on a Discord guild channel. + + The following implement this ABC: + + - :class:`~discord.TextChannel` + - :class:`~discord.VoiceChannel` + - :class:`~discord.CategoryChannel` + - :class:`~discord.StageChannel` + - :class:`~discord.ForumChannel` + + This ABC must also implement :class:`~discord.abc.Snowflake`. + + Attributes + ---------- + name: :class:`str` + The channel name. + guild: :class:`~discord.Guild` + The guild the channel belongs to. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + e.g. the top channel is position 0. + """ + + __slots__: tuple[str, ...] = ("name", "guild", "category_id", "flags", "_overwrites") + + @override + def __init__(self, id: int, *, guild: Guild, state: ConnectionState) -> None: + self.guild: Guild = guild + super().__init__(id, state) + + @classmethod + @override + async def _from_data(cls, *, data: P_guild, state: ConnectionState, guild: Guild, **kwargs) -> Self: + if kwargs: + _log.warning("Unexpected keyword arguments passed to %s._from_data: %r", cls.__name__, kwargs) + self = cls(int(data["id"]), guild=guild, state=state) + await self._update(data) + return self + + @override + async def _update(self, data: P_guild) -> None: + await super()._update(data) + self.name: str = data["name"] + self.category_id: int | None = get_as_snowflake(data, "parent_id") or getattr(self, "category_id", None) + if flags_value := data.get("flags", 0): + self.flags: ChannelFlags = ChannelFlags._from_value(flags_value) + self._fill_overwrites(data) + + @override + def __str__(self) -> str: + return self.name + + async def _edit(self, options: dict[str, Any], reason: str | None) -> ChannelPayload | None: + try: + parent = options.pop("category") + except KeyError: + parent_id = MISSING + else: + parent_id = parent and parent.id + + try: + options["rate_limit_per_user"] = options.pop("slowmode_delay") + except KeyError: + pass + + try: + options["default_thread_rate_limit_per_user"] = options.pop("default_thread_slowmode_delay") + except KeyError: + pass + + try: + options["flags"] = options.pop("flags").value + except KeyError: + pass + + try: + options["available_tags"] = [tag.to_dict() for tag in options.pop("available_tags")] + except KeyError: + pass + + try: + rtc_region = options.pop("rtc_region") + except KeyError: + pass + else: + options["rtc_region"] = None if rtc_region is None else str(rtc_region) + + try: + video_quality_mode = options.pop("video_quality_mode") + except KeyError: + pass + else: + options["video_quality_mode"] = int(video_quality_mode) + + lock_permissions = options.pop("sync_permissions", False) + + try: + position = options.pop("position") + except KeyError: + if parent_id is not MISSING: + if lock_permissions: + category = self.guild.get_channel(parent_id) + if category: + options["permission_overwrites"] = [c._asdict() for c in category._overwrites] + options["parent_id"] = parent_id + elif lock_permissions and self.category_id is not None: + # if we're syncing permissions on a pre-existing channel category without changing it + # we need to update the permissions to point to the pre-existing category + category = self.guild.get_channel(self.category_id) + if category: + options["permission_overwrites"] = [c._asdict() for c in category._overwrites] + else: + await self._move( + position, + parent_id=parent_id, + lock_permissions=lock_permissions, + reason=reason, + ) + + overwrites = options.get("overwrites") + if overwrites is not None: + perms = [] + for target, perm in overwrites.items(): + if not isinstance(perm, PermissionOverwrite): + raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}") + + allow, deny = perm.pair() + payload = { + "allow": allow.value, + "deny": deny.value, + "id": target.id, + "type": (_Overwrites.ROLE if isinstance(target, Role) else _Overwrites.MEMBER), + } + + perms.append(payload) + options["permission_overwrites"] = perms + + try: + ch_type = options["type"] + except KeyError: + pass + else: + if not isinstance(ch_type, ChannelType): + raise InvalidArgument("type field must be of type ChannelType") + options["type"] = ch_type.value + + try: + default_reaction_emoji = options["default_reaction_emoji"] + except KeyError: + pass + else: + if isinstance(default_reaction_emoji, _EmojiTag): # GuildEmoji, PartialEmoji + default_reaction_emoji = default_reaction_emoji._to_partial() + elif isinstance(default_reaction_emoji, int): + default_reaction_emoji = PartialEmoji(name=None, id=default_reaction_emoji) + elif isinstance(default_reaction_emoji, str): + default_reaction_emoji = PartialEmoji.from_str(default_reaction_emoji) + elif default_reaction_emoji is None: + pass + else: + raise InvalidArgument("default_reaction_emoji must be of type: GuildEmoji | int | str | None") + + options["default_reaction_emoji"] = ( + default_reaction_emoji._to_forum_reaction_payload() if default_reaction_emoji else None + ) + + if options: + return await self._state.http.edit_channel(self.id, reason=reason, **options) + + def _fill_overwrites(self, data: GuildChannelPayload) -> None: + self._overwrites: list[_Overwrites] = [] + everyone_index = 0 + everyone_id = self.guild.id + + for index, overridden in enumerate(data.get("permission_overwrites", [])): + overwrite = _Overwrites(overridden) + self._overwrites.append(overwrite) + + if overwrite.type == _Overwrites.MEMBER: + continue + + if overwrite.id == everyone_id: + # the @everyone role is not guaranteed to be the first one + # in the list of permission overwrites, however the permission + # resolution code kind of requires that it is the first one in + # the list since it is special. So we need the index so we can + # swap it to be the first one. + everyone_index = index + + # do the swap + tmp = self._overwrites + if tmp: + tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index] + + @property + def changed_roles(self) -> list[Role]: + """Returns a list of roles that have been overridden from + their default values in the :attr:`~discord.Guild.roles` attribute. + """ + ret = [] + g = self.guild + for overwrite in filter(lambda o: o.is_role(), self._overwrites): + role = g.get_role(overwrite.id) + if role is None: + continue + + role = copy.copy(role) + role.permissions.handle_overwrite(overwrite.allow, overwrite.deny) + ret.append(role) + return ret + + @property + def mention(self) -> str: + """The string that allows you to mention the channel.""" + return f"<#{self.id}>" + + @property + @override + def jump_url(self) -> str: + """Returns a URL that allows the client to jump to the channel. + + .. versionadded:: 2.0 + """ + return f"https://discord.com/channels/{self.guild.id}/{self.id}" + + def overwrites_for(self, obj: Role | User) -> PermissionOverwrite: + """Returns the channel-specific overwrites for a member or a role. + + Parameters + ---------- + obj: Union[:class:`~discord.Role`, :class:`~discord.abc.User`] + The role or user denoting + whose overwrite to get. + + Returns + ------- + :class:`~discord.PermissionOverwrite` + The permission overwrites for this object. + """ + + if isinstance(obj, User): + predicate: Callable[[Any], bool] = lambda p: p.is_member() + elif isinstance(obj, Role): + predicate = lambda p: p.is_role() + else: + predicate = lambda p: True + + for overwrite in filter(predicate, self._overwrites): + if overwrite.id == obj.id: + allow = Permissions(overwrite.allow) + deny = Permissions(overwrite.deny) + return PermissionOverwrite.from_pair(allow, deny) + + return PermissionOverwrite() + + async def get_overwrites(self) -> dict[Role | Member | Object, PermissionOverwrite]: + """Returns all of the channel's overwrites. + + This is returned as a dictionary where the key contains the target which + can be either a :class:`~discord.Role` or a :class:`~discord.Member` and the value is the + overwrite as a :class:`~discord.PermissionOverwrite`. + + Returns + ------- + Dict[Union[:class:`~discord.Role`, :class:`~discord.Member`, :class:`~discord.Object`], :class:`~discord.PermissionOverwrite`] + The channel's permission overwrites. + """ + ret: dict[Role | Member | Object, PermissionOverwrite] = {} + for ow in self._overwrites: + allow = Permissions(ow.allow) + deny = Permissions(ow.deny) + overwrite = PermissionOverwrite.from_pair(allow, deny) + target = None + + if ow.is_role(): + target = self.guild.get_role(ow.id) + elif ow.is_member(): + target = await self.guild.get_member(ow.id) + + if target is not None: + ret[target] = overwrite + else: + ret[Object(id=ow.id)] = overwrite + return ret + + @property + def category(self) -> CategoryChannel | None: + """The category this channel belongs to. + + If there is no category then this is ``None``. + """ + return cast("CategoryChannel | None", self.guild.get_channel(self.category_id)) if self.category_id else None + + @property + def members(self) -> Collection[Member]: + """Returns all members that can view this channel. + + This is calculated based on the channel's permission overwrites and + the members' roles. + + Returns + ------- + Collection[:class:`Member`] + All members who have permission to view this channel. + """ + return [m for m in self.guild.members if self.permissions_for(m).read_messages] + + async def permissions_are_synced(self) -> bool: + """Whether the permissions for this channel are synced with the + category it belongs to. + + If there is no category then this is ``False``. + + .. versionadded:: 3.0 + """ + if self.category_id is None: + return False + + category: CategoryChannel | None = cast("CategoryChannel | None", self.guild.get_channel(self.category_id)) + return bool(category and await category.get_overwrites() == await self.get_overwrites()) + + def permissions_for(self, obj: Member | Role, /) -> Permissions: + """Handles permission resolution for the :class:`~discord.Member` + or :class:`~discord.Role`. + + This function takes into consideration the following cases: + + - Guild owner + - Guild roles + - Channel overrides + - Member overrides + + If a :class:`~discord.Role` is passed, then it checks the permissions + someone with that role would have, which is essentially: + + - The default role permissions + - The permissions of the role used as a parameter + - The default role permission overwrites + - The permission overwrites of the role used as a parameter + + .. versionchanged:: 2.0 + The object passed in can now be a role object. + + Parameters + ---------- + obj: Union[:class:`~discord.Member`, :class:`~discord.Role`] + The object to resolve permissions for. This could be either + a member or a role. If it's a role then member overwrites + are not computed. + + Returns + ------- + :class:`~discord.Permissions` + The resolved permissions for the member or role. + """ + + # The current cases can be explained as: + # Guild owner get all permissions -- no questions asked. Otherwise... + # The @everyone role gets the first application. + # After that, the applied roles that the user has in the channel + # (or otherwise) are then OR'd together. + # After the role permissions are resolved, the member permissions + # have to take into effect. + # After all that is done, you have to do the following: + + # If manage permissions is True, then all permissions are set to True. + + # The operation first takes into consideration the denied + # and then the allowed. + + if self.guild.owner_id == obj.id: + return Permissions.all() + + default = self.guild.default_role + base = Permissions(default.permissions.value if default else 0) + + # Handle the role case first + if isinstance(obj, Role): + base.value |= obj._permissions + + if base.administrator: + return Permissions.all() + + # Apply @everyone allow/deny first since it's special + try: + maybe_everyone = self._overwrites[0] + if maybe_everyone.id == self.guild.id: + base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) + except IndexError: + pass + + if obj.is_default(): + return base + + overwrite = find(lambda o: o.type == _Overwrites.ROLE and o.id == obj.id, self._overwrites) + if overwrite is not None: + base.handle_overwrite(overwrite.allow, overwrite.deny) + + return base + + roles = obj._roles + get_role = self.guild.get_role + + # Apply guild roles that the member has. + for role_id in roles: + role = get_role(role_id) + if role is not None: + base.value |= role._permissions + + # Guild-wide Administrator -> True for everything + # Bypass all channel-specific overrides + if base.administrator: + return Permissions.all() + + # Apply @everyone allow/deny first since it's special + try: + maybe_everyone = self._overwrites[0] + if maybe_everyone.id == self.guild.id: + base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) + remaining_overwrites = self._overwrites[1:] + else: + remaining_overwrites = self._overwrites + except IndexError: + remaining_overwrites = self._overwrites + + denies = 0 + allows = 0 + + # Apply channel specific role permission overwrites + for overwrite in remaining_overwrites: + if overwrite.is_role() and roles.has(overwrite.id): + denies |= overwrite.deny + allows |= overwrite.allow + + base.handle_overwrite(allow=allows, deny=denies) + + # Apply member specific permission overwrites + for overwrite in remaining_overwrites: + if overwrite.is_member() and overwrite.id == obj.id: + base.handle_overwrite(allow=overwrite.allow, deny=overwrite.deny) + break + + # if you can't send a message in a channel then you can't have certain + # permissions as well + if not base.send_messages: + base.send_tts_messages = False + base.mention_everyone = False + base.embed_links = False + base.attach_files = False + + # if you can't read a channel then you have no permissions there + if not base.read_messages: + denied = Permissions.all_channel() + base.value &= ~denied.value + + return base + + async def delete(self, *, reason: str | None = None) -> None: + """|coro| + + Deletes the channel. + + You must have :attr:`~discord.Permissions.manage_channels` permission to use this. + + Parameters + ---------- + reason: Optional[:class:`str`] + The reason for deleting this channel. + Shows up on the audit log. + + Raises + ------ + ~discord.Forbidden + You do not have proper permissions to delete the channel. + ~discord.NotFound + The channel was not found or was already deleted. + ~discord.HTTPException + Deleting the channel failed. + """ + await self._state.http.delete_channel(self.id, reason=reason) + + @overload + async def set_permissions( + self, + target: Member | Role, + *, + overwrite: PermissionOverwrite | None = ..., + reason: str | None = ..., + ) -> None: ... + + @overload + async def set_permissions( + self, + target: Member | Role, + *, + overwrite: Undefined = MISSING, + reason: str | None = ..., + **permissions: bool, + ) -> None: ... + + async def set_permissions( + self, + target: Member | Role, + *, + overwrite: PermissionOverwrite | None | Undefined = MISSING, + reason: str | None = None, + **permissions: bool, + ) -> None: + r"""|coro| + + Sets the channel specific permission overwrites for a target in the + channel. + + The ``target`` parameter should either be a :class:`~discord.Member` or a + :class:`~discord.Role` that belongs to guild. + + The ``overwrite`` parameter, if given, must either be ``None`` or + :class:`~discord.PermissionOverwrite`. For convenience, you can pass in + keyword arguments denoting :class:`~discord.Permissions` attributes. If this is + done, then you cannot mix the keyword arguments with the ``overwrite`` + parameter. + + If the ``overwrite`` parameter is ``None``, then the permission + overwrites are deleted. + + You must have the :attr:`~discord.Permissions.manage_roles` permission to use this. + + .. note:: + + This method *replaces* the old overwrites with the ones given. + + Examples + ---------- + + Setting allow and deny: :: + + await message.channel.set_permissions(message.author, read_messages=True, send_messages=False) + + Deleting overwrites :: + + await channel.set_permissions(member, overwrite=None) + + Using :class:`~discord.PermissionOverwrite` :: + + overwrite = discord.PermissionOverwrite() + overwrite.send_messages = False + overwrite.read_messages = True + await channel.set_permissions(member, overwrite=overwrite) + + Parameters + ----------- + target: Union[:class:`~discord.Member`, :class:`~discord.Role`] + The member or role to overwrite permissions for. + overwrite: Optional[:class:`~discord.PermissionOverwrite`] + The permissions to allow and deny to the target, or ``None`` to + delete the overwrite. + \*\*permissions + A keyword argument list of permissions to set for ease of use. + Cannot be mixed with ``overwrite``. + reason: Optional[:class:`str`] + The reason for doing this action. Shows up on the audit log. + + Raises + ------- + ~discord.Forbidden + You do not have permissions to edit channel specific permissions. + ~discord.HTTPException + Editing channel specific permissions failed. + ~discord.NotFound + The role or member being edited is not part of the guild. + ~discord.InvalidArgument + The overwrite parameter invalid or the target type was not + :class:`~discord.Role` or :class:`~discord.Member`. + """ + + http = self._state.http + + if isinstance(target, User): + perm_type = _Overwrites.MEMBER + elif isinstance(target, Role): + perm_type = _Overwrites.ROLE + else: + raise InvalidArgument("target parameter must be either Member or Role") + + if overwrite is MISSING: + if len(permissions) == 0: + raise InvalidArgument("No overwrite provided.") + try: + overwrite = PermissionOverwrite(**permissions) + except (ValueError, TypeError) as e: + raise InvalidArgument("Invalid permissions given to keyword arguments.") from e + elif len(permissions) > 0: + raise InvalidArgument("Cannot mix overwrite and keyword arguments.") + + # TODO: wait for event + + if overwrite is None: + await http.delete_channel_permissions(self.id, target.id, reason=reason) + elif isinstance(overwrite, PermissionOverwrite): + (allow, deny) = overwrite.pair() + await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason) + else: + raise InvalidArgument("Invalid overwrite type provided.") + + async def _clone_impl( + self, + base_attrs: dict[str, Any], + *, + name: str | None = None, + reason: str | None = None, + ) -> Self: + base_attrs["permission_overwrites"] = [x._asdict() for x in self._overwrites] + base_attrs["parent_id"] = self.category_id + base_attrs["name"] = name or self.name + guild_id = self.guild.id + cls = self.__class__ + data: P_guild = cast( + "P_guild", await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs) + ) + clone = cls(id=int(data["id"]), guild=self.guild, state=self._state) + await clone._update(data) + + self.guild._channels[clone.id] = clone + return clone + + async def clone(self, *, name: str | None = None, reason: str | None = None) -> Self: + """|coro| + + Clones this channel. This creates a channel with the same properties + as this channel. + + You must have the :attr:`~discord.Permissions.manage_channels` permission to + do this. + + .. versionadded:: 1.1 + + Parameters + ---------- + name: Optional[:class:`str`] + The name of the new channel. If not provided, defaults to this + channel name. + reason: Optional[:class:`str`] + The reason for cloning this channel. Shows up on the audit log. + + Returns + ------- + :class:`.abc.GuildChannel` + The channel that was created. + + Raises + ------ + ~discord.Forbidden + You do not have the proper permissions to create this channel. + ~discord.HTTPException + Creating the channel failed. + """ + raise NotImplementedError + + async def create_invite( + self, + *, + reason: str | None = None, + max_age: int = 0, + max_uses: int = 0, + temporary: bool = False, + unique: bool = True, + target_event: ScheduledEvent | None = None, + target_type: InviteTarget | None = None, + target_user: User | None = None, + target_application_id: int | None = None, + ) -> Invite: + """|coro| + + Creates an instant invite from a text or voice channel. + + You must have the :attr:`~discord.Permissions.create_instant_invite` permission to + do this. + + Parameters + ---------- + max_age: :class:`int` + How long the invite should last in seconds. If it's 0 then the invite + doesn't expire. Defaults to ``0``. + max_uses: :class:`int` + How many uses the invite could be used for. If it's 0 then there + are unlimited uses. Defaults to ``0``. + temporary: :class:`bool` + Denotes that the invite grants temporary membership + (i.e. they get kicked after they disconnect). Defaults to ``False``. + unique: :class:`bool` + Indicates if a unique invite URL should be created. Defaults to True. + If this is set to ``False`` then it will return a previously created + invite. + reason: Optional[:class:`str`] + The reason for creating this invite. Shows up on the audit log. + target_type: Optional[:class:`.InviteTarget`] + The type of target for the voice channel invite, if any. + + .. versionadded:: 2.0 + + target_user: Optional[:class:`User`] + The user whose stream to display for this invite, required if `target_type` is `TargetType.stream`. + The user must be streaming in the channel. + + .. versionadded:: 2.0 + + target_application_id: Optional[:class:`int`] + The id of the embedded application for the invite, required if `target_type` is + `TargetType.embedded_application`. + + .. versionadded:: 2.0 + + target_event: Optional[:class:`.ScheduledEvent`] + The scheduled event object to link to the event. + Shortcut to :meth:`.Invite.set_scheduled_event` + + See :meth:`.Invite.set_scheduled_event` for more + info on event invite linking. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`~discord.Invite` + The invite that was created. + + Raises + ------ + ~discord.HTTPException + Invite creation failed. + + ~discord.NotFound + The channel that was passed is a category or an invalid channel. + """ + if target_type is InviteTarget.unknown: + raise TypeError("target_type cannot be unknown") + + data = await self._state.http.create_invite( + self.id, + reason=reason, + max_age=max_age, + max_uses=max_uses, + temporary=temporary, + unique=unique, + target_type=target_type.value if target_type else None, + target_user_id=target_user.id if target_user else None, + target_application_id=target_application_id, + ) + invite = await Invite.from_incomplete(data=data, state=self._state) + if target_event: + invite.set_scheduled_event(target_event) + return invite + + async def invites(self) -> list[Invite]: + """|coro| + + Returns a list of all active instant invites from this channel. + + You must have :attr:`~discord.Permissions.manage_channels` to get this information. + + Returns + ------- + List[:class:`~discord.Invite`] + The list of invites that are currently active. + + Raises + ------ + ~discord.Forbidden + You do not have proper permissions to get the information. + ~discord.HTTPException + An error occurred while fetching the information. + """ + + data = await self._state.http.invites_from_channel(self.id) + guild = self.guild + return [Invite(state=self._state, data=invite, channel=self, guild=guild) for invite in data] + + +P_guild_top_level = TypeVar( + "P_guild_top_level", + bound="TextChannelPayload | NewsChannelPayload | VoiceChannelPayload | CategoryChannelPayload | StageChannelPayload | ForumChannelPayload", + default="TextChannelPayload | NewsChannelPayload | VoiceChannelPayload | CategoryChannelPayload | StageChannelPayload | ForumChannelPayload", +) + + +class GuildTopLevelChannel(GuildChannel[P_guild_top_level], ABC, Generic[P_guild_top_level]): + """An ABC for guild channels that can be positioned in the channel list. + + This includes categories and all channels that appear in the channel sidebar + (text, voice, news, stage, forum, media channels). Threads do not inherit from + this class as they are not positioned in the main channel list. + + .. versionadded:: 3.0 + + Attributes + ---------- + position: int + The position in the channel list. This is a number that starts at 0. + e.g. the top channel is position 0. + """ + + __slots__: tuple[str, ...] = ("position",) + + @override + async def _update(self, data: P_guild_top_level) -> None: + await super()._update(data) + self.position: int = data.get("position", 0) + + @property + @abstractmethod + def _sorting_bucket(self) -> int: + """Returns the bucket for sorting channels by type.""" + raise NotImplementedError + + async def _move( + self, + position: int, + parent_id: Any | None = None, + lock_permissions: bool = False, + *, + reason: str | None, + ) -> None: + """Internal method to move a channel to a specific position. + + Parameters + ---------- + position: int + The new position for the channel. + parent_id: Any | None + The parent category ID, if moving to a category. + lock_permissions: bool + Whether to sync permissions with the category. + reason: str | None + The reason for moving the channel. + + Raises + ------ + InvalidArgument + The position is less than 0. + """ + if position < 0: + raise InvalidArgument("Channel position cannot be less than 0.") + + bucket = self._sorting_bucket + channels: list[Self] = [c for c in self.guild.channels if c._sorting_bucket == bucket] + + channels.sort(key=lambda c: c.position) + + try: + # remove ourselves from the channel list + channels.remove(self) + except ValueError: + # not there somehow lol + return + else: + index = next( + (i for i, c in enumerate(channels) if c.position >= position), + len(channels), + ) + # add ourselves at our designated position + channels.insert(index, self) + + payload: list[ChannelPositionUpdatePayload] = [] + for index, c in enumerate(channels): + d: ChannelPositionUpdatePayload = {"id": c.id, "position": index} + if parent_id is not MISSING and c.id == self.id: + d.update(parent_id=parent_id, lock_permissions=lock_permissions) + payload.append(d) + + await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) + + @overload + async def move( + self, + *, + beginning: bool, + offset: int | Undefined = MISSING, + category: Snowflake | None | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + reason: str | None | Undefined = MISSING, + ) -> None: ... + + @overload + async def move( + self, + *, + end: bool, + offset: int | Undefined = MISSING, + category: Snowflake | None | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + reason: str | Undefined = MISSING, + ) -> None: ... + + @overload + async def move( + self, + *, + before: Snowflake, + offset: int | Undefined = MISSING, + category: Snowflake | None | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + reason: str | Undefined = MISSING, + ) -> None: ... + + @overload + async def move( + self, + *, + after: Snowflake, + offset: int | Undefined = MISSING, + category: Snowflake | None | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + reason: str | Undefined = MISSING, + ) -> None: ... + + async def move(self, **kwargs: Any) -> None: + """|coro| + + A rich interface to help move a channel relative to other channels. + + If exact position movement is required, ``edit`` should be used instead. + + You must have :attr:`~discord.Permissions.manage_channels` permission to + do this. + + .. note:: + + Voice channels will always be sorted below text channels. + This is a Discord limitation. + + .. versionadded:: 1.7 + + Parameters + ---------- + beginning: bool + Whether to move the channel to the beginning of the + channel list (or category if given). + This is mutually exclusive with ``end``, ``before``, and ``after``. + end: bool + Whether to move the channel to the end of the + channel list (or category if given). + This is mutually exclusive with ``beginning``, ``before``, and ``after``. + before: ~discord.abc.Snowflake + The channel that should be before our current channel. + This is mutually exclusive with ``beginning``, ``end``, and ``after``. + after: ~discord.abc.Snowflake + The channel that should be after our current channel. + This is mutually exclusive with ``beginning``, ``end``, and ``before``. + offset: int + The number of channels to offset the move by. For example, + an offset of ``2`` with ``beginning=True`` would move + it 2 after the beginning. A positive number moves it below + while a negative number moves it above. Note that this + number is relative and computed after the ``beginning``, + ``end``, ``before``, and ``after`` parameters. + category: ~discord.abc.Snowflake | None + The category to move this channel under. + If ``None`` is given then it moves it out of the category. + This parameter is ignored if moving a category channel. + sync_permissions: bool + Whether to sync the permissions with the category (if given). + reason: str | None + The reason for the move. + + Raises + ------ + InvalidArgument + An invalid position was given or a bad mix of arguments was passed. + Forbidden + You do not have permissions to move the channel. + HTTPException + Moving the channel failed. + """ + + if not kwargs: + return + + beginning, end = kwargs.get("beginning"), kwargs.get("end") + before, after = kwargs.get("before"), kwargs.get("after") + offset = kwargs.get("offset", 0) + if sum(bool(a) for a in (beginning, end, before, after)) > 1: + raise InvalidArgument("Only one of [before, after, end, beginning] can be used.") + + bucket = self._sorting_bucket + parent_id = kwargs.get("category", MISSING) + channels: list[GuildChannel] + if parent_id not in (MISSING, None): + parent_id = parent_id.id + channels = [ + ch for ch in self.guild.channels if ch._sorting_bucket == bucket and ch.category_id == parent_id + ] + else: + channels = [ + ch for ch in self.guild.channels if ch._sorting_bucket == bucket and ch.category_id == self.category_id + ] + + channels.sort(key=lambda c: (c.position, c.id)) + + try: + # Try to remove ourselves from the channel list + channels.remove(self) + except ValueError: + # If we're not there then it's probably due to not being in the category + pass + + index = None + if beginning: + index = 0 + elif end: + index = len(channels) + elif before: + index = next((i for i, c in enumerate(channels) if c.id == before.id), None) + elif after: + index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None) + + if index is None: + raise InvalidArgument("Could not resolve appropriate move position") + # TODO: This could use self._move to avoid code duplication + channels.insert(max((index + offset), 0), self) + payload: list[ChannelPositionUpdatePayload] = [] + lock_permissions = kwargs.get("sync_permissions", False) + reason = kwargs.get("reason") + for index, channel in enumerate(channels): + d: ChannelPositionUpdatePayload = {"id": channel.id, "position": index} # pyright: ignore[reportAssignmentType] + if parent_id is not MISSING and channel.id == self.id: + d.update(parent_id=parent_id, lock_permissions=lock_permissions) + payload.append(d) + + await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) + + +P_guild_threadable = TypeVar( + "P_guild_threadable", + bound="TextChannelPayload | NewsChannelPayload | ForumChannelPayload | MediaChannelPayload", + default="TextChannelPayload | NewsChannelPayload | ForumChannelPayload | MediaChannelPayload", +) + + +class GuildThreadableChannel: + """An ABC for guild channels that support thread creation. + + This includes text, news, forum, and media channels. + Voice, stage, and category channels do not support threads. + + This is a mixin class that adds threading capabilities to guild channels. + + .. versionadded:: 3.0 + + Attributes + ---------- + default_auto_archive_duration: int + The default auto archive duration in minutes for threads created in this channel. + default_thread_slowmode_delay: int | None + The initial slowmode delay to set on newly created threads in this channel. + """ + + __slots__ = () # Mixin class - slots defined in concrete classes + + # Type hints for attributes that this mixin expects from the inheriting class + if TYPE_CHECKING: + id: int + guild: Guild + default_auto_archive_duration: int + default_thread_slowmode_delay: int | None + + async def _update(self, data) -> None: + """Update threadable channel attributes.""" + await super()._update(data) # Call next in MRO + self.default_auto_archive_duration: int = data.get("default_auto_archive_duration", 1440) + self.default_thread_slowmode_delay: int | None = data.get("default_thread_rate_limit_per_user") + + @property + def threads(self) -> list[Thread]: + """Returns all the threads that you can see in this channel. + + .. versionadded:: 2.0 + + Returns + ------- + list[:class:`Thread`] + All active threads in this channel. + """ + return [thread for thread in self.guild._threads.values() if thread.parent_id == self.id] + + def get_thread(self, thread_id: int, /) -> Thread | None: + """Returns a thread with the given ID. + + .. versionadded:: 2.0 + + Parameters + ---------- + thread_id: int + The ID to search for. + + Returns + ------- + Thread | None + The returned thread or ``None`` if not found. + """ + return self.guild.get_thread(thread_id) + + def archived_threads( + self, + *, + private: bool = False, + joined: bool = False, + limit: int | None = 50, + before: Snowflake | datetime.datetime | None = None, + ) -> ArchivedThreadIterator: + """Returns an iterator that iterates over all archived threads in the channel. + + You must have :attr:`~Permissions.read_message_history` to use this. If iterating over private threads + then :attr:`~Permissions.manage_threads` is also required. + + .. versionadded:: 2.0 + + Parameters + ---------- + limit: int | None + The number of threads to retrieve. + If ``None``, retrieves every archived thread in the channel. Note, however, + that this would make it a slow operation. + before: Snowflake | datetime.datetime | None + Retrieve archived channels before the given date or ID. + private: bool + Whether to retrieve private archived threads. + joined: bool + Whether to retrieve private archived threads that you've joined. + You cannot set ``joined`` to ``True`` and ``private`` to ``False``. + + Yields + ------ + :class:`Thread` + The archived threads. + + Raises + ------ + Forbidden + You do not have permissions to get archived threads. + HTTPException + The request to get the archived threads failed. + """ + return ArchivedThreadIterator( + self.id, + self.guild, + limit=limit, + joined=joined, + private=private, + before=before, + ) + + +P_guild_postable = TypeVar( + "P_guild_postable", + bound="ForumChannelPayload | MediaChannelPayload", + default="ForumChannelPayload | MediaChannelPayload", +) + + +class ForumTag(Hashable): + """Represents a forum tag that can be added to a thread inside a :class:`ForumChannel` + . + .. versionadded:: 2.3 + + .. container:: operations + + .. describe:: x == y + + Checks if two forum tags are equal. + + .. describe:: x != y + + Checks if two forum tags are not equal. + + .. describe:: hash(x) + + Returns the forum tag's hash. + + .. describe:: str(x) + + Returns the forum tag's name. + + Attributes + ---------- + id: :class:`int` + The tag ID. + Note that if the object was created manually then this will be ``0``. + name: :class:`str` + The name of the tag. Can only be up to 20 characters. + moderated: :class:`bool` + Whether this tag can only be added or removed by a moderator with + the :attr:`~Permissions.manage_threads` permission. + emoji: :class:`PartialEmoji` + The emoji that is used to represent this tag. + Note that if the emoji is a custom emoji, it will *not* have name information. + """ + + __slots__ = ("name", "id", "moderated", "emoji") + + def __init__(self, *, name: str, emoji: EmojiInputType, moderated: bool = False) -> None: + self.name: str = name + self.id: int = 0 + self.moderated: bool = moderated + self.emoji: PartialEmoji + if isinstance(emoji, _EmojiTag): + self.emoji = emoji._to_partial() + elif isinstance(emoji, str): + self.emoji = PartialEmoji.from_str(emoji) + else: + raise TypeError(f"emoji must be a GuildEmoji, PartialEmoji, or str and not {emoji.__class__!r}") + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return self.name + + @classmethod + def from_data(cls, *, state: ConnectionState, data: ForumTagPayload) -> ForumTag: + self = cls.__new__(cls) + self.name = data["name"] + self.id = int(data["id"]) + self.moderated = data.get("moderated", False) + + emoji_name = data["emoji_name"] or "" + emoji_id = get_as_snowflake(data, "emoji_id") or None + self.emoji = PartialEmoji.with_state(state=state, name=emoji_name, id=emoji_id) + return self + + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "name": self.name, + "moderated": self.moderated, + } | self.emoji._to_forum_reaction_payload() + + if self.id: + payload["id"] = self.id + + return payload + + +class GuildPostableChannel( + GuildTopLevelChannel[P_guild_postable], GuildThreadableChannel, ABC, Generic[P_guild_postable] +): + """An ABC for guild channels that support posts (threads with tags). + + This is a common base for forum and media channels. These channels don't support + direct messaging, but users create posts (which are threads) with associated tags. + + .. versionadded:: 3.0 + + Attributes + ---------- + topic: str | None + The channel's topic/guidelines. ``None`` if it doesn't exist. + nsfw: bool + Whether the channel is marked as NSFW. + slowmode_delay: int + The number of seconds a member must wait between creating posts + in this channel. A value of ``0`` denotes that it is disabled. + last_message_id: int | None + The ID of the last message sent in this channel. It may not always point to an existing or valid message. + available_tags: list[ForumTag] + The set of tags that can be used in this channel. + default_sort_order: SortOrder | None + The default sort order type used to order posts in this channel. + default_reaction_emoji: str | GuildEmoji | None + The default reaction emoji for posts in this channel. + """ + + __slots__: tuple[str, ...] = ( + "topic", + "nsfw", + "slowmode_delay", + "last_message_id", + "default_auto_archive_duration", + "default_thread_slowmode_delay", + "available_tags", + "default_sort_order", + "default_reaction_emoji", + ) + + @override + async def _update(self, data: P_guild_postable) -> None: + await super()._update(data) + if not data.pop("_invoke_flag", False): + self.topic: str | None = data.get("topic") + self.nsfw: bool = data.get("nsfw", False) + self.slowmode_delay: int = data.get("rate_limit_per_user", 0) + self.last_message_id: int | None = get_as_snowflake(data, "last_message_id") + + self.available_tags: list[ForumTag] = [ + ForumTag.from_data(state=self._state, data=tag) for tag in (data.get("available_tags") or []) + ] + self.default_sort_order: SortOrder | None = data.get("default_sort_order", None) + if self.default_sort_order is not None: + self.default_sort_order = try_enum(SortOrder, self.default_sort_order) + + self.default_reaction_emoji = None + reaction_emoji_ctx: dict = data.get("default_reaction_emoji") + if reaction_emoji_ctx is not None: + emoji_name = reaction_emoji_ctx.get("emoji_name") + if emoji_name is not None: + self.default_reaction_emoji = reaction_emoji_ctx["emoji_name"] + else: + emoji_id = get_as_snowflake(reaction_emoji_ctx, "emoji_id") + if emoji_id: + self.default_reaction_emoji = await self._state.get_emoji(emoji_id) + + @property + def guidelines(self) -> str | None: + """The channel's guidelines. An alias of :attr:`topic`.""" + return self.topic + + @property + def requires_tag(self) -> bool: + """Whether a tag is required to be specified when creating a post in this channel. + + .. versionadded:: 2.3 + """ + return self.flags.require_tag + + def get_tag(self, id: int, /) -> ForumTag | None: + """Returns the :class:`ForumTag` from this channel with the given ID, if any. + + .. versionadded:: 2.3 + """ + return find(lambda t: t.id == id, self.available_tags) + + async def create_thread( + self, + name: str, + content: str | None = None, + *, + embed: Embed | None = None, + embeds: list[Embed] | None = None, + file: File | None = None, + files: list[File] | None = None, + stickers: Sequence[GuildSticker | StickerItem] | None = None, + delete_message_after: float | None = None, + nonce: int | str | None = None, + allowed_mentions: AllowedMentions | None = None, + view: View | None = None, + applied_tags: list[ForumTag] | None = None, + suppress: bool = False, + silent: bool = False, + auto_archive_duration: int | Undefined = MISSING, + slowmode_delay: int | Undefined = MISSING, + reason: str | None = None, + ) -> Thread: + """|coro| + + Creates a post (thread with initial message) in this forum or media channel. + + To create a post, you must have :attr:`~discord.Permissions.create_public_threads` or + :attr:`~discord.Permissions.send_messages` permission. + + .. versionadded:: 2.0 + + Parameters + ---------- + name: :class:`str` + The name of the post/thread. + content: :class:`str` + The content of the initial message. + embed: :class:`~discord.Embed` + The rich embed for the content. + embeds: list[:class:`~discord.Embed`] + A list of embeds to upload. Must be a maximum of 10. + file: :class:`~discord.File` + The file to upload. + files: list[:class:`~discord.File`] + A list of files to upload. Must be a maximum of 10. + stickers: Sequence[:class:`~discord.GuildSticker` | :class:`~discord.StickerItem`] + A list of stickers to upload. Must be a maximum of 3. + delete_message_after: :class:`float` + The time in seconds to wait before deleting the initial message. + nonce: :class:`str` | :class:`int` + The nonce to use for sending this message. + allowed_mentions: :class:`~discord.AllowedMentions` + Controls the mentions being processed in this message. + view: :class:`discord.ui.View` + A Discord UI View to add to the message. + applied_tags: list[:class:`ForumTag`] + A list of tags to apply to the new post. + suppress: :class:`bool` + Whether to suppress embeds in the initial message. + silent: :class:`bool` + Whether to send the message without triggering a notification. + auto_archive_duration: :class:`int` + The duration in minutes before the post is automatically archived for inactivity. + If not provided, the channel's default auto archive duration is used. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages in the new post. + If not provided, the channel's default slowmode is used. + reason: :class:`str` + The reason for creating the post. Shows up on the audit log. + + Returns + ------- + :class:`Thread` + The created post/thread. + + Raises + ------ + Forbidden + You do not have permissions to create a post. + HTTPException + Creating the post failed. + InvalidArgument + You provided invalid arguments. + """ + from ..errors import InvalidArgument + from ..file import File + from ..flags import MessageFlags + + state = self._state + message_content = str(content) if content is not None else None + + if embed is not None and embeds is not None: + raise InvalidArgument("cannot pass both embed and embeds parameter to create_thread()") + + if embed is not None: + embed = embed.to_dict() + + elif embeds is not None: + if len(embeds) > 10: + raise InvalidArgument("embeds parameter must be a list of up to 10 elements") + embeds = [e.to_dict() for e in embeds] + + if stickers is not None: + stickers = [sticker.id for sticker in stickers] + + if allowed_mentions is None: + allowed_mentions = state.allowed_mentions and state.allowed_mentions.to_dict() + elif state.allowed_mentions is not None: + allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict() + else: + allowed_mentions = allowed_mentions.to_dict() + + flags = MessageFlags( + suppress_embeds=bool(suppress), + suppress_notifications=bool(silent), + ) + + if view: + if not hasattr(view, "__discord_ui_view__"): + raise InvalidArgument(f"view parameter must be View not {view.__class__!r}") + + components = view.to_components() + if view.is_components_v2(): + if embeds or content: + raise TypeError("cannot send embeds or content with a view using v2 component logic") + flags.is_components_v2 = True + else: + components = None + + if applied_tags is not None: + applied_tags = [str(tag.id) for tag in applied_tags] + + if file is not None and files is not None: + raise InvalidArgument("cannot pass both file and files parameter to create_thread()") + + if files is not None: + if len(files) > 10: + raise InvalidArgument("files parameter must be a list of up to 10 elements") + elif not all(isinstance(f, File) for f in files): + raise InvalidArgument("files parameter must be a list of File") + + if file is not None: + if not isinstance(file, File): + raise InvalidArgument("file parameter must be File") + files = [file] + + try: + data = await state.http.start_forum_thread( + self.id, + content=message_content, + name=name, + files=files, + embed=embed, + embeds=embeds, + nonce=nonce, + allowed_mentions=allowed_mentions, + stickers=stickers, + components=components, + auto_archive_duration=auto_archive_duration + if auto_archive_duration is not MISSING + else self.default_auto_archive_duration, + rate_limit_per_user=slowmode_delay + if slowmode_delay is not MISSING + else self.default_thread_slowmode_delay, + applied_tags=applied_tags, + flags=flags.value, + reason=reason, + ) + finally: + if files is not None: + for f in files: + f.close() + + from .thread import Thread + + ret = Thread(guild=self.guild, state=self._state, data=data) + msg = ret.get_partial_message(int(data["last_message_id"])) + if view and view.is_dispatchable(): + await state.store_view(view, msg.id) + + if delete_message_after is not None: + await msg.delete(delay=delete_message_after) + return ret + + +P_guild_messageable = TypeVar( + "P_guild_messageable", + bound="TextChannelPayload | NewsChannelPayload | VoiceChannelPayload | StageChannelPayload | ForumChannelPayload", + default="TextChannelPayload | NewsChannelPayload | VoiceChannelPayload | StageChannelPayload | ForumChannelPayload", +) + + +class GuildMessageableChannel(Messageable, ABC): + """An ABC mixin for guild channels that support messaging. + + This includes text and news channels, as well as threads. Voice and stage channels + do not support direct messaging (though they can have threads). + + This is a mixin class that adds messaging capabilities to guild channels. + + .. versionadded:: 3.0 + + Attributes + ---------- + topic: str | None + The channel's topic. ``None`` if it doesn't exist. + nsfw: bool + Whether the channel is marked as NSFW. + slowmode_delay: int + The number of seconds a member must wait between sending messages + in this channel. A value of ``0`` denotes that it is disabled. + Bots and users with :attr:`~Permissions.manage_channels` or + :attr:`~Permissions.manage_messages` bypass slowmode. + last_message_id: int | None + The ID of the last message sent in this channel. It may not always point to an existing or valid message. + """ + + __slots__ = () # Mixin class - slots defined in concrete classes + + # Attributes expected from inheriting classes + id: int + guild: Guild + _state: ConnectionState + topic: str | None + nsfw: bool + slowmode_delay: int + last_message_id: int | None + + async def _update(self, data) -> None: + """Update mutable attributes from API payload.""" + await super()._update(data) + # This data may be missing depending on how this object is being created/updated + if not data.pop("_invoke_flag", False): + self.topic = data.get("topic") + self.nsfw = data.get("nsfw", False) + # Does this need coercion into `int`? No idea yet. + self.slowmode_delay = data.get("rate_limit_per_user", 0) + self.last_message_id = get_as_snowflake(data, "last_message_id") + + @copy_doc(GuildChannel.permissions_for) + @override + def permissions_for(self, obj: Member | Role, /) -> Permissions: + base = super().permissions_for(obj) + + # text channels do not have voice related permissions + denied = Permissions.voice() + base.value &= ~denied.value + return base + + async def get_members(self) -> list[Member]: + """Returns all members that can see this channel.""" + return [m for m in await self.guild.get_members() if self.permissions_for(m).read_messages] + + async def get_last_message(self) -> Message | None: + """Fetches the last message from this channel in cache. + + The message might not be valid or point to an existing message. + + .. admonition:: Reliable Fetching + :class: helpful + + For a slightly more reliable method of fetching the + last message, consider using either :meth:`history` + or :meth:`fetch_message` with the :attr:`last_message_id` + attribute. + + Returns + ------- + Optional[:class:`Message`] + The last message in this channel or ``None`` if not found. + """ + return await self._state._get_message(self.last_message_id) if self.last_message_id else None + + async def edit(self, **options) -> Self: + """Edits the channel.""" + raise NotImplementedError + + @copy_doc(GuildChannel.clone) + @override + async def clone(self, *, name: str | None = None, reason: str | None = None) -> Self: + return await self._clone_impl( + { + "topic": self.topic, + "nsfw": self.nsfw, + "rate_limit_per_user": self.slowmode_delay, + }, + name=name, + reason=reason, + ) + + async def delete_messages(self, messages: Iterable[Snowflake], *, reason: str | None = None) -> None: + """|coro| + + Deletes a list of messages. This is similar to :meth:`Message.delete` + except it bulk deletes multiple messages. + + As a special case, if the number of messages is 0, then nothing + is done. If the number of messages is 1 then single message + delete is done. If it's more than two, then bulk delete is used. + + You cannot bulk delete more than 100 messages or messages that + are older than 14 days old. + + You must have the :attr:`~Permissions.manage_messages` permission to + use this. + + Parameters + ---------- + messages: Iterable[:class:`abc.Snowflake`] + An iterable of messages denoting which ones to bulk delete. + reason: Optional[:class:`str`] + The reason for deleting the messages. Shows up on the audit log. + + Raises + ------ + ClientException + The number of messages to delete was more than 100. + Forbidden + You do not have proper permissions to delete the messages. + NotFound + If single delete, then the message was already deleted. + HTTPException + Deleting the messages failed. + """ + if not isinstance(messages, (list, tuple)): + messages = list(messages) + + if len(messages) == 0: + return # do nothing + + if len(messages) == 1: + message_id: int = messages[0].id + await self._state.http.delete_message(self.id, message_id, reason=reason) + return + + if len(messages) > 100: + raise ClientException("Can only bulk delete messages up to 100 messages") + + message_ids: SnowflakeList = [m.id for m in messages] + await self._state.http.delete_messages(self.id, message_ids, reason=reason) + + async def purge( + self, + *, + limit: int | None = 100, + check: Callable[[Message], bool] | Undefined = MISSING, + before: SnowflakeTime | None = None, + after: SnowflakeTime | None = None, + around: SnowflakeTime | None = None, + oldest_first: bool | None = False, + bulk: bool = True, + reason: str | None = None, + ) -> list[Message]: + """|coro| + + Purges a list of messages that meet the criteria given by the predicate + ``check``. If a ``check`` is not provided then all messages are deleted + without discrimination. + + You must have the :attr:`~Permissions.manage_messages` permission to + delete messages even if they are your own. + The :attr:`~Permissions.read_message_history` permission is + also needed to retrieve message history. + + Parameters + ---------- + limit: Optional[:class:`int`] + The number of messages to search through. This is not the number + of messages that will be deleted, though it can be. + check: Callable[[:class:`Message`], :class:`bool`] + The function used to check if a message should be deleted. + It must take a :class:`Message` as its sole parameter. + before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``before`` in :meth:`history`. + after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``after`` in :meth:`history`. + around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``around`` in :meth:`history`. + oldest_first: Optional[:class:`bool`] + Same as ``oldest_first`` in :meth:`history`. + bulk: :class:`bool` + If ``True``, use bulk delete. Setting this to ``False`` is useful for mass-deleting + a bot's own messages without :attr:`Permissions.manage_messages`. When ``True``, will + fall back to single delete if messages are older than two weeks. + reason: Optional[:class:`str`] + The reason for deleting the messages. Shows up on the audit log. + + Returns + ------- + List[:class:`.Message`] + The list of messages that were deleted. + + Raises + ------ + Forbidden + You do not have proper permissions to do the actions required. + HTTPException + Purging the messages failed. + + Examples + -------- + + Deleting bot's messages :: + + def is_me(m): + return m.author == client.user + + + deleted = await channel.purge(limit=100, check=is_me) + await channel.send(f"Deleted {len(deleted)} message(s)") + """ + return await _purge_messages_helper( + self, + limit=limit, + check=check, + before=before, + after=after, + around=around, + oldest_first=oldest_first, + bulk=bulk, + reason=reason, + ) + + async def webhooks(self) -> list[Webhook]: + """|coro| + + Gets the list of webhooks from this channel. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + Returns + ------- + List[:class:`Webhook`] + The webhooks for this channel. + + Raises + ------ + Forbidden + You don't have permissions to get the webhooks. + """ + + from .webhook import Webhook + + data = await self._state.http.channel_webhooks(self.id) + return [Webhook.from_state(d, state=self._state) for d in data] + + async def create_webhook(self, *, name: str, avatar: bytes | None = None, reason: str | None = None) -> Webhook: + """|coro| + + Creates a webhook for this channel. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + .. versionchanged:: 1.1 + Added the ``reason`` keyword-only parameter. + + Parameters + ---------- + name: :class:`str` + The webhook's name. + avatar: Optional[:class:`bytes`] + A :term:`py:bytes-like object` representing the webhook's default avatar. + This operates similarly to :meth:`~ClientUser.edit`. + reason: Optional[:class:`str`] + The reason for creating this webhook. Shows up in the audit logs. + + Returns + ------- + :class:`Webhook` + The created webhook. + + Raises + ------ + HTTPException + Creating the webhook failed. + Forbidden + You do not have permissions to create a webhook. + """ + + from .webhook import Webhook + + if avatar is not None: + avatar = bytes_to_base64_data(avatar) # type: ignore + + data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) + return Webhook.from_state(data, state=self._state) + + async def follow(self, *, destination: TextChannel, reason: str | None = None) -> Webhook: + """ + Follows a channel using a webhook. + + Only news channels can be followed. + + .. note:: + + The webhook returned will not provide a token to do webhook + actions, as Discord does not provide it. + + .. versionadded:: 1.3 + + Parameters + ---------- + destination: :class:`TextChannel` + The channel you would like to follow from. + reason: Optional[:class:`str`] + The reason for following the channel. Shows up on the destination guild's audit log. + + .. versionadded:: 1.4 + + Returns + ------- + :class:`Webhook` + The created webhook. + + Raises + ------ + HTTPException + Following the channel failed. + Forbidden + You do not have the permissions to create a webhook. + """ + + from .news import NewsChannel + from .text import TextChannel + + if not isinstance(self, NewsChannel): + raise ClientException("The channel must be a news channel.") + + if not isinstance(destination, TextChannel): + raise InvalidArgument(f"Expected TextChannel received {destination.__class__.__name__}") + + from .webhook import Webhook + + data = await self._state.http.follow_webhook(self.id, webhook_channel_id=destination.id, reason=reason) + return Webhook._as_follower(data, channel=destination, user=self._state.user) + + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + .. versionadded:: 1.6 + + Parameters + ---------- + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + ------- + :class:`PartialMessage` + The partial message. + """ + + from .message import PartialMessage + + return PartialMessage(channel=self, id=message_id) diff --git a/discord/channel/category.py b/discord/channel/category.py new file mode 100644 index 0000000000..be1da44f10 --- /dev/null +++ b/discord/channel/category.py @@ -0,0 +1,273 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, overload + +from typing_extensions import override + +if TYPE_CHECKING: + from collections.abc import Mapping + + from ..app.state import ConnectionState + from ..guild import Guild + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from . import ForumChannel, StageChannel, TextChannel, VoiceChannel + +from ..enums import ChannelType, try_enum +from ..flags import ChannelFlags +from ..types.channel import CategoryChannel as CategoryChannelPayload +from ..utils.private import copy_doc +from .base import GuildChannel, GuildTopLevelChannel + + +def comparator(channel: GuildChannel): + # Sorts channels so voice channels (VoiceChannel, StageChannel) appear below non-voice channels + return isinstance(channel, (VoiceChannel, StageChannel)), (channel.position or -1) + + +class CategoryChannel(GuildTopLevelChannel[CategoryChannelPayload]): + """Represents a Discord channel category. + + These are useful to group channels to logical compartments. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the category's hash. + + .. describe:: str(x) + + Returns the category's name. + + Attributes + ---------- + name: str + The category name. + guild: Guild + The guild the category belongs to. + id: int + The category channel ID. + position: int + The position in the category list. This is a number that starts at 0. e.g. the + top category is position 0. + flags: ChannelFlags + Extra features of the channel. + + .. versionadded:: 2.0 + """ + + __slots__: tuple[str, ...] = () + + @override + def __repr__(self) -> str: + return f"" + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.category.value + + @property + def type(self) -> ChannelType: + """The channel's Discord type.""" + return try_enum(ChannelType, self._type) + + @copy_doc(GuildChannel.clone) + async def clone(self, *, name: str | None = None, reason: str | None = None) -> CategoryChannel: + return await self._clone_impl({}, name=name, reason=reason) + + @overload + async def edit( + self, + *, + name: str = ..., + position: int = ..., + overwrites: Mapping[Role | Member, PermissionOverwrite] = ..., + reason: str | None = ..., + ) -> CategoryChannel | None: ... + + @overload + async def edit(self) -> CategoryChannel | None: ... + + async def edit(self, *, reason=None, **options): + """|coro| + + Edits the channel. + + You must have the :attr:`~Permissions.manage_channels` permission to + use this. + + .. versionchanged:: 1.3 + The ``overwrites`` keyword-only parameter was added. + + .. versionchanged:: 2.0 + Edits are no longer in-place, the newly edited channel is returned instead. + + Parameters + ---------- + name: :class:`str` + The new category's name. + position: :class:`int` + The new category's position. + reason: Optional[:class:`str`] + The reason for editing this category. Shows up on the audit log. + overwrites: Dict[Union[:class:`Role`, :class:`Member`, :class:`~discord.abc.Snowflake`], :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. Useful for creating secret channels. + + Returns + ------- + Optional[:class:`.CategoryChannel`] + The newly edited category channel. If the edit was only positional + then ``None`` is returned instead. + + Raises + ------ + InvalidArgument + If position is less than 0 or greater than the number of categories. + Forbidden + You do not have permissions to edit the category. + HTTPException + Editing the category failed. + """ + + payload = await self._edit(options, reason=reason) + if payload is not None: + # the payload will always be the proper channel payload + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + + @copy_doc(GuildTopLevelChannel.move) + async def move(self, **kwargs): + kwargs.pop("category", None) + await super().move(**kwargs) + + @property + def channels(self) -> list[GuildTopLevelChannel]: + """Returns the channels that are under this category. + + These are sorted by the official Discord UI, which places voice channels below the text channels. + """ + + ret = [c for c in self.guild.channels if c.category_id == self.id] + ret.sort(key=comparator) + return ret + + @property + def text_channels(self) -> list[TextChannel]: + """Returns the text channels that are under this category.""" + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, TextChannel)] + ret.sort(key=lambda c: (c.position or -1, c.id)) + return ret + + @property + def voice_channels(self) -> list[VoiceChannel]: + """Returns the voice channels that are under this category.""" + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, VoiceChannel)] + ret.sort(key=lambda c: (c.position or -1, c.id)) + return ret + + @property + def stage_channels(self) -> list[StageChannel]: + """Returns the stage channels that are under this category. + + .. versionadded:: 1.7 + """ + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, StageChannel)] + ret.sort(key=lambda c: (c.position or -1, c.id)) + return ret + + @property + def forum_channels(self) -> list[ForumChannel]: + """Returns the forum channels that are under this category. + + .. versionadded:: 2.0 + """ + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, ForumChannel)] + ret.sort(key=lambda c: (c.position or -1, c.id)) + return ret + + async def create_text_channel(self, name: str, **options: Any) -> TextChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_text_channel` to create a :class:`TextChannel` in the category. + + Returns + ------- + :class:`TextChannel` + The channel that was just created. + """ + return await self.guild.create_text_channel(name, category=self, **options) + + async def create_voice_channel(self, name: str, **options: Any) -> VoiceChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_voice_channel` to create a :class:`VoiceChannel` in the category. + + Returns + ------- + :class:`VoiceChannel` + The channel that was just created. + """ + return await self.guild.create_voice_channel(name, category=self, **options) + + async def create_stage_channel(self, name: str, **options: Any) -> StageChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_stage_channel` to create a :class:`StageChannel` in the category. + + .. versionadded:: 1.7 + + Returns + ------- + :class:`StageChannel` + The channel that was just created. + """ + return await self.guild.create_stage_channel(name, category=self, **options) + + async def create_forum_channel(self, name: str, **options: Any) -> ForumChannel: + """|coro| + + A shortcut method to :meth:`Guild.create_forum_channel` to create a :class:`ForumChannel` in the category. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`ForumChannel` + The channel that was just created. + """ + return await self.guild.create_forum_channel(name, category=self, **options) diff --git a/discord/channel.py b/discord/channel/channel.py.old similarity index 91% rename from discord/channel.py rename to discord/channel/channel.py.old index fb8b196265..00b149b748 100644 --- a/discord/channel.py +++ b/discord/channel/channel.py.old @@ -25,6 +25,7 @@ from __future__ import annotations +import asyncio import datetime from typing import ( TYPE_CHECKING, @@ -38,12 +39,14 @@ overload, ) +from typing_extensions import override + import discord.abc -from . import utils -from .asset import Asset -from .emoji import GuildEmoji -from .enums import ( +from .. import utils +from ..asset import Asset +from ..emoji import GuildEmoji +from ..enums import ( ChannelType, EmbeddedActivity, InviteTarget, @@ -54,21 +57,21 @@ VoiceRegion, try_enum, ) -from .enums import ThreadArchiveDuration as ThreadArchiveDurationEnum -from .errors import ClientException, InvalidArgument -from .file import File -from .flags import ChannelFlags, MessageFlags -from .invite import Invite -from .iterators import ArchivedThreadIterator -from .mixins import Hashable -from .object import Object -from .partial_emoji import PartialEmoji, _EmojiTag -from .permissions import PermissionOverwrite, Permissions -from .soundboard import PartialSoundboardSound, SoundboardSound -from .stage_instance import StageInstance -from .threads import Thread -from .utils import MISSING -from .utils.private import bytes_to_base64_data, copy_doc, get_as_snowflake +from ..enums import ThreadArchiveDuration as ThreadArchiveDurationEnum +from ..errors import ClientException, InvalidArgument +from ..file import File +from ..flags import ChannelFlags, MessageFlags +from ..invite import Invite +from ..iterators import ArchivedThreadIterator +from ..mixins import Hashable +from ..object import Object +from ..partial_emoji import PartialEmoji, _EmojiTag +from ..permissions import PermissionOverwrite, Permissions +from ..soundboard import PartialSoundboardSound, SoundboardSound +from ..stage_instance import StageInstance +from .thread import Thread +from ..utils import MISSING +from ..utils.private import bytes_to_base64_data, copy_doc, get_as_snowflake __all__ = ( "TextChannel", @@ -85,112 +88,30 @@ ) if TYPE_CHECKING: - from .abc import Snowflake, SnowflakeTime - from .embeds import Embed - from .guild import Guild - from .guild import GuildChannel as GuildChannelType - from .member import Member, VoiceState - from .mentions import AllowedMentions - from .message import EmojiInputType, Message, PartialMessage - from .role import Role - from .state import ConnectionState - from .sticker import GuildSticker, StickerItem - from .types.channel import CategoryChannel as CategoryChannelPayload - from .types.channel import DMChannel as DMChannelPayload - from .types.channel import ForumChannel as ForumChannelPayload - from .types.channel import ForumTag as ForumTagPayload - from .types.channel import GroupDMChannel as GroupChannelPayload - from .types.channel import StageChannel as StageChannelPayload - from .types.channel import TextChannel as TextChannelPayload - from .types.channel import VoiceChannel as VoiceChannelPayload - from .types.channel import VoiceChannelEffectSendEvent as VoiceChannelEffectSend - from .types.snowflake import SnowflakeList - from .types.threads import ThreadArchiveDuration - from .ui.view import View - from .user import BaseUser, ClientUser, User - from .webhook import Webhook - - -class ForumTag(Hashable): - """Represents a forum tag that can be added to a thread inside a :class:`ForumChannel` - . - .. versionadded:: 2.3 - - .. container:: operations - - .. describe:: x == y - - Checks if two forum tags are equal. - - .. describe:: x != y - - Checks if two forum tags are not equal. - - .. describe:: hash(x) - - Returns the forum tag's hash. - - .. describe:: str(x) - - Returns the forum tag's name. - - Attributes - ---------- - id: :class:`int` - The tag ID. - Note that if the object was created manually then this will be ``0``. - name: :class:`str` - The name of the tag. Can only be up to 20 characters. - moderated: :class:`bool` - Whether this tag can only be added or removed by a moderator with - the :attr:`~Permissions.manage_threads` permission. - emoji: :class:`PartialEmoji` - The emoji that is used to represent this tag. - Note that if the emoji is a custom emoji, it will *not* have name information. - """ - - __slots__ = ("name", "id", "moderated", "emoji") - - def __init__(self, *, name: str, emoji: EmojiInputType, moderated: bool = False) -> None: - self.name: str = name - self.id: int = 0 - self.moderated: bool = moderated - self.emoji: PartialEmoji - if isinstance(emoji, _EmojiTag): - self.emoji = emoji._to_partial() - elif isinstance(emoji, str): - self.emoji = PartialEmoji.from_str(emoji) - else: - raise TypeError(f"emoji must be a GuildEmoji, PartialEmoji, or str and not {emoji.__class__!r}") - - def __repr__(self) -> str: - return f"" - - def __str__(self) -> str: - return self.name - - @classmethod - def from_data(cls, *, state: ConnectionState, data: ForumTagPayload) -> ForumTag: - self = cls.__new__(cls) - self.name = data["name"] - self.id = int(data["id"]) - self.moderated = data.get("moderated", False) - - emoji_name = data["emoji_name"] or "" - emoji_id = get_as_snowflake(data, "emoji_id") or None - self.emoji = PartialEmoji.with_state(state=state, name=emoji_name, id=emoji_id) - return self - - def to_dict(self) -> dict[str, Any]: - payload: dict[str, Any] = { - "name": self.name, - "moderated": self.moderated, - } | self.emoji._to_forum_reaction_payload() - - if self.id: - payload["id"] = self.id - - return payload + from ..abc import Snowflake, SnowflakeTime + from ..app.state import ConnectionState + from ..embeds import Embed + from ..guild import Guild + from ..guild import GuildChannel as GuildChannelType + from ..member import Member, VoiceState + from ..mentions import AllowedMentions + from ..message import EmojiInputType, Message, PartialMessage + from ..role import Role + from ..sticker import GuildSticker, StickerItem + from ..types.channel import CategoryChannel as CategoryChannelPayload + from ..types.channel import DMChannel as DMChannelPayload + from ..types.channel import ForumChannel as ForumChannelPayload + from ..types.channel import ForumTag as ForumTagPayload + from ..types.channel import GroupDMChannel as GroupChannelPayload + from ..types.channel import StageChannel as StageChannelPayload + from ..types.channel import TextChannel as TextChannelPayload + from ..types.channel import VoiceChannel as VoiceChannelPayload + from ..types.channel import VoiceChannelEffectSendEvent as VoiceChannelEffectSend + from ..types.snowflake import SnowflakeList + from ..types.threads import ThreadArchiveDuration + from ..ui.view import View + from ..user import BaseUser, ClientUser, User + from ..webhook import Webhook class _TextChannel(discord.abc.GuildChannel, Hashable): @@ -218,13 +139,31 @@ class _TextChannel(discord.abc.GuildChannel, Hashable): def __init__( self, *, - state: ConnectionState, + id: int, guild: Guild, - data: TextChannelPayload | ForumChannelPayload, + state: ConnectionState, ): + """Initialize with permanent attributes only.""" self._state: ConnectionState = state - self.id: int = int(data["id"]) - self._update(guild, data) + self.id: int = id + self.guild: Guild = guild + + @classmethod + async def _from_data( + cls, + *, + data: TextChannelPayload | ForumChannelPayload, + state: ConnectionState, + guild: Guild, + ): + """Create channel instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + ) + await self._update(data) + return self @property def _repr_attrs(self) -> tuple[str, ...]: @@ -235,9 +174,9 @@ def __repr__(self) -> str: joined = " ".join("%s=%r" % t for t in attrs) return f"<{self.__class__.__name__} {joined}>" - def _update(self, guild: Guild, data: TextChannelPayload | ForumChannelPayload) -> None: + async def _update(self, data: TextChannelPayload | ForumChannelPayload) -> None: + """Update mutable attributes from API payload.""" # This data will always exist - self.guild: Guild = guild self.name: str = data["name"] self.category_id: int | None = get_as_snowflake(data, "parent_id") self._type: int = data["type"] @@ -289,8 +228,7 @@ def is_nsfw(self) -> bool: """Checks if the channel is NSFW.""" return self.nsfw - @property - def last_message(self) -> Message | None: + async def get_last_message(self) -> Message | None: """Fetches the last message from this channel in cache. The message might not be valid or point to an existing message. @@ -308,7 +246,7 @@ def last_message(self) -> Message | None: Optional[:class:`Message`] The last message in this channel or ``None`` if not found. """ - return self._state._get_message(self.last_message_id) if self.last_message_id else None + return await self._state._get_message(self.last_message_id) if self.last_message_id else None async def edit(self, **options) -> _TextChannel: """Edits the channel.""" @@ -660,7 +598,7 @@ def archived_threads( ) -class TextChannel(discord.abc.Messageable, _TextChannel): +class TextChannel(discord.abc.Messageable, ForumChannel): """Represents a Discord text channel. .. container:: operations @@ -724,15 +662,33 @@ class TextChannel(discord.abc.Messageable, _TextChannel): .. versionadded:: 2.3 """ - def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload): - super().__init__(state=state, guild=guild, data=data) + def __init__(self, *, id: int, guild: Guild, state: ConnectionState): + """Initialize with permanent attributes only.""" + super().__init__(id=id, guild=guild, state=state) + + @classmethod + async def _from_data( + cls, + *, + data: TextChannelPayload, + state: ConnectionState, + guild: Guild, + ): + """Create channel instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + ) + await self._update(data) + return self @property def _repr_attrs(self) -> tuple[str, ...]: return super()._repr_attrs + ("news",) - def _update(self, guild: Guild, data: TextChannelPayload) -> None: - super()._update(guild, data) + async def _update(self, data: TextChannelPayload) -> None: + await super()._update(data) async def _get_channel(self) -> TextChannel: return self @@ -838,7 +794,7 @@ async def edit(self, *, reason=None, **options): payload = await self._edit(options, reason=reason) if payload is not None: # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore async def create_thread( self, @@ -1003,11 +959,31 @@ class ForumChannel(_TextChannel): .. versionadded:: 2.5 """ - def __init__(self, *, state: ConnectionState, guild: Guild, data: ForumChannelPayload): - super().__init__(state=state, guild=guild, data=data) + def __init__(self, *, id: int, guild: Guild, state: ConnectionState): + """Initialize with permanent attributes only.""" + super().__init__(id=id, guild=guild, state=state) + + @classmethod + @override + async def _from_data( + cls, + *, + data: ForumChannelPayload, + state: ConnectionState, + guild: Guild, + ): + """Create channel instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + ) + await self._update(data) + return self - def _update(self, guild: Guild, data: ForumChannelPayload) -> None: - super()._update(guild, data) + @override + async def _update(self, data: ForumChannelPayload) -> None: + await super()._update(data) self.available_tags: list[ForumTag] = [ ForumTag.from_data(state=self._state, data=tag) for tag in (data.get("available_tags") or []) ] @@ -1023,7 +999,9 @@ def _update(self, guild: Guild, data: ForumChannelPayload) -> None: if emoji_name is not None: self.default_reaction_emoji = reaction_emoji_ctx["emoji_name"] else: - self.default_reaction_emoji = self._state.get_emoji(get_as_snowflake(reaction_emoji_ctx, "emoji_id")) + self.default_reaction_emoji = await self._state.get_emoji( + get_as_snowflake(reaction_emoji_ctx, "emoji_id") + ) @property def guidelines(self) -> str | None: @@ -1153,7 +1131,7 @@ async def edit(self, *, reason=None, **options): payload = await self._edit(options, reason=reason) if payload is not None: # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore async def create_thread( self, @@ -1325,7 +1303,7 @@ async def create_thread( ret = Thread(guild=self.guild, state=self._state, data=data) msg = ret.get_partial_message(int(data["last_message_id"])) if view and view.is_dispatchable(): - state.store_view(view, msg.id) + await state.store_view(view, msg.id) if delete_message_after is not None: await msg.delete(delay=delete_message_after) @@ -1519,7 +1497,7 @@ async def edit(self, *, reason=None, **options): payload = await self._edit(options, reason=reason) if payload is not None: # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): @@ -1544,13 +1522,34 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha def __init__( self, *, - state: ConnectionState, + id: int, guild: Guild, - data: VoiceChannelPayload | StageChannelPayload, + state: ConnectionState, + type: int | ChannelType, ): + """Initialize with permanent attributes only.""" self._state: ConnectionState = state - self.id: int = int(data["id"]) - self._update(guild, data) + self.id: int = id + self.guild = guild + self._type: int = int(type) + + @classmethod + async def _from_data( + cls, + *, + data: VoiceChannelPayload | StageChannelPayload, + state: ConnectionState, + guild: Guild, + ): + """Create channel instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + type=data["type"], + ) + await self._update(data) + return self def _get_voice_client_key(self) -> tuple[int, str]: return self.guild.id, "guild_id" @@ -1558,9 +1557,8 @@ def _get_voice_client_key(self) -> tuple[int, str]: def _get_voice_state_pair(self) -> tuple[int, int]: return self.guild.id, self.id - def _update(self, guild: Guild, data: VoiceChannelPayload | StageChannelPayload) -> None: + async def _update(self, data: VoiceChannelPayload | StageChannelPayload) -> None: # This data will always exist - self.guild = guild self.name: str = data["name"] self.category_id: int | None = get_as_snowflake(data, "parent_id") @@ -1582,13 +1580,12 @@ def _update(self, guild: Guild, data: VoiceChannelPayload | StageChannelPayload) def _sorting_bucket(self) -> int: return ChannelType.voice.value - @property - def members(self) -> list[Member]: + async def get_members(self) -> list[Member]: """Returns all members that are currently inside this voice channel.""" ret = [] for user_id, state in self.guild._voice_states.items(): if state.channel and state.channel.id == self.id: - member = self.guild.get_member(user_id) + member = await self.guild.get_member(user_id) if member is not None: ret.append(member) return ret @@ -1704,15 +1701,35 @@ class VoiceChannel(discord.abc.Messageable, VocalGuildChannel): def __init__( self, *, - state: ConnectionState, + id: int, guild: Guild, - data: VoiceChannelPayload, + state: ConnectionState, + type: int | ChannelType, ): + """Initialize with permanent attributes only.""" + super().__init__(id=id, guild=guild, state=state, type=type) self.status: str | None = None - super().__init__(state=state, guild=guild, data=data) - def _update(self, guild: Guild, data: VoiceChannelPayload): - super()._update(guild, data) + @classmethod + async def _from_data( + cls, + *, + data: VoiceChannelPayload, + state: ConnectionState, + guild: Guild, + ): + """Create channel instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + type=data["type"], + ) + await self._update(data) + return self + + async def _update(self, data: VoiceChannelPayload): + await super()._update(data) if data.get("status"): self.status = data.get("status") @@ -1738,8 +1755,7 @@ def is_nsfw(self) -> bool: """Checks if the channel is NSFW.""" return self.nsfw - @property - def last_message(self) -> Message | None: + async def get_last_message(self) -> Message | None: """Fetches the last message from this channel in cache. The message might not be valid or point to an existing message. @@ -1757,7 +1773,7 @@ def last_message(self) -> Message | None: Optional[:class:`Message`] The last message in this channel or ``None`` if not found. """ - return self._state._get_message(self.last_message_id) if self.last_message_id else None + return await self._state._get_message(self.last_message_id) if self.last_message_id else None def get_partial_message(self, message_id: int, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the message ID. @@ -2085,7 +2101,7 @@ async def edit(self, *, reason=None, **options): payload = await self._edit(options, reason=reason) if payload is not None: # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore async def create_activity_invite(self, activity: EmbeddedActivity | int, **kwargs) -> Invite: """|coro| @@ -2252,8 +2268,8 @@ class StageChannel(discord.abc.Messageable, VocalGuildChannel): __slots__ = ("topic",) - def _update(self, guild: Guild, data: StageChannelPayload) -> None: - super()._update(guild, data) + async def _update(self, data: StageChannelPayload) -> None: + await super()._update(data) self.topic = data.get("topic") def __repr__(self) -> str: @@ -2303,8 +2319,7 @@ def is_nsfw(self) -> bool: """Checks if the channel is NSFW.""" return self.nsfw - @property - def last_message(self) -> Message | None: + async def get_last_message(self) -> Message | None: """Fetches the last message from this channel in cache. The message might not be valid or point to an existing message. @@ -2322,7 +2337,7 @@ def last_message(self) -> Message | None: Optional[:class:`Message`] The last message in this channel or ``None`` if not found. """ - return self._state._get_message(self.last_message_id) if self.last_message_id else None + return await self._state._get_message(self.last_message_id) if self.last_message_id else None def get_partial_message(self, message_id: int, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the message ID. @@ -2736,7 +2751,7 @@ async def edit(self, *, reason=None, **options): payload = await self._edit(options, reason=reason) if payload is not None: # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore class CategoryChannel(discord.abc.GuildChannel, Hashable): @@ -2780,29 +2795,38 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): .. versionadded:: 2.0 """ - __slots__ = ( - "name", - "id", - "guild", - "_state", - "position", - "_overwrites", - "category_id", - "flags", - ) + __slots__ = ("name", "id", "guild", "_state", "position", "_overwrites", "category_id", "flags", "_type") - def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload): + def __init__(self, *, id: int, guild: Guild, state: ConnectionState) -> None: + """Initialize with permanent attributes only.""" self._state: ConnectionState = state - self.id: int = int(data["id"]) - self._update(guild, data) + self.id: int = id + self.guild = guild + + @classmethod + async def _from_data( + cls, + *, + data: CategoryChannelPayload, + state: ConnectionState, + guild: Guild, + ): + """Create channel instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + ) + await self._update(data) + return self def __repr__(self) -> str: return f"" - def _update(self, guild: Guild, data: CategoryChannelPayload) -> None: + async def _update(self, data: CategoryChannelPayload) -> None: # This data will always exist - self.guild: Guild = guild self.name: str = data["name"] + self._type: int = data["type"] self.category_id: int | None = get_as_snowflake(data, "parent_id") # This data may be missing depending on how this object is being created/updated @@ -2818,7 +2842,7 @@ def _sorting_bucket(self) -> int: @property def type(self) -> ChannelType: """The channel's Discord type.""" - return ChannelType.category + return try_enum(ChannelType, self._type) @copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: str | None = None, reason: str | None = None) -> CategoryChannel: @@ -2881,7 +2905,7 @@ async def edit(self, *, reason=None, **options): payload = await self._edit(options, reason=reason) if payload is not None: # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore @copy_doc(discord.abc.GuildChannel.move) async def move(self, **kwargs): @@ -3025,15 +3049,28 @@ class DMChannel(discord.abc.Messageable, Hashable): The direct message channel ID. """ - __slots__ = ("id", "recipient", "me", "_state") + __slots__ = ("id", "recipient", "me", "_state", "_type") - def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload): + def __init__(self, *, me: ClientUser, state: ConnectionState, id: int) -> None: + """Initialize with permanent attributes only.""" self._state: ConnectionState = state self.recipient: User | None = None - if r := data.get("recipients"): - self.recipient = state.store_user(r[0]) self.me: ClientUser = me - self.id: int = int(data["id"]) + self.id: int = id + + @classmethod + async def _from_data(cls, *, data: DMChannelPayload, state: ConnectionState, me: ClientUser) -> DMChannel: + """Create channel instance from API payload.""" + self = cls(me=me, state=state, id=int(data["id"])) + await self._update(data) + return self + + async def _update(self, data: DMChannelPayload) -> None: + """Update mutable attributes from API payload.""" + recipients = data.get("recipients", []) + self._type = data["type"] + if recipients: + self.recipient = await self._state.store_user(recipients[0]) async def _get_channel(self): return self @@ -3059,7 +3096,7 @@ def _from_message(cls: type[DMC], state: ConnectionState, channel_id: int) -> DM @property def type(self) -> ChannelType: """The channel's Discord type.""" - return ChannelType.private + return try_enum(ChannelType, self._type) @property def jump_url(self) -> str: @@ -3176,19 +3213,22 @@ class GroupChannel(discord.abc.Messageable, Hashable): "name", "me", "_state", + "_data", ) def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload): self._state: ConnectionState = state + self._data = data self.id: int = int(data["id"]) self.me: ClientUser = me - self._update_group(data) - def _update_group(self, data: GroupChannelPayload) -> None: - self.owner_id: int | None = get_as_snowflake(data, "owner_id") - self._icon: str | None = data.get("icon") - self.name: str | None = data.get("name") - self.recipients: list[User] = [self._state.store_user(u) for u in data.get("recipients", [])] + async def _update_group(self, data: dict[str, Any] | None = None) -> None: + if data: + self._data = data + self.owner_id: int | None = get_as_snowflake(self._data, "owner_id") + self._icon: str | None = self._data.get("icon") + self.name: str | None = self._data.get("name") + self.recipients: list[User] = [await self._state.store_user(u) for u in self._data.get("recipients", [])] self.owner: BaseUser | None if self.owner_id == self.me.id: @@ -3320,7 +3360,7 @@ class PartialMessageable(discord.abc.Messageable, Hashable): The channel type associated with this partial messageable, if given. """ - def __init__(self, state: ConnectionState, id: int, type: ChannelType | None = None): + def __init__(self, state: ConnectionState, id: int): self._state: ConnectionState = state self._channel: Object = Object(id=id) self.id: int = id @@ -3430,57 +3470,3 @@ def __init__( else None ) self.data = data - - -def _guild_channel_factory(channel_type: int): - value = try_enum(ChannelType, channel_type) - if value is ChannelType.text: - return TextChannel, value - elif value is ChannelType.voice: - return VoiceChannel, value - elif value is ChannelType.category: - return CategoryChannel, value - elif value is ChannelType.news: - return TextChannel, value - elif value is ChannelType.stage_voice: - return StageChannel, value - elif value is ChannelType.directory: - return None, value # todo: Add DirectoryChannel when applicable - elif value is ChannelType.forum: - return ForumChannel, value - elif value is ChannelType.media: - return MediaChannel, value - else: - return None, value - - -def _channel_factory(channel_type: int): - cls, value = _guild_channel_factory(channel_type) - if value is ChannelType.private: - return DMChannel, value - elif value is ChannelType.group: - return GroupChannel, value - else: - return cls, value - - -def _threaded_channel_factory(channel_type: int): - cls, value = _channel_factory(channel_type) - if value in ( - ChannelType.private_thread, - ChannelType.public_thread, - ChannelType.news_thread, - ): - return Thread, value - return cls, value - - -def _threaded_guild_channel_factory(channel_type: int): - cls, value = _guild_channel_factory(channel_type) - if value in ( - ChannelType.private_thread, - ChannelType.public_thread, - ChannelType.news_thread, - ): - return Thread, value - return cls, value diff --git a/discord/channel/dm.py b/discord/channel/dm.py new file mode 100644 index 0000000000..626ad439f2 --- /dev/null +++ b/discord/channel/dm.py @@ -0,0 +1,106 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from collections.abc import Collection +from typing import TYPE_CHECKING + +from typing_extensions import override + +from ..abc import Messageable, Snowflake +from ..asset import Asset +from ..permissions import Permissions +from ..types.channel import DMChannel as DMChannelPayload +from ..types.channel import GroupDMChannel as GroupDMChannelPayload +from .base import BaseChannel, P + +if TYPE_CHECKING: + from ..app.state import ConnectionState + from ..message import Message + from ..user import User + + +class DMChannel(BaseChannel[DMChannelPayload], Messageable): + __slots__: tuple[str, ...] = ("last_message", "recipient") + + def __init__(self, id: int, state: "ConnectionState") -> None: + super().__init__(id, state) + self.recipient: User | None = None + self.last_message: Message | None = None + + @override + async def _update(self, data: DMChannelPayload) -> None: + await super()._update(data) + if last_message_id := data.get("last_message_id", None): + self.last_message = await self._state.cache.get_message(int(last_message_id)) + if recipients := data.get("recipients"): + self.recipient = await self._state.cache.store_user(recipients[0]) + + @override + def __repr__(self) -> str: + return f"" + + @property + @override + def jump_url(self) -> str: + """Returns a URL that allows the client to jump to the channel.""" + return f"https://discord.com/channels/@me/{self.id}" + + +class GroupDMChannel(BaseChannel[GroupDMChannelPayload], Messageable): + __slots__: tuple[str, ...] = ("recipients", "icon_hash", "owner", "name") + + def __init__(self, id: int, state: "ConnectionState") -> None: + super().__init__(id, state) + self.recipients: Collection[User] = set() + self.icon_hash: str | None = None + self.owner: User | None = None + + @override + async def _update(self, data: GroupDMChannelPayload) -> None: + await super()._update(data) + self.name: str = data["name"] + if recipients := data.get("recipients"): + self.recipients = {await self._state.cache.store_user(recipient_data) for recipient_data in recipients} + if icon_hash := data.get("icon"): + self.icon_hash = icon_hash + if owner_id := data.get("owner_id"): + self.owner = await self._state.cache.get_user(int(owner_id)) + + @override + def __repr__(self) -> str: + return f"" + + @property + @override + def jump_url(self) -> str: + """Returns a URL that allows the client to jump to the channel.""" + return f"https://discord.com/channels/@me/{self.id}" + + @property + def icon(self) -> Asset | None: + """Returns the channel's icon asset if available.""" + if self.icon_hash is None: + return None + return Asset._from_icon(self._state, self.id, self.icon_hash, path="channel") diff --git a/discord/channel/forum.py b/discord/channel/forum.py new file mode 100644 index 0000000000..39fad7f36d --- /dev/null +++ b/discord/channel/forum.py @@ -0,0 +1,210 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping, overload + +from typing_extensions import Self, override + +from ..enums import ChannelType, SortOrder +from ..flags import ChannelFlags +from ..utils import MISSING, Undefined +from .base import GuildPostableChannel + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..emoji import GuildEmoji + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from ..types.channel import ForumChannel as ForumChannelPayload + from .category import CategoryChannel + from .channel import ForumTag + +__all__ = ("ForumChannel",) + + +class ForumChannel(GuildPostableChannel["ForumChannelPayload"]): + """Represents a Discord forum channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + id: :class:`int` + The channel's ID. + name: :class:`str` + The channel's name. + guild: :class:`Guild` + The guild the channel belongs to. + topic: :class:`str` | None + The channel's topic/guidelines. ``None`` if it doesn't exist. + category_id: :class:`int` | None + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + slowmode_delay: :class:`int` + The number of seconds a member must wait between creating posts + in this channel. A value of `0` denotes that it is disabled. + last_message_id: :class:`int` | None + The last message ID sent to this channel. It may not point to an existing or valid message. + default_auto_archive_duration: :class:`int` + The default auto archive duration in minutes for posts created in this channel. + default_thread_slowmode_delay: :class:`int` | None + The initial slowmode delay to set on newly created posts in this channel. + available_tags: list[:class:`ForumTag`] + The set of tags that can be used in this forum channel. + default_sort_order: :class:`SortOrder` | None + The default sort order type used to order posts in this channel. + default_reaction_emoji: :class:`str` | :class:`GuildEmoji` | None + The default forum reaction emoji. + + .. versionadded:: 3.0 + """ + + __slots__: tuple[str, ...] = () + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.forum.value + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("position", self.position), + ("nsfw", self.nsfw), + ("category_id", self.category_id), + ] + joined = " ".join(f"{k}={v!r}" for k, v in attrs) + return f"" + + @overload + async def edit( + self, + *, + name: str | Undefined = MISSING, + topic: str | Undefined = MISSING, + position: int | Undefined = MISSING, + nsfw: bool | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + category: CategoryChannel | None | Undefined = MISSING, + slowmode_delay: int | Undefined = MISSING, + default_auto_archive_duration: int | Undefined = MISSING, + default_thread_slowmode_delay: int | Undefined = MISSING, + default_sort_order: SortOrder | Undefined = MISSING, + default_reaction_emoji: GuildEmoji | int | str | None | Undefined = MISSING, + available_tags: list[ForumTag] | Undefined = MISSING, + require_tag: bool | Undefined = MISSING, + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] | Undefined = MISSING, + reason: str | None = None, + ) -> Self: ... + + @overload + async def edit(self) -> Self: ... + + async def edit(self, *, reason: str | None = None, **options) -> Self: + """|coro| + + Edits the forum channel. + + You must have :attr:`~Permissions.manage_channels` permission to use this. + + Parameters + ---------- + name: :class:`str` + The new channel name. + topic: :class:`str` + The new channel's topic/guidelines. + position: :class:`int` + The new channel's position. + nsfw: :class:`bool` + Whether the channel should be marked as NSFW. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing category. + category: :class:`CategoryChannel` | None + The new category for this channel. Can be ``None`` to remove the category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for users in this channel, in seconds. + A value of ``0`` disables slowmode. The maximum value possible is ``21600``. + default_auto_archive_duration: :class:`int` + The new default auto archive duration in minutes for posts created in this channel. + Must be one of ``60``, ``1440``, ``4320``, or ``10080``. + default_thread_slowmode_delay: :class:`int` + The new default slowmode delay in seconds for posts created in this channel. + default_sort_order: :class:`SortOrder` + The default sort order type to use to order posts in this channel. + default_reaction_emoji: :class:`GuildEmoji` | :class:`int` | :class:`str` | None + The default reaction emoji for posts. + Can be a unicode emoji or a custom emoji. + available_tags: list[:class:`ForumTag`] + The set of tags that can be used in this channel. Must be less than ``20``. + require_tag: :class:`bool` + Whether a tag should be required to be specified when creating a post in this channel. + overwrites: Mapping[:class:`Role` | :class:`Member` | :class:`~discord.abc.Snowflake`, :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. + reason: :class:`str` | None + The reason for editing this channel. Shows up on the audit log. + + Returns + ------- + :class:`.ForumChannel` + The newly edited forum channel. + + Raises + ------ + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + if "require_tag" in options: + options["flags"] = ChannelFlags._from_value(self.flags.value) + options["flags"].require_tag = options.pop("require_tag") + + payload = await self._edit(options, reason=reason) + if payload is not None: + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + return self diff --git a/discord/channel/media.py b/discord/channel/media.py new file mode 100644 index 0000000000..64b4ea5620 --- /dev/null +++ b/discord/channel/media.py @@ -0,0 +1,227 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping, overload + +from typing_extensions import Self, override + +from ..enums import ChannelType, SortOrder +from ..flags import ChannelFlags +from ..utils import MISSING, Undefined +from .base import GuildPostableChannel + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..emoji import GuildEmoji + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from ..types.channel import MediaChannel as MediaChannelPayload + from .category import CategoryChannel + from .channel import ForumTag + +__all__ = ("MediaChannel",) + + +class MediaChannel(GuildPostableChannel["MediaChannelPayload"]): + """Represents a Discord media channel. + + .. versionadded:: 2.7 + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + id: :class:`int` + The channel's ID. + name: :class:`str` + The channel's name. + guild: :class:`Guild` + The guild the channel belongs to. + topic: :class:`str` | None + The channel's topic/guidelines. ``None`` if it doesn't exist. + category_id: :class:`int` | None + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + slowmode_delay: :class:`int` + The number of seconds a member must wait between creating posts + in this channel. A value of `0` denotes that it is disabled. + last_message_id: :class:`int` | None + The last message ID sent to this channel. It may not point to an existing or valid message. + default_auto_archive_duration: :class:`int` + The default auto archive duration in minutes for posts created in this channel. + default_thread_slowmode_delay: :class:`int` | None + The initial slowmode delay to set on newly created posts in this channel. + available_tags: list[:class:`ForumTag`] + The set of tags that can be used in this media channel. + default_sort_order: :class:`SortOrder` | None + The default sort order type used to order posts in this channel. + default_reaction_emoji: :class:`str` | :class:`GuildEmoji` | None + The default reaction emoji. + + .. versionadded:: 3.0 + """ + + __slots__: tuple[str, ...] = () + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.media.value + + @property + def media_download_options_hidden(self) -> bool: + """Whether media download options are hidden in this media channel. + + .. versionadded:: 2.7 + """ + return self.flags.hide_media_download_options + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("position", self.position), + ("nsfw", self.nsfw), + ("category_id", self.category_id), + ] + joined = " ".join(f"{k}={v!r}" for k, v in attrs) + return f"" + + @overload + async def edit( + self, + *, + name: str | Undefined = MISSING, + topic: str | Undefined = MISSING, + position: int | Undefined = MISSING, + nsfw: bool | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + category: CategoryChannel | None | Undefined = MISSING, + slowmode_delay: int | Undefined = MISSING, + default_auto_archive_duration: int | Undefined = MISSING, + default_thread_slowmode_delay: int | Undefined = MISSING, + default_sort_order: SortOrder | Undefined = MISSING, + default_reaction_emoji: GuildEmoji | int | str | None | Undefined = MISSING, + available_tags: list[ForumTag] | Undefined = MISSING, + require_tag: bool | Undefined = MISSING, + hide_media_download_options: bool | Undefined = MISSING, + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] | Undefined = MISSING, + reason: str | None = None, + ) -> Self: ... + + @overload + async def edit(self) -> Self: ... + + async def edit(self, *, reason: str | None = None, **options) -> Self: + """|coro| + + Edits the media channel. + + You must have :attr:`~Permissions.manage_channels` permission to use this. + + Parameters + ---------- + name: :class:`str` + The new channel name. + topic: :class:`str` + The new channel's topic/guidelines. + position: :class:`int` + The new channel's position. + nsfw: :class:`bool` + Whether the channel should be marked as NSFW. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing category. + category: :class:`CategoryChannel` | None + The new category for this channel. Can be ``None`` to remove the category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for users in this channel, in seconds. + A value of ``0`` disables slowmode. The maximum value possible is ``21600``. + default_auto_archive_duration: :class:`int` + The new default auto archive duration in minutes for posts created in this channel. + Must be one of ``60``, ``1440``, ``4320``, or ``10080``. + default_thread_slowmode_delay: :class:`int` + The new default slowmode delay in seconds for posts created in this channel. + default_sort_order: :class:`SortOrder` + The default sort order type to use to order posts in this channel. + default_reaction_emoji: :class:`GuildEmoji` | :class:`int` | :class:`str` | None + The default reaction emoji for posts. + Can be a unicode emoji or a custom emoji. + available_tags: list[:class:`ForumTag`] + The set of tags that can be used in this channel. Must be less than ``20``. + require_tag: :class:`bool` + Whether a tag should be required to be specified when creating a post in this channel. + hide_media_download_options: :class:`bool` + Whether to hide the media download options in this media channel. + overwrites: Mapping[:class:`Role` | :class:`Member` | :class:`~discord.abc.Snowflake`, :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. + reason: :class:`str` | None + The reason for editing this channel. Shows up on the audit log. + + Returns + ------- + :class:`.MediaChannel` + The newly edited media channel. + + Raises + ------ + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + # Handle require_tag flag + if "require_tag" in options or "hide_media_download_options" in options: + options["flags"] = ChannelFlags._from_value(self.flags.value) + if "require_tag" in options: + options["flags"].require_tag = options.pop("require_tag") + if "hide_media_download_options" in options: + options["flags"].hide_media_download_options = options.pop("hide_media_download_options") + + payload = await self._edit(options, reason=reason) + if payload is not None: + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + return self diff --git a/discord/channel/news.py b/discord/channel/news.py new file mode 100644 index 0000000000..a9dbc299d2 --- /dev/null +++ b/discord/channel/news.py @@ -0,0 +1,283 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping + +from typing_extensions import Self, override + +from ..enums import ChannelType +from ..utils import MISSING, Undefined +from .base import GuildMessageableChannel, GuildThreadableChannel, GuildTopLevelChannel + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from ..types.channel import NewsChannel as NewsChannelPayload + from ..types.channel import TextChannel as TextChannelPayload + from .category import CategoryChannel + from .text import TextChannel + from .thread import Thread + +__all__ = ("NewsChannel",) + + +class NewsChannel( + GuildTopLevelChannel["NewsChannelPayload"], + GuildMessageableChannel, + GuildThreadableChannel, +): + """Represents a Discord guild news/announcement channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + id: :class:`int` + The channel's ID. + name: :class:`str` + The channel's name. + guild: :class:`Guild` + The guild the channel belongs to. + topic: :class:`str` | None + The channel's topic. ``None`` if it isn't set. + category_id: :class:`int` | None + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this channel. A value of `0` denotes that it is disabled. + last_message_id: :class:`int` | None + The last message ID of the message sent to this channel. It may + *not* point to an existing or valid message. + default_auto_archive_duration: :class:`int` + The default auto archive duration in minutes for threads created in this channel. + + .. versionadded:: 3.0 + """ + + __slots__: tuple[str, ...] = ( + "topic", + "nsfw", + "slowmode_delay", + "last_message_id", + "default_auto_archive_duration", + "default_thread_slowmode_delay", + ) + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.news.value + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("position", self.position), + ("nsfw", self.nsfw), + ("category_id", self.category_id), + ] + joined = " ".join(f"{k}={v!r}" for k, v in attrs) + return f"" + + async def edit( + self, + *, + name: str | Undefined = MISSING, + topic: str | Undefined = MISSING, + position: int | Undefined = MISSING, + nsfw: bool | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + category: CategoryChannel | None | Undefined = MISSING, + slowmode_delay: int | Undefined = MISSING, + default_auto_archive_duration: int | Undefined = MISSING, + default_thread_slowmode_delay: int | Undefined = MISSING, + type: ChannelType | Undefined = MISSING, + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] | Undefined = MISSING, + reason: str | None = None, + ) -> Self | TextChannel: + """|coro| + + Edits the channel. + + You must have :attr:`~Permissions.manage_channels` permission to use this. + + Parameters + ---------- + name: :class:`str` + The new channel name. + topic: :class:`str` + The new channel's topic. + position: :class:`int` + The new channel's position. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing category. + category: :class:`CategoryChannel` | None + The new category for this channel. Can be ``None`` to remove the category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for user in this channel, in seconds. + default_auto_archive_duration: :class:`int` + The new default auto archive duration in minutes for threads created in this channel. + default_thread_slowmode_delay: :class:`int` + The new default slowmode delay in seconds for threads created in this channel. + type: :class:`ChannelType` + Change the type of this news channel. Only conversion between text and news is supported. + overwrites: Mapping[:class:`Role` | :class:`Member` | :class:`~discord.abc.Snowflake`, :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. + reason: :class:`str` | None + The reason for editing this channel. Shows up on the audit log. + + Returns + ------- + :class:`.NewsChannel` | :class:`.TextChannel` + The newly edited channel. If type was changed, the appropriate channel type is returned. + + Raises + ------ + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + options = {} + if name is not MISSING: + options["name"] = name + if topic is not MISSING: + options["topic"] = topic + if position is not MISSING: + options["position"] = position + if nsfw is not MISSING: + options["nsfw"] = nsfw + if sync_permissions is not MISSING: + options["sync_permissions"] = sync_permissions + if category is not MISSING: + options["category"] = category + if slowmode_delay is not MISSING: + options["slowmode_delay"] = slowmode_delay + if default_auto_archive_duration is not MISSING: + options["default_auto_archive_duration"] = default_auto_archive_duration + if default_thread_slowmode_delay is not MISSING: + options["default_thread_slowmode_delay"] = default_thread_slowmode_delay + if type is not MISSING: + options["type"] = type + if overwrites is not MISSING: + options["overwrites"] = overwrites + + payload = await self._edit(options, reason=reason) + if payload is not None: + if payload.get("type") == ChannelType.text.value: + from .text import TextChannel + + return await TextChannel._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + + async def create_thread( + self, + *, + name: str, + message: Snowflake | None = None, + auto_archive_duration: int | Undefined = MISSING, + type: ChannelType | None = None, + slowmode_delay: int | None = None, + invitable: bool | None = None, + reason: str | None = None, + ) -> Thread: + """|coro| + + Creates a thread in this news channel. + + Parameters + ---------- + name: :class:`str` + The name of the thread. + message: :class:`abc.Snowflake` | None + A snowflake representing the message to create the thread with. + auto_archive_duration: :class:`int` + The duration in minutes before a thread is automatically archived for inactivity. + type: :class:`ChannelType` | None + The type of thread to create. + slowmode_delay: :class:`int` | None + Specifies the slowmode rate limit for users in this thread, in seconds. + invitable: :class:`bool` | None + Whether non-moderators can add other non-moderators to this thread. + reason: :class:`str` | None + The reason for creating a new thread. + + Returns + ------- + :class:`Thread` + The created thread + """ + from .thread import Thread + + if type is None: + type = ChannelType.private_thread + + if message is None: + data = await self._state.http.start_thread_without_message( + self.id, + name=name, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + type=type.value, + rate_limit_per_user=slowmode_delay or 0, + invitable=invitable, + reason=reason, + ) + else: + data = await self._state.http.start_thread_with_message( + self.id, + message.id, + name=name, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + rate_limit_per_user=slowmode_delay or 0, + reason=reason, + ) + + return Thread(guild=self.guild, state=self._state, data=data) diff --git a/discord/channel/partial.py b/discord/channel/partial.py new file mode 100644 index 0000000000..81058cd40e --- /dev/null +++ b/discord/channel/partial.py @@ -0,0 +1,104 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..abc import Messageable +from ..enums import ChannelType +from ..mixins import Hashable +from ..object import Object + +if TYPE_CHECKING: + from ..message import PartialMessage + from ..state import ConnectionState + +__all__ = ("PartialMessageable",) + + +class PartialMessageable(Messageable, Hashable): + """Represents a partial messageable to aid with working messageable channels when + only a channel ID are present. + + The only way to construct this class is through :meth:`Client.get_partial_messageable`. + + Note that this class is trimmed down and has no rich attributes. + + .. versionadded:: 2.0 + + .. container:: operations + + .. describe:: x == y + + Checks if two partial messageables are equal. + + .. describe:: x != y + + Checks if two partial messageables are not equal. + + .. describe:: hash(x) + + Returns the partial messageable's hash. + + Attributes + ---------- + id: :class:`int` + The channel ID associated with this partial messageable. + type: Optional[:class:`ChannelType`] + The channel type associated with this partial messageable, if given. + """ + + def __init__(self, state: ConnectionState, id: int, type: ChannelType | None = None): + self._state: ConnectionState = state + self._channel: Object = Object(id=id) + self.id: int = id + self.type: ChannelType | None = type + + async def _get_channel(self) -> Object: + return self._channel + + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + Parameters + ---------- + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + ------- + :class:`PartialMessage` + The partial message. + """ + from ..message import PartialMessage + + return PartialMessage(channel=self, id=message_id) + + def __repr__(self) -> str: + return f"" diff --git a/discord/channel/stage.py b/discord/channel/stage.py new file mode 100644 index 0000000000..8c96a666b8 --- /dev/null +++ b/discord/channel/stage.py @@ -0,0 +1,345 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping + +from typing_extensions import Self, override + +from ..abc import Connectable +from ..enums import ChannelType, StagePrivacyLevel, VideoQualityMode, VoiceRegion, try_enum +from ..utils import MISSING, Undefined +from .base import GuildMessageableChannel, GuildTopLevelChannel + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from ..stage_instance import StageInstance + from ..types.channel import StageChannel as StageChannelPayload + from .category import CategoryChannel + +__all__ = ("StageChannel",) + + +class StageChannel( + GuildTopLevelChannel["StageChannelPayload"], + GuildMessageableChannel, + Connectable, +): + """Represents a Discord guild stage channel. + + .. versionadded:: 1.7 + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + id: :class:`int` + The channel's ID. + name: :class:`str` + The channel's name. + guild: :class:`Guild` + The guild the channel belongs to. + topic: :class:`str` | None + The channel's topic. ``None`` if it isn't set. + category_id: :class:`int` | None + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + bitrate: :class:`int` + The channel's preferred audio bitrate in bits per second. + user_limit: :class:`int` + The channel's limit for number of members that can be in a stage channel. + A value of ``0`` indicates no limit. + rtc_region: :class:`VoiceRegion` | None + The region for the stage channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the stage channel's participants. + last_message_id: :class:`int` | None + The ID of the last message sent to this channel. It may not always point to an existing or valid message. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for users in this channel, in seconds. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + + .. versionadded:: 3.0 + """ + + __slots__: tuple[str, ...] = ( + "topic", + "nsfw", + "slowmode_delay", + "last_message_id", + "bitrate", + "user_limit", + "rtc_region", + "video_quality_mode", + ) + + @override + async def _update(self, data: StageChannelPayload) -> None: + await super()._update(data) + self.bitrate: int = data.get("bitrate", 64000) + self.user_limit: int = data.get("user_limit", 0) + rtc = data.get("rtc_region") + self.rtc_region: VoiceRegion | None = try_enum(VoiceRegion, rtc) if rtc is not None else None + self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get("video_quality_mode", 1)) + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.stage_voice.value + + @property + def requesting_to_speak(self) -> list[Member]: + """A list of members who are requesting to speak in the stage channel.""" + return [member for member in self.members if member.voice and member.voice.requested_to_speak_at is not None] + + @property + def speakers(self) -> list[Member]: + """A list of members who have been permitted to speak in the stage channel. + + .. versionadded:: 2.0 + """ + return [member for member in self.members if member.voice and not member.voice.suppress] + + @property + def listeners(self) -> list[Member]: + """A list of members who are listening in the stage channel. + + .. versionadded:: 2.0 + """ + return [member for member in self.members if member.voice and member.voice.suppress] + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("topic", self.topic), + ("rtc_region", self.rtc_region), + ("position", self.position), + ("bitrate", self.bitrate), + ("video_quality_mode", self.video_quality_mode), + ("user_limit", self.user_limit), + ("category_id", self.category_id), + ] + joined = " ".join(f"{k}={v!r}" for k, v in attrs) + return f"" + + @property + def instance(self) -> StageInstance | None: + """Returns the currently running stage instance if any. + + .. versionadded:: 2.0 + + Returns + ------- + :class:`StageInstance` | None + The stage instance or ``None`` if not active. + """ + return self.guild.get_stage_instance(self.id) + + @property + def moderators(self) -> list[Member]: + """Returns a list of members who have stage moderator permissions. + + .. versionadded:: 2.0 + + Returns + ------- + list[:class:`Member`] + The members with stage moderator permissions. + """ + from ..permissions import Permissions + + required = Permissions.stage_moderator() + return [m for m in self.members if (self.permissions_for(m) & required) == required] + + async def edit( + self, + *, + name: str | Undefined = MISSING, + topic: str | Undefined = MISSING, + position: int | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + category: CategoryChannel | None | Undefined = MISSING, + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] | Undefined = MISSING, + rtc_region: VoiceRegion | None | Undefined = MISSING, + video_quality_mode: VideoQualityMode | Undefined = MISSING, + reason: str | None = None, + ) -> Self: + """|coro| + + Edits the stage channel. + + You must have :attr:`~Permissions.manage_channels` permission to use this. + + Parameters + ---------- + name: :class:`str` + The new channel's name. + topic: :class:`str` + The new channel's topic. + position: :class:`int` + The new channel's position. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing category. + category: :class:`CategoryChannel` | None + The new category for this channel. Can be ``None`` to remove the category. + overwrites: Mapping[:class:`Role` | :class:`Member` | :class:`~discord.abc.Snowflake`, :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. + rtc_region: :class:`VoiceRegion` | None + The new region for the stage channel's voice communication. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the stage channel's participants. + reason: :class:`str` | None + The reason for editing this channel. Shows up on the audit log. + + Returns + ------- + :class:`.StageChannel` + The newly edited stage channel. + + Raises + ------ + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + options = {} + if name is not MISSING: + options["name"] = name + if topic is not MISSING: + options["topic"] = topic + if position is not MISSING: + options["position"] = position + if sync_permissions is not MISSING: + options["sync_permissions"] = sync_permissions + if category is not MISSING: + options["category"] = category + if overwrites is not MISSING: + options["overwrites"] = overwrites + if rtc_region is not MISSING: + options["rtc_region"] = rtc_region + if video_quality_mode is not MISSING: + options["video_quality_mode"] = video_quality_mode + + payload = await self._edit(options, reason=reason) + if payload is not None: + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + + async def create_instance( + self, + *, + topic: str, + privacy_level: StagePrivacyLevel = StagePrivacyLevel.guild_only, + reason: str | None = None, + send_notification: bool = False, + ) -> StageInstance: + """|coro| + + Creates a stage instance. + + You must have :attr:`~Permissions.manage_channels` permission to do this. + + Parameters + ---------- + topic: :class:`str` + The stage instance's topic. + privacy_level: :class:`StagePrivacyLevel` + The stage instance's privacy level. + send_notification: :class:`bool` + Whether to send a notification to everyone in the server that the stage is starting. + reason: :class:`str` | None + The reason for creating the stage instance. Shows up on the audit log. + + Returns + ------- + :class:`StageInstance` + The created stage instance. + + Raises + ------ + Forbidden + You do not have permissions to create a stage instance. + HTTPException + Creating the stage instance failed. + """ + from ..stage_instance import StageInstance + + payload = await self._state.http.create_stage_instance( + self.id, + topic=topic, + privacy_level=int(privacy_level), + send_start_notification=send_notification, + reason=reason, + ) + return StageInstance(guild=self.guild, state=self._state, data=payload) + + async def fetch_instance(self) -> StageInstance | None: + """|coro| + + Fetches the currently running stage instance. + + Returns + ------- + :class:`StageInstance` | None + The stage instance or ``None`` if not active. + + Raises + ------ + NotFound + The stage instance is not active or was deleted. + HTTPException + Fetching the stage instance failed. + """ + from ..stage_instance import StageInstance + + try: + payload = await self._state.http.get_stage_instance(self.id) + return StageInstance(guild=self.guild, state=self._state, data=payload) + except Exception: + return None diff --git a/discord/channel/text.py b/discord/channel/text.py new file mode 100644 index 0000000000..4c92925d3b --- /dev/null +++ b/discord/channel/text.py @@ -0,0 +1,327 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping + +from typing_extensions import Self, override + +from ..enums import ChannelType +from ..utils import MISSING, Undefined +from .base import GuildMessageableChannel, GuildThreadableChannel, GuildTopLevelChannel + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from ..types.channel import NewsChannel as NewsChannelPayload + from ..types.channel import TextChannel as TextChannelPayload + from .category import CategoryChannel + from .news import NewsChannel + from .thread import Thread + +__all__ = ("TextChannel",) + + +class TextChannel( + GuildTopLevelChannel["TextChannelPayload"], + GuildMessageableChannel, + GuildThreadableChannel, +): + """Represents a Discord guild text channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + id: :class:`int` + The channel's ID. + name: :class:`str` + The channel's name. + guild: :class:`Guild` + The guild the channel belongs to. + topic: :class:`str` | None + The channel's topic. ``None`` if it isn't set. + category_id: :class:`int` | None + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this channel. A value of `0` denotes that it is disabled. + last_message_id: :class:`int` | None + The last message ID of the message sent to this channel. It may + *not* point to an existing or valid message. + default_auto_archive_duration: :class:`int` + The default auto archive duration in minutes for threads created in this channel. + + .. versionadded:: 3.0 + """ + + __slots__: tuple[str, ...] = ( + "topic", + "nsfw", + "slowmode_delay", + "last_message_id", + "default_auto_archive_duration", + "default_thread_slowmode_delay", + ) + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.text.value + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("position", self.position), + ("nsfw", self.nsfw), + ("category_id", self.category_id), + ] + joined = " ".join(f"{k}={v!r}" for k, v in attrs) + return f"" + + async def edit( + self, + *, + name: str | Undefined = MISSING, + topic: str | Undefined = MISSING, + position: int | Undefined = MISSING, + nsfw: bool | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + category: CategoryChannel | None | Undefined = MISSING, + slowmode_delay: int | Undefined = MISSING, + default_auto_archive_duration: int | Undefined = MISSING, + default_thread_slowmode_delay: int | Undefined = MISSING, + type: ChannelType | Undefined = MISSING, + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] | Undefined = MISSING, + reason: str | None = None, + ) -> Self | NewsChannel: + """|coro| + + Edits the channel. + + You must have :attr:`~Permissions.manage_channels` permission to + use this. + + .. versionchanged:: 1.3 + The ``overwrites`` keyword-only parameter was added. + + .. versionchanged:: 1.4 + The ``type`` keyword-only parameter was added. + + .. versionchanged:: 2.0 + Edits are no longer in-place, the newly edited channel is returned instead. + + .. versionchanged:: 3.0 + The ``default_thread_slowmode_delay`` keyword-only parameter was added. + + Parameters + ---------- + name: :class:`str` + The new channel name. + topic: :class:`str` + The new channel's topic. + position: :class:`int` + The new channel's position. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing + category. Defaults to ``False``. + category: :class:`CategoryChannel` | None + The new category for this channel. Can be ``None`` to remove the + category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for user in this channel, in seconds. + A value of ``0`` disables slowmode. The maximum value possible is ``21600``. + default_auto_archive_duration: :class:`int` + The new default auto archive duration in minutes for threads created in this channel. + Must be one of ``60``, ``1440``, ``4320``, or ``10080``. + default_thread_slowmode_delay: :class:`int` + The new default slowmode delay in seconds for threads created in this channel. + type: :class:`ChannelType` + Change the type of this text channel. Currently, only conversion between + :attr:`ChannelType.text` and :attr:`ChannelType.news` is supported. This + is only available to guilds that contain ``NEWS`` in :attr:`Guild.features`. + overwrites: Mapping[:class:`Role` | :class:`Member` | :class:`~discord.abc.Snowflake`, :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. Useful for creating secret channels. + reason: :class:`str` | None + The reason for editing this channel. Shows up on the audit log. + + Returns + ------- + :class:`.TextChannel` | :class:`.NewsChannel` + The newly edited channel. If the edit was only positional + then ``None`` is returned instead. If the type was changed, + the appropriate channel type is returned. + + Raises + ------ + InvalidArgument + If position is less than 0 or greater than the number of channels, or if + the permission overwrite information is not in proper form. + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + options = {} + if name is not MISSING: + options["name"] = name + if topic is not MISSING: + options["topic"] = topic + if position is not MISSING: + options["position"] = position + if nsfw is not MISSING: + options["nsfw"] = nsfw + if sync_permissions is not MISSING: + options["sync_permissions"] = sync_permissions + if category is not MISSING: + options["category"] = category + if slowmode_delay is not MISSING: + options["slowmode_delay"] = slowmode_delay + if default_auto_archive_duration is not MISSING: + options["default_auto_archive_duration"] = default_auto_archive_duration + if default_thread_slowmode_delay is not MISSING: + options["default_thread_slowmode_delay"] = default_thread_slowmode_delay + if type is not MISSING: + options["type"] = type + if overwrites is not MISSING: + options["overwrites"] = overwrites + + payload = await self._edit(options, reason=reason) + if payload is not None: + # Check if type was changed to news + if payload.get("type") == ChannelType.news.value: + from .news import NewsChannel + + return await NewsChannel._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + + async def create_thread( + self, + *, + name: str, + message: Snowflake | None = None, + auto_archive_duration: int | Undefined = MISSING, + type: ChannelType | None = None, + slowmode_delay: int | None = None, + invitable: bool | None = None, + reason: str | None = None, + ) -> Thread: + """|coro| + + Creates a thread in this text channel. + + To create a public thread, you must have :attr:`~discord.Permissions.create_public_threads`. + For a private thread, :attr:`~discord.Permissions.create_private_threads` is needed instead. + + .. versionadded:: 2.0 + + Parameters + ---------- + name: :class:`str` + The name of the thread. + message: :class:`abc.Snowflake` | None + A snowflake representing the message to create the thread with. + If ``None`` is passed then a private thread is created. + Defaults to ``None``. + auto_archive_duration: :class:`int` + The duration in minutes before a thread is automatically archived for inactivity. + If not provided, the channel's default auto archive duration is used. + type: :class:`ChannelType` | None + The type of thread to create. If a ``message`` is passed then this parameter + is ignored, as a thread created with a message is always a public thread. + By default, this creates a private thread if this is ``None``. + slowmode_delay: :class:`int` | None + Specifies the slowmode rate limit for users in this thread, in seconds. + A value of ``0`` disables slowmode. The maximum value possible is ``21600``. + invitable: :class:`bool` | None + Whether non-moderators can add other non-moderators to this thread. + Only available for private threads, where it defaults to True. + reason: :class:`str` | None + The reason for creating a new thread. Shows up on the audit log. + + Returns + ------- + :class:`Thread` + The created thread + + Raises + ------ + Forbidden + You do not have permissions to create a thread. + HTTPException + Starting the thread failed. + """ + from .thread import Thread + + if type is None: + type = ChannelType.private_thread + + if message is None: + data = await self._state.http.start_thread_without_message( + self.id, + name=name, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + type=type.value, + rate_limit_per_user=slowmode_delay or 0, + invitable=invitable, + reason=reason, + ) + else: + data = await self._state.http.start_thread_with_message( + self.id, + message.id, + name=name, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + rate_limit_per_user=slowmode_delay or 0, + reason=reason, + ) + + return Thread(guild=self.guild, state=self._state, data=data) diff --git a/discord/threads.py b/discord/channel/thread.py similarity index 83% rename from discord/threads.py rename to discord/channel/thread.py index 3a8e9e884b..3a942dff86 100644 --- a/discord/threads.py +++ b/discord/channel/thread.py @@ -27,19 +27,23 @@ from typing import TYPE_CHECKING, Callable, Iterable +from typing_extensions import override + from discord import utils -from .abc import Messageable, _purge_messages_helper -from .enums import ( +from ..abc import Messageable, _purge_messages_helper +from ..enums import ( ChannelType, try_enum, ) -from .enums import ThreadArchiveDuration as ThreadArchiveDurationEnum -from .errors import ClientException -from .flags import ChannelFlags -from .mixins import Hashable -from .utils import MISSING -from .utils.private import get_as_snowflake, parse_time +from ..enums import ThreadArchiveDuration as ThreadArchiveDurationEnum +from ..errors import ClientException +from ..flags import ChannelFlags +from ..mixins import Hashable +from ..types.threads import Thread as ThreadPayload +from ..utils import MISSING +from ..utils.private import get_as_snowflake, parse_time +from .base import BaseChannel, GuildMessageableChannel __all__ = ( "Thread", @@ -47,21 +51,20 @@ ) if TYPE_CHECKING: - from .abc import Snowflake, SnowflakeTime - from .channel import CategoryChannel, ForumChannel, ForumTag, TextChannel - from .guild import Guild - from .member import Member - from .message import Message, PartialMessage - from .permissions import Permissions - from .role import Role - from .state import ConnectionState - from .types.snowflake import SnowflakeList - from .types.threads import Thread as ThreadPayload - from .types.threads import ThreadArchiveDuration, ThreadMetadata - from .types.threads import ThreadMember as ThreadMemberPayload - - -class Thread(Messageable, Hashable): + from ..abc import Snowflake, SnowflakeTime + from ..app.state import ConnectionState + from ..guild import Guild + from ..member import Member + from ..message import Message, PartialMessage + from ..permissions import Permissions + from ..role import Role + from ..types.snowflake import SnowflakeList + from ..types.threads import ThreadArchiveDuration, ThreadMetadata + from ..types.threads import ThreadMember as ThreadMemberPayload + from . import CategoryChannel, ForumChannel, ForumTag, TextChannel + + +class Thread(BaseChannel[ThreadPayload], GuildMessageableChannel): """Represents a Discord thread. .. container:: operations @@ -86,55 +89,55 @@ class Thread(Messageable, Hashable): Attributes ---------- - name: :class:`str` + name: str The thread name. - guild: :class:`Guild` + guild: Guild The guild the thread belongs to. - id: :class:`int` + id: int The thread ID. .. note:: This ID is the same as the thread starting message ID. - parent_id: :class:`int` + parent_id: int The parent :class:`TextChannel` ID this thread belongs to. - owner_id: :class:`int` + owner_id: int The user's ID that created this thread. - last_message_id: Optional[:class:`int`] + last_message_id: int | None The last message ID of the message sent to this thread. It may *not* point to an existing or valid message. - slowmode_delay: :class:`int` + slowmode_delay: int The number of seconds a member must wait between sending messages in this thread. A value of `0` denotes that it is disabled. Bots and users with :attr:`~Permissions.manage_channels` or :attr:`~Permissions.manage_messages` bypass slowmode. - message_count: :class:`int` + message_count: int An approximate number of messages in this thread. This caps at 50. - member_count: :class:`int` + member_count: int An approximate number of members in this thread. This caps at 50. - me: Optional[:class:`ThreadMember`] + me: ThreadMember | None A thread member representing yourself, if you've joined the thread. This could not be available. - archived: :class:`bool` + archived: bool Whether the thread is archived. - locked: :class:`bool` + locked: bool Whether the thread is locked. - invitable: :class:`bool` + invitable: bool Whether non-moderators can add other non-moderators to this thread. This is always ``True`` for public threads. - auto_archive_duration: :class:`int` + auto_archive_duration: int The duration in minutes until the thread is automatically archived due to inactivity. Usually a value of 60, 1440, 4320 and 10080. - archive_timestamp: :class:`datetime.datetime` + archive_timestamp: datetime.datetime An aware timestamp of when the thread's archived status was last updated in UTC. - created_at: Optional[:class:`datetime.datetime`] + created_at: datetime.datetime | None An aware timestamp of when the thread was created. Only available for threads created after 2022-01-09. - flags: :class:`ChannelFlags` + flags: ChannelFlags Extra features of the thread. .. versionadded:: 2.0 - total_message_sent: :class:`int` + total_message_sent: int Number of messages ever sent in a thread. It's similar to message_count on message creation, but will not decrement the number when a message is deleted. @@ -142,20 +145,14 @@ class Thread(Messageable, Hashable): .. versionadded:: 2.3 """ - __slots__ = ( - "name", - "id", + __slots__: tuple[str, ...] = ( "guild", - "_type", - "_state", "_members", "_applied_tags", "owner_id", "parent_id", - "last_message_id", "message_count", "member_count", - "slowmode_delay", "me", "locked", "archived", @@ -163,86 +160,83 @@ class Thread(Messageable, Hashable): "auto_archive_duration", "archive_timestamp", "created_at", - "flags", "total_message_sent", ) - def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload): - self._state: ConnectionState = state - self.guild = guild + @override + def __init__(self, *, id: int, guild: Guild, state: ConnectionState): + super().__init__(id, state) + self.guild: Guild = guild self._members: dict[int, ThreadMember] = {} - self._from_data(data) + + @classmethod + @override + async def _from_data( + cls, + *, + data: ThreadPayload, + state: ConnectionState, + guild: Guild, + ) -> Thread: + """Create thread instance from API payload.""" + self = cls( + id=int(data["id"]), + guild=guild, + state=state, + ) + await self._update(data) + return self + + @override + async def _update(self, data: ThreadPayload) -> None: + """Update mutable attributes from API payload.""" + await super()._update(data) + + # Thread-specific attributes + self.parent_id: int = int(data.get("parent_id", self.parent_id if hasattr(self, "parent_id") else 0)) + self.owner_id: int | None = int(data["owner_id"]) if data.get("owner_id") is not None else None + self.message_count: int | None = data.get("message_count") + self.member_count: int | None = data.get("member_count") + self.total_message_sent: int | None = data.get("total_message_sent") + self._applied_tags: list[int] = [int(tag_id) for tag_id in data.get("applied_tags", [])] + + # Handle thread metadata + if "thread_metadata" in data: + metadata = data["thread_metadata"] + self.archived: bool = metadata["archived"] + self.auto_archive_duration: int = metadata["auto_archive_duration"] + self.archive_timestamp = parse_time(metadata["archive_timestamp"]) + self.locked: bool = metadata["locked"] + self.invitable: bool = metadata.get("invitable", True) + self.created_at = parse_time(metadata.get("create_timestamp")) + + # Handle thread member data + if "member" in data: + self.me: ThreadMember | None = ThreadMember(self, data["member"]) + elif not hasattr(self, "me"): + self.me = None async def _get_channel(self): return self + @override def __repr__(self) -> str: return ( f"" ) - def __str__(self) -> str: - return self.name - - def _from_data(self, data: ThreadPayload): - # This data will always exist - self.id = int(data["id"]) - self.parent_id = int(data["parent_id"]) - self.name = data["name"] - self._type = try_enum(ChannelType, data["type"]) - - # This data may be missing depending on how this object is being created - self.owner_id = int(data.get("owner_id")) if data.get("owner_id", None) is not None else None - self.last_message_id = get_as_snowflake(data, "last_message_id") - self.slowmode_delay = data.get("rate_limit_per_user", 0) - self.message_count = data.get("message_count", None) - self.member_count = data.get("member_count", None) - self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) - self.total_message_sent = data.get("total_message_sent", None) - self._applied_tags: list[int] = [int(tag_id) for tag_id in data.get("applied_tags", [])] - - # Here, we try to fill in potentially missing data - if thread := self.guild.get_thread(self.id) and data.pop("_invoke_flag", False): - self.owner_id = thread.owner_id if self.owner_id is None else self.owner_id - self.last_message_id = thread.last_message_id if self.last_message_id is None else self.last_message_id - self.message_count = thread.message_count if self.message_count is None else self.message_count - self.total_message_sent = ( - thread.total_message_sent if self.total_message_sent is None else self.total_message_sent - ) - self.member_count = thread.member_count if self.member_count is None else self.member_count - - self._unroll_metadata(data["thread_metadata"]) - - try: - member = data["member"] - except KeyError: - self.me = None - else: - self.me = ThreadMember(self, member) - - def _unroll_metadata(self, data: ThreadMetadata): - self.archived = data["archived"] - self.auto_archive_duration = data["auto_archive_duration"] - self.archive_timestamp = parse_time(data["archive_timestamp"]) - self.locked = data["locked"] - self.invitable = data.get("invitable", True) - self.created_at = parse_time(data.get("create_timestamp", None)) - - def _update(self, data): - try: - self.name = data["name"] - except KeyError: - pass - - self._applied_tags: list[int] = [int(tag_id) for tag_id in data.get("applied_tags", [])] - self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0)) - self.slowmode_delay = data.get("rate_limit_per_user", 0) + @property + def topic(self) -> None: + """Threads don't have topics. Always returns None.""" + return None - try: - self._unroll_metadata(data["thread_metadata"]) - except KeyError: - pass + @property + @override + def nsfw(self) -> bool: + """Whether the thread is NSFW. Inherited from parent channel.""" + parent = self.parent + return parent.nsfw if parent else False @property def type(self) -> ChannelType: @@ -254,10 +248,9 @@ def parent(self) -> TextChannel | ForumChannel | None: """The parent channel this thread belongs to.""" return self.guild.get_channel(self.parent_id) # type: ignore - @property - def owner(self) -> Member | None: + async def get_owner(self) -> Member | None: """The member this thread belongs to.""" - return self.guild.get_member(self.owner_id) + return await self.guild.get_member(self.owner_id) @property def mention(self) -> str: @@ -288,14 +281,13 @@ def applied_tags(self) -> list[ForumTag]: This is only available for threads in forum or media channels. """ - from .channel import ForumChannel # noqa: PLC0415 # to prevent circular import + from .channel import ForumChannel # to prevent circular import if isinstance(self.parent, ForumChannel): return [tag for tag_id in self._applied_tags if (tag := self.parent.get_tag(tag_id)) is not None] return [] - @property - def last_message(self) -> Message | None: + async def get_last_message(self) -> Message | None: """Returns the last message from this thread in cache. The message might not be valid or point to an existing message. @@ -313,7 +305,7 @@ def last_message(self) -> Message | None: Optional[:class:`Message`] The last message in this channel or ``None`` if not found. """ - return self._state._get_message(self.last_message_id) if self.last_message_id else None + return await self._state._get_message(self.last_message_id) if self.last_message_id else None @property def category(self) -> CategoryChannel | None: @@ -355,8 +347,7 @@ def category_id(self) -> int | None: raise ClientException("Parent channel not found") return parent.category_id - @property - def starting_message(self) -> Message | None: + async def get_starting_message(self) -> Message | None: """Returns the message that started this thread. The message might not be valid or point to an existing message. @@ -369,7 +360,7 @@ def starting_message(self) -> Message | None: Optional[:class:`Message`] The message that started this thread or ``None`` if not found in the cache. """ - return self._state._get_message(self.id) + return await self._state._get_message(self.id) def is_pinned(self) -> bool: """Whether the thread is pinned to the top of its parent forum or media channel. @@ -657,7 +648,7 @@ async def edit( data = await self._state.http.edit_channel(self.id, **payload, reason=reason) # The data payload will always be a Thread payload - return Thread(data=data, state=self._state, guild=self.guild) # type: ignore + return await Thread._from_data(data=data, state=self._state, guild=self.guild) # type: ignore async def archive(self, locked: bool | utils.Undefined = MISSING) -> Thread: """|coro| @@ -825,7 +816,7 @@ def get_partial_message(self, message_id: int, /) -> PartialMessage: The partial message. """ - from .message import PartialMessage # noqa: PLC0415 + from .message import PartialMessage return PartialMessage(channel=self, id=message_id) diff --git a/discord/channel/voice.py b/discord/channel/voice.py new file mode 100644 index 0000000000..5b0c932d80 --- /dev/null +++ b/discord/channel/voice.py @@ -0,0 +1,328 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-2021 Rapptz +Copyright (c) 2021-present Pycord Development + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Mapping + +from typing_extensions import Self, override + +from ..abc import Connectable +from ..enums import ChannelType, InviteTarget, VideoQualityMode, VoiceRegion, try_enum +from ..utils import MISSING, Undefined +from .base import GuildMessageableChannel, GuildTopLevelChannel + +if TYPE_CHECKING: + from ..abc import Snowflake + from ..enums import EmbeddedActivity + from ..invite import Invite + from ..member import Member + from ..permissions import PermissionOverwrite + from ..role import Role + from ..soundboard import PartialSoundboardSound + from ..types.channel import VoiceChannel as VoiceChannelPayload + from .category import CategoryChannel + +__all__ = ("VoiceChannel",) + + +class VoiceChannel( + GuildTopLevelChannel["VoiceChannelPayload"], + GuildMessageableChannel, + Connectable, +): + """Represents a Discord guild voice channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ---------- + id: :class:`int` + The channel's ID. + name: :class:`str` + The channel's name. + guild: :class:`Guild` + The guild the channel belongs to. + category_id: :class:`int` | None + The category channel ID this channel belongs to, if applicable. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + bitrate: :class:`int` + The channel's preferred audio bitrate in bits per second. + user_limit: :class:`int` + The channel's limit for number of members that can be in a voice channel. + A value of ``0`` indicates no limit. + rtc_region: :class:`VoiceRegion` | None + The region for the voice channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the voice channel's participants. + last_message_id: :class:`int` | None + The ID of the last message sent to this channel. It may not always point to an existing or valid message. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this channel. A value of `0` denotes that it is disabled. + status: :class:`str` | None + The channel's status, if set. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + + .. versionadded:: 3.0 + """ + + __slots__: tuple[str, ...] = ( + "topic", + "nsfw", + "slowmode_delay", + "last_message_id", + "bitrate", + "user_limit", + "rtc_region", + "video_quality_mode", + "status", + ) + + @override + async def _update(self, data: VoiceChannelPayload) -> None: + await super()._update(data) + self.bitrate: int = data.get("bitrate", 64000) + self.user_limit: int = data.get("user_limit", 0) + rtc = data.get("rtc_region") + self.rtc_region: VoiceRegion | None = try_enum(VoiceRegion, rtc) if rtc is not None else None + self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get("video_quality_mode", 1)) + self.status: str | None = data.get("status") + + @property + @override + def _sorting_bucket(self) -> int: + return ChannelType.voice.value + + def __repr__(self) -> str: + attrs = [ + ("id", self.id), + ("name", self.name), + ("status", self.status), + ("rtc_region", self.rtc_region), + ("position", self.position), + ("bitrate", self.bitrate), + ("video_quality_mode", self.video_quality_mode), + ("user_limit", self.user_limit), + ("category_id", self.category_id), + ] + joined = " ".join(f"{k}={v!r}" for k, v in attrs) + return f"" + + async def edit( + self, + *, + name: str | Undefined = MISSING, + bitrate: int | Undefined = MISSING, + user_limit: int | Undefined = MISSING, + position: int | Undefined = MISSING, + sync_permissions: bool | Undefined = MISSING, + category: CategoryChannel | None | Undefined = MISSING, + overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] | Undefined = MISSING, + rtc_region: VoiceRegion | None | Undefined = MISSING, + video_quality_mode: VideoQualityMode | Undefined = MISSING, + slowmode_delay: int | Undefined = MISSING, + nsfw: bool | Undefined = MISSING, + reason: str | None = None, + ) -> Self: + """|coro| + + Edits the voice channel. + + You must have :attr:`~Permissions.manage_channels` permission to use this. + + Parameters + ---------- + name: :class:`str` + The new channel's name. + bitrate: :class:`int` + The new channel's bitrate. + user_limit: :class:`int` + The new channel's user limit. + position: :class:`int` + The new channel's position. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing category. + category: :class:`CategoryChannel` | None + The new category for this channel. Can be ``None`` to remove the category. + overwrites: Mapping[:class:`Role` | :class:`Member` | :class:`~discord.abc.Snowflake`, :class:`PermissionOverwrite`] + The overwrites to apply to channel permissions. + rtc_region: :class:`VoiceRegion` | None + The new region for the voice channel's voice communication. + A value of ``None`` indicates automatic voice region detection. + video_quality_mode: :class:`VideoQualityMode` + The camera video quality for the voice channel's participants. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for user in this channel, in seconds. + nsfw: :class:`bool` + Whether the channel is marked as NSFW. + reason: :class:`str` | None + The reason for editing this channel. Shows up on the audit log. + + Returns + ------- + :class:`.VoiceChannel` + The newly edited voice channel. If the edit was only positional then ``None`` is returned. + + Raises + ------ + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + options = {} + if name is not MISSING: + options["name"] = name + if bitrate is not MISSING: + options["bitrate"] = bitrate + if user_limit is not MISSING: + options["user_limit"] = user_limit + if position is not MISSING: + options["position"] = position + if sync_permissions is not MISSING: + options["sync_permissions"] = sync_permissions + if category is not MISSING: + options["category"] = category + if overwrites is not MISSING: + options["overwrites"] = overwrites + if rtc_region is not MISSING: + options["rtc_region"] = rtc_region + if video_quality_mode is not MISSING: + options["video_quality_mode"] = video_quality_mode + if slowmode_delay is not MISSING: + options["slowmode_delay"] = slowmode_delay + if nsfw is not MISSING: + options["nsfw"] = nsfw + + payload = await self._edit(options, reason=reason) + if payload is not None: + return await self.__class__._from_data(data=payload, state=self._state, guild=self.guild) # type: ignore + + async def create_activity_invite(self, activity: EmbeddedActivity | int, **kwargs) -> Invite: + """|coro| + + A shortcut method that creates an instant activity invite. + + You must have :attr:`~discord.Permissions.start_embedded_activities` permission to do this. + + Parameters + ---------- + activity: :class:`EmbeddedActivity` | :class:`int` + The embedded activity to create an invite for. Can be an :class:`EmbeddedActivity` enum member + or the application ID as an integer. + max_age: :class:`int` + How long the invite should last in seconds. If it's 0 then the invite doesn't expire. + max_uses: :class:`int` + How many uses the invite could be used for. If it's 0 then there are unlimited uses. + temporary: :class:`bool` + Denotes that the invite grants temporary membership. + unique: :class:`bool` + Indicates if a unique invite URL should be created. + reason: :class:`str` | None + The reason for creating this invite. Shows up on the audit log. + + Returns + ------- + :class:`~discord.Invite` + The invite that was created. + + Raises + ------ + HTTPException + Invite creation failed. + """ + from ..enums import EmbeddedActivity + + if isinstance(activity, EmbeddedActivity): + activity = activity.value + + return await self.create_invite( + target_type=InviteTarget.embedded_application, + target_application_id=activity, + **kwargs, + ) + + async def set_status(self, status: str | None, *, reason: str | None = None) -> None: + """|coro| + + Sets the voice channel status. + + You must have :attr:`~discord.Permissions.manage_channels` and + :attr:`~discord.Permissions.connect` permissions to do this. + + Parameters + ---------- + status: :class:`str` | None + The new voice channel status. Set to ``None`` to remove the status. + reason: :class:`str` | None + The reason for setting the voice channel status. Shows up on the audit log. + + Raises + ------ + Forbidden + You do not have permissions to set the voice channel status. + HTTPException + Setting the voice channel status failed. + """ + await self._state.http.edit_voice_channel_status(self.id, status, reason=reason) + + async def send_soundboard_sound(self, sound: PartialSoundboardSound) -> None: + """|coro| + + Sends a soundboard sound to the voice channel. + + Parameters + ---------- + sound: :class:`PartialSoundboardSound` + The soundboard sound to send. + + Raises + ------ + Forbidden + You do not have proper permissions to send the soundboard sound. + HTTPException + Sending the soundboard sound failed. + """ + await self._state.http.send_soundboard_sound(self.id, sound) diff --git a/discord/client.py b/discord/client.py index 225904bd50..27667955e6 100644 --- a/discord/client.py +++ b/discord/client.py @@ -30,8 +30,9 @@ import signal import sys import traceback +from collections.abc import Awaitable from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generator, Sequence, TypeVar +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Generator, Sequence, TypeVar import aiohttp @@ -39,15 +40,20 @@ from . import utils from .activity import ActivityTypes, BaseActivity, create_activity +from .app.cache import Cache, MemoryCache +from .app.event_emitter import Event +from .app.state import ConnectionState from .appinfo import AppInfo, PartialAppInfo from .application_role_connection import ApplicationRoleConnectionMetadata from .backoff import ExponentialBackoff from .channel import PartialMessageable, _threaded_channel_factory +from .channel.thread import Thread from .emoji import AppEmoji, GuildEmoji from .enums import ChannelType, Status from .errors import * from .flags import ApplicationFlags, Intents from .gateway import * +from .gears import Gear from .guild import Guild from .http import HTTPClient from .invite import Invite @@ -57,16 +63,15 @@ from .object import Object from .soundboard import SoundboardSound from .stage_instance import StageInstance -from .state import ConnectionState from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory from .template import Template -from .threads import Thread from .ui.view import View from .user import ClientUser, User from .utils import MISSING from .utils.private import ( SequenceProxy, bytes_to_base64_data, + copy_doc, resolve_invite, resolve_template, ) @@ -75,8 +80,8 @@ from .widget import Widget if TYPE_CHECKING: - from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime - from .channel import DMChannel + from .abc import PrivateChannel, Snowflake, SnowflakeTime + from .channel import DMChannel, GuildChannel from .interactions import Interaction from .member import Member from .message import Message @@ -242,7 +247,6 @@ def __init__( # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop - self._listeners: dict[str, list[tuple[asyncio.Future, Callable[..., bool]]]] = {} self.shard_id: int | None = options.get("shard_id") self.shard_count: int | None = options.get("shard_count") @@ -263,7 +267,14 @@ def __init__( self._hooks: dict[str, Callable] = {"before_identify": self._call_before_identify_hook} self._enable_debug_events: bool = options.pop("enable_debug_events", False) - self._connection: ConnectionState = self._get_state(**options) + self._connection: ConnectionState = ConnectionState( + handlers=self._handlers, + hooks=self._hooks, + http=self.http, + loop=self.loop, + cache=MemoryCache(), + **options, + ) self._connection.shard_count = self.shard_count self._closed: bool = False self._ready: asyncio.Event = asyncio.Event() @@ -271,6 +282,10 @@ def __init__( self._connection._get_client = lambda: self self._event_handlers: dict[str, list[Coro]] = {} + self._main_gear: Gear = Gear() + + self._connection.emitter.add_receiver(self._handle_event) + if VoiceClient.warn_nacl: VoiceClient.warn_nacl = False _log.warning("PyNaCl is not installed, voice will NOT be supported") @@ -278,6 +293,9 @@ def __init__( # Used to hard-reference tasks so they don't get garbage collected (discarded with done_callbacks) self._tasks = set() + async def _handle_event(self, event: Event) -> None: + await asyncio.gather(*self._main_gear._handle_event(event)) + async def __aenter__(self) -> Client: loop = asyncio.get_running_loop() self.loop = loop @@ -297,21 +315,47 @@ async def __aexit__( if not self.is_closed(): await self.close() + # Gear methods + + @copy_doc(Gear.attach_gear) + def attach_gear(self, gear: Gear) -> None: + return self._main_gear.attach_gear(gear) + + @copy_doc(Gear.detach_gear) + def detach_gear(self, gear: Gear) -> None: + return self._main_gear.detach_gear(gear) + + @copy_doc(Gear.add_listener) + def add_listener( + self, + callback: Callable[[Event], Awaitable[None]], + *, + event: type[Event] | Undefined = MISSING, + is_instance_function: bool = False, + once: bool = False, + ) -> None: + return self._main_gear.add_listener(callback, event=event, is_instance_function=is_instance_function, once=once) + + @copy_doc(Gear.remove_listener) + def remove_listener( + self, + callback: Callable[[Event], Awaitable[None]], + event: type[Event] | Undefined = MISSING, + is_instance_function: bool = False, + ) -> None: + return self._main_gear.remove_listener(callback, event=event, is_instance_function=is_instance_function) + + @copy_doc(Gear.listen) + def listen( + self, event: type[Event] | Undefined = MISSING, once: bool = False + ) -> Callable[[Callable[[Event], Awaitable[None]]], Callable[[Event], Awaitable[None]]]: + return self._main_gear.listen(event=event, once=once) + # internals def _get_websocket(self, guild_id: int | None = None, *, shard_id: int | None = None) -> DiscordWebSocket: return self.ws - def _get_state(self, **options: Any) -> ConnectionState: - return ConnectionState( - dispatch=self.dispatch, - handlers=self._handlers, - hooks=self._hooks, - http=self.http, - loop=self.loop, - **options, - ) - def _handle_ready(self) -> None: self._ready.set() @@ -342,62 +386,54 @@ def user(self) -> ClientUser | None: """Represents the connected client. ``None`` if not logged in.""" return self._connection.user - @property - def guilds(self) -> list[Guild]: + async def get_guilds(self) -> list[Guild]: """The guilds that the connected client is a member of.""" - return self._connection.guilds + return await self._connection.get_guilds() - @property - def emojis(self) -> list[GuildEmoji | AppEmoji]: + async def get_emojis(self) -> list[GuildEmoji | AppEmoji]: """The emojis that the connected client has. .. note:: This only includes the application's emojis if `cache_app_emojis` is ``True``. """ - return self._connection.emojis + return await self._connection.get_emojis() - @property - def guild_emojis(self) -> list[GuildEmoji]: + async def get_guild_emojis(self) -> list[GuildEmoji]: """The :class:`~discord.GuildEmoji` that the connected client has.""" - return [e for e in self.emojis if isinstance(e, GuildEmoji)] + return [e for e in await self.get_emojis() if isinstance(e, GuildEmoji)] - @property - def app_emojis(self) -> list[AppEmoji]: + async def get_app_emojis(self) -> list[AppEmoji]: """The :class:`~discord.AppEmoji` that the connected client has. .. note:: This is only available if `cache_app_emojis` is ``True``. """ - return [e for e in self.emojis if isinstance(e, AppEmoji)] + return [e for e in await self.get_emojis() if isinstance(e, AppEmoji)] - @property - def stickers(self) -> list[GuildSticker]: + async def get_stickers(self) -> list[GuildSticker]: """The stickers that the connected client has. .. versionadded:: 2.0 """ - return self._connection.stickers + return await self._connection.get_stickers() - @property - def polls(self) -> list[Poll]: + async def get_polls(self) -> list[Poll]: """The polls that the connected client has. .. versionadded:: 2.6 """ - return self._connection.polls + return await self._connection.get_polls() - @property - def cached_messages(self) -> Sequence[Message]: + async def get_cached_messages(self) -> Sequence[Message]: """Read-only list of messages the connected client has cached. .. versionadded:: 1.1 """ - return SequenceProxy(self._connection._messages or []) + return SequenceProxy(await self._connection.cache.get_all_messages()) - @property - def private_channels(self) -> list[PrivateChannel]: + async def get_private_channels(self) -> list[PrivateChannel]: """The private channels that the connected client is participating on. .. note:: @@ -405,7 +441,7 @@ def private_channels(self) -> list[PrivateChannel]: This returns only up to 128 most recent private channels due to an internal working on how Discord deals with private channels. """ - return self._connection.private_channels + return await self._connection.get_private_channels() @property def voice_clients(self) -> list[VoiceProtocol]: @@ -471,71 +507,6 @@ def _schedule_event( task.add_done_callback(self._tasks.discard) return task - def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: - _log.debug("Dispatching event %s", event) - method = f"on_{event}" - - listeners = self._listeners.get(event) - if listeners: - removed = [] - for i, (future, condition) in enumerate(listeners): - if future.cancelled(): - removed.append(i) - continue - - try: - result = condition(*args) - except Exception as exc: - future.set_exception(exc) - removed.append(i) - else: - if result: - if len(args) == 0: - future.set_result(None) - elif len(args) == 1: - future.set_result(args[0]) - else: - future.set_result(args) - removed.append(i) - - if len(removed) == len(listeners): - self._listeners.pop(event) - else: - for idx in reversed(removed): - del listeners[idx] - - # Schedule the main handler registered with @event - try: - coro = getattr(self, method) - except AttributeError: - pass - else: - self._schedule_event(coro, method, *args, **kwargs) - - # collect the once listeners as removing them from the list - # while iterating over it causes issues - once_listeners = [] - - # Schedule additional handlers registered with @listen - for coro in self._event_handlers.get(method, []): - self._schedule_event(coro, method, *args, **kwargs) - - try: - if coro._once: # added using @listen() - once_listeners.append(coro) - - except AttributeError: # added using @Cog.add_listener() - # https://github.com/Pycord-Development/pycord/pull/1989 - # Although methods are similar to functions, attributes can't be added to them. - # This means that we can't add the `_once` attribute in the `add_listener` method - # and can only be added using the `@listen` decorator. - - continue - - # remove the once listeners - for coro in once_listeners: - self._event_handlers[method].remove(coro) - async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None: """|coro| @@ -697,7 +668,7 @@ async def connect(self, *, reconnect: bool = True) -> None: await self.ws.poll_event() except ReconnectWebSocket as e: _log.info("Got a request to %s the websocket.", e.op) - self.dispatch("disconnect") + # self.dispatch("disconnect") # TODO: dispatch event ws_params.update( sequence=self.ws.sequence, resume=e.resume, @@ -932,10 +903,9 @@ def intents(self) -> Intents: # helpers/getters - @property - def users(self) -> list[User]: + async def get_users(self) -> list[User]: """Returns a list of all the users the bot can see.""" - return list(self._connection._users.values()) + return await self._connection.cache.get_all_users() async def fetch_application(self, application_id: int, /) -> PartialAppInfo: """|coro| @@ -961,7 +931,7 @@ async def fetch_application(self, application_id: int, /) -> PartialAppInfo: data = await self.http.get_application(application_id) return PartialAppInfo(state=self._connection, data=data) - def get_channel(self, id: int, /) -> GuildChannel | Thread | PrivateChannel | None: + async def get_channel(self, id: int, /) -> GuildChannel | Thread | PrivateChannel | None: """Returns a channel or thread with the given ID. Parameters @@ -974,9 +944,9 @@ def get_channel(self, id: int, /) -> GuildChannel | Thread | PrivateChannel | No Optional[Union[:class:`.abc.GuildChannel`, :class:`.Thread`, :class:`.abc.PrivateChannel`]] The returned channel or ``None`` if not found. """ - return self._connection.get_channel(id) + return await self._connection.get_channel(id) - def get_message(self, id: int, /) -> Message | None: + async def get_message(self, id: int, /) -> Message | None: """Returns a message the given ID. This is useful if you have a message_id but don't want to do an API call @@ -992,7 +962,7 @@ def get_message(self, id: int, /) -> Message | None: Optional[:class:`.Message`] The returned message or ``None`` if not found. """ - return self._connection._get_message(id) + return await self._connection._get_message(id) def get_partial_messageable(self, id: int, *, type: ChannelType | None = None) -> PartialMessageable: """Returns a partial messageable with the given channel ID. @@ -1016,7 +986,7 @@ def get_partial_messageable(self, id: int, *, type: ChannelType | None = None) - """ return PartialMessageable(state=self._connection, id=id, type=type) - def get_stage_instance(self, id: int, /) -> StageInstance | None: + async def get_stage_instance(self, id: int, /) -> StageInstance | None: """Returns a stage instance with the given stage channel ID. .. versionadded:: 2.0 @@ -1031,14 +1001,14 @@ def get_stage_instance(self, id: int, /) -> StageInstance | None: Optional[:class:`.StageInstance`] The stage instance or ``None`` if not found. """ - from .channel import StageChannel # noqa: PLC0415 + from .channel import StageChannel - channel = self._connection.get_channel(id) + channel = await self._connection.get_channel(id) if isinstance(channel, StageChannel): return channel.instance - def get_guild(self, id: int, /) -> Guild | None: + async def get_guild(self, id: int, /) -> Guild | None: """Returns a guild with the given ID. Parameters @@ -1051,9 +1021,9 @@ def get_guild(self, id: int, /) -> Guild | None: Optional[:class:`.Guild`] The guild or ``None`` if not found. """ - return self._connection._get_guild(id) + return await self._connection._get_guild(id) - def get_user(self, id: int, /) -> User | None: + async def get_user(self, id: int, /) -> User | None: """Returns a user with the given ID. Parameters @@ -1066,9 +1036,9 @@ def get_user(self, id: int, /) -> User | None: Optional[:class:`~discord.User`] The user or ``None`` if not found. """ - return self._connection.get_user(id) + return await self._connection.get_user(id) - def get_emoji(self, id: int, /) -> GuildEmoji | AppEmoji | None: + async def get_emoji(self, id: int, /) -> GuildEmoji | AppEmoji | None: """Returns an emoji with the given ID. Parameters @@ -1081,9 +1051,9 @@ def get_emoji(self, id: int, /) -> GuildEmoji | AppEmoji | None: Optional[:class:`.GuildEmoji` | :class:`.AppEmoji`] The custom emoji or ``None`` if not found. """ - return self._connection.get_emoji(id) + return await self._connection.get_emoji(id) - def get_sticker(self, id: int, /) -> GuildSticker | None: + async def get_sticker(self, id: int, /) -> GuildSticker | None: """Returns a guild sticker with the given ID. .. versionadded:: 2.0 @@ -1098,9 +1068,9 @@ def get_sticker(self, id: int, /) -> GuildSticker | None: Optional[:class:`.GuildSticker`] The sticker or ``None`` if not found. """ - return self._connection.get_sticker(id) + return await self._connection.get_sticker(id) - def get_poll(self, id: int, /) -> Poll | None: + async def get_poll(self, id: int, /) -> Poll | None: """Returns a poll attached to the given message ID. Parameters @@ -1113,14 +1083,14 @@ def get_poll(self, id: int, /) -> Poll | None: Optional[:class:`.Poll`] The poll or ``None`` if not found. """ - return self._connection.get_poll(id) + return await self._connection.get_poll(id) - def get_all_channels(self) -> Generator[GuildChannel]: + async def get_all_channels(self) -> AsyncGenerator[GuildChannel]: """A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'. This is equivalent to: :: - for guild in client.guilds: + for guild in await client.get_guilds(): for channel in guild.channels: yield channel @@ -1136,15 +1106,16 @@ def get_all_channels(self) -> Generator[GuildChannel]: A channel the client can 'access'. """ - for guild in self.guilds: - yield from guild.channels + for guild in await self.get_guilds(): + for channel in guild.channels: + yield channel - def get_all_members(self) -> Generator[Member]: + async def get_all_members(self) -> AsyncGenerator[Member]: """Returns a generator with every :class:`.Member` the client can see. This is equivalent to: :: - for guild in client.guilds: + for guild in await client.get_guilds(): for member in guild.members: yield member @@ -1153,10 +1124,9 @@ def get_all_members(self) -> Generator[Member]: :class:`.Member` A member the client can see. """ - for guild in self.guilds: - yield from guild.members - - # listeners/waiters + for guild in await self.get_guilds(): + for member in guild.members: + yield member async def wait_until_ready(self) -> None: """|coro| @@ -1165,275 +1135,6 @@ async def wait_until_ready(self) -> None: """ await self._ready.wait() - def wait_for( - self, - event: str, - *, - check: Callable[..., bool] | None = None, - timeout: float | None = None, - ) -> Any: - """|coro| - - Waits for a WebSocket event to be dispatched. - - This could be used to wait for a user to reply to a message, - or to react to a message, or to edit a message in a self-contained - way. - - The ``timeout`` parameter is passed onto :func:`asyncio.wait_for`. By default, - it does not timeout. Note that this does propagate the - :exc:`asyncio.TimeoutError` for you in case of timeout and is provided for - ease of use. - - In case the event returns multiple arguments, a :class:`tuple` containing those - arguments is returned instead. Please check the - :ref:`documentation ` for a list of events and their - parameters. - - This function returns the **first event that meets the requirements**. - - Parameters - ---------- - event: :class:`str` - The event name, similar to the :ref:`event reference `, - but without the ``on_`` prefix, to wait for. - check: Optional[Callable[..., :class:`bool`]] - A predicate to check what to wait for. The arguments must meet the - parameters of the event being waited for. - timeout: Optional[:class:`float`] - The number of seconds to wait before timing out and raising - :exc:`asyncio.TimeoutError`. - - Returns - ------- - Any - Returns no arguments, a single argument, or a :class:`tuple` of multiple - arguments that mirrors the parameters passed in the - :ref:`event reference `. - - Raises - ------ - asyncio.TimeoutError - Raised if a timeout is provided and reached. - - Examples - -------- - - Waiting for a user reply: :: - - @client.event - async def on_message(message): - if message.content.startswith("$greet"): - channel = message.channel - await channel.send("Say hello!") - - def check(m): - return m.content == "hello" and m.channel == channel - - msg = await client.wait_for("message", check=check) - await channel.send(f"Hello {msg.author}!") - - Waiting for a thumbs up reaction from the message author: :: - - @client.event - async def on_message(message): - if message.content.startswith("$thumb"): - channel = message.channel - await channel.send("Send me that \N{THUMBS UP SIGN} reaction, mate") - - def check(reaction, user): - return user == message.author and str(reaction.emoji) == "\N{THUMBS UP SIGN}" - - try: - reaction, user = await client.wait_for("reaction_add", timeout=60.0, check=check) - except asyncio.TimeoutError: - await channel.send("\N{THUMBS DOWN SIGN}") - else: - await channel.send("\N{THUMBS UP SIGN}") - """ - - future = self.loop.create_future() - if check is None: - - def _check(*args): - return True - - check = _check - - ev = event.lower() - try: - listeners = self._listeners[ev] - except KeyError: - listeners = [] - self._listeners[ev] = listeners - - listeners.append((future, check)) - return asyncio.wait_for(future, timeout) - - # event registration - def add_listener(self, func: Coro, name: str | utils.Undefined = MISSING) -> None: - """The non decorator alternative to :meth:`.listen`. - - Parameters - ---------- - func: :ref:`coroutine ` - The function to call. - name: :class:`str` - The name of the event to listen for. Defaults to ``func.__name__``. - - Raises - ------ - TypeError - The ``func`` parameter is not a coroutine function. - ValueError - The ``name`` (event name) does not start with ``on_``. - - Example - ------- - - .. code-block:: python3 - - async def on_ready(): - pass - - - async def my_message(message): - pass - - - client.add_listener(on_ready) - client.add_listener(my_message, "on_message") - """ - name = func.__name__ if name is MISSING else name - - if not name.startswith("on_"): - raise ValueError("The 'name' parameter must start with 'on_'") - - if not asyncio.iscoroutinefunction(func): - raise TypeError("Listeners must be coroutines") - - if name in self._event_handlers: - self._event_handlers[name].append(func) - else: - self._event_handlers[name] = [func] - - _log.debug( - "%s has successfully been registered as a handler for event %s", - func.__name__, - name, - ) - - def remove_listener(self, func: Coro, name: str | utils.Undefined = MISSING) -> None: - """Removes a listener from the pool of listeners. - - Parameters - ---------- - func - The function that was used as a listener to remove. - name: :class:`str` - The name of the event we want to remove. Defaults to - ``func.__name__``. - """ - - name = func.__name__ if name is MISSING else name - - if name in self._event_handlers: - try: - self._event_handlers[name].remove(func) - except ValueError: - pass - - def listen(self, name: str | utils.Undefined = MISSING, once: bool = False) -> Callable[[Coro], Coro]: - """A decorator that registers another function as an external - event listener. Basically this allows you to listen to multiple - events from different places e.g. such as :func:`.on_ready` - - The functions being listened to must be a :ref:`coroutine `. - - Raises - ------ - TypeError - The function being listened to is not a coroutine. - ValueError - The ``name`` (event name) does not start with ``on_``. - - Example - ------- - - .. code-block:: python3 - - @client.listen() - async def on_message(message): - print("one") - - - # in some other file... - - - @client.listen("on_message") - async def my_message(message): - print("two") - - - # listen to the first event only - @client.listen("on_ready", once=True) - async def on_ready(): - print("ready!") - - Would print one and two in an unspecified order. - """ - - def decorator(func: Coro) -> Coro: - # Special case, where default should be overwritten - if name == "on_application_command_error": - return self.event(func) - - func._once = once - self.add_listener(func, name) - return func - - if asyncio.iscoroutinefunction(name): - coro = name - name = coro.__name__ - return decorator(coro) - - return decorator - - def event(self, coro: Coro) -> Coro: - """A decorator that registers an event to listen to. - - You can find more info about the events on the :ref:`documentation below `. - - The events must be a :ref:`coroutine `, if not, :exc:`TypeError` is raised. - - .. note:: - - This replaces any default handlers. - Developers are encouraged to use :py:meth:`~discord.Client.listen` for adding additional handlers - instead of :py:meth:`~discord.Client.event` unless default method replacement is intended. - - Raises - ------ - TypeError - The coroutine passed is not actually a coroutine. - - Example - ------- - - .. code-block:: python3 - - @client.event - async def on_ready(): - print("Ready!") - """ - - if not asyncio.iscoroutinefunction(coro): - raise TypeError("event registered must be a coroutine function") - - setattr(self, coro.__name__, coro) - _log.debug("%s has successfully been registered as an event", coro.__name__) - return coro - async def change_presence( self, *, @@ -1480,7 +1181,7 @@ async def change_presence( await self.ws.change_presence(activity=activity, status=status_str) - for guild in self._connection.guilds: + for guild in await self._connection.get_guilds(): me = guild.me if me is None: continue @@ -1581,7 +1282,7 @@ async def fetch_template(self, code: Template | str) -> Template: """ code = resolve_template(code) data = await self.http.get_template(code) - return Template(data=data, state=self._connection) # type: ignore + return await Template.from_data(data=data, state=self._connection) # type: ignore async def fetch_guild(self, guild_id: int, /, *, with_counts=True) -> Guild: """|coro| @@ -1622,7 +1323,7 @@ async def fetch_guild(self, guild_id: int, /, *, with_counts=True) -> Guild: Getting the guild failed. """ data = await self.http.get_guild(guild_id, with_counts=with_counts) - return Guild(data=data, state=self._connection) + return await Guild._from_data(data=data, state=self._connection) async def create_guild( self, @@ -1671,7 +1372,7 @@ async def create_guild( data = await self.http.create_from_template(code, name, icon_base64) else: data = await self.http.create_guild(name, icon_base64) - return Guild(data=data, state=self._connection) + return await Guild._from_data(data=data, state=self._connection) async def fetch_stage_instance(self, channel_id: int, /) -> StageInstance: """|coro| @@ -1762,7 +1463,7 @@ async def fetch_invite( with_expiration=with_expiration, guild_scheduled_event_id=event_id, ) - return Invite.from_incomplete(state=self._connection, data=data) + return await Invite.from_incomplete(state=self._connection, data=data) async def delete_invite(self, invite: Invite | str) -> None: """|coro| @@ -2002,14 +1703,14 @@ async def create_dm(self, user: Snowflake) -> DMChannel: The channel that was created. """ state = self._connection - found = state._get_private_channel_by_user(user.id) + found = await state._get_private_channel_by_user(user.id) if found: return found data = await state.http.start_private_message(user.id) - return state.add_dm_channel(data) + return await state.add_dm_channel(data) - def add_view(self, view: View, *, message_id: int | None = None) -> None: + async def add_view(self, view: View, *, message_id: int | None = None) -> None: """Registers a :class:`~discord.ui.View` for persistent listening. This method should be used for when a view is comprised of components @@ -2041,15 +1742,14 @@ def add_view(self, view: View, *, message_id: int | None = None) -> None: if not view.is_persistent(): raise ValueError("View is not persistent. Items need to have a custom_id set and View must have no timeout") - self._connection.store_view(view, message_id) + await self._connection.store_view(view, message_id) - @property - def persistent_views(self) -> Sequence[View]: + async def get_persistent_views(self) -> Sequence[View]: """A sequence of persistent views added to the client. .. versionadded:: 2.0 """ - return self._connection.persistent_views + return await self._connection.get_persistent_views() async def fetch_role_connection_metadata_records( self, @@ -2205,7 +1905,7 @@ async def fetch_emojis(self) -> list[AppEmoji]: The retrieved emojis. """ data = await self._connection.http.get_all_application_emojis(self.application_id) - return [self._connection.maybe_store_app_emoji(self.application_id, d) for d in data["items"]] + return [await self._connection.maybe_store_app_emoji(self.application_id, d) for d in data["items"]] async def fetch_emoji(self, emoji_id: int, /) -> AppEmoji: """|coro| @@ -2230,7 +1930,7 @@ async def fetch_emoji(self, emoji_id: int, /) -> AppEmoji: An error occurred fetching the emoji. """ data = await self._connection.http.get_application_emoji(self.application_id, emoji_id) - return self._connection.maybe_store_app_emoji(self.application_id, data) + return await self._connection.maybe_store_app_emoji(self.application_id, data) async def create_emoji( self, @@ -2265,7 +1965,7 @@ async def create_emoji( img = bytes_to_base64_data(image) data = await self._connection.http.create_application_emoji(self.application_id, name, img) - return self._connection.maybe_store_app_emoji(self.application_id, data) + return await self._connection.maybe_store_app_emoji(self.application_id, data) async def delete_emoji(self, emoji: Snowflake) -> None: """|coro| @@ -2284,8 +1984,8 @@ async def delete_emoji(self, emoji: Snowflake) -> None: """ await self._connection.http.delete_application_emoji(self.application_id, emoji.id) - if self._connection.cache_app_emojis and self._connection.get_emoji(emoji.id): - self._connection.remove_emoji(emoji) + if self._connection.cache_app_emojis and await self._connection.get_emoji(emoji.id): + await self._connection._remove_emoji(emoji) def get_sound(self, sound_id: int) -> SoundboardSound | None: """Gets a :class:`.Sound` from the bot's sound cache. diff --git a/discord/commands/context.py b/discord/commands/context.py index d7d1500ab9..7d5810e2bd 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -39,6 +39,7 @@ import discord from .. import Bot + from ..app.state import ConnectionState from ..client import ClientUser from ..cog import Cog from ..guild import Guild @@ -46,7 +47,6 @@ from ..member import Member from ..message import Message from ..permissions import Permissions - from ..state import ConnectionState from ..user import User from ..voice_client import VoiceClient from ..webhook import WebhookMessage diff --git a/discord/commands/core.py b/discord/commands/core.py index 0351513601..acbc2a757f 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -45,7 +45,10 @@ Union, ) +from discord.interactions import AutocompleteInteraction, Interaction + from ..channel import PartialMessageable, _threaded_guild_channel_factory +from ..channel.thread import Thread from ..enums import Enum as DiscordEnum from ..enums import ( IntegrationType, @@ -66,7 +69,6 @@ from ..message import Attachment, Message from ..object import Object from ..role import Role -from ..threads import Thread from ..user import User from ..utils import MISSING, find, utcnow from ..utils.private import async_all, maybe_awaitable, warn_deprecated @@ -111,7 +113,7 @@ def wrap_callback(coro): - from ..ext.commands.errors import CommandError # noqa: PLC0415 + from ..ext.commands.errors import CommandError @functools.wraps(coro) async def wrapped(*args, **kwargs): @@ -131,7 +133,7 @@ async def wrapped(*args, **kwargs): def hooked_wrapped_callback(command, ctx, coro): - from ..ext.commands.errors import CommandError # noqa: PLC0415 + from ..ext.commands.errors import CommandError @functools.wraps(coro) async def wrapped(arg): @@ -188,7 +190,7 @@ class ApplicationCommand(_BaseCommand, Generic[CogT, P, T]): cog = None def __init__(self, func: Callable, **kwargs) -> None: - from ..ext.commands.cooldowns import BucketType, CooldownMapping, MaxConcurrency # noqa: PLC0415 + from ..ext.commands.cooldowns import BucketType, CooldownMapping, MaxConcurrency cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown")) @@ -330,7 +332,7 @@ def _prepare_cooldowns(self, ctx: ApplicationContext): retry_after = bucket.update_rate_limit(current) if retry_after: - from ..ext.commands.errors import CommandOnCooldown # noqa: PLC0415 + from ..ext.commands.errors import CommandOnCooldown raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore @@ -464,7 +466,9 @@ async def dispatch_error(self, ctx: ApplicationContext, error: Exception) -> Non wrapped = wrap_callback(local) await wrapped(ctx, error) finally: - ctx.bot.dispatch("application_command_error", ctx, error) + ctx.bot.dispatch( + "application_command_error", ctx, error + ) # TODO: Remove this when migrating away from ApplicationContext def _get_signature_parameters(self): return OrderedDict(inspect.signature(self.callback).parameters) @@ -963,7 +967,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None: # We resolved the user from the user id _data["user"] = _user_data cache_flag = ctx.interaction._state.member_cache_flags.interaction - arg = ctx.guild._get_and_update_member(_data, int(arg), cache_flag) + arg = await ctx.guild._get_and_update_member(_data, int(arg), cache_flag) elif op.input_type is SlashCommandOptionType.mentionable: if (_data := resolved.get("users", {}).get(arg)) is not None: arg = User(state=ctx.interaction._state, data=_data) @@ -1003,7 +1007,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None: arg = Object(id=int(arg)) elif op.input_type == SlashCommandOptionType.string and (converter := op.converter) is not None: - from discord.ext.commands import Converter # noqa: PLC0415 + from discord.ext.commands import Converter if isinstance(converter, Converter): if isinstance(converter, type): @@ -1042,30 +1046,14 @@ async def _invoke(self, ctx: ApplicationContext) -> None: else: await self.callback(ctx, **kwargs) - async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): - values = {i.name: i.default for i in self.options} - - for op in ctx.interaction.data.get("options", []): - if op.get("focused", False): - # op_name is used because loop variables leak in surrounding scope - option = find(lambda o, op_name=op["name"]: o.name == op_name, self.options) - values.update({i["name"]: i["value"] for i in ctx.interaction.data["options"]}) - ctx.command = self - ctx.focused = option - ctx.value = op.get("value") - ctx.options = values - - if option.autocomplete._is_instance_method: - instance = getattr(option.autocomplete, "__self__", ctx.cog) - result = option.autocomplete(instance, ctx) - else: - result = option.autocomplete(ctx) - - if inspect.isawaitable(result): - result = await result + async def invoke_autocomplete_callback(self, interaction: AutocompleteInteraction) -> None: + option = find(lambda o: o.name == interaction.name, self.options) + if not option.autocomplete: + raise ClientException(f"Option {interaction.name} is not an autocomplete option.") + result = await option.autocomplete(interaction) - choices = [o if isinstance(o, OptionChoice) else OptionChoice(o) for o in result][:25] - return await ctx.interaction.response.send_autocomplete_result(choices=choices) + choices = [o if isinstance(o, OptionChoice) else OptionChoice(o) for o in result][:25] + return await interaction.response.send_autocomplete_result(choices=choices) def copy(self): """Creates a copy of this command. @@ -1230,7 +1218,7 @@ def __init__( self.description_localizations: dict[str, str] = kwargs.get("description_localizations", MISSING) # similar to ApplicationCommand - from ..ext.commands.cooldowns import BucketType, CooldownMapping, MaxConcurrency # noqa: PLC0415 + from ..ext.commands.cooldowns import BucketType, CooldownMapping, MaxConcurrency # no need to getattr, since slash cmds groups cant be created using a decorator @@ -1416,11 +1404,10 @@ async def _invoke(self, ctx: ApplicationContext) -> None: ctx.interaction.data = option await command.invoke(ctx) - async def invoke_autocomplete_callback(self, ctx: AutocompleteContext) -> None: - option = ctx.interaction.data["options"][0] + async def invoke_autocomplete_callback(self, interaction: AutocompleteInteraction) -> None: + option = interaction.data["options"][0] command = find(lambda x: x.name == option["name"], self.subcommands) - ctx.interaction.data = option - await command.invoke_autocomplete_callback(ctx) + await command.invoke_autocomplete_callback(interaction) async def call_before_hooks(self, ctx: ApplicationContext) -> None: # only call local hooks @@ -1708,7 +1695,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None: user = v member["user"] = user cache_flag = ctx.interaction._state.member_cache_flags.interaction - target = ctx.guild._get_and_update_member(member, user["id"], cache_flag) + target = await ctx.guild._get_and_update_member(member, user["id"], cache_flag) if self.cog is not None: await self.callback(self.cog, ctx, target) else: @@ -1815,7 +1802,7 @@ async def _invoke(self, ctx: ApplicationContext): # we got weird stuff going on, make up a channel channel = PartialMessageable(state=ctx.interaction._state, id=int(message["channel_id"])) - target = Message(state=ctx.interaction._state, channel=channel, data=message) + target = Message._from_data(state=ctx.interaction._state, channel=channel, data=message) if self.cog is not None: await self.callback(self.cog, ctx, target) diff --git a/discord/commands/options.py b/discord/commands/options.py index a055022830..0bb5a30f95 100644 --- a/discord/commands/options.py +++ b/discord/commands/options.py @@ -33,24 +33,34 @@ from typing import ( TYPE_CHECKING, Any, + Generic, Literal, Optional, + Sequence, Type, - TypeVar, Union, get_args, + overload, ) +from typing_extensions import TypeAlias, TypeVar, override + +from discord.interactions import AutocompleteInteraction, Interaction + +from ..utils.private import maybe_awaitable + if sys.version_info >= (3, 12): from typing import TypeAliasType else: from typing_extensions import TypeAliasType -from ..abc import GuildChannel, Mentionable +from ..abc import Mentionable from ..channel import ( + BaseChannel, CategoryChannel, DMChannel, ForumChannel, + GuildChannel, MediaChannel, StageChannel, TextChannel, @@ -71,36 +81,26 @@ from ..user import User InputType = ( - Type[str] - | Type[bool] - | Type[int] - | Type[float] - | Type[GuildChannel] - | Type[Thread] - | Type[Member] - | Type[User] - | Type[Attachment] - | Type[Role] - | Type[Mentionable] + type[ + str | bool | int | float | GuildChannel | Thread | Member | User | Attachment | Role | Mentionable + # | Converter + ] | SlashCommandOptionType - | Converter - | Type[Converter] - | Type[Enum] - | Type[DiscordEnum] + # | Converter ) AutocompleteReturnType = Iterable["OptionChoice"] | Iterable[str] | Iterable[int] | Iterable[float] - T = TypeVar("T", bound=AutocompleteReturnType) - MaybeAwaitable = T | Awaitable[T] - AutocompleteFunction = ( - Callable[[AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]] - | Callable[[Cog, AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]] + AR_T = TypeVar("AR_T", bound=AutocompleteReturnType) + MaybeAwaitable = AR_T | Awaitable[AR_T] + AutocompleteFunction: TypeAlias = ( + Callable[[AutocompleteInteraction], MaybeAwaitable[AutocompleteReturnType]] + | Callable[[Any, AutocompleteInteraction], MaybeAwaitable[AutocompleteReturnType]] | Callable[ - [AutocompleteContext, Any], # pyright: ignore [reportExplicitAny] + [AutocompleteInteraction, Any], MaybeAwaitable[AutocompleteReturnType], ] | Callable[ - [Cog, AutocompleteContext, Any], # pyright: ignore [reportExplicitAny] + [Any, AutocompleteInteraction, Any], MaybeAwaitable[AutocompleteReturnType], ] ) @@ -110,7 +110,6 @@ "ThreadOption", "Option", "OptionChoice", - "option", ) CHANNEL_TYPE_MAP = { @@ -147,7 +146,21 @@ def __init__(self, thread_type: Literal["public", "private", "news"]): self._type = type_map[thread_type] -class Option: +T = TypeVar("T", bound="str | int | float", default="str") + + +class ApplicationCommandOptionAutocomplete: + def __init__(self, autocomplete_function: AutocompleteFunction) -> None: + self.autocomplete_function: AutocompleteFunction = autocomplete_function + self.self: Any | None = None + + async def __call__(self, interaction: AutocompleteInteraction) -> AutocompleteReturnType: + if self.self is not None: + return await maybe_awaitable(self.autocomplete_function(self.self, interaction)) + return await maybe_awaitable(self.autocomplete_function(interaction)) + + +class Option(Generic[T]): # TODO: Update docstring @Paillat-dev """Represents a selectable option for a slash command. Attributes @@ -211,198 +224,356 @@ async def hello( .. versionadded:: 2.0 """ - input_type: SlashCommandOptionType - converter: Converter | type[Converter] | None = None - - def __init__(self, input_type: InputType = str, /, description: str | None = None, **kwargs) -> None: - self.name: str | None = kwargs.pop("name", None) - if self.name is not None: - self.name = str(self.name) - self._parameter_name = self.name # default - input_type = self._parse_type_alias(input_type) - input_type = self._strip_none_type(input_type) - self._raw_type: InputType | tuple = input_type - - enum_choices = [] - input_type_is_class = isinstance(input_type, type) - if input_type_is_class and issubclass(input_type, (Enum, DiscordEnum)): - if description is None and input_type.__doc__ is not None: - description = inspect.cleandoc(input_type.__doc__) - if description and len(description) > 100: - description = description[:97] + "..." - _log.warning( - "Option %s's description was truncated due to Enum %s's docstring exceeding 100 characters.", - self.name, - input_type, - ) - enum_choices = [OptionChoice(e.name, e.value) for e in input_type] - value_class = enum_choices[0].value.__class__ - if value_class in SlashCommandOptionType.__members__ and all( - isinstance(elem.value, value_class) for elem in enum_choices - ): - input_type = SlashCommandOptionType.from_datatype(enum_choices[0].value.__class__) - else: - enum_choices = [OptionChoice(e.name, str(e.value)) for e in input_type] - input_type = SlashCommandOptionType.string - - self.description = description or "No description provided" - self.channel_types: list[ChannelType] = kwargs.pop("channel_types", []) - - if self.channel_types: - self.input_type = SlashCommandOptionType.channel - elif isinstance(input_type, SlashCommandOptionType): - self.input_type = input_type - else: - from ..ext.commands import Converter # noqa: PLC0415 - - if isinstance(input_type, tuple) and any(issubclass(op, ApplicationContext) for op in input_type): - input_type = next(op for op in input_type if issubclass(op, ApplicationContext)) - - if isinstance(input_type, Converter) or input_type_is_class and issubclass(input_type, Converter): - self.converter = input_type - self._raw_type = str - self.input_type = SlashCommandOptionType.string - else: - try: - self.input_type = SlashCommandOptionType.from_datatype(input_type) - except TypeError as exc: - from ..ext.commands.converter import CONVERTER_MAPPING # noqa: PLC0415 - - if input_type not in CONVERTER_MAPPING: - raise exc - self.converter = CONVERTER_MAPPING[input_type] - self._raw_type = str - self.input_type = SlashCommandOptionType.string - else: - if self.input_type == SlashCommandOptionType.channel: - if not isinstance(self._raw_type, tuple): - if hasattr(input_type, "__args__"): - self._raw_type = input_type.__args__ # type: ignore # Union.__args__ - else: - self._raw_type = (input_type,) - if not self.channel_types: - self.channel_types = [CHANNEL_TYPE_MAP[t] for t in self._raw_type if t is not GuildChannel] - self.required: bool = kwargs.pop("required", True) if "default" not in kwargs else False - self.default = kwargs.pop("default", None) - - self._autocomplete: AutocompleteFunction | None = None - self.autocomplete = kwargs.pop("autocomplete", None) - if len(enum_choices) > 25: - self.choices: list[OptionChoice] = [] - for e in enum_choices: - e.value = str(e.value) - self.autocomplete = basic_autocomplete(enum_choices) - self.input_type = SlashCommandOptionType.string - else: - self.choices: list[OptionChoice] = enum_choices or [ - o if isinstance(o, OptionChoice) else OptionChoice(o) for o in kwargs.pop("choices", []) - ] - - if self.input_type == SlashCommandOptionType.integer: - minmax_types = (int, type(None)) - minmax_typehint = Optional[int] # noqa: UP045 - elif self.input_type == SlashCommandOptionType.number: - minmax_types = (int, float, type(None)) - minmax_typehint = Optional[int | float] # noqa: UP045 - else: - minmax_types = (type(None),) - minmax_typehint = type(None) - - if self.input_type == SlashCommandOptionType.string: - minmax_length_types = (int, type(None)) - minmax_length_typehint = Optional[int] # noqa: UP045 - else: - minmax_length_types = (type(None),) - minmax_length_typehint = type(None) - - self.min_value: int | float | None = kwargs.pop("min_value", None) - self.max_value: int | float | None = kwargs.pop("max_value", None) - self.min_length: int | None = kwargs.pop("min_length", None) - self.max_length: int | None = kwargs.pop("max_length", None) - - if ( - self.input_type != SlashCommandOptionType.integer - and self.input_type != SlashCommandOptionType.number - and (self.min_value or self.max_value) - ): - raise AttributeError( - "Option does not take min_value or max_value if not of type " - "SlashCommandOptionType.integer or SlashCommandOptionType.number" - ) - if self.input_type != SlashCommandOptionType.string and (self.min_length or self.max_length): - raise AttributeError("Option does not take min_length or max_length if not of type str") + # Overload for options with choices (str, int, or float types) + @overload + def __init__( + self, + name: str, + input_type: type[T] = str, + *, + choices: Sequence[OptionChoice[T]], + description: str | None = None, + channel_types: None = None, + required: bool = ..., + default: Any | Undefined = ..., + min_value: None = None, + max_value: None = None, + min_length: None = None, + max_length: None = None, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + autocomplete: None = None, + ) -> None: ... + + # Overload for channel options with optional channel_types filter + @overload + def __init__( + self, + name: str, + input_type: type[GuildChannel | Thread] + | Literal[SlashCommandOptionType.channel] = SlashCommandOptionType.channel, + *, + choices: None = None, + description: str | None = None, + channel_types: Sequence[ChannelType] | None = None, + required: bool = ..., + default: Any | Undefined = ..., + min_value: None = None, + max_value: None = None, + min_length: None = None, + max_length: None = None, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + autocomplete: None = None, + ) -> None: ... + + # Overload for required string options with min_length/max_length constraints + @overload + def __init__( + self, + name: str, + input_type: type[str] | Literal[SlashCommandOptionType.string] = str, + *, + description: str | None = None, + choices: None = None, + channel_types: None = None, + required: Literal[True], + default: Undefined = MISSING, + min_length: int | None = None, + max_length: int | None = None, + min_value: None = None, + max_value: None = None, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + autocomplete: None = None, + ) -> None: ... + + # Overload for optional string options with default value and min_length/max_length constraints + @overload + def __init__( + self, + name: str, + input_type: type[str] | Literal[SlashCommandOptionType.string] = str, + *, + description: str | None = None, + choices: None = None, + channel_types: None = None, + required: bool = False, + default: Any, + min_length: int | None = None, + max_length: int | None = None, + min_value: None = None, + max_value: None = None, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + autocomplete: None = None, + ) -> None: ... + + # Overload for required integer options with min_value/max_value constraints (integers only) + @overload + def __init__( + self, + name: str, + input_type: type[int] | Literal[SlashCommandOptionType.integer], + *, + description: str | None = None, + choices: None = None, + channel_types: None = None, + required: Literal[True], + default: Undefined = MISSING, + min_value: int | None = None, + max_value: int | None = None, + min_length: None = None, + max_length: None = None, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + autocomplete: None = None, + ) -> None: ... + + # Overload for optional integer options with default value and min_value/max_value constraints (integers only) + @overload + def __init__( + self, + name: str, + input_type: type[int] | Literal[SlashCommandOptionType.integer], + *, + description: str | None = None, + choices: None = None, + channel_types: None = None, + required: bool = False, + default: Any, + min_value: int | None = None, + max_value: int | None = None, + min_length: None = None, + max_length: None = None, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + autocomplete: None = None, + ) -> None: ... + + # Overload for required float options with min_value/max_value constraints (integers or floats) + @overload + def __init__( + self, + name: str, + input_type: type[float] | Literal[SlashCommandOptionType.number], + *, + description: str | None = None, + choices: None = None, + channel_types: None = None, + required: Literal[True], + default: Undefined = MISSING, + min_value: int | float | None = None, + max_value: int | float | None = None, + min_length: None = None, + max_length: None = None, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + autocomplete: None = None, + ) -> None: ... + + # Overload for optional float options with default value and min_value/max_value constraints (integers or floats) + @overload + def __init__( + self, + name: str, + input_type: type[float] | Literal[SlashCommandOptionType.number], + *, + description: str | None = None, + choices: None = None, + channel_types: None = None, + required: bool = False, + default: Any, + min_value: int | float | None = None, + max_value: int | float | None = None, + min_length: None = None, + max_length: None = None, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + autocomplete: None = None, + ) -> None: ... + + # Overload for required options with autocomplete (no choices or min/max constraints allowed) + @overload + def __init__( + self, + name: str, + input_type: type[str | int | float] = str, + *, + description: str | None = None, + choices: None = None, + channel_types: None = None, + required: Literal[True], + default: Undefined = MISSING, + min_value: None = None, + max_value: None = None, + min_length: None = None, + max_length: None = None, + autocomplete: ApplicationCommandOptionAutocomplete, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + ) -> None: ... + + # Overload for optional options with autocomplete and default value (no choices or min/max constraints allowed) + @overload + def __init__( + self, + name: str, + input_type: type[str | int | float] = str, + *, + description: str | None = None, + choices: None = None, + channel_types: None = None, + required: bool = False, + default: Any, + min_value: None = None, + max_value: None = None, + min_length: None = None, + max_length: None = None, + autocomplete: ApplicationCommandOptionAutocomplete, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + ) -> None: ... + + # Overload for required options of other types (bool, User, Member, Role, Attachment, Mentionable, etc.) + @overload + def __init__( + self, + name: str, + input_type: type[T] = str, + *, + description: str | None = None, + choices: None = None, + channel_types: None = None, + required: Literal[True], + default: Undefined = MISSING, + min_value: None = None, + max_value: None = None, + min_length: None = None, + max_length: None = None, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + autocomplete: None = None, + ) -> None: ... + + # Overload for optional options of other types with default value (bool, User, Member, Role, Attachment, Mentionable, etc.) + @overload + def __init__( + self, + name: str, + input_type: type[T] = str, + *, + description: str | None = None, + choices: None = None, + channel_types: None = None, + required: bool = False, + default: Any, + min_value: None = None, + max_value: None = None, + min_length: None = None, + max_length: None = None, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + autocomplete: None = None, + ) -> None: ... - if self.min_value is not None and not isinstance(self.min_value, minmax_types): - raise TypeError(f'Expected {minmax_typehint} for min_value, got "{type(self.min_value).__name__}"') - if self.max_value is not None and not isinstance(self.max_value, minmax_types): - raise TypeError(f'Expected {minmax_typehint} for max_value, got "{type(self.max_value).__name__}"') + def __init__( + self, + name: str, + input_type: InputType | type[T] = str, + *, + description: str | None = None, + choices: Sequence[OptionChoice[T]] | None = None, + channel_types: Sequence[ChannelType] | None = None, + required: bool = True, + default: Any | Undefined = MISSING, + min_value: int | float | None = None, + max_value: int | float | None = None, + min_length: int | None = None, + max_length: int | None = None, + name_localizations: dict[str, str] | None = None, + description_localizations: dict[str, str] | None = None, + autocomplete: ApplicationCommandOptionAutocomplete | None = None, + ) -> None: + self.name: str = name + + self.description: str | None = description + + self.choices: list[OptionChoice[T]] | None = list(choices) if choices is not None else None + if self.choices is not None: + if len(self.choices) > 25: + raise ValueError("Option choices cannot exceed 25 items.") + if not issubclass(input_type, str | int | float): + raise TypeError("Option choices can only be used with str, int, or float input types.") + + self.channel_types: list[ChannelType] | None = list(channel_types) if channel_types is not None else None + + self.input_type: SlashCommandOptionType - if self.min_length is not None: - if not isinstance(self.min_length, minmax_length_types): - raise TypeError( - f'Expected {minmax_length_typehint} for min_length, got "{type(self.min_length).__name__}"' - ) - if self.min_length < 0 or self.min_length > 6000: - raise AttributeError("min_length must be between 0 and 6000 (inclusive)") - if self.max_length is not None: - if not isinstance(self.max_length, minmax_length_types): - raise TypeError( - f'Expected {minmax_length_typehint} for max_length, got "{type(self.max_length).__name__}"' - ) - if self.max_length < 1 or self.max_length > 6000: - raise AttributeError("max_length must between 1 and 6000 (inclusive)") - - self.name_localizations = kwargs.pop("name_localizations", MISSING) - self.description_localizations = kwargs.pop("description_localizations", MISSING) - - if input_type is None: - raise TypeError("input_type cannot be NoneType.") - - @staticmethod - def _parse_type_alias(input_type: InputType) -> InputType: - if isinstance(input_type, TypeAliasType): - return input_type.__value__ - return input_type - - @staticmethod - def _strip_none_type(input_type): if isinstance(input_type, SlashCommandOptionType): - return input_type + self.input_type = input_type + elif issubclass(input_type, str): + self.input_type = SlashCommandOptionType.string + elif issubclass(input_type, bool): + self.input_type = SlashCommandOptionType.boolean + elif issubclass(input_type, int): + self.input_type = SlashCommandOptionType.integer + elif issubclass(input_type, float): + self.input_type = SlashCommandOptionType.number + elif issubclass(input_type, Attachment): + self.input_type = SlashCommandOptionType.attachment + elif issubclass(input_type, User | Member): + self.input_type = SlashCommandOptionType.user + elif issubclass(input_type, Role): + self.input_type = SlashCommandOptionType.role + elif issubclass(input_type, GuildChannel | Thread): + self.input_type = SlashCommandOptionType.channel + elif issubclass(input_type, Mentionable): + self.input_type = SlashCommandOptionType.mentionable - if input_type is type(None): - raise TypeError("Option type cannot be only NoneType") + self.required: bool = required if default is MISSING else False + self.default: Any | Undefined = default - args = () - if isinstance(input_type, types.UnionType): - args = get_args(input_type) - elif getattr(input_type, "__origin__", None) is Union: - args = get_args(input_type) - elif isinstance(input_type, tuple): - args = input_type + self.autocomplete: ApplicationCommandOptionAutocomplete | None = autocomplete - if args: - filtered = tuple(t for t in args if t is not type(None)) - if not filtered: - raise TypeError("Option type cannot be only NoneType") - if len(filtered) == 1: - return filtered[0] + self.min_value: int | float | None = min_value + self.max_value: int | float | None = max_value + if self.input_type not in (SlashCommandOptionType.integer, SlashCommandOptionType.number) and ( + self.min_value is not None or self.max_value is not None + ): + raise TypeError( + f"min_value and max_value can only be used with int or float input types, not {self.input_type.name}" + ) + if self.input_type is not SlashCommandOptionType.integer and ( + isinstance(self.min_value, float) or isinstance(self.max_value, float) + ): + raise TypeError("min_value and max_value must be integers when input_type is integer") - return filtered + self.min_length: int | None = min_length + self.max_length: int | None = max_length + if self.input_type is not SlashCommandOptionType.string and ( + self.min_length is not None or self.max_length is not None + ): + raise TypeError( + f"min_length and max_length can only be used with str input type, not {self.input_type.name}" + ) - return input_type + self.name_localizations: dict[str, str] | None = name_localizations + self.description_localizations: dict[str, str] | None = description_localizations - def to_dict(self) -> dict: - as_dict = { + def to_dict(self) -> dict[str, Any]: + as_dict: dict[str, Any] = { "name": self.name, "description": self.description, "type": self.input_type.value, "required": self.required, - "choices": [c.to_dict() for c in self.choices], "autocomplete": bool(self.autocomplete), } - if self.name_localizations is not MISSING: + if self.choices: + as_dict["choices"] = [choice.to_dict() for choice in self.choices] + if self.name_localizations: as_dict["name_localizations"] = self.name_localizations - if self.description_localizations is not MISSING: + if self.description_localizations: as_dict["description_localizations"] = self.description_localizations if self.channel_types: as_dict["channel_types"] = [t.value for t in self.channel_types] @@ -417,46 +588,12 @@ def to_dict(self) -> dict: return as_dict + @override def __repr__(self): - return f"" - - @property - def autocomplete(self) -> AutocompleteFunction | None: - """ - The autocomplete handler for the option. Accepts a callable (sync or async) - that takes a single required argument of :class:`AutocompleteContext` or two arguments - of :class:`discord.Cog` (being the command's cog) and :class:`AutocompleteContext`. - The callable must return an iterable of :class:`str` or :class:`OptionChoice`. - Alternatively, :func:`discord.utils.basic_autocomplete` may be used in place of the callable. - - Returns - ------- - Optional[AutocompleteFunction] - - .. versionchanged:: 2.7 - - .. note:: - Does not validate the input value against the autocomplete results. - """ - return self._autocomplete - - @autocomplete.setter - def autocomplete(self, value: AutocompleteFunction | None) -> None: - self._autocomplete = value - # this is done here so it does not have to be computed every time the autocomplete is invoked - if self._autocomplete is not None: - self._autocomplete._is_instance_method = ( # pyright: ignore [reportFunctionMemberAccess] - sum( - 1 - for param in inspect.signature(self._autocomplete).parameters.values() - if param.default == param.empty # pyright: ignore[reportAny] - and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) - ) - == 2 - ) + return f"