From 9016eba7cb5a2af59300359c8a105d76ed58420a Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Thu, 5 Jan 2023 09:20:09 +0100 Subject: [PATCH 01/16] test: minor improvements to test - Lint tests (this will cause a few typing issues which will need to be fixed) - Improve coverage Not sure if trying to get a high coverage is worth it, especially with some pretty bad tests. --- pyproject.toml | 2 +- tests/__init__.py | 0 tests/common/__init__.py | 0 tests/gateway/__init__.py | 0 tests/http/__init__.py | 0 tests/http/global_rate_limiter/__init__.py | 0 .../global_rate_limiter/test_unlimited.py | 13 ++++++++-- tests/http/test_file.py | 4 ++++ tests/http/test_ratelimit_storage.py | 24 ++++++++++++------- 9 files changed, 32 insertions(+), 11 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/common/__init__.py create mode 100644 tests/gateway/__init__.py create mode 100644 tests/http/__init__.py create mode 100644 tests/http/global_rate_limiter/__init__.py create mode 100644 tests/http/test_file.py diff --git a/pyproject.toml b/pyproject.toml index 3043400b8..7ccba06a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ testpaths = ["tests"] pythonPlatform = "All" typeCheckingMode = "strict" pythonVersion = "3.8" -exclude = ["tests/"] +#exclude = ["tests/"] [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/common/__init__.py b/tests/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/gateway/__init__.py b/tests/gateway/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/http/__init__.py b/tests/http/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/http/global_rate_limiter/__init__.py b/tests/http/global_rate_limiter/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/http/global_rate_limiter/test_unlimited.py b/tests/http/global_rate_limiter/test_unlimited.py index eb96f3264..6b82fdac9 100644 --- a/tests/http/global_rate_limiter/test_unlimited.py +++ b/tests/http/global_rate_limiter/test_unlimited.py @@ -1,5 +1,4 @@ -from asyncio import TimeoutError as AsyncioTimeoutError -from asyncio import sleep, wait_for +from asyncio import CancelledError, TimeoutError as AsyncioTimeoutError, sleep, wait_for from pytest import mark, raises @@ -48,3 +47,13 @@ async def test_no_wait() -> None: with raises(RateLimitedError): async with rate_limiter.acquire(wait=False): ... + +@mark.asyncio +async def test_cancel() -> None: + rate_limiter = UnlimitedGlobalRateLimiter() + + with raises(CancelledError): + async with rate_limiter.acquire(): + raise CancelledError() + assert len(rate_limiter._pending_requests) == 0, "Pending request was not cleared" # type: ignore [reportPrivateUsage] + diff --git a/tests/http/test_file.py b/tests/http/test_file.py new file mode 100644 index 000000000..f72576bab --- /dev/null +++ b/tests/http/test_file.py @@ -0,0 +1,4 @@ +from nextcore.http import File + +def test_file_creation(): + _file = File("hello.txt", "hi!") diff --git a/tests/http/test_ratelimit_storage.py b/tests/http/test_ratelimit_storage.py index 721a5ef7b..be38bb01c 100644 --- a/tests/http/test_ratelimit_storage.py +++ b/tests/http/test_ratelimit_storage.py @@ -1,5 +1,4 @@ import gc -import sys from pytest import mark @@ -15,10 +14,10 @@ async def test_does_gc_collect_unused_buckets() -> None: metadata = BucketMetadata() bucket = Bucket(metadata) - await storage.store_bucket_by_nextcore_id(1, bucket) + await storage.store_bucket_by_nextcore_id("1", bucket) gc.collect() - assert await storage.get_bucket_by_nextcore_id(1) is None, "Bucket was not collected" + assert await storage.get_bucket_by_nextcore_id("1") is None, "Bucket was not collected" @mark.asyncio @@ -30,10 +29,10 @@ async def test_does_not_collect_dirty_buckets() -> None: await bucket.update(0, 1) - await storage.store_bucket_by_nextcore_id(1, bucket) + await storage.store_bucket_by_nextcore_id("1", bucket) gc.collect() - assert await storage.get_bucket_by_nextcore_id(1) is not None, "Bucket should not be collected" + assert await storage.get_bucket_by_nextcore_id("1") is not None, "Bucket should not be collected" @mark.asyncio @@ -58,10 +57,10 @@ async def test_stores_and_get_nextcore_id() -> None: metadata = BucketMetadata() bucket = Bucket(metadata) - assert await storage.get_bucket_by_nextcore_id(1) is None, "Bucket should not exist as it is not added yet" + assert await storage.get_bucket_by_nextcore_id("1") is None, "Bucket should not exist as it is not added yet" - await storage.store_bucket_by_nextcore_id(1, bucket) - assert await storage.get_bucket_by_nextcore_id(1) is bucket, "Bucket was not stored" + await storage.store_bucket_by_nextcore_id("1", bucket) + assert await storage.get_bucket_by_nextcore_id("1") is bucket, "Bucket was not stored" @mark.asyncio @@ -75,3 +74,12 @@ async def test_stores_and_get_discord_id() -> None: await storage.store_bucket_by_discord_id("1", bucket) assert await storage.get_bucket_by_discord_id("1") is bucket, "Bucket was not stored" + +@mark.asyncio +async def test_bucket_metadata_stored() -> None: + storage = RateLimitStorage() + metadata = BucketMetadata() + + await storage.store_metadata("1", metadata) + + assert await storage.get_bucket_metadata("1") is not None From 2f68a94583761ec170c9f8bac53f7a195863c575 Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Fri, 6 Jan 2023 10:02:40 +0100 Subject: [PATCH 02/16] test: more coverage --- tests/common/test_times_per.py | 18 ++++++++++++++ tests/http/test_bucket.py | 43 +++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/tests/common/test_times_per.py b/tests/common/test_times_per.py index 4e4b443ee..47047118a 100644 --- a/tests/common/test_times_per.py +++ b/tests/common/test_times_per.py @@ -3,6 +3,7 @@ from nextcore.common.errors import RateLimitedError from nextcore.common.times_per import TimesPer from tests.utils import match_time +from asyncio import Future, sleep, create_task @mark.asyncio @@ -41,6 +42,23 @@ async def test_exception_undos(): except: pass +@mark.asyncio +@match_time(.1, .01) +async def test_exception_undos_with_pending(): + rate_limiter = TimesPer(1, 1) + waiting_future: Future[None] = Future() + + async def wait_for_a_second(): + async with rate_limiter.acquire(): + waiting_future.set_result(None) + await sleep(.1) + raise + + create_task(wait_for_a_second()) + await waiting_future + + async with rate_limiter.acquire(): + ... @mark.asyncio async def test_no_wait(): diff --git a/tests/http/test_bucket.py b/tests/http/test_bucket.py index 41f2ddd30..d448bb27a 100644 --- a/tests/http/test_bucket.py +++ b/tests/http/test_bucket.py @@ -1,6 +1,7 @@ import asyncio -from pytest import mark +from pytest import mark, raises +from nextcore.common.errors import RateLimitedError from nextcore.http.bucket import Bucket from nextcore.http.bucket_metadata import BucketMetadata @@ -74,6 +75,46 @@ async def test_unlimited() -> None: async with bucket.acquire(): ... +@mark.asyncio +async def test_out_no_wait() -> None: + metadata = BucketMetadata(limit=1) + bucket = Bucket(metadata) + + await bucket.update(0, 1) + + with raises(RateLimitedError): + async with bucket.acquire(wait=False): + ... + +@mark.asyncio +@mark.skipif(True, reason="Currently broken") +@match_time(0, 0.1) +async def test_re_release() -> None: + metadata = BucketMetadata(limit=1) + bucket = Bucket(metadata) + + started: asyncio.Future[None] = asyncio.Future() + can_raise: asyncio.Future[None] = asyncio.Future() + + await bucket.update(1, 1) + + async def use(): + try: + async with bucket.acquire(): + started.set_result(None) + await can_raise + raise + except: + pass + + asyncio.create_task(use()) + + await started + can_raise.set_result(None) + async with bucket.acquire(): + ... + + # Dirty tests def test_clean_bucket_is_not_dirty() -> None: From fac435be376e770a7316fbc3ab91cfa49e317601 Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Sun, 15 Jan 2023 13:20:14 +0100 Subject: [PATCH 03/16] test: decompressor test --- tests/gateway/test_decompressor.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 tests/gateway/test_decompressor.py diff --git a/tests/gateway/test_decompressor.py b/tests/gateway/test_decompressor.py new file mode 100644 index 000000000..e7584ba99 --- /dev/null +++ b/tests/gateway/test_decompressor.py @@ -0,0 +1,11 @@ +from nextcore.gateway import Decompressor +import zlib + +def test_decompress(): + decompressor = Decompressor() + + content = b"Hello, world!" + compressed = zlib.compress(content) + Decompressor.ZLIB_SUFFIX + + assert decompressor.decompress(compressed) == content + From bb4c814e105ec2e0448225ce98c99aacebe632b0 Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Sun, 15 Jan 2023 13:41:22 +0100 Subject: [PATCH 04/16] style: lint --- nextcore/http/client/client.py | 6 +- nextcore/http/client/wrappers/channel.py | 100 +++++++++++++++--- tests/common/test_times_per.py | 9 +- tests/gateway/test_decompressor.py | 9 +- .../global_rate_limiter/test_unlimited.py | 10 +- tests/http/test_bucket.py | 7 +- tests/http/test_file.py | 1 + tests/http/test_ratelimit_storage.py | 1 + 8 files changed, 113 insertions(+), 30 deletions(-) diff --git a/nextcore/http/client/client.py b/nextcore/http/client/client.py index a1cf71ed5..74882aa64 100644 --- a/nextcore/http/client/client.py +++ b/nextcore/http/client/client.py @@ -21,10 +21,10 @@ from __future__ import annotations +from collections import defaultdict from logging import getLogger from time import time from typing import TYPE_CHECKING -from collections import defaultdict from aiohttp import ClientSession @@ -173,7 +173,9 @@ def __init__( "User-Agent": f"DiscordBot (https://github.com/nextsnake/nextcore, {nextcore_version})" } self.max_retries: int = max_rate_limit_retries - self.rate_limit_storages: defaultdict[str | None, RateLimitStorage] = defaultdict(RateLimitStorage) # User ID -> RateLimitStorage + self.rate_limit_storages: defaultdict[str | None, RateLimitStorage] = defaultdict( + RateLimitStorage + ) # User ID -> RateLimitStorage self.dispatcher: Dispatcher[Literal["request_response"]] = Dispatcher() # Internals diff --git a/nextcore/http/client/wrappers/channel.py b/nextcore/http/client/wrappers/channel.py index 00680d7f8..75a98300f 100644 --- a/nextcore/http/client/wrappers/channel.py +++ b/nextcore/http/client/wrappers/channel.py @@ -566,7 +566,12 @@ async def delete_channel( headers["X-Audit-Log-Reason"] = reason await self._request( - route, headers=headers, rate_limit_key=authentication.rate_limit_key, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait, + route, + headers=headers, + rate_limit_key=authentication.rate_limit_key, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) @overload @@ -1020,7 +1025,12 @@ async def create_reaction( headers = {"Authorization": str(authentication)} await self._request( - route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) async def delete_own_reaction( @@ -1084,7 +1094,12 @@ async def delete_own_reaction( headers = {"Authorization": str(authentication)} await self._request( - route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) async def delete_user_reaction( @@ -1159,7 +1174,12 @@ async def delete_user_reaction( headers = {"Authorization": str(authentication)} await self._request( - route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) async def get_reactions( @@ -1332,7 +1352,12 @@ async def delete_all_reactions( headers = {"Authorization": str(authentication)} await self._request( - route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) async def delete_all_reactions_for_emoji( @@ -1392,7 +1417,12 @@ async def delete_all_reactions_for_emoji( headers = {"Authorization": str(authentication)} await self._request( - route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) async def edit_message( @@ -1581,7 +1611,12 @@ async def delete_message( headers["X-Audit-Log-Reason"] = reason await self._request( - route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) async def bulk_delete_messages( @@ -1802,7 +1837,12 @@ async def get_channel_invites( headers = {"Authorization": str(authentication)} r = await self._request( - route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) # TODO: Make this verify the data from Discord @@ -2021,7 +2061,12 @@ async def delete_channel_permission( headers["X-Audit-Log-Reason"] = reason await self._request( - route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) async def follow_news_channel( @@ -2122,7 +2167,12 @@ async def trigger_typing_indicator( headers = {"Authorization": str(authentication)} await self._request( - route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) async def get_pinned_messages( @@ -2170,7 +2220,12 @@ async def get_pinned_messages( headers = {"Authorization": str(authentication)} r = await self._request( - route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) # TODO: Make this verify the data from Discord @@ -2232,7 +2287,14 @@ async def pin_message( if reason is not UNDEFINED: headers["X-Audit-Log-Reason"] = reason - await self._request(route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait) + await self._request( + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, + ) async def unpin_message( self, @@ -2279,7 +2341,12 @@ async def unpin_message( headers = {"Authorization": str(authentication)} await self._request( - route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) async def group_dm_add_recipient( @@ -2322,7 +2389,12 @@ async def group_dm_add_recipient( headers = {"Authorization": str(authentication)} await self._request( - route, rate_limit_key=authentication.rate_limit_key, headers=headers, bucket_priority=bucket_priority, global_priority=global_priority, wait=wait + route, + rate_limit_key=authentication.rate_limit_key, + headers=headers, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, ) async def group_dm_remove_recipient( diff --git a/tests/common/test_times_per.py b/tests/common/test_times_per.py index 47047118a..88e7868aa 100644 --- a/tests/common/test_times_per.py +++ b/tests/common/test_times_per.py @@ -1,9 +1,10 @@ +from asyncio import Future, create_task, sleep + from pytest import mark, raises from nextcore.common.errors import RateLimitedError from nextcore.common.times_per import TimesPer from tests.utils import match_time -from asyncio import Future, sleep, create_task @mark.asyncio @@ -42,8 +43,9 @@ async def test_exception_undos(): except: pass + @mark.asyncio -@match_time(.1, .01) +@match_time(0.1, 0.01) async def test_exception_undos_with_pending(): rate_limiter = TimesPer(1, 1) waiting_future: Future[None] = Future() @@ -51,7 +53,7 @@ async def test_exception_undos_with_pending(): async def wait_for_a_second(): async with rate_limiter.acquire(): waiting_future.set_result(None) - await sleep(.1) + await sleep(0.1) raise create_task(wait_for_a_second()) @@ -60,6 +62,7 @@ async def wait_for_a_second(): async with rate_limiter.acquire(): ... + @mark.asyncio async def test_no_wait(): rate_limiter = TimesPer(1, 1) diff --git a/tests/gateway/test_decompressor.py b/tests/gateway/test_decompressor.py index e7584ba99..ea354cbe8 100644 --- a/tests/gateway/test_decompressor.py +++ b/tests/gateway/test_decompressor.py @@ -1,11 +1,12 @@ -from nextcore.gateway import Decompressor import zlib +from nextcore.gateway import Decompressor + + def test_decompress(): decompressor = Decompressor() - + content = b"Hello, world!" compressed = zlib.compress(content) + Decompressor.ZLIB_SUFFIX - assert decompressor.decompress(compressed) == content - + assert decompressor.decompress(compressed) == content diff --git a/tests/http/global_rate_limiter/test_unlimited.py b/tests/http/global_rate_limiter/test_unlimited.py index 6b82fdac9..8287d0658 100644 --- a/tests/http/global_rate_limiter/test_unlimited.py +++ b/tests/http/global_rate_limiter/test_unlimited.py @@ -1,4 +1,6 @@ -from asyncio import CancelledError, TimeoutError as AsyncioTimeoutError, sleep, wait_for +from asyncio import CancelledError +from asyncio import TimeoutError as AsyncioTimeoutError +from asyncio import sleep, wait_for from pytest import mark, raises @@ -48,12 +50,12 @@ async def test_no_wait() -> None: async with rate_limiter.acquire(wait=False): ... + @mark.asyncio async def test_cancel() -> None: rate_limiter = UnlimitedGlobalRateLimiter() - + with raises(CancelledError): async with rate_limiter.acquire(): raise CancelledError() - assert len(rate_limiter._pending_requests) == 0, "Pending request was not cleared" # type: ignore [reportPrivateUsage] - + assert len(rate_limiter._pending_requests) == 0, "Pending request was not cleared" # type: ignore [reportPrivateUsage] diff --git a/tests/http/test_bucket.py b/tests/http/test_bucket.py index d448bb27a..d381eac24 100644 --- a/tests/http/test_bucket.py +++ b/tests/http/test_bucket.py @@ -1,8 +1,8 @@ import asyncio from pytest import mark, raises -from nextcore.common.errors import RateLimitedError +from nextcore.common.errors import RateLimitedError from nextcore.http.bucket import Bucket from nextcore.http.bucket_metadata import BucketMetadata from tests.utils import match_time @@ -75,6 +75,7 @@ async def test_unlimited() -> None: async with bucket.acquire(): ... + @mark.asyncio async def test_out_no_wait() -> None: metadata = BucketMetadata(limit=1) @@ -86,6 +87,7 @@ async def test_out_no_wait() -> None: async with bucket.acquire(wait=False): ... + @mark.asyncio @mark.skipif(True, reason="Currently broken") @match_time(0, 0.1) @@ -106,7 +108,7 @@ async def use(): raise except: pass - + asyncio.create_task(use()) await started @@ -115,7 +117,6 @@ async def use(): ... - # Dirty tests def test_clean_bucket_is_not_dirty() -> None: metadata = BucketMetadata() diff --git a/tests/http/test_file.py b/tests/http/test_file.py index f72576bab..a320903b9 100644 --- a/tests/http/test_file.py +++ b/tests/http/test_file.py @@ -1,4 +1,5 @@ from nextcore.http import File + def test_file_creation(): _file = File("hello.txt", "hi!") diff --git a/tests/http/test_ratelimit_storage.py b/tests/http/test_ratelimit_storage.py index be38bb01c..f1cdca421 100644 --- a/tests/http/test_ratelimit_storage.py +++ b/tests/http/test_ratelimit_storage.py @@ -75,6 +75,7 @@ async def test_stores_and_get_discord_id() -> None: await storage.store_bucket_by_discord_id("1", bucket) assert await storage.get_bucket_by_discord_id("1") is bucket, "Bucket was not stored" + @mark.asyncio async def test_bucket_metadata_stored() -> None: storage = RateLimitStorage() From 259e9d2de0f673d6f744e35def0d2645ac63234b Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Wed, 18 Jan 2023 00:01:38 +0100 Subject: [PATCH 05/16] test(integration): basic integration tests --- pyproject.toml | 1 + tests/integration/test_discord_api.py | 114 ++++++++++++++++++++++++++ 2 files changed, 115 insertions(+) create mode 100644 tests/integration/test_discord_api.py diff --git a/pyproject.toml b/pyproject.toml index 7ccba06a0..f761e583b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ typing-extensions = "^4.1.1" # Same as above orjson = {version = "^3.6.8", optional = true} types-orjson = {version = "^3.6.2", optional = true} discord-typings = "^0.5.0" +pytest-harmony = {path = "../pytest-harmony"} [tool.poetry.dev-dependencies] Sphinx = "^4.4.0" diff --git a/tests/integration/test_discord_api.py b/tests/integration/test_discord_api.py new file mode 100644 index 000000000..350ec84a0 --- /dev/null +++ b/tests/integration/test_discord_api.py @@ -0,0 +1,114 @@ +from __future__ import annotations +from pytest_harmony import TreeTests +import pytest +import typing +import os +from nextcore.http import BotAuthentication, HTTPClient +from nextcore.gateway import ShardManager, GatewayOpcode +from discord_typings import GuildData + +tree = TreeTests() + +@pytest.mark.asyncio +async def test_discord_api(): + await tree.run_tests() +# Get token +@tree.append() +async def get_token(state: dict[str, typing.Any]): + token = os.environ.get("TOKEN") + + if token is None: + pytest.skip("No TOKEN env var") + + state["token"] = token + state["authentication"] = BotAuthentication(token) + + http_client = HTTPClient() + await http_client.setup() + state["http_client"] = http_client + +@get_token.cleanup() +async def cleanup_get_token(state: dict[str, typing.Any]): + del state["token"] + del state["authentication"] + + http_client: HTTPClient = state["http_client"] + + await http_client.close() + +# Get token / create guild +@get_token.append() +async def create_guild(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + + guild = await http_client.create_guild(authentication, name="Test guild") + + state["guild"] = guild + +@create_guild.cleanup() +async def cleanup_create_guild(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + guild: GuildData = state["guild"] + + await http_client.delete_guild(authentication, guild["id"]) + + +# Get token / create guild / get audit logs +@create_guild.append() +async def get_audit_logs(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + guild: GuildData = state["guild"] + + logs = await http_client.get_guild_audit_log(authentication, guild["id"], limit=10) + + + assert logs["audit_log_entries"] == [] + +# Get token / connect to gateway +@get_token.append() +async def connect_to_gateway(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + + intents = 3276799 # Everything. TODO: Do this using a intents helper? + + gateway = ShardManager(authentication, intents, http_client) + + state["gateway"] = gateway + + await gateway.connect() + + await gateway.event_dispatcher.wait_for(lambda _: True, "READY") + +@connect_to_gateway.cleanup() +async def cleanup_connect_to_gateway(state: dict[str, typing.Any]): + gateway: ShardManager = state["gateway"] + + await gateway.close() + +# Get token / connect to gateway / get latency +@connect_to_gateway.append() +async def get_latency(state: dict[str, typing.Any]): + gateway: ShardManager = state["gateway"] + + for shard in gateway.active_shards: + # Heartbeats are calculated after heartbeating. This may be VERY slow on bots with many shards, so use a test bot. + await shard.raw_dispatcher.wait_for(lambda _: True, GatewayOpcode.HEARTBEAT_ACK) + print(shard.latency) + +# Get token / connect to gateway +@connect_to_gateway.append() +async def rescale_shards(state: dict[str, typing.Any]): + gateway: ShardManager = state["gateway"] + + await gateway.rescale_shards(5) + + assert len(gateway.active_shards) == 5 + +@rescale_shards.cleanup() +async def cleanup_rescale_shards(state: dict[str, typing.Any]): + gateway: ShardManager = state["gateway"] + await gateway.rescale_shards(1) From ac4ab405bf552676d7e04808ba74adb6b70b17d3 Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Thu, 19 Jan 2023 08:52:32 +0100 Subject: [PATCH 06/16] test(integration): more tests --- tests/integration/test_discord_api.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_discord_api.py b/tests/integration/test_discord_api.py index 350ec84a0..3294a75e1 100644 --- a/tests/integration/test_discord_api.py +++ b/tests/integration/test_discord_api.py @@ -5,13 +5,15 @@ import os from nextcore.http import BotAuthentication, HTTPClient from nextcore.gateway import ShardManager, GatewayOpcode -from discord_typings import GuildData +from discord_typings import GuildData, ReadyData tree = TreeTests() @pytest.mark.asyncio async def test_discord_api(): await tree.run_tests() + + # Get token @tree.append() async def get_token(state: dict[str, typing.Any]): @@ -35,6 +37,16 @@ async def cleanup_get_token(state: dict[str, typing.Any]): http_client: HTTPClient = state["http_client"] await http_client.close() + del state["http_client"] + + +# Get gateway +@get_token.append() +async def get_gateway(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + + await http_client.get_gateway() + # Get token / create guild @get_token.append() @@ -54,6 +66,8 @@ async def cleanup_create_guild(state: dict[str, typing.Any]): await http_client.delete_guild(authentication, guild["id"]) + del state["guild"] + # Get token / create guild / get audit logs @create_guild.append() @@ -81,7 +95,8 @@ async def connect_to_gateway(state: dict[str, typing.Any]): await gateway.connect() - await gateway.event_dispatcher.wait_for(lambda _: True, "READY") + ready_data: ReadyData = (await gateway.event_dispatcher.wait_for(lambda _: True, "READY"))[0] + state["bot_user"] = ready_data["user"] @connect_to_gateway.cleanup() async def cleanup_connect_to_gateway(state: dict[str, typing.Any]): @@ -89,6 +104,9 @@ async def cleanup_connect_to_gateway(state: dict[str, typing.Any]): await gateway.close() + del state["gateway"] + del state["bot_user"] + # Get token / connect to gateway / get latency @connect_to_gateway.append() async def get_latency(state: dict[str, typing.Any]): From 905e71f76198609d174e40e9d2c720e5be277fed Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Thu, 19 Jan 2023 09:47:52 +0100 Subject: [PATCH 07/16] feat: add create channel --- nextcore/http/client/wrappers/guild.py | 175 +++++++++++++++++++++++++ 1 file changed, 175 insertions(+) diff --git a/nextcore/http/client/wrappers/guild.py b/nextcore/http/client/wrappers/guild.py index 537847e50..7012dd373 100644 --- a/nextcore/http/client/wrappers/guild.py +++ b/nextcore/http/client/wrappers/guild.py @@ -48,6 +48,7 @@ RoleData, RolePositionData, Snowflake, + VideoQualityModes, VoiceRegionData, WelcomeChannelData, WelcomeScreenData, @@ -404,6 +405,180 @@ async def get_guild_channels( # TODO: Make this verify the data from Discord return await r.json() # type: ignore [no-any-return] + async def create_guild_channel( + self, + authentication: BotAuthentication, + guild_id: Snowflake, + name: str, + *, + type: int | None | UndefinedType, + topic: str | None | UndefinedType = UNDEFINED, + bitrate: int | None | UndefinedType = UNDEFINED, + user_limit: int | None | UndefinedType = UNDEFINED, + rate_limit_per_user: int | None | UndefinedType = UNDEFINED, + position: int | None | UndefinedType = UNDEFINED, + permission_overwrites: list[dict[str, Any]] | None | UndefinedType = UNDEFINED, + parent_id: Snowflake | None | UndefinedType = UNDEFINED, + nsfw: bool | None | UndefinedType = UNDEFINED, + rtc_region: str | None | UndefinedType = UNDEFINED, + video_quality_mode: VideoQualityModes | None | UndefinedType = UNDEFINED, + default_auto_archive_duration: int | None | UndefinedType = UNDEFINED, + default_reaction_emoji: Any | None | UndefinedType = UNDEFINED, + available_tags: list[Any] | None | UndefinedType = UNDEFINED, + default_sort_order: int | None | UndefinedType = UNDEFINED, + reason: str | UndefinedType = UNDEFINED, + bucket_priority: int = 0, + global_priority: int = 0, + wait: bool = True, + ) -> ChannelData: + """Creates a channel + + Read the `documentation `__ + + Parameters + ---------- + authentication: + The auth info. + guild_id: + The guild to create a channel in. + name: + The name of the channel. + + .. note:: + This has to be between 1-100 characters. + type: + The type of the channel. + topic: + The channel topic or forum guidelines if creating a forum channel. + + .. note:: + This has to be between 0-1024 characters + bitrate: + The voice bitrate. + + .. note:: + This has to be more than 8000. + + - If the guild is boost level 3 or it has the ``VIP_REGIONS`` feature, the max is 384000 + - If the guild is boost level 2 this is 256000 + - If the guild is boost level 1 this is 128000 + - Else this is 96000 + user_limit: + The most amount of people that can be in a voice channel at once + + .. note:: + Stage channels are not affected by this + rate_limit_per_user: + amount of seconds a user has to wait before sending another message or create another thread. + + .. note:: + This has to be between 0-21600 + .. note:: + Bots and members with ``MANAGE_MESSAGES`` or ``MANAGE_CHANNEL`` are immune. + position: + The sorting position of the channel inside its category. + permission_overwrites: + The channels permissions overwrites. + + .. note:: + The allow or deny keys default to 0. + .. note:: + Only permissions your bot has can be allowed/denied. + + ``MANAGE_ROLES`` can also only be allowed/denied by members with the ``ADMINISTRATOR`` permission + parent_id: + The category to put this channel under. + nsfw: + If the channel is age restricted. + rtc_region: + The voice region id to use. If this is :data:`None` this will automatically decide a voice region when needed. + video_quality_mode: + The quality mode for camera. + + Only affects voice channels and stage channels. + default_auto_archive_duration: + The default auto archive duration used by the Discord client in minutes. + default_reaction_emoji: + The default reaction emoji that will be shown in forum channels on posts + available_tags: + The tags that can be added to forum posts. + default_sort_order: + The default sort order for the forum posts. + reason: + The reason to put in the audit log + global_priority: + The priority of the request for the global rate-limiter. + bucket_priority: + The priority of the request for the bucket rate-limiter. + wait: + Wait when rate limited. + + This will raise :exc:`RateLimitedError` if set to :data:`False` and you are rate limited. + + Raises + ------ + RateLimitedError + You are rate limited, and ``wait`` was set to :data:`False` + """ + route = Route("POST", "/guilds/{guild_id}/channels", guild_id=guild_id) + + headers = {} + + # These have different behaviour when not provided and set to None. + # This only adds them if they are provided (not Undefined) + if reason is not UNDEFINED: + headers["X-Audit-Log-Reason"] = reason + + payload: dict[str, Any] = {"name": name} + + # These have different behaviour when not provided and set to None. + # This only adds them if they are provided (not Undefined) + if type is not UNDEFINED: + payload["type"] = type + if topic is not UNDEFINED: + payload["topic"] = topic + if bitrate is not UNDEFINED: + payload["bitrate"] = bitrate + if user_limit is not UNDEFINED: + payload["user_limit"] = user_limit + if rate_limit_per_user is not UNDEFINED: + payload["rate_limit_per_user"] = rate_limit_per_user + if position is not UNDEFINED: + payload["position"] = position + if permission_overwrites is not UNDEFINED: + payload["permission_overwrites"] = permission_overwrites + if parent_id is not UNDEFINED: + payload["parent_id"] = parent_id + if nsfw is not UNDEFINED: + payload["nsfw"] = nsfw + if rtc_region is not UNDEFINED: + payload["rtc_region"] = rtc_region + if video_quality_mode is not UNDEFINED: + payload["video_quality_mode"] = video_quality_mode + if default_auto_archive_duration is not UNDEFINED: + payload["default_auto_archive_duration"] = default_auto_archive_duration + if default_reaction_emoji is not UNDEFINED: + payload["default_reaction_emoji"] = default_reaction_emoji + if available_tags is not UNDEFINED: + payload["available_tags"] = available_tags + if default_sort_order is not UNDEFINED: + payload["default_sort_order"] = default_sort_order + + r = await self._request( + route, + json=payload, + rate_limit_key=authentication.rate_limit_key, + headers={"Authorization": str(authentication)}, + bucket_priority=bucket_priority, + global_priority=global_priority, + wait=wait, + ) + + # TODO: Make this verify the data from Discord + return await r.json() # type: ignore [no-any-return] + + + # TODO: Implement create guild channel async def modify_guild_channel_positions( From 04d15066351473ac85e69cbfbae5657175f09984 Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Thu, 19 Jan 2023 10:14:05 +0100 Subject: [PATCH 08/16] fix(http): use BotAuthentication instead of BearerAuthentication in modify_guild_channel --- nextcore/http/client/wrappers/channel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nextcore/http/client/wrappers/channel.py b/nextcore/http/client/wrappers/channel.py index 75a98300f..c957a1621 100644 --- a/nextcore/http/client/wrappers/channel.py +++ b/nextcore/http/client/wrappers/channel.py @@ -219,7 +219,7 @@ async def modify_group_dm( async def modify_guild_channel( self, - authentication: BearerAuthentication, + authentication: BotAuthentication, channel_id: Snowflake, *, name: str | UndefinedType = UNDEFINED, From 9567b539f006d510ef50a54302645265dddaea6f Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Thu, 19 Jan 2023 10:55:44 +0100 Subject: [PATCH 09/16] fix(http): change get_channel_messages around, before, after to be Snowflakes instead of ints --- nextcore/http/client/wrappers/channel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nextcore/http/client/wrappers/channel.py b/nextcore/http/client/wrappers/channel.py index c957a1621..ffa796192 100644 --- a/nextcore/http/client/wrappers/channel.py +++ b/nextcore/http/client/wrappers/channel.py @@ -633,9 +633,9 @@ async def get_channel_messages( authentication: BotAuthentication, channel_id: Snowflake, *, - around: int | UndefinedType = UNDEFINED, - before: int | UndefinedType = UNDEFINED, - after: int | UndefinedType = UNDEFINED, + around: Snowflake | UndefinedType = UNDEFINED, + before: Snowflake | UndefinedType = UNDEFINED, + after: Snowflake | UndefinedType = UNDEFINED, limit: int | UndefinedType = UNDEFINED, bucket_priority: int = 0, global_priority: int = 0, From 3cc7953c5f8e58c87b2be9ce7b7a96efc860385f Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Thu, 19 Jan 2023 10:57:06 +0100 Subject: [PATCH 10/16] fix(http): update overwrites with Snowflakes --- nextcore/http/client/wrappers/channel.py | 6 +- tests/integration/test_discord_api.py | 173 ++++++++++++++++++++++- 2 files changed, 169 insertions(+), 10 deletions(-) diff --git a/nextcore/http/client/wrappers/channel.py b/nextcore/http/client/wrappers/channel.py index ffa796192..71aa0c41e 100644 --- a/nextcore/http/client/wrappers/channel.py +++ b/nextcore/http/client/wrappers/channel.py @@ -580,7 +580,7 @@ async def get_channel_messages( authentication: BotAuthentication, channel_id: Snowflake, *, - around: int, + around: Snowflake, limit: int | UndefinedType, bucket_priority: int = 0, global_priority: int = 0, @@ -594,7 +594,7 @@ async def get_channel_messages( authentication: BotAuthentication, channel_id: Snowflake, *, - before: int, + before: Snowflake, limit: int | UndefinedType, bucket_priority: int = 0, global_priority: int = 0, @@ -608,7 +608,7 @@ async def get_channel_messages( authentication: BotAuthentication, channel_id: Snowflake, *, - after: int, + after: Snowflake, limit: int | UndefinedType, bucket_priority: int = 0, global_priority: int = 0, diff --git a/tests/integration/test_discord_api.py b/tests/integration/test_discord_api.py index 3294a75e1..ebaa36407 100644 --- a/tests/integration/test_discord_api.py +++ b/tests/integration/test_discord_api.py @@ -1,14 +1,19 @@ from __future__ import annotations -from pytest_harmony import TreeTests -import pytest -import typing + import os +import typing + +import pytest +from discord_typings import GuildData, ReadyData, ChannelData, MessageData, ThreadChannelData +from pytest_harmony import TreeTests + +from nextcore.gateway import GatewayOpcode, ShardManager from nextcore.http import BotAuthentication, HTTPClient -from nextcore.gateway import ShardManager, GatewayOpcode -from discord_typings import GuildData, ReadyData +from nextcore.http.errors import BadRequestError tree = TreeTests() + @pytest.mark.asyncio async def test_discord_api(): await tree.run_tests() @@ -29,6 +34,7 @@ async def get_token(state: dict[str, typing.Any]): await http_client.setup() state["http_client"] = http_client + @get_token.cleanup() async def cleanup_get_token(state: dict[str, typing.Any]): del state["token"] @@ -58,6 +64,7 @@ async def create_guild(state: dict[str, typing.Any]): state["guild"] = guild + @create_guild.cleanup() async def cleanup_create_guild(state: dict[str, typing.Any]): http_client: HTTPClient = state["http_client"] @@ -78,16 +85,16 @@ async def get_audit_logs(state: dict[str, typing.Any]): logs = await http_client.get_guild_audit_log(authentication, guild["id"], limit=10) - assert logs["audit_log_entries"] == [] + # Get token / connect to gateway @get_token.append() async def connect_to_gateway(state: dict[str, typing.Any]): http_client: HTTPClient = state["http_client"] authentication: BotAuthentication = state["authentication"] - intents = 3276799 # Everything. TODO: Do this using a intents helper? + intents = 3276799 # Everything. TODO: Do this using a intents helper? gateway = ShardManager(authentication, intents, http_client) @@ -98,6 +105,7 @@ async def connect_to_gateway(state: dict[str, typing.Any]): ready_data: ReadyData = (await gateway.event_dispatcher.wait_for(lambda _: True, "READY"))[0] state["bot_user"] = ready_data["user"] + @connect_to_gateway.cleanup() async def cleanup_connect_to_gateway(state: dict[str, typing.Any]): gateway: ShardManager = state["gateway"] @@ -107,6 +115,7 @@ async def cleanup_connect_to_gateway(state: dict[str, typing.Any]): del state["gateway"] del state["bot_user"] + # Get token / connect to gateway / get latency @connect_to_gateway.append() async def get_latency(state: dict[str, typing.Any]): @@ -117,6 +126,7 @@ async def get_latency(state: dict[str, typing.Any]): await shard.raw_dispatcher.wait_for(lambda _: True, GatewayOpcode.HEARTBEAT_ACK) print(shard.latency) + # Get token / connect to gateway @connect_to_gateway.append() async def rescale_shards(state: dict[str, typing.Any]): @@ -126,7 +136,156 @@ async def rescale_shards(state: dict[str, typing.Any]): assert len(gateway.active_shards) == 5 + @rescale_shards.cleanup() async def cleanup_rescale_shards(state: dict[str, typing.Any]): gateway: ShardManager = state["gateway"] await gateway.rescale_shards(1) + + +# Get token / create guild / create text channel +@create_guild.append() +async def create_text_channel(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + guild: GuildData = state["guild"] + + channel = await http_client.create_guild_channel(authentication, guild["id"], "test-text", type=0) + + state["channel"] = channel + +@create_text_channel.cleanup() +async def cleanup_create_text_channel(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + + await http_client.delete_channel(authentication, channel["id"]) + + del state["channel"] + + +# Get token / create guild / create text channel / get channel +@create_text_channel.append() +async def get_channel(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + + fetched_channel = await http_client.get_channel(authentication, channel["id"]) + + assert fetched_channel["id"] == channel["id"] + +# Get token / create guild / create text channel / modify text channel +@create_text_channel.append() +async def modify_text_channel(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + + modified_channel = await http_client.modify_guild_channel(authentication, channel["id"], name="cool-name", topic="This is a cool channel topic", nsfw=True, rate_limit_per_user=50, default_auto_archive_duration=1440) + + assert modified_channel["name"] == "cool-name" + assert modified_channel["topic"] == "This is a cool channel topic" + assert modified_channel["nsfw"] == True + assert modified_channel["rate_limit_per_user"] == 50 + assert modified_channel["default_auto_archive_duration"] == 1440 + + +# Get token / create guild / create text channel / create message +@create_text_channel.append() +async def create_message(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + + message = await http_client.create_message(authentication, channel["id"], content="Hello!") + + state["message"] = message + +@create_message.cleanup() +async def cleanup_create_message(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + message: MessageData = state["message"] + + await http_client.delete_message(authentication, channel["id"], message["id"]) + + del state["message"] + + +# Get token / create guild / create text channel / create message / create reaction +@create_message.append() +async def create_reaction(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + message: ChannelData = state["message"] + + await http_client.create_reaction(authentication, channel["id"], message["id"], "👋") + + +@create_message.cleanup() +async def cleanup_create_reaction(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + message: ChannelData = state["message"] + + await http_client.delete_own_reaction(authentication, channel["id"], message["id"], "👋") + +# Get token / create guild / create text channel / create message / create thread +@create_message.append() +async def create_thread(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + message: ChannelData = state["message"] + + thread = await http_client.start_thread_from_message(authentication, channel["id"], message["id"], "Test thread stuff") + await http_client.join_thread(authentication, thread["id"]) + + state["thread"] = thread + +@create_thread.cleanup() +async def cleanup_create_thread(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + thread: ThreadChannelData = state["thread"] + + await http_client.leave_thread(authentication, thread["id"]) + await http_client.delete_channel(authentication, thread["id"], reason="Bad thread") + +# Get token / create guild / create text channel / create message / create thread / modify thread +@create_thread.append() +async def modify_thread(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + thread: ThreadChannelData = state["thread"] + + updated_thread = await http_client.modify_thread(authentication, thread["id"], archived=True, locked=True) + updated_metadata = updated_thread["thread_metadata"] + + assert updated_metadata["archived"] + assert updated_metadata["locked"] + +@modify_thread.cleanup() +async def cleanup_modify_thread(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + thread: ThreadChannelData = state["thread"] + + await http_client.modify_thread(authentication, thread["id"], archived=False, locked=False) + +# Get token / create guild / create text channel / create message / create thread / modify thread +@create_message.append() +async def get_channel_history(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + message: MessageData = state["message"] + + messages = await http_client.get_channel_messages(authentication, channel["id"], around=message["id"], limit=1) + + assert len(messages) == 1 From 31f59153794fdb9df0066606540b8b8fdddcc083 Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Thu, 19 Jan 2023 12:05:01 +0100 Subject: [PATCH 11/16] test(integration): more tests --- tests/integration/test_discord_api.py | 74 +++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_discord_api.py b/tests/integration/test_discord_api.py index ebaa36407..9135cedd8 100644 --- a/tests/integration/test_discord_api.py +++ b/tests/integration/test_discord_api.py @@ -4,12 +4,13 @@ import typing import pytest -from discord_typings import GuildData, ReadyData, ChannelData, MessageData, ThreadChannelData +from discord_typings import EmbedData, GuildData, ReadyData, ChannelData, MessageData, ThreadChannelData from pytest_harmony import TreeTests from nextcore.gateway import GatewayOpcode, ShardManager from nextcore.http import BotAuthentication, HTTPClient from nextcore.http.errors import BadRequestError +from nextcore.http.file import File tree = TreeTests() @@ -235,6 +236,19 @@ async def cleanup_create_reaction(state: dict[str, typing.Any]): await http_client.delete_own_reaction(authentication, channel["id"], message["id"], "👋") + +# Get token / create guild / create text channel / create message / create reaction / get reactions +@create_reaction.append() +async def get_reactions(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + message: ChannelData = state["message"] + + reactions = await http_client.get_reactions(authentication, channel["id"], message["id"], "👋", limit=1, after=0) + assert len(reactions) == 1 + + # Get token / create guild / create text channel / create message / create thread @create_message.append() async def create_thread(state: dict[str, typing.Any]): @@ -278,14 +292,66 @@ async def cleanup_modify_thread(state: dict[str, typing.Any]): await http_client.modify_thread(authentication, thread["id"], archived=False, locked=False) -# Get token / create guild / create text channel / create message / create thread / modify thread +# Get token / create guild / create text channel / create message / get channel messages @create_message.append() -async def get_channel_history(state: dict[str, typing.Any]): +async def get_channel_messages(state: dict[str, typing.Any]): http_client: HTTPClient = state["http_client"] authentication: BotAuthentication = state["authentication"] channel: ChannelData = state["channel"] message: MessageData = state["message"] messages = await http_client.get_channel_messages(authentication, channel["id"], around=message["id"], limit=1) - assert len(messages) == 1 + + before_messages = await http_client.get_channel_messages(authentication, channel["id"], before=message["id"], limit=1) + assert len(before_messages) == 0 + + after_messages = await http_client.get_channel_messages(authentication, channel["id"], after=message["id"], limit=1) + assert len(after_messages) == 0 + +# Get token / create guild / create text channel / create message / edit message +@create_message.append() +async def edit_message(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + message: MessageData = state["message"] + + new_message = await http_client.edit_message(authentication, channel["id"], message["id"], content="foobar") + + assert new_message["content"] == "foobar" + +# Get token / create guild / create text channel / create message +@create_text_channel.append() +async def create_message_advanced(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + + embeds: list[EmbedData] = [ + { + "title": "Hi", + "description": "Hello" + } + ] + file = File("test.txt", "Test contents") + + message = await http_client.create_message(authentication, channel["id"], embeds=embeds, files=[file]) + + state["message"] = message + + assert len(message["embeds"]) == 1 + assert len(message["attachments"]) == 1 + +@create_message.cleanup() +async def cleanup_create_message(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + message: MessageData = state["message"] + + await http_client.delete_message(authentication, channel["id"], message["id"]) + + del state["message"] + + From 40b53e0d616ba74d4d8f11d13809e335142f07d6 Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Thu, 19 Jan 2023 14:24:23 +0100 Subject: [PATCH 12/16] test: add remove event listener test --- tests/common/test_dispatcher.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/common/test_dispatcher.py b/tests/common/test_dispatcher.py index b09327698..1f66a4cd5 100644 --- a/tests/common/test_dispatcher.py +++ b/tests/common/test_dispatcher.py @@ -180,3 +180,21 @@ def false_callback(event: str | None = None) -> bool: # Check for logging errors. error_count = len([record for record in caplog.records if record.levelname == "ERROR"]) assert error_count == 0, "Logged errors where present" + +@mark.asyncio +@mark.parametrize("event_name", [None, "test"]) +async def test_remove_listener(event_name): + failed: Future[None] = Future() + + async def handler(): + failed.set_result(None) + + dispatcher = Dispatcher() + + dispatcher.add_listener(handler, event_name) + dispatcher.remove_listener(handler, event_name) + + await dispatcher.dispatch(event_name) + + with raises(AsyncioTimeoutError): + await wait_for(dispatcher.wait_for(lambda: True, event_name), timeout=1) From e722dcc1f6d0ed7de201335bb6d04ea66dc40eb2 Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Thu, 19 Jan 2023 14:46:57 +0100 Subject: [PATCH 13/16] test(http): Mark reraise as stable again --- tests/http/test_bucket.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/http/test_bucket.py b/tests/http/test_bucket.py index d381eac24..ec487cfa1 100644 --- a/tests/http/test_bucket.py +++ b/tests/http/test_bucket.py @@ -89,7 +89,6 @@ async def test_out_no_wait() -> None: @mark.asyncio -@mark.skipif(True, reason="Currently broken") @match_time(0, 0.1) async def test_re_release() -> None: metadata = BucketMetadata(limit=1) @@ -105,7 +104,7 @@ async def use(): async with bucket.acquire(): started.set_result(None) await can_raise - raise + raise RuntimeError("Raising so bucket gets un-acquired") except: pass From afd3e1aaf84c31c345211169f37f91cbef38ae59 Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Thu, 19 Jan 2023 14:47:21 +0100 Subject: [PATCH 14/16] test(http): test that unlimited releases requests --- tests/http/global_rate_limiter/test_unlimited.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/http/global_rate_limiter/test_unlimited.py b/tests/http/global_rate_limiter/test_unlimited.py index 8287d0658..3d0956ad6 100644 --- a/tests/http/global_rate_limiter/test_unlimited.py +++ b/tests/http/global_rate_limiter/test_unlimited.py @@ -59,3 +59,15 @@ async def test_cancel() -> None: async with rate_limiter.acquire(): raise CancelledError() assert len(rate_limiter._pending_requests) == 0, "Pending request was not cleared" # type: ignore [reportPrivateUsage] + +@mark.asyncio +@match_time(1, .1) +async def test_reset() -> None: + rate_limiter = UnlimitedGlobalRateLimiter() + + rate_limiter.update(1) + + await sleep(.5) # Ensure it registers + + async with rate_limiter.acquire(): + ... From 745773b868f952e8ddea134141ea5626e071fdeb Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Thu, 19 Jan 2023 15:00:56 +0100 Subject: [PATCH 15/16] fix: delete webhook responds with no content --- nextcore/http/client/wrappers/webhook.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/nextcore/http/client/wrappers/webhook.py b/nextcore/http/client/wrappers/webhook.py index dcaee087d..3dc87374a 100644 --- a/nextcore/http/client/wrappers/webhook.py +++ b/nextcore/http/client/wrappers/webhook.py @@ -139,7 +139,6 @@ async def create_webhook( global_priority=global_priority, wait=wait, ) - # TODO: Make this verify the payload from discord? return await r.json() # type: ignore [no-any-return] @@ -572,7 +571,7 @@ async def delete_webhook( if reason is not UNDEFINED: headers["X-Audit-Log-Reason"] = reason - r = await self._request( + await self._request( route, rate_limit_key=authentication.rate_limit_key, headers=headers, @@ -581,9 +580,6 @@ async def delete_webhook( wait=wait, ) - # TODO: Make this verify the payload from discord? - return await r.json() # type: ignore [no-any-return] - async def delete_webhook_with_token( self, webhook_id: Snowflake, From 3fb3e9527ca0d4b4bfdfadc2fdf895faa1f79fa2 Mon Sep 17 00:00:00 2001 From: TAG-Epic Date: Thu, 19 Jan 2023 15:01:14 +0100 Subject: [PATCH 16/16] test: create webhook --- tests/integration/test_discord_api.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_discord_api.py b/tests/integration/test_discord_api.py index 9135cedd8..224cabcd1 100644 --- a/tests/integration/test_discord_api.py +++ b/tests/integration/test_discord_api.py @@ -2,9 +2,10 @@ import os import typing +from aiohttp import ContentTypeError import pytest -from discord_typings import EmbedData, GuildData, ReadyData, ChannelData, MessageData, ThreadChannelData +from discord_typings import EmbedData, GuildData, ReadyData, ChannelData, MessageData, ThreadChannelData, WebhookData from pytest_harmony import TreeTests from nextcore.gateway import GatewayOpcode, ShardManager @@ -193,6 +194,25 @@ async def modify_text_channel(state: dict[str, typing.Any]): assert modified_channel["default_auto_archive_duration"] == 1440 +# Get token / create guild / create text channel / create webhook +@create_text_channel.append() +async def create_webhook(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + channel: ChannelData = state["channel"] + + webhook = await http_client.create_webhook(authentication, channel["id"], "test") + + state["webhook"] = webhook + +@create_webhook.cleanup() +async def cleanup_create_webhook(state: dict[str, typing.Any]): + http_client: HTTPClient = state["http_client"] + authentication: BotAuthentication = state["authentication"] + webhook: WebhookData = state["webhook"] + + await http_client.delete_webhook(authentication, webhook["id"]) + # Get token / create guild / create text channel / create message @create_text_channel.append() async def create_message(state: dict[str, typing.Any]): @@ -343,8 +363,8 @@ async def create_message_advanced(state: dict[str, typing.Any]): assert len(message["embeds"]) == 1 assert len(message["attachments"]) == 1 -@create_message.cleanup() -async def cleanup_create_message(state: dict[str, typing.Any]): +@create_message_advanced.cleanup() +async def cleanup_create_message_advanced(state: dict[str, typing.Any]): http_client: HTTPClient = state["http_client"] authentication: BotAuthentication = state["authentication"] channel: ChannelData = state["channel"]