diff --git a/README.md b/README.md
index c53c020..6f59721 100644
--- a/README.md
+++ b/README.md
@@ -254,6 +254,7 @@ $ python export_db.py --help
 usage: export_db.py [-h] [--users] [--workspace]
                     [--notebook-format {DBC,SOURCE,HTML}] [--download]
                     [--libs] [--clusters] [--jobs] [--metastore] [--secrets]
+                    [--scope-names SCOPE_NAMES [SCOPE_NAMES ...]] [--skip-scope-acl]
                     [--metastore-unicode] [--cluster-name CLUSTER_NAME]
                     [--database DATABASE] [--iam IAM] [--skip-failed]
                     [--mounts] [--azure] [--profile PROFILE]
@@ -286,6 +287,10 @@ optional arguments:
   --metastore           log all the metastore table definitions
   --metastore-unicode   log all the metastore table definitions including
                         unicode characters
+  --secrets             log all the secret scopes
+  --scope-names SCOPE_NAMES [SCOPE_NAMES ...]
+                        log only the specified secret scopes
+  --skip-scope-acl      Skip logging the secret ACLs during export
   --table-acls          log all table ACL grant and deny statements
   --cluster-name CLUSTER_NAME
                         Cluster name to export the metastore to a specific
@@ -335,6 +340,7 @@ usage: import_db.py [-h] [--users] [--workspace] [--workspace-top-level]
                     [--archive-missing] [--libs] [--clusters] [--jobs]
                     [--metastore] [--metastore-unicode] [--get-repair-log]
                     [--cluster-name CLUSTER_NAME] [--skip-failed] [--azure]
+                    [--secrets] [--scope-names SCOPE_NAMES [SCOPE_NAMES ...]] [--skip-scope-acl]
                     [--profile PROFILE] [--single-user SINGLE_USER]
                     [--no-ssl-verification] [--silent] [--debug]
                     [--set-export-dir SET_EXPORT_DIR] [--pause-all-jobs]
@@ -376,6 +382,10 @@ optional arguments:
                         cluster. Cluster will be started.
   --skip-failed         Skip missing users that do not exist when importing
                         user notebooks
+  --secrets             Import all secret scopes
+  --scope-names SCOPE_NAMES [SCOPE_NAMES ...]
+                        import only the specified secret scopes
+  --skip-scope-acl      Skip importing the secret ACLs during import
   --azure               Run on Azure. (Default is AWS)
   --profile PROFILE     Profile to parse the credentials
   --no-ssl-verification
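Taken together, the new flags allow a scoped, ACL-free secrets migration. A usage sketch (the profile, cluster, and scope names are placeholders, not values from this change):

```bash
# export only two scopes, skipping their ACLs
$ python export_db.py --secrets --scope-names scope-a scope-b --skip-scope-acl \
    --cluster-name my-cluster --profile oldWS

# import the same scopes on the target workspace, again leaving ACLs untouched
$ python import_db.py --secrets --scope-names scope-a scope-b --skip-scope-acl --profile newWS
```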
diff --git a/dbclient/ClustersClient.py b/dbclient/ClustersClient.py
index 631b105..c2591c8 100644
--- a/dbclient/ClustersClient.py
+++ b/dbclient/ClustersClient.py
@@ -12,7 +12,7 @@ def __init__(self, configs, checkpoint_service):
         super().__init__(configs)
         self._checkpoint_service = checkpoint_service
         self.groups_to_keep = configs.get("groups_to_keep", False)
-        self.skip_missing_users = configs['skip_missing_users']
+        self.skip_missing_users = configs.get('skip_missing_users', False)
 
         create_configs = {'num_workers',
                          'autoscale',
diff --git a/dbclient/SecretsClient.py b/dbclient/SecretsClient.py
index 82bce96..ce38a71 100644
--- a/dbclient/SecretsClient.py
+++ b/dbclient/SecretsClient.py
@@ -31,9 +31,12 @@ def get_secret_value(self, scope_name, secret_key, cid, ec_id, error_logger):
         else:
             return results_get.get('data')
 
-    def log_all_secrets(self, cluster_name=None, log_dir='secret_scopes/'):
+    def secret_scope_map(self, scope_name: str):
+        return {'name': scope_name}
+
+    def log_all_secrets(self, cluster_name=None, log_dir='secret_scopes/', secret_scopes: list = None):
         scopes_dir = self.get_export_dir() + log_dir
-        scopes_list = self.get_secret_scopes_list()
+        scopes_list = map(self.secret_scope_map, secret_scopes) if secret_scopes else self.get_secret_scopes_list()
         error_logger = logging_utils.get_error_logger(
             wmconstants.WM_EXPORT, wmconstants.SECRET_OBJECT, self.get_export_dir())
         os.makedirs(scopes_dir, exist_ok=True)
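`secret_scope_map` exists so that names supplied via `--scope-names` can flow through the same code path as the Secrets API response, which yields `{'name': ...}` records. A minimal sketch of that normalization (the scope names are illustrative):

```python
# Sketch: normalize CLI-supplied scope names into the {'name': ...} records
# that the rest of log_all_secrets expects from get_secret_scopes_list().
def secret_scope_map(scope_name: str) -> dict:
    return {'name': scope_name}

cli_scopes = ['scope-a', 'scope-b']              # e.g. parsed from --scope-names
scopes_list = map(secret_scope_map, cli_scopes)  # lazy, like the real call site
print([scope['name'] for scope in scopes_list])  # ['scope-a', 'scope-b']
```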
Skipping...") - continue - # check if users has can manage perms then we can add during creation time - has_user_manage = self.has_users_can_manage_permission(scope_name, scopes_acl_dict) - create_scope_args = {'scope': scope_name} - if has_user_manage: - create_scope_args['initial_manage_principal'] = 'users' - other_permissions = self.get_all_other_permissions(scope_name, scopes_acl_dict) - create_resp = self.post('/secrets/scopes/create', create_scope_args) - logging_utils.log_response_error( - error_logger, create_resp, ignore_error_list=['RESOURCE_ALREADY_EXISTS']) - if other_permissions: - # use this dict minus the `users:MANAGE` permissions and apply the other permissions to the scope - for perm, principal_list in other_permissions.items(): - put_acl_args = {"scope": scope_name, - "permission": perm} - for x in principal_list: - put_acl_args["principal"] = x - logging.info(put_acl_args) - put_resp = self.post('/secrets/acls/put', put_acl_args) - logging_utils.log_response_error(error_logger, put_resp) - # loop through the scope and create the k/v pairs - with open(file_path, 'r', encoding="utf-8") as fp: - for s in fp: - s_dict = json.loads(s) - k = s_dict.get('name') - v = s_dict.get('value') - if 'WARNING: skipped' in v: - error_logger.error(f"Skipping scope {scope_name} as value is corrupted due to being too large \n") + if (contains_elements and scope_name in secret_scopes) or (not contains_elements): + file_path = root + scope_name + # print('Log file: ', file_path) + if not skip_acl_updates: + # check if scopes acls are empty, then skip + if scopes_acl_dict.get(scope_name, None) is None: + print("Scope is empty with no manage permissions. Skipping...") continue - try: - put_secret_args = {'scope': scope_name, - 'key': k, - 'string_value': base64.b64decode(v.encode('ascii')).decode('ascii')} - put_resp = self.post('/secrets/put', put_secret_args) - logging_utils.log_response_error(error_logger, put_resp) - except Exception as error: - if "Invalid base64-encoded string" in str(error) or 'decode' in str(error) or "padding" in str(error): - error_msg = f"secret_scope: {scope_name} has invalid invalid data characters: {str(error)} skipping.. and logging to error file." 
diff --git a/dbclient/parser.py b/dbclient/parser.py
index c478e71..e8ece4f 100644
--- a/dbclient/parser.py
+++ b/dbclient/parser.py
@@ -109,6 +109,14 @@ def get_export_parser():
     parser.add_argument('--secrets', action='store_true',
                         help='log all the secret scopes')
 
+    # export only the specified secret scopes
+    parser.add_argument('--scope-names', type=str, action='store', nargs="+",
+                        help='log only the specified secret scopes')
+
+    # skip exporting the secret ACLs
+    parser.add_argument('--skip-scope-acl', action='store_true', default=False,
+                        help='Skip logging the secret ACLs during export')
+
     # get all mlflow experiments
     parser.add_argument('--mlflow-experiments', action='store_true',
                         help='log all the mlflow experiments')
@@ -326,6 +334,14 @@ def get_import_parser():
     parser.add_argument('--secrets', action='store_true',
                         help='Import all secret scopes')
 
+    # import only the specified secret scopes
+    parser.add_argument('--scope-names', type=str, action='store', nargs="+",
+                        help='import only the specified secret scopes')
+
+    # skip importing the secret ACLs
+    parser.add_argument('--skip-scope-acl', action='store_true', default=False,
+                        help='Skip importing the secret ACLs during import')
+
     # import all mlflow experiments
     parser.add_argument('--mlflow-experiments', action='store_true',
                         help='Import all the mlflow experiments')
@@ -426,8 +442,8 @@ def build_client_config(profile, url, token, args):
         'skip_failed': args.skip_failed,
         'debug': args.debug,
         'file_format': str(args.notebook_format),
-        'timeout':args.timeout,
-        'skip_missing_users':args.skip_missing_users
+        'timeout': args.timeout if 'timeout' in args else 86400,
+        'skip_missing_users': args.skip_missing_users if 'skip_missing_users' in args else True
     }
     # this option only exists during imports so we check for existence
     if 'overwrite_notebooks' in args:
@@ -450,6 +466,7 @@ def build_client_config(profile, url, token, args):
         config['num_parallel'] = args.num_parallel
         config['retry_total'] = args.retry_total
         config['retry_backoff'] = args.retry_backoff
+        config['timeout'] = args.timeout if 'timeout' in args else 86400
     return config
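The `'timeout' in args` and `'skip_missing_users' in args` guards work because `argparse.Namespace` implements `__contains__`, so the shared `build_client_config` can serve parsers that never define those flags. A small sketch:

```python
import argparse

# Sketch: Namespace supports `in`, so flags that a given parser never
# defined can fall back to defaults instead of raising AttributeError.
ns = argparse.Namespace(profile='demo')  # no `timeout` attribute set
timeout = ns.timeout if 'timeout' in ns else 86400
skip_missing_users = ns.skip_missing_users if 'skip_missing_users' in ns else True
print(timeout, skip_missing_users)       # 86400 True
```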
-                            else:
-                                raise error
+                                else:
+                                    raise error
+                else:
+                    logging.info("Skipping import of {}".format(scope_name))
diff --git a/export_db.py b/export_db.py
index 8a0d12e..91ff4f1 100644
--- a/export_db.py
+++ b/export_db.py
@@ -188,12 +188,16 @@ def main():
         if not args.cluster_name:
             print("Please provide an existing cluster name w/ --cluster-name option\n")
             return
-        print("Export the secret scopes configs at {0}".format(now))
+        scope_names = args.scope_names if args.scope_names else None
+        print("Export the secret scopes configs at {0} for scopes: {1}".format(now, scope_names))
         start = timer()
         sc = SecretsClient(client_config, checkpoint_service)
         # log job configs
-        sc.log_all_secrets(args.cluster_name)
-        sc.log_all_secrets_acls()
+        sc.log_all_secrets(args.cluster_name, secret_scopes=scope_names)
+        if not args.skip_scope_acl:
+            sc.log_all_secrets_acls()
+        else:
+            print("Skipping ACL Export")
         end = timer()
         print("Complete Secrets Export Time: " + str(timedelta(seconds=end - start)))
 
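The exported secret values are logged base64-encoded, which is why the import path shown earlier decodes with `base64.b64decode` and treats padding or invalid-character failures as corrupt entries rather than fatal errors. A round-trip sketch under that assumption:

```python
import base64

# Sketch: the round-trip the importer performs on each logged secret value.
plaintext = 'my-secret-value'  # assumed ASCII, as the importer requires
encoded = base64.b64encode(plaintext.encode('ascii')).decode('ascii')
decoded = base64.b64decode(encoded.encode('ascii')).decode('ascii')
assert decoded == plaintext

try:
    base64.b64decode('not-valid-b64!'.encode('ascii'))  # hits the padding/validation error path
except Exception as error:
    print(f'corrupt value skipped: {error}')
```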
diff --git a/import_db.py b/import_db.py
index 49ac607..342f480 100644
--- a/import_db.py
+++ b/import_db.py
@@ -6,6 +6,7 @@
 import logging_utils
 import os
 
+
 # python 3.6
 def main():
     # define a parser to identify what component to import / export
@@ -130,7 +131,7 @@ def main():
         hive_c = HiveClient(client_config, checkpoint_service)
         # log job configs
         hive_c.import_hive_metastore(cluster_name=args.cluster_name,
                                      has_unicode=args.metastore_unicode,
-                                    should_repair_table=args.repair_metastore_tables)
+                                     should_repair_table=args.repair_metastore_tables)
         end = timer()
         print("Complete Metastore Import Time: " + str(timedelta(seconds=end - start)))
@@ -149,13 +150,15 @@ def main():
         # log table ACLS configs
         notebook_exit_value = table_acls_c.import_table_acls()
         end = timer()
-        print(f'Complete Table ACLs with exit value: {json.dumps(notebook_exit_value)}, Import Time: {timedelta(seconds=end - start)}')
+        print(
+            f'Complete Table ACLs with exit value: {json.dumps(notebook_exit_value)}, Import Time: {timedelta(seconds=end - start)}')
 
     if args.secrets:
         print("Import secret scopes configs at {0}".format(now))
         start = timer()
         sc = SecretsClient(client_config, checkpoint_service)
-        sc.import_all_secrets()
+        scopes = args.scope_names if args.scope_names else []
+        sc.import_all_secrets(secret_scopes=scopes, skip_acl_updates=args.skip_scope_acl)
         end = timer()
         print("Complete Secrets Import Time: " + str(timedelta(seconds=end - start)))
@@ -176,7 +179,7 @@ def main():
         jobs_c.pause_all_jobs(False)
         end = timer()
         print("Unpaused all jobs time: " + str(timedelta(seconds=end - start)))
-    
+
     if args.import_pause_status:
         print("Importing pause status for migrated jobs {0}".format(now))
         start = timer()
@@ -186,7 +189,6 @@ def main():
         end = timer()
         print("Import pause jobs time: " + str(timedelta(seconds=end - start)))
 
-
     if args.delete_all_jobs:
         print("Delete all current jobs {0}".format(now))
         start = timer()
@@ -254,14 +256,17 @@ def main():
         print("Importing MLflow experiments.")
         mlflow_c = MLFlowClient(client_config, checkpoint_service)
         mlflow_c.import_mlflow_experiments(num_parallel=args.num_parallel)
-        failed_task_log = logging_utils.get_error_log_file(wmconstants.WM_IMPORT, wmconstants.MLFLOW_EXPERIMENT_OBJECT, client_config['export_dir'])
+        failed_task_log = logging_utils.get_error_log_file(wmconstants.WM_IMPORT, wmconstants.MLFLOW_EXPERIMENT_OBJECT,
+                                                           client_config['export_dir'])
         logging_utils.raise_if_failed_task_file_exists(failed_task_log, "MLflow Runs Import.")
 
     if args.mlflow_experiments_permissions:
         print("Importing MLflow experiment permissions.")
         mlflow_c = MLFlowClient(client_config, checkpoint_service)
         mlflow_c.import_mlflow_experiments_acls(num_parallel=args.num_parallel)
-        failed_task_log = logging_utils.get_error_log_file(wmconstants.WM_IMPORT, wmconstants.MLFLOW_EXPERIMENT_PERMISSION_OBJECT, client_config['export_dir'])
+        failed_task_log = logging_utils.get_error_log_file(wmconstants.WM_IMPORT,
+                                                           wmconstants.MLFLOW_EXPERIMENT_PERMISSION_OBJECT,
+                                                           client_config['export_dir'])
         logging_utils.raise_if_failed_task_file_exists(failed_task_log, "MLflow Experiments Permissions Import.")
 
     if args.mlflow_runs:
         print("Importing MLflow runs.")
         mlflow_c = MLFlowClient(client_config, checkpoint_service)
         assert args.src_profile is not None, "Import MLflow runs requires --src-profile flag."
         src_login_args = get_login_credentials(profile=args.src_profile)
-        src_client_config = build_client_config(args.src_profile, src_login_args['host'], src_login_args.get('token', login_args.get('password')), args)
+        src_client_config = build_client_config(args.src_profile, src_login_args['host'],
+                                                src_login_args.get('token', login_args.get('password')), args)
         mlflow_c.import_mlflow_runs(src_client_config, num_parallel=args.num_parallel)
-        failed_task_log = logging_utils.get_error_log_file(wmconstants.WM_IMPORT, wmconstants.MLFLOW_RUN_OBJECT, client_config['export_dir'])
+        failed_task_log = logging_utils.get_error_log_file(wmconstants.WM_IMPORT, wmconstants.MLFLOW_RUN_OBJECT,
+                                                           client_config['export_dir'])
         logging_utils.raise_if_failed_task_file_exists(failed_task_log, "MLflow Runs Import.")
 
-
     if args.get_repair_log:
         print("Finding partitioned tables to repair at {0}".format(now))
         start = timer()
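One consequence of wiring `skip_acl_updates` through from `--skip-scope-acl`: with ACLs skipped, `scopes_acl_dict` stays empty and the "no manage permissions, skipping" branch is bypassed, so every selected scope file is imported. A distilled sketch of that control flow (the helper name is hypothetical):

```python
# Sketch: whether a scope file proceeds to creation/import, distilled from
# import_all_secrets. `should_import_scope` is a hypothetical helper.
def should_import_scope(scope_name, scopes_acl_dict, skip_acl_updates):
    if not skip_acl_updates:
        # without --skip-scope-acl, scopes with no exported ACLs are skipped
        return scopes_acl_dict.get(scope_name) is not None
    return True  # with --skip-scope-acl, secrets are imported regardless

assert should_import_scope('s1', {}, skip_acl_updates=True)
assert not should_import_scope('s1', {}, skip_acl_updates=False)
assert should_import_scope('s1', {'s1': {'MANAGE': ['users']}}, skip_acl_updates=False)
```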