diff --git a/mssql/base.py b/mssql/base.py index c69046ac..8c69e380 100644 --- a/mssql/base.py +++ b/mssql/base.py @@ -46,6 +46,26 @@ EDITION_AZURE_SQL_DB = 5 EDITION_AZURE_SQL_MANAGED_INSTANCE = 8 + +class BaseTokenManager: + invalidate_seconds = 600 # 10 minutes + + _token = None + last_token_creation = datetime.datetime.now() + + @property + def token(self): + td = datetime.timedelta(seconds=self.invalidate_seconds) + now = datetime.datetime.now() + if self._token is None or self.last_token_creation < now - td: + self._token = self.get_token() + self.last_token_creation = datetime.datetime.now() + return self._token + + def get_token(self): + raise NotImplementedError("This method should be implemented!") + + def encode_connection_string(fields): """Encode dictionary of keys and values as an ODBC connection String. @@ -360,8 +380,13 @@ def get_new_connection(self, conn_params): 'timeout': timeout, } if 'TOKEN' in conn_params: + token = conn_params['TOKEN'] + if isinstance(token, BaseTokenManager): + token_obj = token + token = token_obj.token + args['attrs_before'] = { - 1256: prepare_token_for_odbc(conn_params['TOKEN']) + 1256: prepare_token_for_odbc(token) } while conn is None: try: