diff --git a/aikido_zen/background_process/cloud_connection_manager/update_service_config.py b/aikido_zen/background_process/cloud_connection_manager/update_service_config.py index 8e881a29..e09f2c3d 100644 --- a/aikido_zen/background_process/cloud_connection_manager/update_service_config.py +++ b/aikido_zen/background_process/cloud_connection_manager/update_service_config.py @@ -30,4 +30,4 @@ def update_service_config(connection_manager, res): ) if "domains" in res: - connection_manager.conf.update_domains(res["domains"]) + connection_manager.conf.update_outbound_domains(res["domains"]) diff --git a/aikido_zen/background_process/cloud_connection_manager/update_service_config_test.py b/aikido_zen/background_process/cloud_connection_manager/update_service_config_test.py index be35babb..be6ddfed 100644 --- a/aikido_zen/background_process/cloud_connection_manager/update_service_config_test.py +++ b/aikido_zen/background_process/cloud_connection_manager/update_service_config_test.py @@ -34,7 +34,7 @@ def test_update_service_config_outbound_blocking(): # Verify that the outbound blocking configuration was set assert connection_manager.conf.block_new_outgoing_requests is True - assert connection_manager.conf.domains == { + assert connection_manager.conf.outbound_domains == { "example.com": "block", "allowed.com": "allow", } @@ -61,7 +61,7 @@ def test_update_service_config_outbound_blocking_false(): # Verify that the outbound blocking configuration was set assert connection_manager.conf.block_new_outgoing_requests is False - assert connection_manager.conf.domains == {} + assert connection_manager.conf.outbound_domains == {} def test_update_service_config_outbound_blocking_missing(): @@ -89,7 +89,7 @@ def test_update_service_config_outbound_blocking_missing(): # Verify that the outbound blocking configuration was not changed assert connection_manager.conf.block_new_outgoing_requests is False - assert connection_manager.conf.domains == {} + assert connection_manager.conf.outbound_domains == {} def test_update_service_config_failure(): @@ -108,7 +108,9 @@ def test_update_service_config_failure(): # Set initial values connection_manager.conf.set_block_new_outgoing_requests(True) - connection_manager.conf.update_domains([{"hostname": "test.com", "mode": "block"}]) + connection_manager.conf.update_outbound_domains( + [{"hostname": "test.com", "mode": "block"}] + ) # Test failed response res = {"success": False, "blockNewOutgoingRequests": False, "domains": []} @@ -117,7 +119,7 @@ def test_update_service_config_failure(): # Verify that nothing was changed due to failure assert connection_manager.conf.block_new_outgoing_requests is True - assert connection_manager.conf.domains == {"test.com": "block"} + assert connection_manager.conf.outbound_domains == {"test.com": "block"} def test_update_service_config_complete(): @@ -160,7 +162,7 @@ def test_update_service_config_complete(): assert connection_manager.conf.blocked_uids == {"user1", "user2"} assert connection_manager.conf.received_any_stats is True assert connection_manager.conf.block_new_outgoing_requests is True - assert connection_manager.conf.domains == { + assert connection_manager.conf.outbound_domains == { "blocked.com": "block", "allowed.com": "allow", "test.com": "block", @@ -194,7 +196,7 @@ def test_update_service_config_domains_only(): # Verify that only domains were updated assert connection_manager.conf.block_new_outgoing_requests is False # Not changed - assert connection_manager.conf.domains == { + assert connection_manager.conf.outbound_domains == { "api.example.com": "block", "cdn.example.com": "allow", } @@ -215,7 +217,7 @@ def test_update_service_config_block_new_outgoing_requests_only(): connection_manager.block = False # Set initial domains - connection_manager.conf.update_domains( + connection_manager.conf.update_outbound_domains( [{"hostname": "existing.com", "mode": "allow"}] ) @@ -229,4 +231,6 @@ def test_update_service_config_block_new_outgoing_requests_only(): # Verify that only blockNewOutgoingRequests was updated assert connection_manager.conf.block_new_outgoing_requests is True - assert connection_manager.conf.domains == {"existing.com": "allow"} # Not changed + assert connection_manager.conf.outbound_domains == { + "existing.com": "allow" + } # Not changed diff --git a/aikido_zen/background_process/service_config.py b/aikido_zen/background_process/service_config.py index 39222a34..283369d7 100644 --- a/aikido_zen/background_process/service_config.py +++ b/aikido_zen/background_process/service_config.py @@ -23,7 +23,7 @@ def __init__( endpoints, last_updated_at, blocked_uids, bypassed_ips, received_any_stats ) self.block_new_outgoing_requests = False - self.domains = {} + self.outbound_domains = {} def update( self, @@ -77,15 +77,17 @@ def is_bypassed_ip(self, ip): """Checks if the IP is on the bypass list""" return self.bypassed_ips.has(ip) - def update_domains(self, domains): - self.domains = {domain["hostname"]: domain["mode"] for domain in domains} + def update_outbound_domains(self, domains): + self.outbound_domains = { + domain["hostname"]: domain["mode"] for domain in domains + } def set_block_new_outgoing_requests(self, value: bool): """Set whether to block new outgoing requests""" self.block_new_outgoing_requests = bool(value) def should_block_outgoing_request(self, hostname: str) -> bool: - mode = self.domains.get(hostname) + mode = self.outbound_domains.get(hostname) if self.block_new_outgoing_requests: # Only allow outgoing requests if the mode is "allow" diff --git a/aikido_zen/background_process/service_config_test.py b/aikido_zen/background_process/service_config_test.py index 447fd53e..d3461169 100644 --- a/aikido_zen/background_process/service_config_test.py +++ b/aikido_zen/background_process/service_config_test.py @@ -15,9 +15,9 @@ def test_service_config_outbound_blocking_initialization(): # Test initial values assert hasattr(config, "block_new_outgoing_requests") - assert hasattr(config, "domains") + assert hasattr(config, "outbound_domains") assert config.block_new_outgoing_requests is False - assert config.domains == {} + assert config.outbound_domains == {} def test_service_config_set_block_new_outgoing_requests(): @@ -53,7 +53,7 @@ def test_service_config_set_block_new_outgoing_requests(): def test_service_config_update_domains(): - """Test the update_domains method""" + """Test the update_outbound_domains method""" config = ServiceConfig( endpoints=[], last_updated_at=0, @@ -63,7 +63,7 @@ def test_service_config_update_domains(): ) # Test initial state - assert config.domains == {} + assert config.outbound_domains == {} # Test updating with domains domains_data = [ @@ -71,20 +71,20 @@ def test_service_config_update_domains(): {"hostname": "allowed.com", "mode": "allow"}, {"hostname": "test.com", "mode": "block"}, ] - config.update_domains(domains_data) - assert config.domains == { + config.update_outbound_domains(domains_data) + assert config.outbound_domains == { "example.com": "block", "allowed.com": "allow", "test.com": "block", } # Test updating with empty list - config.update_domains([]) - assert config.domains == {} + config.update_outbound_domains([]) + assert config.outbound_domains == {} # Test updating with single domain - config.update_domains([{"hostname": "single.com", "mode": "allow"}]) - assert config.domains == {"single.com": "allow"} + config.update_outbound_domains([{"hostname": "single.com", "mode": "allow"}]) + assert config.outbound_domains == {"single.com": "allow"} def test_service_config_should_block_outgoing_request(): @@ -99,7 +99,7 @@ def test_service_config_should_block_outgoing_request(): # Test with block_new_outgoing_requests = False (default) # Only block if mode is "block" - config.update_domains( + config.update_outbound_domains( [ {"hostname": "blocked.com", "mode": "block"}, {"hostname": "allowed.com", "mode": "allow"}, @@ -125,13 +125,13 @@ def test_service_config_should_block_outgoing_request(): # Test edge cases config.set_block_new_outgoing_requests(False) - config.update_domains([]) # No domains configured + config.update_outbound_domains([]) # No domains configured assert ( config.should_block_outgoing_request("any.com") is False ) # No blocking when no domains config.set_block_new_outgoing_requests(True) - config.update_domains([]) # No domains configured + config.update_outbound_domains([]) # No domains configured assert ( config.should_block_outgoing_request("any.com") is True ) # Block all when block_new_outgoing_requests=True diff --git a/aikido_zen/sinks/socket/__init__.py b/aikido_zen/sinks/socket/__init__.py index 20c1f83a..574f2483 100644 --- a/aikido_zen/sinks/socket/__init__.py +++ b/aikido_zen/sinks/socket/__init__.py @@ -2,9 +2,13 @@ Sink module for `socket` """ +from aikido_zen.errors import AikidoSSRF from aikido_zen.helpers.get_argument import get_argument from aikido_zen.sinks import on_import, patch_function, after -from aikido_zen.sinks.socket.report_and_check_hostname import report_and_check_hostname +from aikido_zen.sinks.socket.normalize_hostname import normalize_hostname +from aikido_zen.sinks.socket.should_block_outbound_domain import ( + should_block_outbound_domain, +) from aikido_zen.vulnerabilities import run_vulnerability_scan @@ -14,7 +18,13 @@ def _getaddrinfo_after(func, instance, args, kwargs, return_value): host = get_argument(args, kwargs, 0, "host") port = get_argument(args, kwargs, 1, "port") - report_and_check_hostname(host, port) + # We want a normalized hostname for reporting & blocking outbound domains + # This function decodes the hostname if its written in punycode + hostname = normalize_hostname(host) + + # Store hostname and check if we should stop this request from happening + if should_block_outbound_domain(hostname, port): + raise AikidoSSRF(f"Zen has blocked an outbound connection to {hostname}") # Run vulnerability scan with the return value (DNS results) op = "socket.getaddrinfo" diff --git a/aikido_zen/sinks/socket/report_and_check_hostname.py b/aikido_zen/sinks/socket/report_and_check_hostname.py deleted file mode 100644 index 74099e91..00000000 --- a/aikido_zen/sinks/socket/report_and_check_hostname.py +++ /dev/null @@ -1,20 +0,0 @@ -from aikido_zen.context import get_current_context -from aikido_zen.errors import AikidoSSRF -from aikido_zen.sinks.socket.normalize_hostname import normalize_hostname -from aikido_zen.thread.thread_cache import get_cache - - -def report_and_check_hostname(hostname, port): - cache = get_cache() - if not cache: - return - - hostname = normalize_hostname(hostname) - cache.hostnames.add(hostname, port) - - context = get_current_context() - is_bypassed = context and cache.is_bypassed_ip(context.remote_address) - - if cache.config and not is_bypassed: - if cache.config.should_block_outgoing_request(hostname): - raise AikidoSSRF(f"Zen has blocked an outbound connection to {hostname}") diff --git a/aikido_zen/sinks/socket/should_block_outbound_domain.py b/aikido_zen/sinks/socket/should_block_outbound_domain.py new file mode 100644 index 00000000..aefdd14b --- /dev/null +++ b/aikido_zen/sinks/socket/should_block_outbound_domain.py @@ -0,0 +1,14 @@ +from aikido_zen.thread.thread_cache import get_cache + + +def should_block_outbound_domain(hostname, port): + process_cache = get_cache() + if not process_cache: + return False + + # We store the hostname before checking the blocking status + # This is because if we are in lockdown mode and blocking all new hostnames, it should still + # show up in the dashboard. This allows the user to allow traffic to newly detected hostnames. + process_cache.hostnames.add(hostname, port) + + return process_cache.config.should_block_outgoing_request(hostname) diff --git a/aikido_zen/sinks/tests/socket_test.py b/aikido_zen/sinks/tests/socket_test.py index 7d039276..6e8bfa45 100644 --- a/aikido_zen/sinks/tests/socket_test.py +++ b/aikido_zen/sinks/tests/socket_test.py @@ -44,7 +44,7 @@ def test_socket_getaddrinfo_block_specific_domain(): # Reset cache and set up blocking for specific domain cache = get_cache() cache.reset() - cache.config.update_domains( + cache.config.update_outbound_domains( [ {"hostname": "blocked.com", "mode": "block"}, {"hostname": "allowed.com", "mode": "allow"}, @@ -77,7 +77,7 @@ def test_socket_getaddrinfo_block_all_new_requests(): cache = get_cache() cache.reset() cache.config.set_block_new_outgoing_requests(True) - cache.config.update_domains([{"hostname": "allowed.com", "mode": "allow"}]) + cache.config.update_outbound_domains([{"hostname": "allowed.com", "mode": "allow"}]) # Test that unknown domain raises exception with pytest.raises(Exception) as exc_info: @@ -131,7 +131,7 @@ def test_service_config_should_block_outgoing_request(): assert not config.should_block_outgoing_request("example.com") # Test with specific domain blocked - config.update_domains([{"hostname": "blocked.com", "mode": "block"}]) + config.update_outbound_domains([{"hostname": "blocked.com", "mode": "block"}]) assert config.should_block_outgoing_request("blocked.com") assert not config.should_block_outgoing_request("allowed.com") @@ -141,13 +141,13 @@ def test_service_config_should_block_outgoing_request(): assert config.should_block_outgoing_request("blocked.com") # Still blocked # Test with explicitly allowed domain when block_new_outgoing_requests is True - config.update_domains([{"hostname": "allowed.com", "mode": "allow"}]) + config.update_outbound_domains([{"hostname": "allowed.com", "mode": "allow"}]) assert not config.should_block_outgoing_request("allowed.com") # Explicitly allowed assert config.should_block_outgoing_request("unknown.com") # Unknown still blocked def test_service_config_update_domains(): - """Test the update_domains method""" + """Test the update_outbound_domains method""" config = ServiceConfig( endpoints=[], last_updated_at=0, @@ -157,20 +157,20 @@ def test_service_config_update_domains(): ) # Test initial state - assert config.domains == {} + assert config.outbound_domains == {} # Test updating domains - config.update_domains( + config.update_outbound_domains( [ {"hostname": "example.com", "mode": "block"}, {"hostname": "allowed.com", "mode": "allow"}, ] ) - assert config.domains == {"example.com": "block", "allowed.com": "allow"} + assert config.outbound_domains == {"example.com": "block", "allowed.com": "allow"} # Test updating with empty list - config.update_domains([]) - assert config.domains == {} + config.update_outbound_domains([]) + assert config.outbound_domains == {} def test_service_config_set_block_new_outgoing_requests(): @@ -195,47 +195,6 @@ def test_service_config_set_block_new_outgoing_requests(): assert not config.block_new_outgoing_requests -def test_socket_getaddrinfo_bypassed_ip(): - """Test that getaddrinfo works when IP is in bypassed_ips list""" - # Reset cache and set up bypassed IPs - cache = get_cache() - cache.reset() - cache.config.set_bypassed_ips(["192.168.1.0/24"]) - cache.config.set_block_new_outgoing_requests(True) - cache.config.update_domains([{"hostname": "allowed.com", "mode": "allow"}]) - - # Bypassed IP not enforced : no context - with pytest.raises(Exception) as exc_info: - socket.getaddrinfo("unknown.com", 80) - assert "Zen has blocked an outbound connection to unknown.com" in str( - exc_info.value - ) - - generate_context(ip="1.1.1.1").set_as_current_context() - with pytest.raises(Exception) as exc_info: - socket.getaddrinfo("unknown.com", 80) - assert "Zen has blocked an outbound connection to unknown.com" in str( - exc_info.value - ) - - generate_context(ip="192.168.1.80").set_as_current_context() - try: - socket.getaddrinfo("unknown.com", 80) - except Exception: - pytest.fail("getaddrinfo should not throw an error if IP is bypassed") - - # Verify hostname was tracked even when bypassed - hostnames = get_cache().hostnames.as_array() - assert ( - len(hostnames) == 1 - ) # All attempts to same hostname:port are tracked together - assert hostnames[0]["hostname"] == "unknown.com" - assert hostnames[0]["port"] == 80 - assert ( - hostnames[0]["hits"] == 3 - ) # All 3 attempts were tracked (2 blocked, 1 bypassed) - - def test_socket_getaddrinfo_ip_address_as_hostname(): """Test that getaddrinfo works when hostname is an IP address""" # Reset cache to ensure clean state @@ -261,7 +220,7 @@ def test_punycode_normalization(): # Reset cache and set up blocking cache = get_cache() cache.reset() - cache.config.update_domains( + cache.config.update_outbound_domains( [ {"hostname": "ssrf-rédirects.testssandbox.com", "mode": "block"}, ]