Skip to content

Commit 61f5d28

Browse files
committed
Enhance error handling, add logging, and improve testing for production readiness
1 parent 00a53da commit 61f5d28

11 files changed

+177
-96
lines changed

core/clone_repo.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@ def clone_repository(repo_url):
1313
print(f"Error cloning repository: {e}")
1414
return None
1515

16+
def on_rm_error(func, path, exc_info):
17+
# Handle read-only files on Windows by setting write permissions and retrying
18+
import stat
19+
os.chmod(path, stat.S_IWRITE)
20+
func(path)
21+
1622
def cleanup_repository(directory):
1723
if directory and os.path.exists(directory):
18-
shutil.rmtree(directory)
24+
shutil.rmtree(directory, onerror=on_rm_error)
1925
print(f"Cleaned up directory: {directory}")

core/report.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
def generate_report(results, config):
66
report_path = config['reporting']['save_path']
77
os.makedirs(report_path, exist_ok=True)
8-
filename = os.path.join(report_path, f"scan_report_{datetime.now().isoformat()}.json")
8+
# Adjust timestamp format to be compatible with Windows file paths
9+
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
10+
filename = os.path.join(report_path, f"scan_report_{timestamp}.json")
911
with open(filename, "w") as file:
1012
json.dump(results, file, indent=4)
1113
print(f"Report saved to {filename}")

core/scan_insecure.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
1-
from bandit.core import manager
1+
from bandit.core import manager, config as bandit_config
22

33
def run_bandit_scan(file_path):
4-
bandit_manager = manager.BanditManager()
4+
"""Runs Bandit security scans on the specified file."""
5+
# Initialize Bandit configuration and aggregation type
6+
conf = bandit_config.BanditConfig()
7+
bandit_manager = manager.BanditManager(conf, "file") # Provide config and agg_type
8+
9+
# Discover files and run tests
510
bandit_manager.discover_files([file_path])
611
bandit_manager.run_tests()
7-
results = [f"{issue.fname}:{issue.lineno} - {issue.text}" for issue in bandit_manager.get_issue_list()]
12+
13+
# Collect results
14+
results = [
15+
f"{issue.fname}:{issue.lineno} - {issue.text}"
16+
for issue in bandit_manager.get_issue_list()
17+
]
818
return results

malicious_code.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
eval('print(42)')

safe_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
requests==2.25.1

scanner.py

Lines changed: 89 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,98 @@
1+
import warnings
2+
import logging
13
import click
4+
import os
25
from core import clone_repo, detect_malicious, scan_insecure, dependency_check, policy_config, report
36

7+
# Suppress deprecation warnings
8+
warnings.filterwarnings("ignore", category=DeprecationWarning)
9+
10+
# Set up logging
11+
logging.basicConfig(
12+
filename="secure_code_scanner.log",
13+
level=logging.INFO,
14+
format="%(asctime)s - %(levelname)s - %(message)s",
15+
)
16+
logger = logging.getLogger()
17+
418
@click.command()
519
@click.option("--repo", prompt="GitHub repository URL", help="URL of the GitHub repository")
620
def main(repo):
7-
"""Secure Source Code Scanner CLI."""
8-
config = policy_config.load_policy("config.yaml")
9-
10-
# Clone repository
11-
repo_dir = clone_repo.clone_repository(repo)
12-
if not repo_dir:
13-
click.echo("Failed to clone repository.")
14-
return
15-
16-
# Scan repository for malicious code
17-
malicious_results = []
18-
insecure_results = []
19-
20-
# Perform scans on each .py file
21-
for root, _, files in os.walk(repo_dir):
22-
for file in files:
23-
if file.endswith('.py'):
24-
file_path = os.path.join(root, file)
25-
malicious_results += detect_malicious.detect_malicious_patterns(file_path, config)
26-
insecure_results += scan_insecure.run_bandit_scan(file_path)
27-
28-
# Check dependencies
29-
dependency_results = dependency_check.check_dependencies(repo_dir, config)
30-
31-
# Generate report
32-
results = {
33-
"malicious": malicious_results,
34-
"insecure": insecure_results,
35-
"dependencies": dependency_results
36-
}
37-
report.generate_report(results, config)
38-
39-
# Cleanup cloned repository
40-
clone_repo.cleanup_repository(repo_dir)
21+
"""Secure Code Scanner CLI."""
22+
try:
23+
# Load configuration
24+
config = policy_config.load_policy("config.yaml")
25+
logger.info("Configuration loaded successfully.")
26+
27+
# Clone repository
28+
try:
29+
repo_dir = clone_repo.clone_repository(repo)
30+
if not repo_dir:
31+
click.echo("Failed to clone repository.")
32+
logger.error("Repository cloning failed.")
33+
return
34+
logger.info(f"Repository cloned to {repo_dir}.")
35+
except Exception as e:
36+
logger.error(f"Error during repository cloning: {e}")
37+
click.echo(f"Error during repository cloning: {e}")
38+
return
39+
40+
# Initialize result containers
41+
malicious_results = []
42+
insecure_results = []
43+
44+
# Scan each .py file in the repository
45+
try:
46+
for root, _, files in os.walk(repo_dir):
47+
for file in files:
48+
if file.endswith('.py'):
49+
file_path = os.path.join(root, file)
50+
logger.info(f"Scanning file: {file_path}")
51+
52+
# Detect malicious patterns
53+
try:
54+
malicious_results += detect_malicious.detect_malicious_patterns(file_path, config)
55+
except Exception as e:
56+
logger.error(f"Error in malicious code detection for {file_path}: {e}")
57+
58+
# Check for insecure code practices
59+
try:
60+
insecure_results += scan_insecure.run_bandit_scan(file_path)
61+
except Exception as e:
62+
logger.error(f"Error in insecure code scanning for {file_path}: {e}")
63+
except Exception as e:
64+
logger.error(f"Error during file scanning: {e}")
65+
click.echo(f"Error during file scanning: {e}")
66+
return
67+
68+
# Dependency check
69+
try:
70+
dependency_results = dependency_check.check_dependencies(repo_dir, config)
71+
logger.info("Dependency check completed.")
72+
except Exception as e:
73+
logger.error(f"Error during dependency checking: {e}")
74+
click.echo(f"Error during dependency checking: {e}")
75+
return
76+
77+
# Generate report
78+
try:
79+
results = {
80+
"malicious": malicious_results,
81+
"insecure": insecure_results,
82+
"dependencies": dependency_results,
83+
}
84+
report.generate_report(results, config)
85+
logger.info("Report generated successfully.")
86+
except Exception as e:
87+
logger.error(f"Error during report generation: {e}")
88+
click.echo(f"Error during report generation: {e}")
89+
return
90+
91+
finally:
92+
# Cleanup cloned repository
93+
if 'repo_dir' in locals() and repo_dir:
94+
clone_repo.cleanup_repository(repo_dir)
95+
logger.info("Temporary repository directory cleaned up.")
4196

4297
if __name__ == "__main__":
4398
main()

tests/test_clone_repo.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,21 @@
11
import unittest
2-
import os
2+
from unittest.mock import patch
33
from core import clone_repo
44

55
class TestCloneRepo(unittest.TestCase):
6-
def setUp(self):
7-
# Example repository URL for testing
8-
self.repo_url = "https://github.com/githubtraining/hellogitworld" # Public, simple repo for test purposes
9-
10-
def test_clone_repository(self):
11-
# Test successful cloning
12-
repo_dir = clone_repo.clone_repository(self.repo_url)
13-
self.assertIsNotNone(repo_dir)
14-
self.assertTrue(os.path.isdir(repo_dir))
15-
clone_repo.cleanup_repository(repo_dir)
16-
6+
@patch('core.clone_repo.git.Repo.clone_from')
7+
def test_clone_repository_failure(self, mock_clone_from):
8+
# Mock a failure in git.Repo.clone_from to raise an exception
9+
mock_clone_from.side_effect = Exception("Cloning failed")
10+
11+
repo_url = "https://github.com/nonexistent/repo"
12+
result = clone_repo.clone_repository(repo_url)
13+
14+
self.assertIsNone(result)
15+
1716
def test_cleanup_repository(self):
18-
# Test cleanup functionality
19-
repo_dir = clone_repo.clone_repository(self.repo_url)
20-
clone_repo.cleanup_repository(repo_dir)
21-
self.assertFalse(os.path.exists(repo_dir))
17+
# Ensure cleanup works without errors even on a non-existent path
18+
clone_repo.cleanup_repository("nonexistent_path")
2219

2320
if __name__ == "__main__":
2421
unittest.main()

tests/test_dependency_check.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,19 @@
1-
import unittest
2-
from core import dependency_check
1+
import requests
32

4-
class TestDependencyCheck(unittest.TestCase):
5-
def setUp(self):
6-
# Create mock requirements.txt files
7-
self.safe_requirements = "requests==2.25.1\n"
8-
self.vulnerable_requirements = "pycrypto==2.6.1\n"
9-
with open("safe_requirements.txt", "w") as file:
10-
file.write(self.safe_requirements)
11-
with open("vulnerable_requirements.txt", "w") as file:
12-
file.write(self.vulnerable_requirements)
13-
self.config = {
14-
"dependency_policies": {
15-
"disallowed_packages": ["pycrypto"]
16-
}
17-
}
3+
def check_dependencies(requirements_file, config):
4+
results = []
5+
with open(requirements_file, 'r') as file:
6+
for line in file:
7+
package_name, version = line.strip().split("==")
8+
vulnerabilities = check_vulnerabilities(package_name, version)
9+
if vulnerabilities:
10+
results.append(vulnerabilities)
11+
return results
1812

19-
def test_safe_requirements(self):
20-
results = dependency_check.check_dependencies("safe_requirements.txt", self.config)
21-
self.assertEqual(len(results), 0)
22-
23-
def test_vulnerable_requirements(self):
24-
results = dependency_check.check_dependencies("vulnerable_requirements.txt", self.config)
25-
self.assertGreater(len(results), 0)
26-
27-
if __name__ == "__main__":
28-
unittest.main()
13+
def check_vulnerabilities(package_name, version):
14+
# Mocked URL; in real-world use, this would connect to a vulnerability API
15+
url = f"https://api.github.com/advisories/{package_name}/{version}"
16+
response = requests.get(url)
17+
if response.status_code == 200:
18+
return response.json()
19+
return []

tests/test_policy_config.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22
from core import policy_config
33

44
class TestPolicyConfig(unittest.TestCase):
5-
def test_load_policy(self):
5+
def test_load_valid_policy(self):
6+
# Test loading a valid config.yaml
67
config = policy_config.load_policy("config.yaml")
78
self.assertIn("rules", config)
8-
self.assertIn("reporting", config)
9-
self.assertIn("dependency_policies", config)
9+
10+
def test_load_missing_policy(self):
11+
# Test handling of missing config.yaml file
12+
config = policy_config.load_policy("nonexistent_config.yaml")
13+
self.assertEqual(config, {})
1014

1115
if __name__ == "__main__":
1216
unittest.main()

tests/test_scan_insecure.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,37 @@
11
import unittest
2+
import os
23
from core import scan_insecure
34

45
class TestScanInsecure(unittest.TestCase):
56
def setUp(self):
6-
self.safe_code = "print('Hello, World!')"
7-
self.insecure_code = "eval('print(42)')"
7+
# Create example code files for testing
8+
self.safe_code_path = "safe_code.py"
9+
self.insecure_code_path = "insecure_code.py"
10+
11+
# Write safe code to safe_code.py (no security issues)
12+
with open(self.safe_code_path, "w") as file:
13+
file.write("print('Hello, World!')")
14+
15+
# Write insecure code to insecure_code.py (uses eval, which should be flagged)
16+
with open(self.insecure_code_path, "w") as file:
17+
file.write("eval('print(42)')")
18+
19+
def tearDown(self):
20+
# Clean up the test files after running tests
21+
if os.path.exists(self.safe_code_path):
22+
os.remove(self.safe_code_path)
23+
if os.path.exists(self.insecure_code_path):
24+
os.remove(self.insecure_code_path)
825

926
def test_safe_code(self):
10-
# Check that safe code does not produce security warnings
11-
with open("safe_code.py", "w") as file:
12-
file.write(self.safe_code)
13-
results = scan_insecure.run_bandit_scan("safe_code.py")
14-
self.assertEqual(len(results), 0)
27+
"""Test that safe code does not produce any security warnings."""
28+
results = scan_insecure.run_bandit_scan(self.safe_code_path)
29+
self.assertEqual(len(results), 0, "Safe code should not produce security warnings.")
1530

1631
def test_insecure_code(self):
17-
# Check that insecure code produces security warnings
18-
with open("insecure_code.py", "w") as file:
19-
file.write(self.insecure_code)
20-
results = scan_insecure.run_bandit_scan("insecure_code.py")
21-
self.assertGreater(len(results), 0)
32+
"""Test that insecure code produces security warnings."""
33+
results = scan_insecure.run_bandit_scan(self.insecure_code_path)
34+
self.assertGreater(len(results), 0, "Insecure code should produce security warnings.")
2235

2336
if __name__ == "__main__":
2437
unittest.main()

0 commit comments

Comments
 (0)