Skip to content
This repository was archived by the owner on Oct 2, 2023. It is now read-only.

Commit 145d1e8

Browse files
committed
Added verification reactions (fixes #116)
1 parent 935a467 commit 145d1e8

File tree

2 files changed

+102
-13
lines changed

2 files changed

+102
-13
lines changed

information/user_info/cog.py

Lines changed: 99 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,38 @@
1+
import asyncio
12
import time
3+
from asyncio import Event
4+
from collections import defaultdict
25
from datetime import datetime, timedelta
36
from typing import Optional, Union
47

5-
from dateutil.relativedelta import relativedelta
6-
from discord import User, NotFound, Embed, Guild, Forbidden, HTTPException, Member, Role
7-
from discord.ext import commands
8-
from discord.ext.commands import Context, UserInputError, CommandError, max_concurrency, guild_only
9-
from discord.utils import snowflake_time
10-
118
from PyDrocsid.async_thread import semaphore_gather
129
from PyDrocsid.cog import Cog
1310
from PyDrocsid.command import reply, optional_permissions
14-
from PyDrocsid.config import Contributor, Config
15-
from PyDrocsid.database import db, filter_by, db_wrapper
11+
from PyDrocsid.config import Contributor
12+
from PyDrocsid.database import db, filter_by, db_wrapper, db_context
1613
from PyDrocsid.embeds import send_long_embed
1714
from PyDrocsid.emojis import name_to_emoji
15+
from PyDrocsid.logger import get_logger
1816
from PyDrocsid.settings import RoleSettings
1917
from PyDrocsid.translations import t
18+
from dateutil.relativedelta import relativedelta
19+
from discord import (
20+
User,
21+
NotFound,
22+
Embed,
23+
Guild,
24+
Forbidden,
25+
HTTPException,
26+
Member,
27+
Role,
28+
Message,
29+
MessageType,
30+
TextChannel,
31+
)
32+
from discord.ext import commands
33+
from discord.ext.commands import Context, UserInputError, CommandError, max_concurrency, guild_only
34+
from discord.utils import snowflake_time
35+
2036
from .colors import Colors
2137
from .models import Join, Leave, UsernameUpdate, Verification
2238
from .permissions import UserInfoPermission
@@ -25,8 +41,11 @@
2541
get_user_info_entries,
2642
get_user_status_entries,
2743
revoke_verification,
44+
send_alert,
2845
)
2946

47+
logger = get_logger(__name__)
48+
3049
tg = t.g
3150
t = t.user_info
3251

@@ -72,12 +91,36 @@ async def get_user(
7291
class UserInfoCog(Cog, name="User Information"):
7392
CONTRIBUTORS = [Contributor.Defelo]
7493

75-
async def on_member_join(self, member: Member):
76-
await Join.create(member.id, str(member))
94+
def __init__(self):
95+
self.join_events: dict[int, Event] = defaultdict(Event)
96+
self.join_id: dict[int, int] = {}
7797

78-
if "verified" not in Config.ROLES or not Config.ROLES["verified"][1]:
98+
async def on_message(self, message: Message):
99+
if message.type != MessageType.new_member:
79100
return
80101

102+
member_id: int = message.author.id
103+
await self.join_events[member_id].wait()
104+
self.join_events[member_id].clear()
105+
self.join_events.pop(member_id)
106+
107+
async with db_context():
108+
join: Join = await db.get(Join, id=self.join_id.pop(member_id))
109+
join.join_msg_channel_id = message.channel.id
110+
join.join_msg_id = message.id
111+
112+
async def on_member_join(self, member: Member):
113+
self.join_events[member.id].clear()
114+
115+
join: Join = await Join.create(member.id, str(member), member.joined_at.replace(microsecond=0))
116+
117+
async def trigger_join_event():
118+
await db.wait_for_close_event()
119+
self.join_id[member.id] = join.id
120+
self.join_events[member.id].set()
121+
122+
asyncio.create_task(trigger_join_event())
123+
81124
last_verification: Optional[Verification] = await db.first(
82125
filter_by(Verification, member=member.id).order_by(Verification.timestamp.desc()),
83126
)
@@ -89,6 +132,8 @@ async def on_member_join(self, member: Member):
89132
await member.add_roles(role)
90133

91134
async def on_member_remove(self, member: Member):
135+
self.join_events.pop(member.id, None)
136+
self.join_id.pop(member.id, None)
92137
await Leave.create(member.id, str(member))
93138

94139
async def on_member_nick_update(self, before: Member, after: Member):
@@ -100,10 +145,47 @@ async def on_user_update(self, before: User, after: User):
100145

101146
await UsernameUpdate.create(before.id, str(before), str(after), False)
102147

148+
async def update_verification_reaction(self, member: Member, add: bool):
149+
guild: Guild = member.guild
150+
151+
for _ in range(10):
152+
async with db_context():
153+
join: Optional[Join] = await db.get(
154+
Join,
155+
member=member.id,
156+
timestamp=member.joined_at.replace(microsecond=0),
157+
)
158+
if not join or not join.join_msg_id or not join.join_msg_channel_id:
159+
await asyncio.sleep(2)
160+
continue
161+
162+
channel_id: int = join.join_msg_channel_id
163+
message_id: int = join.join_msg_id
164+
break
165+
else:
166+
return
167+
168+
channel: Optional[TextChannel] = self.bot.get_channel(channel_id)
169+
if not channel:
170+
return
171+
172+
try:
173+
message: Message = await channel.fetch_message(message_id)
174+
except (NotFound, Forbidden, HTTPException):
175+
return
176+
177+
try:
178+
await message.remove_reaction(name_to_emoji["x" if add else "white_check_mark"], guild.me)
179+
await message.add_reaction(name_to_emoji["white_check_mark" if add else "x"])
180+
except Forbidden:
181+
await send_alert(guild, tg.could_not_add_reaction(message.channel.mention))
182+
103183
async def on_member_role_add(self, member: Member, role: Role):
104184
if role.id != await RoleSettings.get("verified"):
105185
return
106186

187+
asyncio.create_task(self.update_verification_reaction(member, add=True))
188+
107189
last_verification: Optional[Verification] = await db.first(
108190
filter_by(Verification, member=member.id).order_by(Verification.timestamp.desc()),
109191
)
@@ -113,8 +195,12 @@ async def on_member_role_add(self, member: Member, role: Role):
113195
await Verification.create(member.id, str(member), True)
114196

115197
async def on_member_role_remove(self, member: Member, role: Role):
116-
if role.id == await RoleSettings.get("verified"):
117-
await Verification.create(member.id, str(member), False)
198+
if role.id != await RoleSettings.get("verified"):
199+
return
200+
201+
asyncio.create_task(self.update_verification_reaction(member, add=False))
202+
203+
await Verification.create(member.id, str(member), False)
118204

119205
@revoke_verification.subscribe
120206
async def handle_revoke_verification(self, member: Member):

information/user_info/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@ class Join(db.Base):
1414
member: Union[Column, int] = Column(BigInteger)
1515
member_name: Union[Column, str] = Column(Text(collation="utf8mb4_bin"))
1616
timestamp: Union[Column, datetime] = Column(DateTime)
17+
join_msg_channel_id: Union[Column, int] = Column(BigInteger, nullable=True)
18+
join_msg_id: Union[Column, int] = Column(BigInteger, nullable=True)
1719

1820
@staticmethod
1921
async def create(member: int, member_name: str, timestamp: Optional[datetime] = None) -> Join:
2022
row = Join(member=member, member_name=member_name, timestamp=timestamp or datetime.utcnow())
2123
await db.add(row)
24+
await db.session.flush()
2225
return row
2326

2427
@staticmethod

0 commit comments

Comments
 (0)