Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ flake8.report
junit*.xml
doc/build
.cache
.idea/
33 changes: 33 additions & 0 deletions test/inventory.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
all:
hosts:
s1:
ansible_host: 192.168.1.1
s2:
ansible_host: 192.168.1.2
s3:
ansible_host: 192.168.1.3
s4:
ansible_host: 192.168.1.4
s5:
ansible_host: 192.168.1.5
s6:
ansible_host: 192.168.1.6
s7:
ansible_host: 192.168.1.7
s8:
ansible_host: 192.168.1.8
s9:
ansible_host: 192.168.1.9

children:
servers:
hosts:
s1:
s2:
s3:
s4:
s5:
s6:
s7:
s8:
s9:
30 changes: 30 additions & 0 deletions test/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,14 @@ def test_docker_encoding(host):
def test_parse_hostspec(hostspec, expected):
assert BaseBackend.parse_hostspec(hostspec) == expected

@pytest.mark.parametrize(
"hostspec,expected",
[
("ansible://host1", ('host1', {'connection': 'ansible'})),
],
)
def test_init_parse_hostspec(hostspec, expected):
assert testinfra.backend.parse_hostspec(hostspec) == expected

@pytest.mark.parametrize(
"hostspec,pod,container,namespace,kubeconfig,context",
Expand Down Expand Up @@ -642,6 +650,28 @@ def test_get_hosts():
]


def test_get_hosts_ansible_limit():
# Hosts returned by get_host must be deduplicated (by name & kwargs) and in
# same order as asked
hosts = testinfra.backend.get_backends(
[
"ansible://s%5B1-4%5D%2A?ansible_inventory=inventory.yml" # s%5B1-4%5D%2A == s[1-4]*
]
)
assert [h.hostname for h in hosts] == ["s1", "s2", "s3", "s4"]

def test_get_hosts_ansible_limit_from_kwargs():
# Hosts returned by get_host must be deduplicated (by name & kwargs) and in
# same order as asked
hosts = testinfra.backend.get_backends(
[
"ansible://all"
],
ansible_inventory="inventory.yml",
ansible_limit="s[5-8]*"
)
assert [h.hostname for h in hosts] == ["s5", "s6", "s7", "s8"]

@pytest.mark.testinfra_hosts(*HOSTS)
def test_command_deadlock(host):
# Test for deadlock when exceeding Paramiko transport buffer (2MB)
Expand Down
1 change: 1 addition & 0 deletions testinfra/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def get_backends(
backends = {}
for hostspec in hosts:
host, kw = parse_hostspec(hostspec)
host = urllib.parse.unquote(host)
for k, v in kwargs.items():
kw.setdefault(k, v)
connection = kw.get("connection")
Expand Down
13 changes: 12 additions & 1 deletion testinfra/backend/ansible.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,15 @@ def get_variables(self) -> dict[str, Any]:
@classmethod
def get_hosts(cls, host: str, **kwargs: Any) -> list[str]:
inventory = kwargs.get("ansible_inventory")
return AnsibleRunner.get_runner(inventory).get_hosts(host or "all")
hosts = AnsibleRunner.get_runner(inventory).get_hosts(host or "all")
limit = kwargs.get("ansible_limit")
if limit:
# Filter hosts based on the limit expression
from ansible.parsing.dataloader import DataLoader
from ansible.inventory.manager import InventoryManager
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it will be clearer to bring the imports to the top of the file

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@osk-8 done :)


loader = DataLoader()
inventory_manager = InventoryManager(loader=loader, sources=inventory)
return list(map(lambda h: h.address, inventory_manager.get_hosts(pattern=limit)))
else:
return hosts
7 changes: 7 additions & 0 deletions testinfra/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def pytest_addoption(parser: pytest.Parser) -> None:
dest="nagios",
help="Nagios plugin",
)
group.addoption(
"--ansible-limit",
action="store_true",
dest="ansible_limit",
help="Limit to specific hosts using the same syntax as Ansible's --limit option.",
)


def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
Expand All @@ -126,6 +132,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
sudo_user=metafunc.config.option.sudo_user,
ansible_inventory=metafunc.config.option.ansible_inventory,
force_ansible=metafunc.config.option.force_ansible,
ansible_limit=metafunc.config.option.ansible_limit,
)
params = sorted(params, key=lambda x: x.backend.get_pytest_id())
ids = [e.backend.get_pytest_id() for e in params]
Expand Down