#!/usr/bin/python3
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------
#    This file is part of TISBackup
#
#    TISBackup is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    TISBackup is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with TISBackup.  If not, see <http://www.gnu.org/licenses/>.
#
# -----------------------------------------------------------------------
"""
Test suite for libtisbackup.ssh module.

Tests SSH key loading and remote command execution functionality.
"""

import os
import tempfile
from unittest.mock import Mock, patch

import paramiko
import pytest

from libtisbackup.ssh import load_ssh_private_key, ssh_exec


def _mock_ssh_client(output, exit_status=0, spec=None):
    """Build a mocked SSH client wired to return canned command results.

    The mock chain mirrors the calls these tests exercise through ssh_exec:
    ``ssh.get_transport().open_session()`` yields a channel whose
    ``makefile().read()`` returns *output* and whose ``recv_exit_status()``
    returns *exit_status*.

    Args:
        output: Raw bytes returned by the channel stdout's ``read()``.
        exit_status: Value returned by ``channel.recv_exit_status()``.
        spec: Optional spec for the client mock (e.g. ``paramiko.SSHClient``).
            Leave it as None while ``paramiko.SSHClient`` is patched, because
            the patched name is itself a mock and not a usable spec.

    Returns:
        A ``(ssh, channel)`` tuple so tests can assert on the executed command.
    """
    ssh = Mock(spec=spec)
    transport = Mock()
    channel = Mock()
    stdout = Mock()
    ssh.get_transport.return_value = transport
    transport.open_session.return_value = channel
    channel.makefile.return_value = stdout
    stdout.read.return_value = output
    channel.recv_exit_status.return_value = exit_status
    return ssh, channel


class TestLoadSSHPrivateKey:
    """Test cases for load_ssh_private_key() function."""

    def test_load_ed25519_key_success(self):
        """Test loading a valid Ed25519 key."""
        with patch.object(paramiko.Ed25519Key, "from_private_key_file") as mock_ed25519:
            mock_key = Mock()
            mock_ed25519.return_value = mock_key

            result = load_ssh_private_key("/path/to/ed25519_key")

            assert result == mock_key
            mock_ed25519.assert_called_once_with("/path/to/ed25519_key")

    def test_load_ecdsa_key_fallback(self):
        """Test loading ECDSA key when Ed25519 fails."""
        with patch.object(paramiko.Ed25519Key, "from_private_key_file") as mock_ed25519, patch.object(
            paramiko.ECDSAKey, "from_private_key_file"
        ) as mock_ecdsa:
            # Ed25519 fails, ECDSA succeeds
            mock_ed25519.side_effect = paramiko.SSHException("Not Ed25519")
            mock_key = Mock()
            mock_ecdsa.return_value = mock_key

            result = load_ssh_private_key("/path/to/ecdsa_key")

            assert result == mock_key
            # The Ed25519 attempt must have happened before falling back.
            mock_ed25519.assert_called_once_with("/path/to/ecdsa_key")
            mock_ecdsa.assert_called_once_with("/path/to/ecdsa_key")

    def test_load_rsa_key_fallback(self):
        """Test loading RSA key when Ed25519 and ECDSA fail."""
        with patch.object(paramiko.Ed25519Key, "from_private_key_file") as mock_ed25519, patch.object(
            paramiko.ECDSAKey, "from_private_key_file"
        ) as mock_ecdsa, patch.object(paramiko.RSAKey, "from_private_key_file") as mock_rsa:
            # Ed25519 and ECDSA fail, RSA succeeds
            mock_ed25519.side_effect = paramiko.SSHException("Not Ed25519")
            mock_ecdsa.side_effect = paramiko.SSHException("Not ECDSA")
            mock_key = Mock()
            mock_rsa.return_value = mock_key

            result = load_ssh_private_key("/path/to/rsa_key")

            assert result == mock_key
            # Both preferred formats must have been tried before RSA.
            mock_ed25519.assert_called_once_with("/path/to/rsa_key")
            mock_ecdsa.assert_called_once_with("/path/to/rsa_key")
            mock_rsa.assert_called_once_with("/path/to/rsa_key")

    def test_load_key_all_formats_fail(self):
        """Test that appropriate error is raised when all key formats fail."""
        with patch.object(paramiko.Ed25519Key, "from_private_key_file") as mock_ed25519, patch.object(
            paramiko.ECDSAKey, "from_private_key_file"
        ) as mock_ecdsa, patch.object(paramiko.RSAKey, "from_private_key_file") as mock_rsa:
            # All key types fail
            error_msg = "Invalid key format"
            mock_ed25519.side_effect = paramiko.SSHException(error_msg)
            mock_ecdsa.side_effect = paramiko.SSHException(error_msg)
            mock_rsa.side_effect = paramiko.SSHException(error_msg)

            with pytest.raises(paramiko.SSHException) as exc_info:
                load_ssh_private_key("/path/to/invalid_key")

            message = str(exc_info.value)
            assert "Unable to load private key" in message
            assert "Ed25519 (recommended), ECDSA, RSA" in message
            assert "DSA keys are no longer supported" in message

    def test_load_key_with_real_ed25519_key(self):
        """Test loading a real Ed25519 private key file."""
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.primitives.asymmetric import ed25519

        # Create a temporary Ed25519 key for testing
        with tempfile.TemporaryDirectory() as tmpdir:
            key_path = os.path.join(tmpdir, "test_ed25519_key")

            # Generate a real Ed25519 key using cryptography library
            private_key = ed25519.Ed25519PrivateKey.generate()

            # Write the key in OpenSSH format (required for paramiko)
            pem = private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.OpenSSH,
                encryption_algorithm=serialization.NoEncryption(),
            )
            with open(key_path, "wb") as f:
                f.write(pem)

            # Load the key with our function
            loaded_key = load_ssh_private_key(key_path)
            assert isinstance(loaded_key, paramiko.Ed25519Key)

    def test_load_key_with_real_rsa_key(self):
        """Test loading a real RSA private key file."""
        with tempfile.TemporaryDirectory() as tmpdir:
            key_path = os.path.join(tmpdir, "test_rsa_key")

            # Generate a real RSA key
            key = paramiko.RSAKey.generate(2048)
            key.write_private_key_file(key_path)

            # Load the key
            loaded_key = load_ssh_private_key(key_path)
            assert isinstance(loaded_key, paramiko.RSAKey)


class TestSSHExec:
    """Test cases for ssh_exec() function."""

    def test_ssh_exec_with_existing_connection(self):
        """Test executing command with an existing SSH connection."""
        mock_ssh, mock_channel = _mock_ssh_client(b"command output\n", spec=paramiko.SSHClient)

        exit_code, output = ssh_exec("ls -la", ssh=mock_ssh)

        assert exit_code == 0
        assert "command output" in output
        mock_channel.exec_command.assert_called_once_with("ls -la")

    def test_ssh_exec_creates_new_connection(self):
        """Test that ssh_exec creates a new connection when ssh parameter is None."""
        with patch("libtisbackup.ssh.load_ssh_private_key") as mock_load_key, patch(
            "libtisbackup.ssh.paramiko.SSHClient"
        ) as mock_ssh_client_class:
            # Setup mocks (no spec: paramiko.SSHClient is patched here)
            mock_key = Mock()
            mock_load_key.return_value = mock_key
            mock_ssh, _ = _mock_ssh_client(b"test output")
            mock_ssh_client_class.return_value = mock_ssh

            # Execute
            exit_code, output = ssh_exec(
                command="whoami",
                server_name="testserver",
                remote_user="testuser",
                private_key="/path/to/key",
                ssh_port=22,
            )

            # Verify
            assert exit_code == 0
            assert "test output" in output
            mock_load_key.assert_called_once_with("/path/to/key")
            mock_ssh.set_missing_host_key_policy.assert_called_once()
            mock_ssh.connect.assert_called_once_with("testserver", username="testuser", pkey=mock_key, port=22)

    def test_ssh_exec_with_non_zero_exit_code(self):
        """Test handling of commands that exit with non-zero status."""
        mock_ssh, _ = _mock_ssh_client(b"error: command failed\n", exit_status=1, spec=paramiko.SSHClient)

        exit_code, output = ssh_exec("false", ssh=mock_ssh)

        assert exit_code == 1
        assert "error: command failed" in output

    def test_ssh_exec_with_custom_port(self):
        """Test ssh_exec with custom SSH port."""
        with patch("libtisbackup.ssh.load_ssh_private_key") as mock_load_key, patch(
            "libtisbackup.ssh.paramiko.SSHClient"
        ) as mock_ssh_client_class:
            mock_key = Mock()
            mock_load_key.return_value = mock_key
            mock_ssh, _ = _mock_ssh_client(b"output")
            mock_ssh_client_class.return_value = mock_ssh

            ssh_exec(command="ls", server_name="server", remote_user="user", private_key="/key", ssh_port=2222)

            mock_ssh.connect.assert_called_once_with("server", username="user", pkey=mock_key, port=2222)

    def test_ssh_exec_output_decoding(self):
        """Test that ssh_exec properly decodes output and handles special characters."""
        # Output with single quotes that should be removed
        mock_ssh, _ = _mock_ssh_client(b"output with 'quotes' included", spec=paramiko.SSHClient)

        exit_code, output = ssh_exec("echo test", ssh=mock_ssh)

        assert exit_code == 0
        # ssh_exec removes single quotes from output
        assert output == "output with quotes included"

    def test_ssh_exec_empty_output(self):
        """Test handling of commands with no output."""
        mock_ssh, _ = _mock_ssh_client(b"", spec=paramiko.SSHClient)

        exit_code, output = ssh_exec("true", ssh=mock_ssh)

        assert exit_code == 0
        assert output == ""

    def test_ssh_exec_requires_connection_params(self):
        """Test that ssh_exec requires connection parameters when ssh is None."""
        # With neither an ssh connection nor the parameters needed to open
        # one, ssh_exec is expected to fail with an AssertionError.
        with pytest.raises(AssertionError):
            ssh_exec(command="ls")


class TestSSHModuleIntegration:
    """Integration tests for SSH module functionality."""

    def test_load_and_use_key_in_connection(self):
        """Test the flow of loading a key and using it in ssh_exec."""
        with tempfile.TemporaryDirectory() as tmpdir:
            key_path = os.path.join(tmpdir, "test_key")

            # Generate a real RSA key (more compatible across paramiko versions)
            key = paramiko.RSAKey.generate(2048)
            key.write_private_key_file(key_path)

            # Mock the SSH connection part
            with patch("libtisbackup.ssh.paramiko.SSHClient") as mock_ssh_client_class:
                mock_ssh, _ = _mock_ssh_client(b"success")
                mock_ssh_client_class.return_value = mock_ssh

                # Execute with real key file
                exit_code, output = ssh_exec(
                    command="echo hello",
                    server_name="localhost",
                    remote_user="testuser",
                    private_key=key_path,
                    ssh_port=22,
                )

                assert exit_code == 0
                assert output == "success"

                # Verify that connect was called once, with a real RSAKey
                mock_ssh.connect.assert_called_once()
                args, kwargs = mock_ssh.connect.call_args
                assert args == ("localhost",)
                assert kwargs["username"] == "testuser"
                assert kwargs["port"] == 22
                assert isinstance(kwargs["pkey"], paramiko.RSAKey)