Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 78 additions & 12 deletions virl2_client/models/cl_pyats.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from __future__ import annotations

import io
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any

Expand All @@ -35,9 +36,17 @@
# Ensure markup processor never uses the command line arguments as that's broken
_PyatsProcessor.argv.clear()

try:
from unicon.core.errors import ConnectionError as _UConnectionError
from unicon.core.errors import SubCommandFailure as _USubCommandFailure
except Exception:
_UConnectionError = _USubCommandFailure = None


from ..exceptions import PyatsDeviceNotFound, PyatsNotInstalled

_LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
from genie.libs.conf.device import Device
from genie.libs.conf.testbed import Testbed
Expand Down Expand Up @@ -156,6 +165,17 @@ def _prepare_params(
params["init_config_commands"] = init_config_commands
return params

def _reconnect(self, pyats_device: "Device", params: dict) -> None:
"""Helper method to reconnect a PyATS device with proper cleanup."""
if pyats_device in self._connections:
self._connections.remove(pyats_device)
try:
pyats_device.destroy()
except Exception:
pass
pyats_device.connect(log_stdout=False, learn_hostname=True, **params)
self._connections.add(pyats_device)

def _execute_command(
self,
node_label: str,
Expand Down Expand Up @@ -184,25 +204,51 @@ def _execute_command(
"""
self._check_pyats_installed()

if self._testbed is None:
raise RuntimeError("pyATS testbed is not initialized")

try:
pyats_device: Device = self._testbed.devices[node_label]
pyats_device: "Device" = self._testbed.devices[node_label]
except KeyError:
raise PyatsDeviceNotFound(node_label)

params = self._prepare_params(
init_exec_commands, init_config_commands, **pyats_params
)

if pyats_device not in self._connections or not pyats_device.is_connected():
if pyats_device in self._connections:
pyats_device.destroy()
self._reconnect(pyats_device, params)

pyats_device.connect(log_stdout=False, learn_hostname=True, **params)
self._connections.add(pyats_device)
if configure_mode:
return pyats_device.configure(command, log_stdout=False, **params)
else:
def _run():
if configure_mode:
return pyats_device.configure(command, log_stdout=False, **params)
return pyats_device.execute(command, log_stdout=False, **params)

try:
return _run()
except Exception as exc:
should_retry = False
retry_reason = None

if _UConnectionError and isinstance(exc, _UConnectionError):
should_retry = True
retry_reason = f"ConnectionError: {exc}"
elif _USubCommandFailure and isinstance(exc, _USubCommandFailure):
cause = getattr(exc, "__cause__", None)
if isinstance(cause, TimeoutError):
should_retry = True
retry_reason = f"SubCommandFailure with TimeoutError cause: {cause}"

if not should_retry:
raise

_LOGGER.info(
f"PyATS command failed on node {node_label}, retrying after reconnection. Reason: {retry_reason}"
)

self._reconnect(pyats_device, params)
return _run()

def run_command(
self,
node_label: str,
Expand Down Expand Up @@ -267,7 +313,27 @@ def run_config_command(
)

def cleanup(self) -> None:
"""Clean up the pyATS connections."""
for pyats_device in self._connections:
pyats_device.destroy()
self._connections.clear()
"""Clean up all pyATS connections."""
for pyats_device in tuple(self._connections):
try:
pyats_device.destroy()
finally:
self._connections.discard(pyats_device)

def cleanup_node_connection(self, node_label: str) -> None:
"""
Clean up the pyATS connection for a specific node.

:param node_label: The label/title of the node whose connection to cleanup
"""
if self._testbed is None:
return
try:
pyats_device: "Device" = self._testbed.devices[node_label]
except Exception:
return
if pyats_device in self._connections:
try:
pyats_device.destroy()
finally:
self._connections.discard(pyats_device)
1 change: 1 addition & 0 deletions virl2_client/models/lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,7 @@ def stop(self, wait: bool | None = None) -> None:
"""
url = self._url_for("stop")
self._session.put(url)
self.cleanup_pyats_connections()
if self.need_to_wait(wait):
self.wait_until_lab_converged()
_LOGGER.debug(f"Stopped lab: {self._id}")
Expand Down
2 changes: 1 addition & 1 deletion virl2_client/models/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import logging
import time
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -673,6 +672,7 @@ def stop(self, wait=False) -> None:
"""
url = self._url_for("stop")
self._session.put(url)
self._lab.pyats.cleanup_node_connection(self.label)
if self._lab.need_to_wait(wait):
self.wait_until_converged()

Expand Down