|
15 | 15 |
|
16 | 16 | import importlib.util |
17 | 17 | import json |
18 | | -import uuid |
19 | | -from typing import Any, Type, List, Dict, Optional, Union |
20 | | -from dataclasses import dataclass, field |
21 | 18 | import logging |
22 | 19 | import os |
23 | 20 | import re |
24 | | - |
| 21 | +import uuid |
| 22 | +from dataclasses import dataclass, field |
25 | 23 | from pathlib import Path |
| 24 | +from typing import Any, Dict, List, Optional, Type, Union |
26 | 25 |
|
27 | 26 | from botocore.exceptions import ClientError |
28 | | -from sagemaker_core.main.resources import TrainingJob |
29 | | - |
30 | | -from sagemaker.transformer import Transformer |
| 27 | +from sagemaker import Session |
31 | 28 | from sagemaker.async_inference import AsyncInferenceConfig |
32 | | -from sagemaker.batch_inference.batch_transform_inference_config import BatchTransformInferenceConfig |
| 29 | +from sagemaker.base_predictor import PredictorBase |
| 30 | +from sagemaker.batch_inference.batch_transform_inference_config import \ |
| 31 | + BatchTransformInferenceConfig |
33 | 32 | from sagemaker.compute_resource_requirements import ResourceRequirements |
34 | | -from sagemaker.enums import Tag, EndpointType |
| 33 | +from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer |
| 34 | +from sagemaker.enums import EndpointType, Tag |
35 | 35 | from sagemaker.estimator import Estimator |
| 36 | +from sagemaker.huggingface.llm_utils import ( |
| 37 | + download_huggingface_model_metadata, get_huggingface_model_metadata) |
36 | 38 | from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor |
| 39 | +from sagemaker.jumpstart.model import JumpStartModel |
37 | 40 | from sagemaker.jumpstart.utils import get_jumpstart_content_bucket |
38 | | -from sagemaker.s3 import S3Downloader |
39 | | -from sagemaker import Session |
40 | 41 | from sagemaker.model import Model |
41 | | -from sagemaker.jumpstart.model import JumpStartModel |
42 | | -from sagemaker.base_predictor import PredictorBase |
| 42 | +from sagemaker.modules import logger |
| 43 | +from sagemaker.modules.train import ModelTrainer |
| 44 | +from sagemaker.predictor import Predictor |
| 45 | +from sagemaker.s3 import S3Downloader |
43 | 46 | from sagemaker.serializers import NumpySerializer, TorchTensorSerializer |
44 | | -from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer |
| 47 | +from sagemaker.serve.builder.djl_builder import DJL |
| 48 | +from sagemaker.serve.builder.jumpstart_builder import JumpStart |
45 | 49 | from sagemaker.serve.builder.schema_builder import SchemaBuilder |
46 | | -from sagemaker.serve.builder.tf_serving_builder import TensorflowServing |
47 | | -from sagemaker.serve.mode.function_pointers import Mode |
48 | | -from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode |
49 | | -from sagemaker.serve.mode.local_container_mode import LocalContainerMode |
50 | | -from sagemaker.serve.mode.in_process_mode import InProcessMode |
51 | | -from sagemaker.serve.detector.pickler import save_pkl, save_xgboost |
52 | 50 | from sagemaker.serve.builder.serve_settings import _ServeSettings |
53 | | -from sagemaker.serve.builder.djl_builder import DJL |
54 | 51 | from sagemaker.serve.builder.tei_builder import TEI |
| 52 | +from sagemaker.serve.builder.tf_serving_builder import TensorflowServing |
55 | 53 | from sagemaker.serve.builder.tgi_builder import TGI |
56 | | -from sagemaker.serve.builder.jumpstart_builder import JumpStart |
57 | 54 | from sagemaker.serve.builder.transformers_builder import Transformers |
58 | | -from sagemaker.predictor import Predictor |
| 55 | +from sagemaker.serve.detector.image_detector import ( |
| 56 | + _detect_framework_and_version, _get_model_base, auto_detect_container) |
| 57 | +from sagemaker.serve.detector.pickler import save_pkl, save_xgboost |
| 58 | +from sagemaker.serve.mode.function_pointers import Mode |
| 59 | +from sagemaker.serve.mode.in_process_mode import InProcessMode |
| 60 | +from sagemaker.serve.mode.local_container_mode import LocalContainerMode |
| 61 | +from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode |
59 | 62 | from sagemaker.serve.model_format.mlflow.constants import ( |
60 | | - MLFLOW_MODEL_PATH, |
61 | | - MLFLOW_TRACKING_ARN, |
62 | | - MLFLOW_RUN_ID_REGEX, |
63 | | - MLFLOW_REGISTRY_PATH_REGEX, |
64 | | - MODEL_PACKAGE_ARN_REGEX, |
65 | | - MLFLOW_METADATA_FILE, |
66 | | - MLFLOW_PIP_DEPENDENCY_FILE, |
67 | | -) |
| 63 | + MLFLOW_METADATA_FILE, MLFLOW_MODEL_PATH, MLFLOW_PIP_DEPENDENCY_FILE, |
| 64 | + MLFLOW_REGISTRY_PATH_REGEX, MLFLOW_RUN_ID_REGEX, MLFLOW_TRACKING_ARN, |
| 65 | + MODEL_PACKAGE_ARN_REGEX) |
68 | 66 | from sagemaker.serve.model_format.mlflow.utils import ( |
69 | | - _get_default_model_server_for_mlflow, |
70 | | - _download_s3_artifacts, |
71 | | - _select_container_for_mlflow_model, |
72 | | - _generate_mlflow_artifact_path, |
73 | | - _get_all_flavor_metadata, |
74 | | - _get_deployment_flavor, |
75 | | - _validate_input_for_mlflow, |
76 | | - _copy_directory_contents, |
77 | | -) |
78 | | -from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata |
| 67 | + _copy_directory_contents, _download_s3_artifacts, |
| 68 | + _generate_mlflow_artifact_path, _get_all_flavor_metadata, |
| 69 | + _get_default_model_server_for_mlflow, _get_deployment_flavor, |
| 70 | + _select_container_for_mlflow_model, _validate_input_for_mlflow) |
| 71 | +from sagemaker.serve.model_server.smd.prepare import prepare_for_smd |
| 72 | +from sagemaker.serve.model_server.torchserve.prepare import \ |
| 73 | + prepare_for_torchserve |
| 74 | +from sagemaker.serve.model_server.triton.triton_builder import Triton |
| 75 | +from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import ( |
| 76 | + Metadata, get_metadata) |
| 77 | +from sagemaker.serve.save_retrive.version_1_0_0.save.save_handler import \ |
| 78 | + SaveHandler |
| 79 | +from sagemaker.serve.spec.inference_base import (AsyncCustomOrchestrator, |
| 80 | + CustomOrchestrator) |
79 | 81 | from sagemaker.serve.spec.inference_spec import InferenceSpec |
80 | | -from sagemaker.serve.spec.inference_base import CustomOrchestrator, AsyncCustomOrchestrator |
81 | 82 | from sagemaker.serve.utils import task |
82 | 83 | from sagemaker.serve.utils.exceptions import TaskNotFoundException |
83 | | -from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model |
84 | | -from sagemaker.serve.utils.optimize_utils import ( |
85 | | - _generate_optimized_model, |
86 | | - _generate_model_source, |
87 | | - _extract_optimization_config_and_env, |
88 | | - _is_s3_uri, |
89 | | - _custom_speculative_decoding, |
90 | | - _extract_speculative_draft_model_provider, |
91 | | - _jumpstart_speculative_decoding, |
92 | | -) |
93 | | -from sagemaker.serve.utils.predictors import ( |
94 | | - _get_local_mode_predictor, |
95 | | - _get_in_process_mode_predictor, |
96 | | -) |
97 | 84 | from sagemaker.serve.utils.hardware_detector import ( |
98 | | - _get_gpu_info, |
99 | | - _get_gpu_info_fallback, |
100 | | - _total_inference_model_size_mib, |
101 | | -) |
102 | | -from sagemaker.serve.detector.image_detector import ( |
103 | | - auto_detect_container, |
104 | | - _detect_framework_and_version, |
105 | | - _get_model_base, |
106 | | -) |
107 | | -from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve |
108 | | -from sagemaker.serve.model_server.smd.prepare import prepare_for_smd |
109 | | -from sagemaker.serve.model_server.triton.triton_builder import Triton |
| 85 | + _get_gpu_info, _get_gpu_info_fallback, _total_inference_model_size_mib) |
| 86 | +from sagemaker.serve.utils.lineage_utils import \ |
| 87 | + _maintain_lineage_tracking_for_mlflow_model |
| 88 | +from sagemaker.serve.utils.optimize_utils import ( |
| 89 | + _custom_speculative_decoding, _extract_optimization_config_and_env, |
| 90 | + _extract_speculative_draft_model_provider, _generate_model_source, |
| 91 | + _generate_optimized_model, _is_s3_uri, _jumpstart_speculative_decoding) |
| 92 | +from sagemaker.serve.utils.predictors import (_get_in_process_mode_predictor, |
| 93 | + _get_local_mode_predictor) |
110 | 94 | from sagemaker.serve.utils.telemetry_logger import _capture_telemetry |
111 | | -from sagemaker.serve.utils.types import ModelServer, ModelHub |
| 95 | +from sagemaker.serve.utils.types import ModelHub, ModelServer |
| 96 | +from sagemaker.serve.validations.check_image_and_hardware_type import \ |
| 97 | + validate_image_uri_and_hardware |
112 | 98 | from sagemaker.serve.validations.check_image_uri import is_1p_image_uri |
113 | | -from sagemaker.serve.save_retrive.version_1_0_0.save.save_handler import SaveHandler |
114 | | -from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import get_metadata |
115 | | -from sagemaker.serve.validations.check_image_and_hardware_type import ( |
116 | | - validate_image_uri_and_hardware, |
117 | | -) |
| 99 | +from sagemaker.serve.validations.optimization import \ |
| 100 | + _validate_optimization_configuration |
118 | 101 | from sagemaker.serverless import ServerlessInferenceConfig |
| 102 | +from sagemaker.transformer import Transformer |
119 | 103 | from sagemaker.utils import Tags, unique_name_from_base |
120 | 104 | from sagemaker.workflow.entities import PipelineVariable |
121 | | -from sagemaker.huggingface.llm_utils import ( |
122 | | - get_huggingface_model_metadata, |
123 | | - download_huggingface_model_metadata, |
124 | | -) |
125 | | -from sagemaker.serve.validations.optimization import _validate_optimization_configuration |
126 | | -from sagemaker.modules.train import ModelTrainer |
127 | | -from sagemaker.modules import logger |
| 105 | +from sagemaker_core.main.resources import TrainingJob |
128 | 106 |
|
129 | 107 | # Any new server type should be added here |
130 | 108 | supported_model_servers = { |
@@ -1176,9 +1154,10 @@ def _get_smd_image_uri(self, processing_unit: str = None) -> str: |
1176 | 1154 | Returns: |
1177 | 1155 | str: SMD Inference Image URI. |
1178 | 1156 | """ |
1179 | | - from sagemaker import image_uris |
1180 | 1157 | import sys |
1181 | 1158 |
|
| 1159 | + from sagemaker import image_uris |
| 1160 | + |
1182 | 1161 | self.sagemaker_session = self.sagemaker_session or Session() |
1183 | 1162 | from packaging.version import Version |
1184 | 1163 |
|
@@ -1469,7 +1448,8 @@ def _hf_schema_builder_init(self, model_task: str): |
1469 | 1448 | sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task) |
1470 | 1449 | except ValueError: |
1471 | 1450 | # samples could not be loaded locally, try to fetch remote hf schema |
1472 | | - from sagemaker_schema_inference_artifacts.huggingface import remote_schema_retriever |
| 1451 | + from sagemaker_schema_inference_artifacts.huggingface import \ |
| 1452 | + remote_schema_retriever |
1473 | 1453 |
|
1474 | 1454 | if model_task in ("text-to-image", "automatic-speech-recognition"): |
1475 | 1455 | logger.warning( |
|
0 commit comments