Skip to content

Commit 51974dc

Browse files
adithyasolaiAdithya Solai
andauthored
add multi-user rotation support for RDS Managed Master Password feature (#96)
Co-authored-by: Adithya Solai <adisolai@amazon.com>
1 parent b45024d commit 51974dc

File tree

5 files changed

+518
-35
lines changed

5 files changed

+518
-35
lines changed

SecretsManagerRDSMariaDBRotationMultiUser/lambda_function.py

Lines changed: 102 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
logger = logging.getLogger()
1111
logger.setLevel(logging.INFO)
1212

13+
MAX_RDS_DB_INSTANCE_ARN_LENGTH = 256
14+
1315

1416
def lambda_handler(event, context):
1517
"""Secrets Manager RDS MariaDB Handler
@@ -180,9 +182,10 @@ def set_secret(service_client, arn, token):
180182
raise ValueError("Unable to log into database using current credentials for secret %s" % arn)
181183
conn.close()
182184

183-
# Now get the master arn from the current secret
185+
# Use the master arn from the current secret to fetch master secret contents
184186
master_arn = current_dict['masterarn']
185-
master_dict = get_secret_dict(service_client, master_arn, "AWSCURRENT")
187+
master_dict = get_secret_dict(service_client, master_arn, "AWSCURRENT", None, True)
188+
186189
if current_dict['host'] != master_dict['host'] and not is_rds_replica_database(current_dict, master_dict):
187190
# If current dict is a replica of the master dict, can proceed
188191
logger.error("setSecret: Current database host %s is not the same host as/rds replica of master %s" % (current_dict['host'], master_dict['host']))
@@ -404,7 +407,7 @@ def connect_and_authenticate(secret_dict, port, dbname, use_ssl):
404407
return None
405408

406409

407-
def get_secret_dict(service_client, arn, stage, token=None):
410+
def get_secret_dict(service_client, arn, stage, token=None, master_secret=False):
408411
"""Gets the secret dictionary corresponding for the secret arn, stage, and token
409412
410413
This helper function gets credentials for the arn and stage passed in and returns the dictionary by parsing the JSON string
@@ -414,9 +417,11 @@ def get_secret_dict(service_client, arn, stage, token=None):
414417
415418
arn (string): The secret ARN or other identifier
416419
420+
stage (string): The stage identifying the secret version
421+
417422
token (string): The ClientRequestToken associated with the secret version, or None if no validation is desired
418423
419-
stage (string): The stage identifying the secret version
424+
master_secret (boolean): A flag that indicates if we are getting a master secret.
420425
421426
Returns:
422427
SecretDictionary: Secret dictionary
@@ -427,7 +432,7 @@ def get_secret_dict(service_client, arn, stage, token=None):
427432
ValueError: If the secret is not valid JSON
428433
429434
"""
430-
required_fields = ['host', 'username', 'password']
435+
required_fields = ['host', 'username', 'password', 'engine']
431436

432437
# Only do VersionId validation against the stage if a token is passed in
433438
if token:
@@ -438,12 +443,24 @@ def get_secret_dict(service_client, arn, stage, token=None):
438443
secret_dict = json.loads(plaintext)
439444

440445
# Run validations against the secret
441-
if 'engine' not in secret_dict or secret_dict['engine'] != 'mariadb':
442-
raise KeyError("Database engine must be set to 'mariadb' in order to use this rotation lambda")
446+
if master_secret and (set(secret_dict.keys()) == set(['username', 'password'])):
447+
# If this is an RDS-made Master Secret, we can fetch `host` and other connection params
448+
# from the DescribeDBInstances RDS API using the DB Instance ARN as a filter.
449+
# The DB Instance ARN is fetched from the RDS-made Master Secret's System Tags.
450+
db_instance_arn = fetch_instance_arn_from_system_tags(service_client, arn)
451+
if db_instance_arn is not None:
452+
secret_dict = get_connection_params_from_rds_api(secret_dict, db_instance_arn)
453+
logger.info("setSecret: Successfully fetched connection params for Master Secret %s from DescribeDBInstances API." % arn)
454+
455+
# For non-RDS-made Master Secrets that are missing `host`, this will error below when checking for required connection params.
456+
443457
for field in required_fields:
444458
if field not in secret_dict:
445459
raise KeyError("%s key is missing from secret JSON" % field)
446460

461+
if secret_dict['engine'] != 'mariadb':
462+
raise KeyError("Database engine must be set to 'mariadb' in order to use this rotation lambda")
463+
447464
# Parse and return the secret JSON string
448465
return secret_dict
449466

@@ -511,3 +528,81 @@ def is_rds_replica_database(replica_dict, master_dict):
511528
# DB Instance identifiers are unique - can only be one result
512529
current_instance = instances[0]
513530
return master_instance_id == current_instance.get('ReadReplicaSourceDBInstanceIdentifier')
531+
532+
533+
def fetch_instance_arn_from_system_tags(service_client, secret_arn):
534+
"""Fetches DB Instance ARN from the given secret's metadata.
535+
536+
Fetches DB Instance ARN from the given secret's metadata.
537+
538+
Args:
539+
service_client (client): The secrets manager service client
540+
541+
secret_arn (String): The secret ARN used in a DescribeSecrets API call to fetch the secret's metadata.
542+
543+
Returns:
544+
db_instance_arn (String): The DB Instance ARN of the Primary RDS Instance
545+
546+
"""
547+
548+
metadata = service_client.describe_secret(SecretId=secret_arn)
549+
tags = metadata['Tags']
550+
551+
# Check if DB Instance ARN is present in secret Tags
552+
db_instance_arn = None
553+
for tag in tags:
554+
if tag['Key'].lower() == 'aws:rds:primarydbinstancearn':
555+
db_instance_arn = tag['Value']
556+
557+
# DB Instance ARN must be present in secret System Tags to use this work-around
558+
if db_instance_arn is None:
559+
logger.warning("setSecret: DB Instance ARN not present in Metadata System Tags for secret %s" % secret_arn)
560+
elif len(db_instance_arn) > MAX_RDS_DB_INSTANCE_ARN_LENGTH:
561+
logger.error("setSecret: %s is not a valid DB Instance ARN. It exceeds the maximum length of %d." % (db_instance_arn, MAX_RDS_DB_INSTANCE_ARN_LENGTH))
562+
raise ValueError("%s is not a valid DB Instance ARN. It exceeds the maximum length of %d." % (db_instance_arn, MAX_RDS_DB_INSTANCE_ARN_LENGTH))
563+
564+
return db_instance_arn
565+
566+
567+
def get_connection_params_from_rds_api(master_dict, master_instance_arn):
568+
"""Fetches connection parameters (`host`, `port`, etc.) from the DescribeDBInstances RDS API using `master_instance_arn` in the master secret metadata as a filter.
569+
570+
This helper function fetches connection parameters from the DescribeDBInstances RDS API using `master_instance_arn` in the master secret metadata as a filter.
571+
572+
Args:
573+
master_dict (dictionary): The master secret dictionary that will be updated with connection parameters.
574+
575+
master_instance_arn (string): The DB Instance ARN from master secret System Tags that will be used as a filter in DescribeDBInstances RDS API calls.
576+
577+
Returns:
578+
master_dict (dictionary): An updated master secret dictionary that now contains connection parameters such as `host`, `port`, etc.
579+
580+
Raises:
581+
Exception: If there is some error/throttling when calling the DescribeDBInstances RDS API
582+
583+
ValueError: If the DescribeDBInstances RDS API Response contains no Instances or more than 1 Instance
584+
"""
585+
# Setup the client
586+
rds_client = boto3.client('rds')
587+
588+
# Call DescribeDBInstances RDS API
589+
try:
590+
describe_response = rds_client.describe_db_instances(DBInstanceIdentifier=master_instance_arn)
591+
except Exception as err:
592+
logger.error("setSecret: Encountered API error while fetching connection parameters from DescribeDBInstances RDS API: %s" % err)
593+
raise Exception("Encountered API error while fetching connection parameters from DescribeDBInstances RDS API: %s" % err)
594+
595+
# Verify the instance was found
596+
instances = describe_response['DBInstances']
597+
if len(instances) == 0:
598+
logger.error("setSecret: %s is not a valid DB Instance ARN. No Instances found when using DescribeDBInstances RDS API to get connection params." % master_instance_arn)
599+
raise ValueError("%s is not a valid DB Instance ARN. No Instances found when using DescribeDBInstances RDS API to get connection params." % master_instance_arn)
600+
601+
# put connection parameters in master secret dictionary
602+
primary_instance = instances[0]
603+
master_dict['host'] = primary_instance['Endpoint']['Address']
604+
master_dict['port'] = primary_instance['Endpoint']['Port']
605+
master_dict['dbname'] = primary_instance.get('DBName', None) # `DBName` doesn't have to be present.
606+
master_dict['engine'] = primary_instance['Engine']
607+
608+
return master_dict

SecretsManagerRDSMySQLRotationMultiUser/lambda_function.py

Lines changed: 102 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
logger = logging.getLogger()
1111
logger.setLevel(logging.INFO)
1212

13+
MAX_RDS_DB_INSTANCE_ARN_LENGTH = 256
14+
1315

1416
def lambda_handler(event, context):
1517
"""Secrets Manager RDS MySQL Handler
@@ -180,9 +182,10 @@ def set_secret(service_client, arn, token):
180182
raise ValueError("Unable to log into database using current credentials for secret %s" % arn)
181183
conn.close()
182184

183-
# Now get the master arn from the current secret
185+
# Use the master arn from the current secret to fetch master secret contents
184186
master_arn = current_dict['masterarn']
185-
master_dict = get_secret_dict(service_client, master_arn, "AWSCURRENT")
187+
master_dict = get_secret_dict(service_client, master_arn, "AWSCURRENT", None, True)
188+
186189
if current_dict['host'] != master_dict['host'] and not is_rds_replica_database(current_dict, master_dict):
187190
# If current dict is a replica of the master dict, can proceed
188191
logger.error("setSecret: Current database host %s is not the same host as/rds replica of master %s" % (current_dict['host'], master_dict['host']))
@@ -416,7 +419,7 @@ def connect_and_authenticate(secret_dict, port, dbname, use_ssl):
416419
return None
417420

418421

419-
def get_secret_dict(service_client, arn, stage, token=None):
422+
def get_secret_dict(service_client, arn, stage, token=None, master_secret=False):
420423
"""Gets the secret dictionary corresponding for the secret arn, stage, and token
421424
422425
This helper function gets credentials for the arn and stage passed in and returns the dictionary by parsing the JSON string
@@ -426,9 +429,11 @@ def get_secret_dict(service_client, arn, stage, token=None):
426429
427430
arn (string): The secret ARN or other identifier
428431
432+
stage (string): The stage identifying the secret version
433+
429434
token (string): The ClientRequestToken associated with the secret version, or None if no validation is desired
430435
431-
stage (string): The stage identifying the secret version
436+
master_secret (boolean): A flag that indicates if we are getting a master secret.
432437
433438
Returns:
434439
SecretDictionary: Secret dictionary
@@ -439,7 +444,7 @@ def get_secret_dict(service_client, arn, stage, token=None):
439444
ValueError: If the secret is not valid JSON
440445
441446
"""
442-
required_fields = ['host', 'username', 'password']
447+
required_fields = ['host', 'username', 'password', 'engine']
443448

444449
# Only do VersionId validation against the stage if a token is passed in
445450
if token:
@@ -450,12 +455,24 @@ def get_secret_dict(service_client, arn, stage, token=None):
450455
secret_dict = json.loads(plaintext)
451456

452457
# Run validations against the secret
453-
if 'engine' not in secret_dict or secret_dict['engine'] != 'mysql':
454-
raise KeyError("Database engine must be set to 'mysql' in order to use this rotation lambda")
458+
if master_secret and (set(secret_dict.keys()) == set(['username', 'password'])):
459+
# If this is an RDS-made Master Secret, we can fetch `host` and other connection params
460+
# from the DescribeDBInstances RDS API using the DB Instance ARN as a filter.
461+
# The DB Instance ARN is fetched from the RDS-made Master Secret's System Tags.
462+
db_instance_arn = fetch_instance_arn_from_system_tags(service_client, arn)
463+
if db_instance_arn is not None:
464+
secret_dict = get_connection_params_from_rds_api(secret_dict, db_instance_arn)
465+
logger.info("setSecret: Successfully fetched connection params for Master Secret %s from DescribeDBInstances API." % arn)
466+
467+
# For non-RDS-made Master Secrets that are missing `host`, this will error below when checking for required connection params.
468+
455469
for field in required_fields:
456470
if field not in secret_dict:
457471
raise KeyError("%s key is missing from secret JSON" % field)
458472

473+
if secret_dict['engine'] != 'mysql':
474+
raise KeyError("Database engine must be set to 'mysql' in order to use this rotation lambda")
475+
459476
# Parse and return the secret JSON string
460477
return secret_dict
461478

@@ -561,3 +578,81 @@ def is_rds_replica_database(replica_dict, master_dict):
561578
# DB Instance identifiers are unique - can only be one result
562579
current_instance = instances[0]
563580
return master_instance_id == current_instance.get('ReadReplicaSourceDBInstanceIdentifier')
581+
582+
583+
def fetch_instance_arn_from_system_tags(service_client, secret_arn):
584+
"""Fetches DB Instance ARN from the given secret's metadata.
585+
586+
Fetches DB Instance ARN from the given secret's metadata.
587+
588+
Args:
589+
service_client (client): The secrets manager service client
590+
591+
secret_arn (String): The secret ARN used in a DescribeSecrets API call to fetch the secret's metadata.
592+
593+
Returns:
594+
db_instance_arn (String): The DB Instance ARN of the Primary RDS Instance
595+
596+
"""
597+
598+
metadata = service_client.describe_secret(SecretId=secret_arn)
599+
tags = metadata['Tags']
600+
601+
# Check if DB Instance ARN is present in secret Tags
602+
db_instance_arn = None
603+
for tag in tags:
604+
if tag['Key'].lower() == 'aws:rds:primarydbinstancearn':
605+
db_instance_arn = tag['Value']
606+
607+
# DB Instance ARN must be present in secret System Tags to use this work-around
608+
if db_instance_arn is None:
609+
logger.warning("setSecret: DB Instance ARN not present in Metadata System Tags for secret %s" % secret_arn)
610+
elif len(db_instance_arn) > MAX_RDS_DB_INSTANCE_ARN_LENGTH:
611+
logger.error("setSecret: %s is not a valid DB Instance ARN. It exceeds the maximum length of %d." % (db_instance_arn, MAX_RDS_DB_INSTANCE_ARN_LENGTH))
612+
raise ValueError("%s is not a valid DB Instance ARN. It exceeds the maximum length of %d." % (db_instance_arn, MAX_RDS_DB_INSTANCE_ARN_LENGTH))
613+
614+
return db_instance_arn
615+
616+
617+
def get_connection_params_from_rds_api(master_dict, master_instance_arn):
618+
"""Fetches connection parameters (`host`, `port`, etc.) from the DescribeDBInstances RDS API using `master_instance_arn` in the master secret metadata as a filter.
619+
620+
This helper function fetches connection parameters from the DescribeDBInstances RDS API using `master_instance_arn` in the master secret metadata as a filter.
621+
622+
Args:
623+
master_dict (dictionary): The master secret dictionary that will be updated with connection parameters.
624+
625+
master_instance_arn (string): The DB Instance ARN from master secret System Tags that will be used as a filter in DescribeDBInstances RDS API calls.
626+
627+
Returns:
628+
master_dict (dictionary): An updated master secret dictionary that now contains connection parameters such as `host`, `port`, etc.
629+
630+
Raises:
631+
Exception: If there is some error/throttling when calling the DescribeDBInstances RDS API
632+
633+
ValueError: If the DescribeDBInstances RDS API Response contains no Instances or more than 1 Instance
634+
"""
635+
# Setup the client
636+
rds_client = boto3.client('rds')
637+
638+
# Call DescribeDBInstances RDS API
639+
try:
640+
describe_response = rds_client.describe_db_instances(DBInstanceIdentifier=master_instance_arn)
641+
except Exception as err:
642+
logger.error("setSecret: Encountered API error while fetching connection parameters from DescribeDBInstances RDS API: %s" % err)
643+
raise Exception("Encountered API error while fetching connection parameters from DescribeDBInstances RDS API: %s" % err)
644+
645+
# Verify the instance was found
646+
instances = describe_response['DBInstances']
647+
if len(instances) == 0:
648+
logger.error("setSecret: %s is not a valid DB Instance ARN. No Instances found when using DescribeDBInstances RDS API to get connection params." % master_instance_arn)
649+
raise ValueError("%s is not a valid DB Instance ARN. No Instances found when using DescribeDBInstances RDS API to get connection params." % master_instance_arn)
650+
651+
# put connection parameters in master secret dictionary
652+
primary_instance = instances[0]
653+
master_dict['host'] = primary_instance['Endpoint']['Address']
654+
master_dict['port'] = primary_instance['Endpoint']['Port']
655+
master_dict['dbname'] = primary_instance.get('DBName', None) # `DBName` doesn't have to be present.
656+
master_dict['engine'] = primary_instance['Engine']
657+
658+
return master_dict

0 commit comments

Comments
 (0)