44import json
55import os
66import sys
7- from typing import List
7+ from typing import Any , Dict , List
88
99# currently we don't support python 3.13t due to tensorrt does not support 3.13t
1010disabled_python_versions : List [str ] = ["3.13t" ]
1717sbsa_container_image : str = "quay.io/pypa/manylinux_2_34_aarch64"
1818
1919
20+ def validate_matrix (matrix_dict : Dict [str , Any ]) -> None :
21+ """Validate the structure of the input matrix."""
22+ if not isinstance (matrix_dict , dict ):
23+ raise ValueError ("Matrix must be a dictionary" )
24+ if "include" not in matrix_dict :
25+ raise ValueError ("Matrix must contain 'include' key" )
26+ if not isinstance (matrix_dict ["include" ], list ):
27+ raise ValueError ("Matrix 'include' must be a list" )
28+
29+
30+ def filter_matrix_item (
31+ item : Dict [str , Any ], is_jetpack : bool , limit_pr_builds : bool , is_nightly : bool
32+ ) -> bool :
33+ """Filter a single matrix item based on the build type and requirements."""
34+ if item ["python_version" ] in disabled_python_versions :
35+ # Skipping disabled Python version
36+ return False
37+
38+ if is_jetpack :
39+ if limit_pr_builds :
40+ # pr build,matrix passed from test-infra is cu128, python 3.9, change to cu126, python 3.10
41+ item ["desired_cuda" ] = "cu126"
42+ item ["python_version" ] = "3.10"
43+ item ["container_image" ] = jetpack_container_image
44+ return True
45+ elif is_nightly :
46+ # nightly build, matrix passed from test-infra is cu128, all python versions, change to cu126, python 3.10
47+ if item ["python_version" ] in jetpack_python_versions :
48+ item ["desired_cuda" ] = "cu126"
49+ item ["container_image" ] = jetpack_container_image
50+ return True
51+ return False
52+ else :
53+ if (
54+ item ["python_version" ] in jetpack_python_versions
55+ and item ["desired_cuda" ] in jetpack_cuda_versions
56+ ):
57+ item ["container_image" ] = jetpack_container_image
58+ return True
59+ return False
60+ else :
61+ if item ["gpu_arch_type" ] == "cuda-aarch64" :
62+ # pytorch image:pytorch/manylinuxaarch64-builder:cuda12.8 comes with glibc2.28
63+ # however, TensorRT requires glibc2.31 on aarch64 platform
64+ # TODO: in future, if pytorch supports aarch64 with glibc2.31, we should switch to use the pytorch image
65+ item ["container_image" ] = sbsa_container_image
66+ return True
67+ return True
68+
69+
2070def main (args : list [str ]) -> None :
2171 parser = argparse .ArgumentParser ()
2272 parser .add_argument (
@@ -42,41 +92,39 @@ def main(args: list[str]) -> None:
4292 default = os .getenv ("LIMIT_PR_BUILDS" , "false" ),
4393 )
4494
95+ parser .add_argument (
96+ "--is-nightly" ,
97+ help = "If it is a nightly build" ,
98+ type = str ,
99+ choices = ["true" , "false" ],
100+ default = os .getenv ("LIMIT_PR_BUILDS" , "false" ),
101+ )
102+
45103 options = parser .parse_args (args )
46104 if options .matrix == "" :
47- raise Exception ("--matrix needs to be provided" )
105+ raise ValueError ("--matrix needs to be provided" )
106+
107+ try :
108+ matrix_dict = json .loads (options .matrix )
109+ validate_matrix (matrix_dict )
110+ except json .JSONDecodeError as e :
111+ raise ValueError (f"Invalid JSON in matrix: { e } " )
112+ except ValueError as e :
113+ raise ValueError (f"Invalid matrix structure: { e } " )
48114
49- matrix_dict = json .loads (options .matrix )
50115 includes = matrix_dict ["include" ]
51116 filtered_includes = []
117+
52118 for item in includes :
53- if item ["python_version" ] in disabled_python_versions :
54- continue
55- if options .jetpack == "true" :
56- if options .limit_pr_builds == "true" :
57- # limit pr build, matrix passed in from test-infra is cu128, python 3.9, change to cu126, python 3.10
58- item ["desired_cuda" ] = "cu126"
59- item ["python_version" ] = "3.10"
60- item ["container_image" ] = jetpack_container_image
61- filtered_includes .append (item )
62- else :
63- if (
64- item ["python_version" ] in jetpack_python_versions
65- and item ["desired_cuda" ] in jetpack_cuda_versions
66- ):
67- item ["container_image" ] = jetpack_container_image
68- filtered_includes .append (item )
69- else :
70- if item ["gpu_arch_type" ] == "cuda-aarch64" :
71- # pytorch image:pytorch/manylinuxaarch64-builder:cuda12.8 comes with glibc2.28
72- # however, TensorRT requires glibc2.31 on aarch64 platform
73- # TODO: in future, if pytorch supports aarch64 with glibc2.31, we should switch to use the pytorch image
74- item ["container_image" ] = sbsa_container_image
75- filtered_includes .append (item )
76- else :
77- filtered_includes .append (item )
78- filtered_matrix_dict = {}
79- filtered_matrix_dict ["include" ] = filtered_includes
119+ if filter_matrix_item (
120+ item ,
121+ options .jetpack == "true" ,
122+ options .limit_pr_builds == "true" ,
123+ options .is_nightly == "true" ,
124+ ):
125+ filtered_includes .append (item )
126+
127+ filtered_matrix_dict = {"include" : filtered_includes }
80128 print (json .dumps (filtered_matrix_dict ))
81129
82130
0 commit comments