Skip to content

Commit 53ce8d0

Browse files
committed
Refactor OS plugin to incorporate OCI OS UploadManager for enhanced functionality.
1 parent b0a33a5 commit 53ce8d0

File tree

4 files changed

+183
-54
lines changed

4 files changed

+183
-54
lines changed

container-image/environment.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ dependencies:
33
- main::pip
44
- pip:
55
- oracledb
6-
- mlflow==2.3.2
6+
- mlflow
77
- oracle-ads>=2.8.5
88
- mysql-connector-python
99
- oci-mlflow

docs/source/project.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ The ``{work_dir}`` can be used to point out that the YAML template located insid
5050
Data Science Job Template
5151
=========================
5252

53-
The template file contains the information about the infrastructure on which a Data Science job should be run, and also the runtime information. More details can be found in the `ADS documentation <https://accelerated-data-science.readthedocs.io/en/latest/user_guide/jobs/data_science_job.html>`__. The template file is divided into two main sections: ``infrastructure`` and ``runtime``.
53+
The template file contains the information about the infrastructure on which a Data Science job should be run, and also the runtime information. More details can be found in the `ADS documentation <https://accelerated-data-science.readthedocs.io/en/latest/user_guide/jobs/data_science_job.html>`__. The template file is divided into two main sections: ``infrastructure`` and ``runtime``. The template can also be generated using the ``ads opctl init`` command. More details can be found in the `ADS documentation <https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/opctl/configure.html#generate-starter-yaml>`__.
5454

5555
Data Science Job Infrastructure
5656
###############################
@@ -507,7 +507,7 @@ This example demonstrates an MLflow project that trains a logistic regression mo
507507

508508
Copy the ``oci-datascience-config.json`` file to the ``pyspark_ml_autologging`` folder.
509509

510-
- Prepare a ``oci-datascience-template.yaml`` job configuration file.
510+
- Prepare an ``oci-datascience-template.yaml`` job configuration file. The template can be generated using the ``ads opctl init`` command. More details can be found in the `ADS documentation <https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/opctl/configure.html#generate-starter-yaml>`__.
511511

512512
.. tabs::
513513

oci_mlflow/oci_object_storage.py

Lines changed: 141 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,40 +5,101 @@
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import os
8+
from typing import List
89
from urllib.parse import urlparse
910

1011
import fsspec
12+
from ads.common.auth import AuthType, default_signer, set_auth
13+
from ads.common.oci_client import OCIClientFactory
1114
from mlflow.entities import FileInfo
1215
from mlflow.store.artifact.artifact_repo import ArtifactRepository
1316
from mlflow.utils.file_utils import relative_path_to_artifact_path
17+
from oci import object_storage
1418
from ocifs import OCIFileSystem
1519

1620
from oci_mlflow import logger
1721

18-
OCI_PREFIX = "oci://"
22+
OCI_SCHEME = "oci"
23+
OCI_PREFIX = f"{OCI_SCHEME}://"
1924

2025

21-
class OCIObjectStorageArtifactRepository(ArtifactRepository):
26+
def parse_os_uri(uri: str):
27+
"""
28+
Parse an OCI Object Storage URI, returning a tuple (bucket, namespace, path).
29+
30+
Parameters
31+
----------
32+
uri: str
33+
The OCI Object Storage URI.
34+
35+
Returns
36+
-------
37+
Tuple
38+
The (bucket, namespace, path) components of the URI.
39+
40+
Raises
41+
------
42+
Exception
43+
If provided URI is not an OCI OS bucket URI.
2244
"""
23-
MLFlow Plugin implementation for storing artifacts to OCI Object Storage
45+
parsed = urlparse(uri)
46+
if parsed.scheme.lower() != OCI_SCHEME:
47+
raise Exception("Not an OCI object storage URI: %s" % uri)
48+
path = parsed.path
49+
50+
if path.startswith("/"):
51+
path = path[1:]
52+
53+
bucket, ns = parsed.netloc.split("@")
54+
55+
return bucket, ns, path
56+
57+
58+
class ArtifactUploader:
2459
"""
60+
Helper class for uploading model artifacts.
61+
62+
Attributes
63+
----------
64+
upload_manager: UploadManager
65+
The UploadManager, which simplifies interaction with the Object Storage service.
66+
"""
67+
68+
def __init__(self):
69+
"""Initializes `ArtifactUploader` instance."""
70+
auth_type = os.environ.get(
71+
"OCI_IAM_TYPE", os.environ.get("OCIFS_IAM_TYPE", AuthType.API_KEY)
72+
)
73+
logger.debug(f"Using auth {auth_type=}")
74+
set_auth(auth_type)
75+
76+
self.upload_manager = object_storage.UploadManager(
77+
OCIClientFactory(**default_signer()).object_storage
78+
)
79+
80+
def upload(self, file_path: str, dst_path: str):
81+
"""Uploads model artifacts.
82+
83+
Parameters
84+
----------
85+
file_path: str
86+
The source file path.
87+
dst_path: str
88+
The destination path.
89+
"""
90+
bucket_name, namespace_name, object_name = parse_os_uri(dst_path)
91+
logger.debug(f"{bucket_name=}, {namespace_name=}, {object_name=}")
92+
response = self.upload_manager.upload_file(
93+
namespace_name=namespace_name,
94+
bucket_name=bucket_name,
95+
object_name=object_name,
96+
file_path=file_path,
97+
)
98+
logger.debug(response)
2599

26-
@staticmethod
27-
def parse_os_uri(uri):
28-
"""Parse an OCI object storage URI, returning (bucket, namespace, path)"""
29-
parsed = urlparse(uri)
30-
if parsed.scheme != "oci":
31-
raise Exception("Not an OCI object storage URI: %s" % uri)
32-
path = parsed.path
33-
if path.startswith("/"):
34-
path = path[1:]
35-
bucket, ns = parsed.netloc.split("@")
36-
return bucket, ns, path
37-
38-
def _upload_file(self, local_file, dest_path):
39-
with open(local_file, "rb") as data:
40-
with fsspec.open(dest_path, "wb") as outfile:
41-
outfile.write(data.read())
100+
101+
class OCIObjectStorageArtifactRepository(ArtifactRepository):
102+
"""MLFlow Plugin implementation for storing artifacts to OCI Object Storage."""
42103

43104
def _download_file(self, remote_file_path, local_path):
44105
if not remote_file_path.startswith(self.artifact_uri):
@@ -49,43 +110,84 @@ def _download_file(self, remote_file_path, local_path):
49110
logger.info(f"{full_path}, {remote_file_path}")
50111
fs.download(full_path, local_path)
51112

52-
def log_artifact(self, local_file, artifact_path=None):
113+
def log_artifact(self, local_file: str, artifact_path: str = None):
114+
"""
115+
Logs a local file as an artifact, optionally taking an ``artifact_path`` to place it in
116+
within the run's artifacts. Run artifacts can be organized into directories, so you can
117+
place the artifact in a directory this way.
118+
119+
Parameters
120+
----------
121+
local_file:str
122+
Path to artifact to log.
123+
artifact_path:str
124+
Directory within the run's artifact directory in which to log the artifact.
125+
"""
53126
if artifact_path:
54127
dest_path = os.path.join(self.artifact_uri, artifact_path)
55128
else:
56129
dest_path = self.artifact_uri
57130
dest_path = os.path.join(dest_path, os.path.basename(local_file))
58-
self._upload_file(local_file, dest_path)
59131

60-
def log_artifacts(self, local_dir, artifact_path=None):
132+
ArtifactUploader().upload(local_file, dest_path)
133+
134+
def log_artifacts(self, local_dir: str, artifact_path: str = None):
135+
"""
136+
Logs the files in the specified local directory as artifacts, optionally taking
137+
an ``artifact_path`` to place them in within the run's artifacts.
138+
139+
Parameters
140+
----------
141+
local_dir:str
142+
Directory of local artifacts to log.
143+
artifact_path:str
144+
Directory within the run's artifact directory in which to log the artifacts.
145+
"""
146+
artifact_uploader = ArtifactUploader()
147+
61148
if artifact_path:
62149
dest_path = os.path.join(self.artifact_uri, artifact_path)
63150
else:
64151
dest_path = artifact_path
65152
local_dir = os.path.abspath(local_dir)
66-
for (root, _, filenames) in os.walk(local_dir):
153+
154+
for root, _, filenames in os.walk(local_dir):
67155
upload_path = dest_path
68156
if root != local_dir:
69157
rel_path = os.path.relpath(root, local_dir)
70158
rel_path = relative_path_to_artifact_path(rel_path)
71159
upload_path = os.path.join(dest_path, rel_path)
72160
for f in filenames:
73-
self._upload_file(
74-
local_file=os.path.join(root, f),
75-
dest_path=os.path.join(upload_path, f),
161+
artifact_uploader.upload(
162+
file_path=os.path.join(root, f),
163+
dst_path=os.path.join(upload_path, f),
76164
)
77165

78166
def get_fs(self):
79167
"""
80-
Get fssepc filesystem based on the uri scheme
168+
Gets the fsspec filesystem based on the URI scheme.
81169
"""
82170
self.fs = fsspec.filesystem(
83171
urlparse(self.artifact_uri).scheme
84172
) # FileSystem class corresponding to the URI scheme.
85173

86174
return self.fs
87175

88-
def list_artifacts(self, path: str = ""):
176+
def list_artifacts(self, path: str = "") -> List[FileInfo]:
177+
"""
178+
Return all the artifacts for this run_id directly under path. If path is a file, returns
179+
an empty list. Will error if path is neither a file nor directory.
180+
181+
Parameters
182+
----------
183+
path:str
184+
Relative source path that contains desired artifacts
185+
186+
Returns
187+
-------
188+
List[FileInfo]
189+
List of artifacts as FileInfo listed directly under path.
190+
"""
89191
result = []
90192
dest_path = self.artifact_uri
91193
if path:
@@ -111,7 +213,17 @@ def list_artifacts(self, path: str = ""):
111213
result.sort(key=lambda f: f.path)
112214
return result
113215

114-
def delete_artifacts(self, artifact_path=None):
216+
def delete_artifacts(self, artifact_path: str = None):
217+
"""
218+
Delete the artifacts at the specified location.
219+
Supports the deletion of a single file or of a directory. Deletion of a directory
220+
is recursive.
221+
222+
Parameters
223+
----------
224+
artifact_path: str
225+
Path of the artifact to delete.
226+
"""
115227
dest_path = self.artifact_uri
116228
if artifact_path:
117229
dest_path = os.path.join(self.artifact_uri, artifact_path)

tests/plugins/unitary/test_oci_object_storage.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4+
import os
5+
import tempfile
6+
from unittest.mock import MagicMock, Mock, patch
7+
48
# Copyright (c) 2023 Oracle and/or its affiliates.
59
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
610
import pytest
7-
from unittest.mock import MagicMock, Mock, patch
8-
import tempfile
9-
import os
10-
1111
from mlflow.entities import FileInfo
1212

1313
from oci_mlflow import oci_object_storage
14-
from oci_mlflow.oci_object_storage import OCIObjectStorageArtifactRepository
14+
from oci_mlflow.oci_object_storage import (
15+
ArtifactUploader,
16+
OCIObjectStorageArtifactRepository,
17+
)
18+
from oci import object_storage
1519

1620

1721
class DataObject:
@@ -41,7 +45,7 @@ def mock_fsspec_open(self):
4145
yield mock_open
4246

4347
def test_parse_os_uri(self, oci_artifact_repo):
44-
bucket, namespace, path = oci_artifact_repo.parse_os_uri(
48+
bucket, namespace, path = oci_object_storage.parse_os_uri(
4549
"oci://my-bucket@my-namespace/my-artifact-path"
4650
)
4751
assert bucket == "my-bucket"
@@ -50,20 +54,7 @@ def test_parse_os_uri(self, oci_artifact_repo):
5054

5155
def test_parse_os_uri_with_invalid_scheme(self, oci_artifact_repo):
5256
with pytest.raises(Exception):
53-
oci_artifact_repo.parse_os_uri("s3://my-bucket/my-artifact-path")
54-
55-
def test_upload_file(self, mock_fsspec_open):
56-
local_file = os.path.join(self.curr_dir, "test_files/test.txt")
57-
dest_path = "oci://my-bucket@my-namespace/path/to/test.txt"
58-
repository = OCIObjectStorageArtifactRepository(artifact_uri=dest_path)
59-
mock_outfile = Mock()
60-
mock_fsspec_open.return_value.__enter__.return_value = mock_outfile
61-
62-
repository._upload_file(local_file, dest_path)
63-
64-
mock_fsspec_open.assert_called_once_with(dest_path, "wb")
65-
with open(local_file, "rb") as f:
66-
mock_outfile.write.assert_called_once_with(f.read())
57+
oci_object_storage.parse_os_uri("s3://my-bucket/my-artifact-path")
6758

6859
def test_download_file(self, oci_artifact_repo):
6960
mock_fs = MagicMock()
@@ -82,7 +73,7 @@ def test_download_file(self, oci_artifact_repo):
8273
local_path,
8374
)
8475

85-
@patch.object(OCIObjectStorageArtifactRepository, "_upload_file")
76+
@patch.object(ArtifactUploader, "upload")
8677
def test_log_artifact(self, mock_upload_file, oci_artifact_repo):
8778
local_file = "test_files/test.txt"
8879
artifact_path = "logs"
@@ -92,7 +83,7 @@ def test_log_artifact(self, mock_upload_file, oci_artifact_repo):
9283
)
9384
mock_upload_file.assert_called_once_with(local_file, expected_dest_path)
9485

95-
@patch.object(OCIObjectStorageArtifactRepository, "_upload_file")
86+
@patch.object(ArtifactUploader, "upload")
9687
def test_log_artifacts(self, mock_upload_file, oci_artifact_repo):
9788
local_dir = os.path.join(self.curr_dir, "test_files")
9889
dest_path = "path/to/dest"
@@ -128,3 +119,29 @@ def test_list_artifacts(self):
128119
FileInfo("sub_folder", True, 0),
129120
]
130121
assert artifacts == expected_artifacts
122+
123+
124+
class TestArtifactUploader:
125+
def test_init(self):
126+
"""Ensures the ArtifactUploader instance can be initialized."""
127+
artifact_uploader = ArtifactUploader()
128+
assert isinstance(
129+
artifact_uploader.upload_manager, object_storage.UploadManager
130+
)
131+
132+
@patch.object(object_storage.UploadManager, "upload_file")
133+
def test_upload(self, mock_upload_file):
134+
"""Tests uploading model artifacts."""
135+
artifact_uploader = ArtifactUploader()
136+
137+
local_file = "test_files/test.txt"
138+
dest_path = "oci://my-bucket@my-namespace/my-artifact-path/logs/test.txt"
139+
artifact_uploader.upload(local_file, dest_path)
140+
141+
mock_upload_file.assert_called_with(
142+
namespace_name="my-namespace",
143+
bucket_name="my-bucket",
144+
object_name="my-artifact-path/logs/test.txt",
145+
file_path=local_file,
146+
)
147+

0 commit comments

Comments
 (0)