diff --git a/lightflow/models/dag.py b/lightflow/models/dag.py index 26220d4..defed38 100644 --- a/lightflow/models/dag.py +++ b/lightflow/models/dag.py @@ -36,11 +36,13 @@ class Dag: schema (dict): A dictionary with the definition of the task graph. """ def __init__(self, name, *, autostart=True, queue=DefaultJobQueueName.Dag, - schema=None): + schema=None, + graph=None): self._name = name self._autostart = autostart self._queue = queue self._schema = schema + self._graph = graph self._copy_counter = 0 self._workflow_name = None @@ -114,7 +116,10 @@ def run(self, config, workflow_id, signal, *, data=None): DirectedAcyclicGraphInvalid: If the graph is not a dag (e.g. contains loops). ConfigNotDefinedError: If the configuration for the dag is empty. """ - graph = self.make_graph(self._schema) + if not self._graph: + graph = self.make_graph(self._schema) + else: + graph = self._graph # pre-checks self.validate(graph) @@ -252,8 +257,7 @@ def validate(self, graph): if not nx.is_directed_acyclic_graph(graph): raise DirectedAcyclicGraphInvalid(graph_name=self._name) - @staticmethod - def make_graph(schema): + def make_graph(self, schema): """ Construct the task graph (dag) from a given schema. Parses the graph schema definition and creates the task graph. Tasks are the @@ -315,7 +319,8 @@ def make_graph(schema): else: graph.add_node(parent) - return graph + self._graph = graph + return self._graph def __deepcopy__(self, memo): """ Create a copy of the dag object. diff --git a/requirements-dev.txt b/requirements-dev.txt index f604537..f88f319 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,3 @@ -pytest +pytest>=3.6 pytest-cov flake8