diff --git a/src/ahttpx/_request.py b/src/ahttpx/_request.py index 78b82282d0..a247b8b458 100644 --- a/src/ahttpx/_request.py +++ b/src/ahttpx/_request.py @@ -29,6 +29,11 @@ def __init__( if "Host" not in self.headers: self.headers = self.headers.copy_set("Host", self.url.netloc) + if self.url.scheme not in ('http', 'https'): + raise ValueError(f'Invalid scheme for URL {str(self.url)!r}.') + if not self.url.netloc: + raise ValueError(f'Missing host for URL {str(self.url)!r}.') + if content is not None: if isinstance(content, bytes): self.stream = ByteStream(content) diff --git a/src/httpx/_request.py b/src/httpx/_request.py index 1b739b1872..03c78dfea2 100644 --- a/src/httpx/_request.py +++ b/src/httpx/_request.py @@ -28,6 +28,11 @@ def __init__( # A client MUST include a Host header field in all HTTP/1.1 request messages. if "Host" not in self.headers: self.headers = self.headers.copy_set("Host", self.url.netloc) + + if self.url.scheme not in ('http', 'https'): + raise ValueError(f'Invalid scheme for URL {str(self.url)!r}.') + if not self.url.netloc: + raise ValueError(f'Missing host for URL {str(self.url)!r}.') if content is not None: if isinstance(content, bytes): diff --git a/tests/test_request.py b/tests/test_request.py index a69e1d1358..cab0a5e109 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -1,4 +1,5 @@ import httpx +import pytest class ByteIterator: @@ -77,3 +78,13 @@ def test_request_empty_post(): "Content-Length": "0", } assert r.read() == b'' + + +def test_request_invalid_scheme(): + with pytest.raises(ValueError): + httpx.Request("GET", "ws://example.com") + + +def test_request_missing_host(): + with pytest.raises(ValueError): + r = httpx.Request("GET", "https:/example.com")