Skip to content

Commit b627c82

Browse files
authored
[Key Vault] Revise KeyClient.get_random_bytes (Azure#20097)
1 parent f1d0d20 commit b627c82

File tree

7 files changed

+81
-13
lines changed

7 files changed

+81
-13
lines changed

sdk/keyvault/azure-keyvault-keys/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
### Features Added
66

77
### Breaking Changes
8+
> These changes do not impact the API of stable versions such as 4.4.0.
9+
> Only code written against a beta version such as 4.5.0b1 may be affected.
10+
- `KeyClient.get_random_bytes` now returns a `RandomBytes` model with bytes in a `value`
11+
property, rather than returning the bytes directly
812

913
### Bugs Fixed
1014

sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# -------------------------------------
55
from ._enums import KeyCurveName, KeyOperation, KeyType
66
from ._shared.client_base import ApiVersion
7-
from ._models import DeletedKey, JsonWebKey, KeyProperties, KeyVaultKey, KeyVaultKeyIdentifier
7+
from ._models import DeletedKey, JsonWebKey, KeyProperties, KeyVaultKey, KeyVaultKeyIdentifier, RandomBytes
88
from ._client import KeyClient
99

1010
__all__ = [
@@ -18,6 +18,7 @@
1818
"KeyType",
1919
"DeletedKey",
2020
"KeyProperties",
21+
"RandomBytes",
2122
]
2223

2324
from ._version import VERSION

sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ._shared import KeyVaultClientBase
99
from ._shared.exceptions import error_map as _error_map
1010
from ._shared._polling import DeleteRecoverPollingMethod, KeyVaultOperationPoller
11-
from ._models import DeletedKey, KeyVaultKey, KeyProperties
11+
from ._models import DeletedKey, KeyVaultKey, KeyProperties, RandomBytes
1212

1313
try:
1414
from typing import TYPE_CHECKING
@@ -615,13 +615,26 @@ def import_key(self, name, key, **kwargs):
615615

616616
@distributed_trace
617617
def get_random_bytes(self, count, **kwargs):
618-
# type: (int, **Any) -> bytes
618+
# type: (int, **Any) -> RandomBytes
619619
"""Get the requested number of random bytes from a managed HSM.
620620
621621
:param int count: The requested number of random bytes.
622622
:return: The random bytes.
623-
:rtype: bytes
623+
:rtype: ~azure.keyvault.keys.RandomBytes
624+
:raises:
625+
:class:`ValueError` if less than one random byte is requested,
626+
:class:`~azure.core.exceptions.HttpResponseError` for other errors
627+
628+
Example:
629+
.. literalinclude:: ../tests/test_key_client.py
630+
:start-after: [START get_random_bytes]
631+
:end-before: [END get_random_bytes]
632+
:language: python
633+
:caption: Get random bytes
634+
:dedent: 12
624635
"""
636+
if count < 1:
637+
raise ValueError("At least one random byte must be requested")
625638
parameters = self._models.GetRandomBytesRequest(count=count)
626639
result = self._client.get_random_bytes(vault_base_url=self._vault_url, parameters=parameters, **kwargs)
627-
return result.value
640+
return RandomBytes(value=result.value)

sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,3 +438,13 @@ def scheduled_purge_date(self):
438438
:rtype: ~datetime.datetime
439439
"""
440440
return self._scheduled_purge_date
441+
442+
443+
class RandomBytes(object):
444+
"""Contains random bytes returned from a managed HSM.
445+
446+
:param bytes value: the random bytes
447+
"""
448+
449+
def __init__(self, value):
450+
self.value = value

sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/aio/_client.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .._shared._polling_async import AsyncDeleteRecoverPollingMethod
1212
from .._shared import AsyncKeyVaultClientBase
1313
from .._shared.exceptions import error_map as _error_map
14-
from .. import DeletedKey, JsonWebKey, KeyVaultKey, KeyProperties
14+
from .. import DeletedKey, JsonWebKey, KeyVaultKey, KeyProperties, RandomBytes
1515

1616
if TYPE_CHECKING:
1717
# pylint:disable=ungrouped-imports
@@ -590,13 +590,26 @@ async def import_key(self, name: str, key: JsonWebKey, **kwargs: "Any") -> KeyVa
590590
return KeyVaultKey._from_key_bundle(bundle)
591591

592592
@distributed_trace_async
593-
async def get_random_bytes(self, count: int, **kwargs: "Any") -> bytes:
593+
async def get_random_bytes(self, count: int, **kwargs: "Any") -> RandomBytes:
594594
"""Get the requested number of random bytes from a managed HSM.
595595
596596
:param int count: The requested number of random bytes.
597597
:return: The random bytes.
598-
:rtype: bytes
598+
:rtype: ~azure.keyvault.keys.RandomBytes
599+
:raises:
600+
:class:`ValueError` if less than one random byte is requested,
601+
:class:`~azure.core.exceptions.HttpResponseError` for other errors
602+
603+
Example:
604+
.. literalinclude:: ../tests/test_keys_async.py
605+
:start-after: [START get_random_bytes]
606+
:end-before: [END get_random_bytes]
607+
:language: python
608+
:caption: Get random bytes
609+
:dedent: 12
599610
"""
611+
if count < 1:
612+
raise ValueError("At least one random byte must be requested")
600613
parameters = self._models.GetRandomBytesRequest(count=count)
601614
result = await self._client.get_random_bytes(vault_base_url=self._vault_url, parameters=parameters, **kwargs)
602-
return result.value
615+
return RandomBytes(value=result.value)

sdk/keyvault/azure-keyvault-keys/tests/test_key_client.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
1313
from azure.core.pipeline.policies import SansIOHTTPPolicy
1414
from azure.keyvault.keys import ApiVersion, JsonWebKey, KeyClient
15+
import pytest
1516
from six import byte2int
1617

1718
from _shared.test_case import KeyVaultTestCase
@@ -422,12 +423,24 @@ def test_get_random_bytes(self, client, **kwargs):
422423

423424
generated_random_bytes = []
424425
for i in range(5):
425-
random_bytes = client.get_random_bytes(count=8)
426+
# [START get_random_bytes]
427+
# get eight random bytes from a managed HSM
428+
result = client.get_random_bytes(count=8)
429+
random_bytes = result.value
430+
# [END get_random_bytes]
426431
assert len(random_bytes) == 8
427-
assert all([random_bytes != rb] for rb in generated_random_bytes)
432+
assert all(random_bytes != rb for rb in generated_random_bytes)
428433
generated_random_bytes.append(random_bytes)
429434

430435

436+
def test_positive_bytes_count_required():
437+
client = KeyClient("...", object())
438+
with pytest.raises(ValueError):
439+
client.get_random_bytes(count=0)
440+
with pytest.raises(ValueError):
441+
client.get_random_bytes(count=-1)
442+
443+
431444
def test_service_headers_allowed_in_logs():
432445
service_headers = {"x-ms-keyvault-network-info", "x-ms-keyvault-region", "x-ms-keyvault-service-version"}
433446
client = KeyClient("...", object())

sdk/keyvault/azure-keyvault-keys/tests/test_keys_async.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from azure.core.pipeline.policies import SansIOHTTPPolicy
1414
from azure.keyvault.keys import ApiVersion, JsonWebKey
1515
from azure.keyvault.keys.aio import KeyClient
16+
import pytest
1617
from six import byte2int
1718

1819
from _shared.test_case_async import KeyVaultTestCase
@@ -452,12 +453,25 @@ async def test_get_random_bytes(self, client, **kwargs):
452453

453454
generated_random_bytes = []
454455
for i in range(5):
455-
random_bytes = await client.get_random_bytes(count=8)
456+
# [START get_random_bytes]
457+
# get eight random bytes from a managed HSM
458+
result = await client.get_random_bytes(count=8)
459+
random_bytes = result.value
460+
# [END get_random_bytes]
456461
assert len(random_bytes) == 8
457-
assert all([random_bytes != rb] for rb in generated_random_bytes)
462+
assert all(random_bytes != rb for rb in generated_random_bytes)
458463
generated_random_bytes.append(random_bytes)
459464

460465

466+
@pytest.mark.asyncio
467+
async def test_positive_bytes_count_required():
468+
client = KeyClient("...", object())
469+
with pytest.raises(ValueError):
470+
await client.get_random_bytes(count=0)
471+
with pytest.raises(ValueError):
472+
await client.get_random_bytes(count=-1)
473+
474+
461475
def test_service_headers_allowed_in_logs():
462476
service_headers = {"x-ms-keyvault-network-info", "x-ms-keyvault-region", "x-ms-keyvault-service-version"}
463477
client = KeyClient("...", object())

0 commit comments

Comments
 (0)