f-stack/dpdk/dts/framework/remote_session/remote/ssh_session.py

134 lines
4.2 KiB
Python

# SPDX-License-Identifier: BSD-3-Clause
# Copyright(c) 2023 PANTHEON.tech s.r.o.
import socket
import traceback
from pathlib import PurePath
from fabric import Connection # type: ignore[import]
from invoke.exceptions import ( # type: ignore[import]
CommandTimedOut,
ThreadException,
UnexpectedExit,
)
from paramiko.ssh_exception import ( # type: ignore[import]
AuthenticationException,
BadHostKeyException,
NoValidConnectionsError,
SSHException,
)
from framework.config import NodeConfiguration
from framework.exception import SSHConnectionError, SSHSessionDeadError, SSHTimeoutError
from framework.logger import DTSLOG
from .remote_session import CommandResult, RemoteSession
class SSHSession(RemoteSession):
    """A persistent SSH connection to a remote Node.

    The connection is implemented with the Fabric Python library.

    Args:
        node_config: The configuration of the Node to connect to.
        session_name: The name of the session.
        logger: The logger used for logging.
            This should be passed from the parent OSSession.

    Attributes:
        session: The underlying Fabric SSH connection.

    Raises:
        SSHConnectionError: The connection cannot be established.
    """

    session: Connection

    def __init__(
        self,
        node_config: NodeConfiguration,
        session_name: str,
        logger: DTSLOG,
    ):
        super().__init__(node_config, session_name, logger)

    def _connect(self) -> None:
        """Open the Fabric SSH connection, retrying on transient errors.

        Up to ``retry_attempts`` connection attempts are made. Errors that a retry
        cannot fix (ValueError, a bad host key, failed authentication) abort at once.
        Connection-level errors (no valid connections, socket errors, other SSH errors)
        are logged and the connection is attempted again.

        Raises:
            SSHConnectionError: A non-retryable error occurred, or every attempt
                failed. In the latter case the distinct error reprs are attached.
        """
        # Distinct repr() strings of the retryable errors seen, reported on exhaustion.
        errors: list[str] = []
        retry_attempts = 10
        # Fabric's connect_timeout is in seconds. NOTE(review): the reason for the
        # longer timeout when a port is configured is not visible here - confirm.
        login_timeout = 20 if self.port else 10
        for retry_attempt in range(retry_attempts):
            try:
                # ip/username/port/password/hostname are presumably populated by
                # RemoteSession from node_config (not visible in this file).
                self.session = Connection(
                    self.ip,
                    user=self.username,
                    port=self.port,
                    connect_kwargs={"password": self.password},
                    connect_timeout=login_timeout,
                )
                self.session.open()

            # BadHostKeyException and AuthenticationException subclass SSHException,
            # so this clause must precede the retryable one below.
            except (ValueError, BadHostKeyException, AuthenticationException) as e:
                self._logger.exception(e)
                raise SSHConnectionError(self.hostname) from e

            except (NoValidConnectionsError, socket.error, SSHException) as e:
                self._logger.debug(traceback.format_exc())
                self._logger.warning(e)

                error = repr(e)
                if error not in errors:
                    errors.append(error)

                # Only announce a retry when another attempt will actually be made.
                if retry_attempt + 1 < retry_attempts:
                    self._logger.info(f"Retrying connection: retry number {retry_attempt + 1}.")

            else:
                break
        else:
            # The loop ran to completion without a successful open (no break).
            raise SSHConnectionError(self.hostname, errors)

    def is_alive(self) -> bool:
        """Return whether the underlying Fabric connection reports being connected."""
        return self.session.is_connected

    def _send_command(self, command: str, timeout: float, env: dict | None) -> CommandResult:
        """Send a command and return the result of the execution.

        Args:
            command: The command to execute.
            timeout: Wait at most this many seconds for the execution to complete.
            env: Extra environment variables that will be used in command execution.

        Returns:
            The command's stdout, stderr and return code, tagged with this session's name.

        Raises:
            SSHSessionDeadError: The session died while executing the command.
            SSHTimeoutError: The command execution timed out.
        """
        try:
            # warn=True: a non-zero exit status is returned in the result instead of
            # raising. hide=True: output is captured, not echoed locally.
            output = self.session.run(command, env=env, warn=True, hide=True, timeout=timeout)

        except (UnexpectedExit, ThreadException) as e:
            self._logger.exception(e)
            raise SSHSessionDeadError(self.hostname) from e

        except CommandTimedOut as e:
            self._logger.exception(e)
            raise SSHTimeoutError(command, e.result.stderr) from e

        return CommandResult(self.name, command, output.stdout, output.stderr, output.return_code)

    def copy_from(
        self,
        source_file: str | PurePath,
        destination_file: str | PurePath,
    ) -> None:
        """Copy a file from the remote Node to the local filesystem.

        Args:
            source_file: The path of the file on the remote Node.
            destination_file: The local path to copy the file to.
        """
        # Fabric's signature is get(remote, local): the remote source comes first.
        self.session.get(str(source_file), str(destination_file))

    def copy_to(
        self,
        source_file: str | PurePath,
        destination_file: str | PurePath,
    ) -> None:
        """Copy a file from the local filesystem to the remote Node.

        Args:
            source_file: The local path of the file to copy.
            destination_file: The path on the remote Node to copy the file to.
        """
        # Fabric's signature is put(local, remote).
        self.session.put(str(source_file), str(destination_file))

    def _close(self, force: bool = False) -> None:
        """Close the underlying Fabric connection.

        Args:
            force: Accepted for interface compatibility; unused by this implementation.
        """
        self.session.close()