1+ import re
2+ import logging
3+ from typing import List , Iterator , Optional
4+ from sagemaker_core .helper .session_helper import Session
5+ from sagemaker_core .resources import HubContent
6+
7+ logger = logging .getLogger (__name__ )
8+
9+ class _Filter :
10+ """
11+ A filter that evaluates logical expressions against a list of keyword strings.
12+
13+ Supports logical operators (AND, OR, NOT), parentheses for grouping, and wildcard patterns
14+ (e.g., `text-*`, `*ai`, `@task:foo`).
15+
16+ Example:
17+ filt = _Filter("(@framework:huggingface OR text-*) AND NOT deprecated")
18+ filt.match(["@framework:huggingface", "text-generation"]) # Returns True
19+ """
20+
21+ def __init__ (self , expression : str ) -> None :
22+ """
23+ Initialize the filter with a string expression.
24+
25+ Args:
26+ expression (str): A logical expression to evaluate against keywords.
27+ Supports AND, OR, NOT, parentheses, and wildcard patterns (*).
28+ """
29+ self .expression : str = expression
30+
31+ def match (self , keywords : List [str ]) -> bool :
32+ """
33+ Evaluate the filter expression against a list of keywords.
34+
35+ Args:
36+ keywords (List[str]): A list of keyword strings to test.
37+
38+ Returns:
39+ bool: True if the expression evaluates to True for the given keywords, else False.
40+ """
41+ expr : str = self ._convert_expression (self .expression )
42+ try :
43+ return eval (expr , {"__builtins__" : {}}, {"keywords" : keywords , "any" : any })
44+ except Exception :
45+ return False
46+
47+ def _convert_expression (self , expr : str ) -> str :
48+ """
49+ Convert the logical filter expression into a Python-evaluable string.
50+
51+ Args:
52+ expr (str): The raw expression to convert.
53+
54+ Returns:
55+ str: A Python expression string using 'any' and logical operators.
56+ """
57+ tokens : List [str ] = re .findall (r'\bAND\b|\bOR\b|\bNOT\b|[^\s()]+|\(|\)' , expr , flags = re .IGNORECASE )
58+
59+ def wildcard_condition (pattern : str ) -> str :
60+ pattern = pattern .strip ('"' ).strip ("'" )
61+ stripped = pattern .strip ("*" )
62+
63+ if pattern .startswith ("*" ) and pattern .endswith ("*" ):
64+ return f"{ repr (stripped )} in k"
65+ elif pattern .startswith ("*" ):
66+ return f"k.endswith({ repr (stripped )} )"
67+ elif pattern .endswith ("*" ):
68+ return f"k.startswith({ repr (stripped )} )"
69+ else :
70+ return f"k == { repr (pattern )} "
71+
72+ def convert_token (token : str ) -> str :
73+ upper = token .upper ()
74+ if upper == 'AND' :
75+ return 'and'
76+ elif upper == 'OR' :
77+ return 'or'
78+ elif upper == 'NOT' :
79+ return 'not'
80+ elif token in ('(' , ')' ):
81+ return token
82+ else :
83+ return f"any({ wildcard_condition (token )} for k in keywords)"
84+
85+ converted_tokens = [convert_token (tok ) for tok in tokens ]
86+ return ' ' .join (converted_tokens )
87+
88+
89+ def _list_all_hub_models (hub_name : str , sm_client : Session ) -> Iterator [HubContent ]:
90+ """
91+ Retrieve all model entries from the specified hub and yield them one by one.
92+
93+ This function paginates through the SageMaker Hub API to retrieve all published models of type "Model"
94+ and yields them as `HubContent` objects.
95+
96+ Args:
97+ hub_name (str): The name of the hub to query.
98+ sm_client (Session): The SageMaker session.
99+
100+ Yields:
101+ HubContent: A `HubContent` object representing a single model entry from the hub.
102+ """
103+ next_token = None
104+
105+ while True :
106+ # Prepare the request parameters
107+ params = {
108+ "HubName" : hub_name ,
109+ "HubContentType" : "Model" ,
110+ "MaxResults" : 100
111+ }
112+
113+ # Add NextToken if it exists
114+ if next_token :
115+ params ["NextToken" ] = next_token
116+
117+ # Make the API call
118+ response = sm_client .list_hub_contents (** params )
119+
120+ # Yield each content summary
121+ for content in response ["HubContentSummaries" ]:
122+ yield HubContent (
123+ hub_name = hub_name ,
124+ hub_content_arn = content ["HubContentArn" ],
125+ hub_content_type = "Model" ,
126+ hub_content_name = content ["HubContentName" ],
127+ hub_content_version = content ["HubContentVersion" ],
128+ hub_content_description = content .get ("HubContentDescription" , "" ),
129+ hub_content_search_keywords = content .get ("HubContentSearchKeywords" , []),
130+ )
131+
132+ # Check if there are more results
133+ next_token = response .get ("NextToken" , None )
134+ if not next_token or len (response ["HubContentSummaries" ]) == 0 :
135+ break # Exit the loop if there are no more pages
136+
137+
138+ def search_public_hub_models (
139+ query : str ,
140+ hub_name : Optional [str ] = "SageMakerPublicHub" ,
141+ sagemaker_session : Optional [Session ] = None ,
142+ ) -> List [HubContent ]:
143+ """
144+ Search and filter models from hub using a keyword expression.
145+
146+ Args:
147+ query (str): A logical expression used to filter models by keywords.
148+ Example: "@task:text-generation AND NOT @framework:legacy"
149+ hub_name (Optional[str]): The name of the hub to query. Defaults to "SageMakerPublicHub".
150+ sagemaker_session (Optional[Session]): An optional SageMaker `Session` object. If not provided,
151+ a default session will be created and a warning will be logged.
152+
153+ Returns:
154+ List[HubContent]: A list of filtered `HubContent` model objects that match the query.
155+ """
156+ if sagemaker_session is None :
157+ sagemaker_session = Session ()
158+ logger .warning ("SageMaker session not provided. Using default Session." )
159+ sm_client = sagemaker_session .sagemaker_client
160+
161+ models = _list_all_hub_models (hub_name , sm_client )
162+ filt = _Filter (query )
163+ results : List [HubContent ] = []
164+
165+ for model in models :
166+ keywords = model .hub_content_search_keywords
167+ normalized_keywords = [kw .replace (" " , "-" ) for kw in keywords ]
168+
169+ if filt .match (normalized_keywords ):
170+ results .append (model )
171+
172+ return results
0 commit comments