diff --git a/CMakeLists.txt b/CMakeLists.txt index ffa2e807a..413636440 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,7 +36,7 @@ function(py_test TARGET_NAME) add_test(NAME ${TARGET_NAME}"_with_abs_path" COMMAND python -u ${py_test_SRCS} ${py_test_ARGS} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) - set_tests_properties(${TARGET_NAME}"_with_abs_path" PROPERTIES TIMEOUT 500) + set_tests_properties(${TARGET_NAME}"_with_abs_path" PROPERTIES TIMEOUT 7200) else() get_filename_component(WORKING_DIR ${py_test_SRCS} DIRECTORY) get_filename_component(FILE_NAME ${py_test_SRCS} NAME) @@ -44,7 +44,7 @@ function(py_test TARGET_NAME) add_test(NAME ${TARGET_NAME} COMMAND python -u ${FILE_NAME} ${py_test_ARGS} WORKING_DIRECTORY ${COMBINED_PATH}) - set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 500) + set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 7200) endif() endfunction() diff --git a/parl/remote/client.py b/parl/remote/client.py index 6d63c10ce..981f26fdc 100644 --- a/parl/remote/client.py +++ b/parl/remote/client.py @@ -59,6 +59,7 @@ def __init__(self, master_address, process_id, distributed_files=[]): remote instances(e,g. the configuration file for initialization) . """ + self.client_is_alive = mp.Value('i', True) self._create_heartbeat_server() self.master_address = master_address self.process_id = process_id @@ -85,6 +86,8 @@ def destroy(self): for th in self.threads: th.join() self.ctx.destroy() + self.client_is_alive.value = False + self.job_heartbeat_process.join() def get_executable_path(self): """Return current executable path.""" @@ -296,7 +299,7 @@ def _create_heartbeat_server(self): """ job_heartbeat_port = mp.Value('i', 0) self.actor_num = mp.Value('i', 0) - self.job_heartbeat_process = HeartbeatServerProcess(job_heartbeat_port, self.actor_num) + self.job_heartbeat_process = HeartbeatServerProcess(job_heartbeat_port, self.actor_num, self.client_is_alive) self.job_heartbeat_process.daemon = True self.job_heartbeat_process.start() assert job_heartbeat_port.value != 0, "fail to initialize heartbeat server for jobs." diff --git a/parl/remote/grpc_heartbeat/heartbeat_server.py b/parl/remote/grpc_heartbeat/heartbeat_server.py index 1dd514e16..0d39d774c 100644 --- a/parl/remote/grpc_heartbeat/heartbeat_server.py +++ b/parl/remote/grpc_heartbeat/heartbeat_server.py @@ -25,11 +25,12 @@ class GrpcHeartbeatServer(heartbeat_pb2_grpc.GrpcHeartbeatServicer): - def __init__(self, client_count=None): + def __init__(self, client_count=None, host_is_alive=True): self.last_heartbeat_time = time.time() self.last_heartbeat_table = dict() self.exit_flag = False self.client_count = client_count + self.host_is_alive = host_is_alive def Send(self, request, context): client_id = request.client_id @@ -54,6 +55,8 @@ def timeout_timer(self): break def _parent_process_is_running(self): + if not self.host_is_alive.value: + return False ppid = os.getppid() return ppid != 1 @@ -133,7 +136,7 @@ def exit(self): self.heartbeat_server.exit() class HeartbeatServerProcess(mp.Process): - def __init__(self, port, client_count): + def __init__(self, port, client_count, host_is_alive): """Create a process to run the heartbeat server. Args: port(mp.Value): notify the main prcoess of the severt port. @@ -144,7 +147,7 @@ def __init__(self, port, client_count): futures.ThreadPoolExecutor(max_workers=500), options=[('grpc.max_receive_message_length', -1), ('grpc.max_send_message_length', -1)]) - self.heartbeat_server = GrpcHeartbeatServer(client_count) + self.heartbeat_server = GrpcHeartbeatServer(client_count, host_is_alive) heartbeat_pb2_grpc.add_GrpcHeartbeatServicer_to_server( self.heartbeat_server, self.grpc_server)