diff --git a/micro-benchmarks/nccl-tests/slurm/topology-aware-nccl-tests/hostfile_topologify.py b/micro-benchmarks/nccl-tests/slurm/topology-aware-nccl-tests/hostfile_topologify.py index cf1e1b27f..3c7a6db58 100644 --- a/micro-benchmarks/nccl-tests/slurm/topology-aware-nccl-tests/hostfile_topologify.py +++ b/micro-benchmarks/nccl-tests/slurm/topology-aware-nccl-tests/hostfile_topologify.py @@ -9,16 +9,17 @@ # topology. Default is to print to stdout, although an output file # can be specified. -import botocore import boto3 import argparse import sys import socket import time +from collections import defaultdict # To avoid overwhelming the EC2 APIs with large requests, process only # pagination_count entries through the search loops at a time. pagination_count = 64 +max_retries = 5 def generate_topology_csv(input_file, output_file, region): @@ -26,7 +27,7 @@ def generate_topology_csv(input_file, output_file, region): done = False - network_to_hostname = {} + network_to_hostname = defaultdict(lambda: defaultdict(list)) while not done: hostname_to_ip = {} @@ -35,7 +36,7 @@ def generate_topology_csv(input_file, output_file, region): # translate hostname to private ip, since PCluster uses custom # hostnames that the EC2 control plane doesn't see. - for i in range(pagination_count): + for _ in range(pagination_count): hostname = input_file.readline() if not hostname: done = True @@ -43,14 +44,16 @@ def generate_topology_csv(input_file, output_file, region): hostname = hostname.strip() ip = None - for i in range(5): + for _ in range(max_retries): try: ip = socket.gethostbyname(socket.getfqdn(hostname)) - except: + except Exception as e: + print("Error getting ip address for %s: %s" % (hostname, e)) time.sleep(1) else: break - if ip == None: + + if ip is None: print("Error getting ip address for %s" % (hostname)) sys.exit(1) @@ -107,14 +110,9 @@ def generate_topology_csv(input_file, output_file, region): for instance in response['Instances']: instanceid = instance['InstanceId'] - t2_node = instance['NetworkNodes'][1] t1_node = instance['NetworkNodes'][2] - if network_to_hostname.get(t2_node) == None: - network_to_hostname[t2_node] = {} - if network_to_hostname[t2_node].get(t1_node) == None: - network_to_hostname[t2_node][t1_node] = [] network_to_hostname[t2_node][t1_node].append( instanceid_to_hostname[instanceid]) @@ -147,15 +145,12 @@ def generate_topology_csv(input_file, output_file, region): args = parser.parse_args() - if args.output != None: - output_file_handle = open(args.output, "w") + if args.output is not None: + with ( + open(args.output, "w") as output_file_handle, + open(args.input, "r") as input_file_handle, + ): + generate_topology_csv(input_file_handle, output_file_handle, args.region) else: - output_file_handle = sys.stdout - - input_file_handle = open(args.input, "r") - - generate_topology_csv(input_file_handle, output_file_handle, args.region) - - input_file_handle.close() - if args.output != None: - output_file_handle.close() \ No newline at end of file + with open(args.input, "r") as input_file_handle: + generate_topology_csv(input_file_handle, sys.stdout, args.region)