11import asyncio
22import logging
33import signal
4+ import socket
45import typing as t
5- from typing import Any
66
7+ import aiohttp
78import arrow
89import discord
9- from aiohttp import ClientSession
1010from discord import Activity , AllowedMentions , Intents
1111from discord .client import _cleanup_loop
1212from discord .ext import commands
@@ -41,9 +41,12 @@ class ModmailBot(commands.Bot):
4141 def __init__ (self , ** kwargs ):
4242 self .config = CONFIG
4343 self .start_time : t .Optional [arrow .Arrow ] = None # arrow.utcnow()
44- self .http_session : t .Optional [ClientSession ] = None
44+ self .http_session : t .Optional [aiohttp . ClientSession ] = None
4545 self .dispatcher = Dispatcher ()
4646
47+ self ._connector = None
48+ self ._resolver = None
49+
4750 status = discord .Status .online
4851 activity = Activity (type = discord .ActivityType .listening , name = "users dming me!" )
4952 # listen to messages mentioning the bot or matching the prefix
@@ -65,6 +68,24 @@ def __init__(self, **kwargs):
6568 ** kwargs ,
6669 )
6770
71+ async def create_connectors (self , * args , ** kwargs ) -> None :
72+ """Re-create the connector and set up sessions before logging into Discord."""
73+ # Use asyncio for DNS resolution instead of threads so threads aren't spammed.
74+ self ._resolver = aiohttp .AsyncResolver ()
75+
76+ # Use AF_INET as its socket family to prevent HTTPS related problems both locally
77+ # and in production.
78+ self ._connector = aiohttp .TCPConnector (
79+ resolver = self ._resolver ,
80+ family = socket .AF_INET ,
81+ )
82+
83+ # Client.login() will call HTTPClient.static_login() which will create a session using
84+ # this connector attribute.
85+ self .http .connector = self ._connector
86+
87+ self .http_session = aiohttp .ClientSession (connector = self ._connector )
88+
6889 async def start (self , token : str , reconnect : bool = True ) -> None :
6990 """
7091 Start the bot.
@@ -74,8 +95,8 @@ async def start(self, token: str, reconnect: bool = True) -> None:
7495 """
7596 try :
7697 # create the aiohttp session
77- self . http_session = ClientSession ( loop = self .loop )
78- self .logger .trace ("Created ClientSession." )
98+ await self .create_connectors ( )
99+ self .logger .trace ("Created aiohttp. ClientSession." )
79100 # set start time to when we started the bot.
80101 # This is now, since we're about to connect to the gateway.
81102 # This should also be before we load any extensions, since if they have a load time, it should
@@ -122,7 +143,7 @@ def run(self, *args, **kwargs) -> None:
122143 except NotImplementedError :
123144 pass
124145
125- def stop_loop_on_completion (f : Any ) -> None :
146+ def stop_loop_on_completion (f : t . Any ) -> None :
126147 loop .stop ()
127148
128149 future = asyncio .ensure_future (self .start (* args , ** kwargs ), loop = loop )
@@ -164,10 +185,16 @@ async def close(self) -> None:
164185 except Exception :
165186 self .logger .error (f"Exception occured while removing cog { cog .name } " , exc_info = True )
166187
188+ await super ().close ()
189+
167190 if self .http_session :
168191 await self .http_session .close ()
169192
170- await super ().close ()
193+ if self ._connector :
194+ await self ._connector .close ()
195+
196+ if self ._resolver :
197+ await self ._resolver .close ()
171198
172199 def load_extensions (self ) -> None :
173200 """Load all enabled extensions."""
0 commit comments