Skip to content

Commit 05ee084

Browse files
Merge pull request #42 from salesforce/fix_torch
Fix the torch and transformers import issue
2 parents a3b3ded + 933473e commit 05ee084

File tree

5 files changed

+136
-21
lines changed

5 files changed

+136
-21
lines changed

logai/algorithms/anomaly_detection_algo/__init__.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,30 @@
88
from .dbl import DBLDetector
99
from .distribution_divergence import DistributionDivergence
1010
from .ets import ETSDetector
11-
from .forecast_nn import ForecastBasedLSTM, ForecastBasedCNN, ForecastBasedTransformer
1211
from .isolation_forest import IsolationForestDetector
1312
from .local_outlier_factor import LOFDetector
14-
from .logbert import LogBERT
1513
from .one_class_svm import OneClassSVMDetector
14+
from logai.utils.misc import is_torch_available, \
15+
is_transformers_available
1616

17-
18-
__all__ = [
17+
_MODULES = [
1918
"DBLDetector",
2019
"DistributionDivergence",
2120
"ETSDetector",
22-
"ForecastBasedLSTM",
23-
"ForecastBasedCNN",
24-
"ForecastBasedTransformer",
2521
"IsolationForestDetector",
2622
"LOFDetector",
27-
"LogBERT",
28-
"OneClassSVMDetector",
23+
"OneClassSVMDetector"
2924
]
25+
26+
if is_torch_available() and is_transformers_available():
27+
from .forecast_nn import ForecastBasedLSTM, ForecastBasedCNN, ForecastBasedTransformer
28+
from .logbert import LogBERT
29+
30+
_MODULES += [
31+
"LogBERT",
32+
"ForecastBasedLSTM",
33+
"ForecastBasedCNN",
34+
"ForecastBasedTransformer"
35+
]
36+
37+
__all__ = _MODULES

logai/algorithms/anomaly_detection_algo/forecast_nn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@ def predict(self, test_data: ForecastNNVectorizedDataset):
7676
@factory.register("detection", "lstm", LSTMParams)
7777
class ForecastBasedLSTM(ForcastBasedNeuralAD):
7878
"""Forecasting based lstm model for log anomaly detection.
79+
7980
:param config: A config object containing parameters for LSTM based anomaly detection model.
8081
"""
8182

8283
def __init__(self, config: LSTMParams):
83-
8484
super().__init__(config)
8585
self.config = config
8686
self.model = LSTM(config=self.config)
@@ -89,6 +89,7 @@ def __init__(self, config: LSTMParams):
8989
@factory.register("detection", "cnn", CNNParams)
9090
class ForecastBasedCNN(ForcastBasedNeuralAD):
9191
"""Forecasting based cnn model for log anomaly detection.
92+
9293
:param config: A config object containing parameters for CNN based anomaly detection model.
9394
"""
9495

@@ -101,11 +102,11 @@ def __init__(self, config: CNNParams):
101102
@factory.register("detection", "transformer", TransformerParams)
102103
class ForecastBasedTransformer(ForcastBasedNeuralAD):
103104
"""Forecasting based transformer model for log anomaly detection.
105+
104106
:param config: A config object containing parameters for Transformer based anomaly detection model.
105107
"""
106108

107109
def __init__(self, config: TransformerParams):
108-
109110
super().__init__(config)
110111
self.config = config
111112
self.model = Transformer(config=self.config)

logai/algorithms/factory.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
66
#
77
#
8+
from logai.utils.misc import is_torch_available, \
9+
is_transformers_available
810

911

1012
class AlgorithmFactory:
@@ -18,6 +20,10 @@ class AlgorithmFactory:
1820
"clustering": {},
1921
"vectorization": {},
2022
}
23+
_algorithms_with_torch = {
24+
"lstm", "cnn", "transformer",
25+
"logbert", "forecast_nn"
26+
}
2127

2228
def __new__(cls):
2329
if not hasattr(cls, "instance"):
@@ -58,14 +64,22 @@ def unregister(cls, task, name):
5864
"""
5965
return cls._algorithms[task].pop(name, None)
6066

67+
def _check_algorithm(self, task, name):
68+
if name in self._algorithms_with_torch:
69+
if not is_torch_available() or not is_transformers_available():
70+
raise ImportError("Some deep learning packages are missing. "
71+
"Please install them via `pip install logai[deep-learning]`.")
72+
assert name in self._algorithms[task], \
73+
f"Unknown algorithm {name}, please choose from {self._algorithms[task].keys()}."
74+
6175
def get_config_class(self, task, name):
6276
"""
6377
Gets the corresponding configuration class given an algorithm name.
6478
6579
:param task: The task name.
6680
:param name: The algorithm name.
6781
"""
68-
assert name in self._algorithms[task], f"Unknown algorithm {name}."
82+
self._check_algorithm(task, name)
6983
return self._algorithms[task][name][0]
7084

7185
def get_algorithm_class(self, task, name):
@@ -75,7 +89,7 @@ def get_algorithm_class(self, task, name):
7589
:param task: The task name.
7690
:param name: The algorithm name.
7791
"""
78-
assert name in self._algorithms[task], f"Unknown algorithm {name}."
92+
self._check_algorithm(task, name)
7993
return self._algorithms[task][name][1]
8094

8195
def get_config(self, task, name, config_dict):
@@ -86,7 +100,7 @@ def get_config(self, task, name, config_dict):
86100
:param name: The algorithm name.
87101
:param config_dict: The config dictionary.
88102
"""
89-
assert name in self._algorithms[task], f"Unknown algorithm {name}."
103+
self._check_algorithm(task, name)
90104
return self._algorithms[task][name][0].from_dict(config_dict)
91105

92106
def get_algorithm(self, task, name, config):
@@ -97,7 +111,7 @@ def get_algorithm(self, task, name, config):
97111
:param name: The algorithm name.
98112
:param config: The config instance.
99113
"""
100-
assert name in self._algorithms[task], f"Unknown algorithm {name}."
114+
self._check_algorithm(task, name)
101115
config_class, algorithm_class = self._algorithms[task][name]
102116
if config and config.algo_params:
103117
assert isinstance(

logai/algorithms/vectorization_algo/__init__.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,28 @@
66
#
77
#
88
from .fasttext import FastText
9-
from .forecast_nn import ForecastNN
10-
from .logbert import LogBERT
119
from .semantic import Semantic
1210
from .sequential import Sequential
1311
from .tfidf import TfIdf
1412
from .word2vec import Word2Vec
13+
from logai.utils.misc import is_torch_available, \
14+
is_transformers_available
1515

16-
__all__ = [
16+
_MODULES = [
1717
"FastText",
18-
"ForecastNN",
19-
"LogBERT",
2018
"Semantic",
2119
"Sequential",
2220
"TfIdf",
23-
"Word2Vec",
21+
"Word2Vec"
2422
]
23+
24+
if is_torch_available() and is_transformers_available():
25+
from .forecast_nn import ForecastNN
26+
from .logbert import LogBERT
27+
28+
_MODULES += [
29+
"ForecastNN",
30+
"LogBERT"
31+
]
32+
33+
__all__ = _MODULES

logai/utils/misc.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#
2+
# Copyright (c) 2022 salesforce.com, inc.
3+
# All rights reserved.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6+
#
7+
import sys
8+
import importlib.util
9+
from packaging import version
10+
11+
if sys.version_info < (3, 8):
12+
import importlib_metadata
13+
else:
14+
import importlib.metadata as importlib_metadata
15+
16+
get_pkg_version = importlib_metadata.version
17+
18+
19+
def is_torch_available():
20+
"""
21+
Checks if pytorch is available.
22+
"""
23+
if importlib.util.find_spec("torch") is not None:
24+
_torch_version = importlib_metadata.version("torch")
25+
if version.parse(_torch_version) < version.parse("1.6"):
26+
raise EnvironmentError(f"Torch found but with version {_torch_version}. " f"The minimum version is 1.6")
27+
return True
28+
else:
29+
return False
30+
31+
32+
def is_tf_available():
33+
"""
34+
Checks if tensorflow 2.0 is available.
35+
"""
36+
candidates = (
37+
"tensorflow",
38+
"tensorflow-cpu",
39+
"tensorflow-gpu",
40+
"tf-nightly",
41+
"tf-nightly-cpu",
42+
"tf-nightly-gpu",
43+
"intel-tensorflow",
44+
"intel-tensorflow-avx512",
45+
"tensorflow-rocm",
46+
"tensorflow-macos",
47+
)
48+
_tf_version = None
49+
for pkg in candidates:
50+
try:
51+
_tf_version = importlib_metadata.version(pkg)
52+
break
53+
except importlib_metadata.PackageNotFoundError:
54+
pass
55+
if _tf_version is not None:
56+
if version.parse(_tf_version) < version.parse("2"):
57+
raise EnvironmentError(f"Tensorflow found but with version {_tf_version}. " f"The minimum version is 2.0")
58+
return True
59+
else:
60+
return False
61+
62+
63+
def is_transformers_available():
64+
"""
65+
Checks if the `transformers` library is installed.
66+
"""
67+
if importlib.util.find_spec("transformers") is not None:
68+
_version = importlib_metadata.version("transformers")
69+
if version.parse(_version) < version.parse("4.0"):
70+
raise EnvironmentError(f"Transformers found but with version {_version}. " f"The minimum version is 4.0")
71+
return True
72+
else:
73+
return False
74+
75+
76+
def is_nltk_available():
77+
"""
78+
Checks if the `nltk` library is installed.
79+
"""
80+
if importlib.util.find_spec("nltk") is not None:
81+
return True
82+
else:
83+
return False

0 commit comments

Comments
 (0)