Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ These changes are available on the `master` branch, but have not yet been releas
([#2915](https://github.com/Pycord-Development/pycord/pull/2915))
- `View.message` being `None` when it had not been interacted with yet.
([#2916](https://github.com/Pycord-Development/pycord/pull/2916))
- Commands not properly parsing string-like / forward references annotations.
([#2919](https://github.com/Pycord-Development/pycord/pull/2919))

### Removed

Expand Down
48 changes: 33 additions & 15 deletions discord/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import sys
import types
from collections import OrderedDict
from collections.abc import Iterator
from enum import Enum
from typing import (
TYPE_CHECKING,
Expand All @@ -50,9 +51,7 @@
from ..enums import (
IntegrationType,
InteractionContextType,
MessageType,
SlashCommandOptionType,
try_enum,
)
from ..errors import (
ApplicationCommandError,
Expand All @@ -62,13 +61,20 @@
InvalidArgument,
ValidationError,
)
from ..member import Member
from ..message import Attachment, Message
from ..object import Object
from ..role import Role
from ..threads import Thread
from ..user import User
from ..utils import MISSING, async_all, find, maybe_coroutine, utcnow, warn_deprecated
from ..utils import (
MISSING,
async_all,
find,
maybe_coroutine,
resolve_annotation,
utcnow,
warn_deprecated,
)
from .context import ApplicationContext, AutocompleteContext
from .options import Option, OptionChoice

Expand Down Expand Up @@ -101,7 +107,7 @@

T = TypeVar("T")
CogT = TypeVar("CogT", bound="Cog")
Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]])
Coro = Coroutine[Any, Any, T]

if TYPE_CHECKING:
P = ParamSpec("P")
Expand Down Expand Up @@ -190,7 +196,7 @@ class ApplicationCommand(_BaseCommand, Generic[CogT, P, T]):
cog = None

def __init__(self, func: Callable, **kwargs) -> None:
from ..ext.commands.cooldowns import BucketType, CooldownMapping, MaxConcurrency
from ..ext.commands.cooldowns import BucketType, CooldownMapping

cooldown = getattr(func, "__commands_cooldown__", kwargs.get("cooldown"))

Expand Down Expand Up @@ -314,7 +320,11 @@ def guild_only(self) -> bool:
"2.6",
reference="https://discord.com/developers/docs/change-log#userinstallable-apps-preview",
)
return InteractionContextType.guild in self.contexts and len(self.contexts) == 1
return (
self.contexts is not None
and InteractionContextType.guild in self.contexts
and len(self.contexts) == 1
)

@guild_only.setter
def guild_only(self, value: bool) -> None:
Expand Down Expand Up @@ -779,33 +789,41 @@ def _validate_parameters(self):
else:
self.options = self._parse_options(params)

def _check_required_params(self, params):
params = iter(params.items())
def _check_required_params(
self, params: OrderedDict[str, inspect.Parameter]
) -> Iterator[tuple[str, inspect.Parameter]]:
params_iter = iter(params.items())
required_params = (
["self", "context"] if self.attached_to_group or self.cog else ["context"]
)
for p in required_params:
try:
next(params)
next(params_iter)
except StopIteration:
raise ClientException(
f'Callback for {self.name} command is missing "{p}" parameter.'
)

return params
return params_iter

def _parse_options(self, params, *, check_params: bool = True) -> list[Option]:
def _parse_options(
self, params: OrderedDict[str, inspect.Parameter], *, check_params: bool = True
) -> list[Option]:
if check_params:
params = self._check_required_params(params)
params_iter = self._check_required_params(params)
else:
params = iter(params.items())
params_iter = iter(params.items())

final_options = []
for p_name, p_obj in params:
cache = {}
for p_name, p_obj in params_iter:
option = p_obj.annotation
if option == inspect.Parameter.empty:
option = str

if isinstance(option, str):
option = resolve_annotation(option, globals(), locals(), cache)

option = Option._strip_none_type(option)
if self._is_typing_literal(option):
literal_values = get_args(option)
Expand Down
Loading