LSST Applications  22.0.1,22.0.1+01bcf6a671,22.0.1+046ee49490,22.0.1+05c7de27da,22.0.1+0c6914dbf6,22.0.1+1220d50b50,22.0.1+12fd109e95,22.0.1+1a1dd69893,22.0.1+1c910dc348,22.0.1+1ef34551f5,22.0.1+30170c3d08,22.0.1+39153823fd,22.0.1+611137eacc,22.0.1+771eb1e3e8,22.0.1+94e66cc9ed,22.0.1+9a075d06e2,22.0.1+a5ff6e246e,22.0.1+a7db719c1a,22.0.1+ba0d97e778,22.0.1+bfe1ee9056,22.0.1+c4e1e0358a,22.0.1+cc34b8281e,22.0.1+d640e2c0fa,22.0.1+d72a2e677a,22.0.1+d9a6b571bd,22.0.1+e485e9761b,22.0.1+ebe8d3385e
LSST Data Management Base Package
pipeTools.py
Go to the documentation of this file.
1 # This file is part of pipe_base.
2 #
3 # Developed for the LSST Data Management System.
4 # This product includes software developed by the LSST Project
5 # (http://www.lsst.org).
6 # See the COPYRIGHT file at the top-level directory of this distribution
7 # for details of code ownership.
8 #
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <http://www.gnu.org/licenses/>.
21 
"""Functions that query or fix the ordering of tasks in a pipeline.
"""

# Explicit public API of this module; wildcard imports from it are discouraged.
__all__ = [
    "isPipelineOrdered",
    "orderPipeline",
]
27 
28 # -------------------------------
29 # Imports of standard modules --
30 # -------------------------------
import heapq
import itertools
32 
33 # -----------------------------
34 # Imports for other modules --
35 # -----------------------------
36 from .connections import iterConnections
37 
38 # ----------------------------------
39 # Local non-exported definitions --
40 # ----------------------------------
41 
42 
43 def _loadTaskClass(taskDef, taskFactory):
44  """Import task class if necessary.
45 
46  Raises
47  ------
48  `ImportError` is raised when task class cannot be imported.
49  `MissingTaskFactoryError` is raised when TaskFactory is needed but not
50  provided.
51  """
52  taskClass = taskDef.taskClass
53  if not taskClass:
54  if not taskFactory:
55  raise MissingTaskFactoryError("Task class is not defined but task "
56  "factory instance is not provided")
57  taskClass = taskFactory.loadTaskClass(taskDef.taskName)
58  return taskClass
59 
60 # ------------------------
61 # Exported definitions --
62 # ------------------------
63 
64 
class MissingTaskFactoryError(Exception):
    """Raised when a task class has to be loaded but no TaskFactory was given."""
69 
70 
class DuplicateOutputError(Exception):
    """Raised when a DatasetType appears more than once as a Pipeline output."""
76 
77 
class PipelineDataCycleError(Exception):
    """Raised when the data dependencies between Pipeline tasks form a cycle."""
82 
83 
def isPipelineOrdered(pipeline, taskFactory=None):
    """Checks whether tasks in pipeline are correctly ordered.

    Pipeline is correctly ordered if for any DatasetType produced by a task
    in a pipeline all its consumer tasks are located after producer.
    DatasetTypes that no task produces are treated as pre-existing and never
    violate the ordering. Only the regular ``inputs`` connections are
    checked; ``prerequisiteInputs`` are not considered here.

    Parameters
    ----------
    pipeline : `pipe.base.Pipeline` or iterable of `pipe.base.TaskDef`
        Pipeline description. It is iterated exactly once, so single-pass
        iterables such as generators are handled correctly.
    taskFactory: `pipe.base.TaskFactory`, optional
        Not used by this implementation; accepted for backward
        compatibility.

    Returns
    -------
    True for correctly ordered pipeline, False otherwise.

    Raises
    ------
    `DuplicateOutputError` is raised when there is more than one producer for a
    dataset type.
    Any exception raised while iterating ``pipeline`` propagates unchanged
    (for a `Pipeline` this may include `ImportError` when task classes are
    loaded -- TODO confirm).
    """
    # Materialize once so a single-pass iterable is not exhausted by the
    # first loop, which would make the second loop silently check nothing.
    taskDefs = list(pipeline)

    # Build a map of DatasetType name to producer's index in the pipeline.
    producerIndex = {}
    for idx, taskDef in enumerate(taskDefs):
        connections = taskDef.connections
        for name in connections.outputs:
            dsTypeName = getattr(connections, name).name
            if dsTypeName in producerIndex:
                raise DuplicateOutputError("DatasetType `{}' appears more than "
                                           "once as output".format(dsTypeName))
            producerIndex[dsTypeName] = idx

    # Every input must be produced strictly before its consumer; pre-existing
    # datasets have an effective producer index of -1. A task consuming its
    # own output (index equal to its own) is therefore also rejected.
    for idx, taskDef in enumerate(taskDefs):
        connections = taskDef.connections
        for name in connections.inputs:
            if producerIndex.get(getattr(connections, name).name, -1) >= idx:
                # Producer is this task or a downstream one.
                return False

    return True
133 
134 
def orderPipeline(pipeline):
    """Re-order tasks in pipeline to satisfy data dependencies.

    When possible new ordering keeps original relative order of the tasks:
    whenever several tasks have all their inputs available, the one that
    appears earliest in ``pipeline`` is emitted first.

    Parameters
    ----------
    pipeline : `list` (or any iterable) of `pipe.base.TaskDef`
        Pipeline description. It is consumed exactly once.

    Returns
    -------
    Correctly ordered pipeline (`list` of `pipe.base.TaskDef` objects).

    Raises
    ------
    `DuplicateOutputError` is raised when there is more than one producer for a
    dataset type, including when one task declares the same dataset type in
    more than one output connection.
    `PipelineDataCycleError` is raised when pipeline has dependency cycles.
    """
    # This is a modified version of Kahn's algorithm that preserves order.

    # Materialize so that any iterable works and tasks can be indexed later.
    taskDefs = list(pipeline)

    # Build mapping of the tasks to their inputs and outputs.
    inputs = {}  # task index -> input DatasetType names not yet available
    outputs = {}  # task index -> output DatasetType names
    allInputs = set()  # all inputs of all tasks
    allOutputs = set()  # all outputs of all tasks
    for idx, taskDef in enumerate(taskDefs):
        connections = taskDef.connections

        # Task outputs: check each connection individually so a dataset type
        # repeated within a single task is caught too, as isPipelineOrdered
        # does.
        taskOutputs = set()
        for name in connections.outputs:
            dsTypeName = getattr(connections, name).name
            if dsTypeName in allOutputs:
                raise DuplicateOutputError("DatasetType `{}' appears more than "
                                           "once as output".format(dsTypeName))
            allOutputs.add(dsTypeName)
            taskOutputs.add(dsTypeName)
        outputs[idx] = taskOutputs

        # Task inputs: prerequisite inputs also constrain the ordering.
        connectionInputs = itertools.chain(connections.inputs, connections.prerequisiteInputs)
        inputs[idx] = {getattr(connections, name).name for name in connectionInputs}
        allInputs.update(inputs[idx])

    # For simplicity add a pseudo-node (index -1) that "produces" every input
    # no task in the pipeline produces, i.e. the pre-existing datasets.
    outputs[-1] = allInputs - allOutputs

    # Min-heap of nodes with no unresolved inputs, initially the pseudo-node.
    # Always popping the smallest index keeps the original relative order.
    ready = [-1]
    result = []
    while ready:
        idx = heapq.heappop(ready)
        if idx >= 0:
            result.append(idx)

        # This node's outputs are now available to all remaining tasks.
        produced = outputs[idx]
        for taskInputs in inputs.values():
            taskInputs -= produced

        # Tasks with every input available become ready.
        newlyReady = [key for key, value in inputs.items() if not value]
        for key in newlyReady:
            del inputs[key]
            heapq.heappush(ready, key)

    # Any task left over still waits on inputs, which means a cycle.
    if inputs:
        # Format it in a usable way.
        loops = []
        for idx, inputNames in inputs.items():
            taskName = taskDefs[idx].label
            outputNames = outputs[idx]
            edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames)
            loops.append(edge)
        raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

    return [taskDefs[idx] for idx in result]
daf::base::PropertySet * set
Definition: fits.cc:912
def format(config, name=None, writeSourceLine=True, prefix="", verbose=False)
Definition: history.py:174
typing.Generator[BaseConnection, None, None] iterConnections(PipelineTaskConnections connections, Union[str, Iterable[str]] connectionType)
Definition: connections.py:503
def isPipelineOrdered(pipeline, taskFactory=None)
Definition: pipeTools.py:84
def orderPipeline(pipeline)
Definition: pipeTools.py:135