Skip to content

Commit 0fbac6d

Browse files
Add search_public_hub_models function for jumpstart hub (#1683)
* first commit for search_public_hub_models to test * update SearchKeyWords style and HubContent initialization * update keywords for unquote logic, add pydoc, only expose search_public_hub_models function * update _list_all_hub_models to speed up, add unit and integ tests * update sagemaker_core import path, unit test mock_models data type * expand search_public_hub_models functionality for optional hub_name and sagemaker_session, update integ and unit test following the change
1 parent ce3099a commit 0fbac6d

File tree

4 files changed

+310
-1
lines changed

4 files changed

+310
-1
lines changed

sagemaker_utils/src/sagemaker/utils/jumpstart/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
"""This module contains JumpStart utilites for the SageMaker Python SDK."""
1414
from __future__ import absolute_import
1515

16-
from sagemaker.utils.jumpstart.configs import JumpStartConfig # noqa: F401
16+
from sagemaker.utils.jumpstart.configs import JumpStartConfig # noqa: F401
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import re
2+
import logging
3+
from typing import List, Iterator, Optional
4+
from sagemaker_core.helper.session_helper import Session
5+
from sagemaker_core.resources import HubContent
6+
7+
logger = logging.getLogger(__name__)
8+
9+
class _Filter:
10+
"""
11+
A filter that evaluates logical expressions against a list of keyword strings.
12+
13+
Supports logical operators (AND, OR, NOT), parentheses for grouping, and wildcard patterns
14+
(e.g., `text-*`, `*ai`, `@task:foo`).
15+
16+
Example:
17+
filt = _Filter("(@framework:huggingface OR text-*) AND NOT deprecated")
18+
filt.match(["@framework:huggingface", "text-generation"]) # Returns True
19+
"""
20+
21+
def __init__(self, expression: str) -> None:
22+
"""
23+
Initialize the filter with a string expression.
24+
25+
Args:
26+
expression (str): A logical expression to evaluate against keywords.
27+
Supports AND, OR, NOT, parentheses, and wildcard patterns (*).
28+
"""
29+
self.expression: str = expression
30+
31+
def match(self, keywords: List[str]) -> bool:
32+
"""
33+
Evaluate the filter expression against a list of keywords.
34+
35+
Args:
36+
keywords (List[str]): A list of keyword strings to test.
37+
38+
Returns:
39+
bool: True if the expression evaluates to True for the given keywords, else False.
40+
"""
41+
expr: str = self._convert_expression(self.expression)
42+
try:
43+
return eval(expr, {"__builtins__": {}}, {"keywords": keywords, "any": any})
44+
except Exception:
45+
return False
46+
47+
def _convert_expression(self, expr: str) -> str:
48+
"""
49+
Convert the logical filter expression into a Python-evaluable string.
50+
51+
Args:
52+
expr (str): The raw expression to convert.
53+
54+
Returns:
55+
str: A Python expression string using 'any' and logical operators.
56+
"""
57+
tokens: List[str] = re.findall(r'\bAND\b|\bOR\b|\bNOT\b|[^\s()]+|\(|\)', expr, flags=re.IGNORECASE)
58+
59+
def wildcard_condition(pattern: str) -> str:
60+
pattern = pattern.strip('"').strip("'")
61+
stripped = pattern.strip("*")
62+
63+
if pattern.startswith("*") and pattern.endswith("*"):
64+
return f"{repr(stripped)} in k"
65+
elif pattern.startswith("*"):
66+
return f"k.endswith({repr(stripped)})"
67+
elif pattern.endswith("*"):
68+
return f"k.startswith({repr(stripped)})"
69+
else:
70+
return f"k == {repr(pattern)}"
71+
72+
def convert_token(token: str) -> str:
73+
upper = token.upper()
74+
if upper == 'AND':
75+
return 'and'
76+
elif upper == 'OR':
77+
return 'or'
78+
elif upper == 'NOT':
79+
return 'not'
80+
elif token in ('(', ')'):
81+
return token
82+
else:
83+
return f"any({wildcard_condition(token)} for k in keywords)"
84+
85+
converted_tokens = [convert_token(tok) for tok in tokens]
86+
return ' '.join(converted_tokens)
87+
88+
89+
def _list_all_hub_models(hub_name: str, sm_client: Session) -> Iterator[HubContent]:
90+
"""
91+
Retrieve all model entries from the specified hub and yield them one by one.
92+
93+
This function paginates through the SageMaker Hub API to retrieve all published models of type "Model"
94+
and yields them as `HubContent` objects.
95+
96+
Args:
97+
hub_name (str): The name of the hub to query.
98+
sm_client (Session): The SageMaker session.
99+
100+
Yields:
101+
HubContent: A `HubContent` object representing a single model entry from the hub.
102+
"""
103+
next_token = None
104+
105+
while True:
106+
# Prepare the request parameters
107+
params = {
108+
"HubName": hub_name,
109+
"HubContentType": "Model",
110+
"MaxResults": 100
111+
}
112+
113+
# Add NextToken if it exists
114+
if next_token:
115+
params["NextToken"] = next_token
116+
117+
# Make the API call
118+
response = sm_client.list_hub_contents(**params)
119+
120+
# Yield each content summary
121+
for content in response["HubContentSummaries"]:
122+
yield HubContent(
123+
hub_name=hub_name,
124+
hub_content_arn=content["HubContentArn"],
125+
hub_content_type="Model",
126+
hub_content_name=content["HubContentName"],
127+
hub_content_version=content["HubContentVersion"],
128+
hub_content_description=content.get("HubContentDescription", ""),
129+
hub_content_search_keywords=content.get("HubContentSearchKeywords", []),
130+
)
131+
132+
# Check if there are more results
133+
next_token = response.get("NextToken", None)
134+
if not next_token or len(response["HubContentSummaries"]) == 0:
135+
break # Exit the loop if there are no more pages
136+
137+
138+
def search_public_hub_models(
139+
query: str,
140+
hub_name: Optional[str] = "SageMakerPublicHub",
141+
sagemaker_session: Optional[Session] = None,
142+
) -> List[HubContent]:
143+
"""
144+
Search and filter models from hub using a keyword expression.
145+
146+
Args:
147+
query (str): A logical expression used to filter models by keywords.
148+
Example: "@task:text-generation AND NOT @framework:legacy"
149+
hub_name (Optional[str]): The name of the hub to query. Defaults to "SageMakerPublicHub".
150+
sagemaker_session (Optional[Session]): An optional SageMaker `Session` object. If not provided,
151+
a default session will be created and a warning will be logged.
152+
153+
Returns:
154+
List[HubContent]: A list of filtered `HubContent` model objects that match the query.
155+
"""
156+
if sagemaker_session is None:
157+
sagemaker_session = Session()
158+
logger.warning("SageMaker session not provided. Using default Session.")
159+
sm_client = sagemaker_session.sagemaker_client
160+
161+
models = _list_all_hub_models(hub_name, sm_client)
162+
filt = _Filter(query)
163+
results: List[HubContent] = []
164+
165+
for model in models:
166+
keywords = model.hub_content_search_keywords
167+
normalized_keywords = [kw.replace(" ", "-") for kw in keywords]
168+
169+
if filt.match(normalized_keywords):
170+
results.append(model)
171+
172+
return results
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Test for JumpStart search_public_hub_models function."""
14+
from __future__ import absolute_import
15+
16+
import pytest
17+
from sagemaker.utils.jumpstart.search import search_public_hub_models
18+
from sagemaker_core.helper.session_helper import Session
19+
from sagemaker_core.resources import HubContent
20+
21+
22+
@pytest.mark.integ
23+
def test_search_public_hub_models_default_args():
24+
# Only query, uses default hub name and session
25+
query = "@task:text-generation OR @framework:huggingface"
26+
results = search_public_hub_models(query)
27+
28+
assert isinstance(results, list)
29+
assert all(isinstance(m, HubContent) for m in results)
30+
assert len(results) > 0, "Expected at least one matching model from the public hub"
31+
32+
33+
@pytest.mark.integ
34+
def test_search_public_hub_models_custom_session():
35+
# Provide a custom SageMaker session
36+
session = Session()
37+
query = "@task:text-generation"
38+
results = search_public_hub_models(query, sagemaker_session=session)
39+
40+
assert isinstance(results, list)
41+
assert all(isinstance(m, HubContent) for m in results)
42+
43+
44+
@pytest.mark.integ
45+
def test_search_public_hub_models_custom_hub_name():
46+
# Using the default public hub but provided explicitly
47+
query = "@framework:huggingface"
48+
results = search_public_hub_models(query, hub_name="SageMakerPublicHub")
49+
50+
assert isinstance(results, list)
51+
assert all(isinstance(m, HubContent) for m in results)
52+
53+
54+
@pytest.mark.integ
55+
def test_search_public_hub_models_all_args():
56+
# Provide both hub_name and session explicitly
57+
session = Session()
58+
query = "@task:natural-language-processing"
59+
results = search_public_hub_models(query, hub_name="SageMakerPublicHub", sagemaker_session=session)
60+
61+
assert isinstance(results, list)
62+
assert all(isinstance(m, HubContent) for m in results)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Test for JumpStart search_public_hub_models function."""
14+
from __future__ import absolute_import
15+
16+
import pytest
17+
from unittest.mock import patch
18+
from sagemaker.utils.jumpstart.search import _Filter, search_public_hub_models
19+
from sagemaker_core.resources import HubContent
20+
21+
22+
@pytest.mark.parametrize("query,keywords,expected", [
23+
("text-*", ["text-classification"], True),
24+
("@task:foo", ["@task:foo"], True),
25+
("@task:foo AND bar-*", ["@task:foo", "bar-baz"], True),
26+
("@task:foo AND bar-*", ["@task:foo"], False),
27+
("@task:foo OR bar-*", ["bar-qux"], True),
28+
("@task:foo OR bar-*", ["nothing"], False),
29+
("NOT @task:legacy", ["@task:modern"], True),
30+
("NOT @task:legacy", ["@task:legacy"], False),
31+
("(@framework:huggingface OR text-*) AND NOT @provider:qwen",
32+
["@framework:huggingface", "text-generator"], True),
33+
("(@framework:huggingface OR text-*) AND NOT @provider:qwen",
34+
["@framework:huggingface", "@provider:qwen"], False),
35+
])
36+
def test_filter_match(query, keywords, expected):
37+
f = _Filter(query)
38+
assert f.match(keywords) == expected
39+
40+
41+
def test_search_public_hub_models():
42+
mock_models = [
43+
HubContent(
44+
hub_content_type="Model",
45+
hub_content_name="textgen",
46+
hub_content_arn="arn:example:textgen",
47+
hub_content_version="1.0",
48+
document_schema_version="1.0",
49+
hub_content_display_name="Text Gen",
50+
hub_content_description="Generates text",
51+
hub_content_search_keywords=["@task:text-generation", "@framework:huggingface"],
52+
hub_content_status="Published",
53+
creation_time="2023-01-01T00:00:00Z",
54+
hub_name="SageMakerPublicHub"
55+
),
56+
HubContent(
57+
hub_content_type="Model",
58+
hub_content_name="qwen-model",
59+
hub_content_arn="arn:example:qwen",
60+
hub_content_version="1.0",
61+
document_schema_version="1.0",
62+
hub_content_display_name="Qwen",
63+
hub_content_description="Qwen LLM",
64+
hub_content_search_keywords=["@provider:qwen"],
65+
hub_content_status="Published",
66+
creation_time="2023-01-01T00:00:00Z",
67+
hub_name="SageMakerPublicHub"
68+
),
69+
]
70+
71+
with patch("sagemaker.utils.jumpstart.search._list_all_hub_models", return_value=mock_models):
72+
results = search_public_hub_models("(@task:text-generation OR huggingface) AND NOT @provider:qwen")
73+
assert len(results) == 1
74+
assert isinstance(results[0], HubContent)
75+
assert results[0].hub_content_name == "textgen"

0 commit comments

Comments
 (0)