21 from __future__
import annotations
24 __all__ = (
"QuantumGraph",
"IncompatibleGraphError")
26 from collections
import defaultdict, deque
28 from itertools
import chain, count
31 from networkx.drawing.nx_agraph
import write_dot
38 from typing
import (DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator, Optional, Tuple,
41 from ..connections
import iterConnections
42 from ..pipeline
import TaskDef
43 from lsst.daf.butler
import Quantum, DatasetRef, ButlerURI, DimensionUniverse
45 from ._implDetails
import _DatasetTracker, DatasetTypeName
46 from .quantumNode
import QuantumNode, NodeId, BuildId
47 from ._loadHelpers
import LoadHelper
50 _T = TypeVar(
"_T", bound=
"QuantumGraph")
59 STRUCT_FMT_STRING =
'>HQQ'
63 MAGIC_BYTES = b
"qgraph4\xf6\xe8\xa9"
67 """Exception class to indicate that a lookup by NodeId is impossible due
74 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects
76 This data structure represents a concrete workflow generated from a
81 quanta : Mapping of `TaskDef` to sets of `Quantum`
82 This maps tasks (and their configs) to the sets of data they are to
85 def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]]):
88 def _buildGraphs(self,
89 quanta: Mapping[TaskDef, Set[Quantum]],
91 _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] =
None,
92 _buildId: Optional[BuildId] =
None):
93 """Builds the graph that is used to store the relation between tasks,
94 and the graph that holds the relations between quanta
97 self.
_buildId_buildId = _buildId
if _buildId
is not None else BuildId(f
"{time.time()}-{os.getpid()}")
101 self.
_datasetDict_datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
102 self.
_datasetRefDict_datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()
104 nodeNumberGenerator = count()
105 self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
106 self._taskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
109 connections = taskDef.connections
114 for inpt
in iterConnections(connections, (
"inputs",
"prerequisiteInputs",
"initInputs")):
117 for output
in iterConnections(connections, (
"outputs",
"initOutputs")):
125 self.
_count_count += len(quantumSet)
126 for quantum
in quantumSet:
128 nodeId = _quantumToNodeId.get(quantum)
130 raise ValueError(
"If _quantuMToNodeNumber is not None, all quanta must have an "
131 "associated value in the mapping")
135 inits = quantum.initInputs.values()
136 inputs = quantum.inputs.values()
138 self._taskToQuantumNode[taskDef].add(value)
139 self._nodeIdMap[nodeId] = value
141 for dsRef
in chain(inits, inputs):
145 if isinstance(dsRef, Iterable):
150 for dsRef
in chain.from_iterable(quantum.outputs.values()):
161 """Return a graph representing the relations between the tasks inside
166 taskGraph : `networkx.Digraph`
167 Internal datastructure that holds relations of `TaskDef` objects
173 """Return a graph representing the relations between all the
174 `QuantumNode` objects. Largely it should be preferred to iterate
175 over, and use methods of this class, but sometimes direct access to
176 the networkx object may be helpful
180 graph : `networkx.Digraph`
181 Internal datastructure that holds relations of `QuantumNode`
188 """Make a `list` of all `QuantumNode` objects that are 'input' nodes
189 to the graph, meaning those nodes to not depend on any other nodes in
194 inputNodes : iterable of `QuantumNode`
195 A list of nodes that are inputs to the graph
197 return (q
for q, n
in self.
_connectedQuanta_connectedQuanta.in_degree
if n == 0)
201 """Make a `list` of all `QuantumNode` objects that are 'output' nodes
202 to the graph, meaning those nodes have no nodes that depend them in
207 outputNodes : iterable of `QuantumNode`
208 A list of nodes that are outputs of the graph
210 return [q
for q, n
in self.
_connectedQuanta_connectedQuanta.out_degree
if n == 0]
214 """Return all the `DatasetTypeName` objects that are contained inside
219 tuple of `DatasetTypeName`
220 All the data set type names that are present in the graph
226 """Return True if all of the nodes in the graph are connected, ignores
227 directionality of connections.
232 """Lookup a `QuantumNode` from an id associated with the node.
237 The number associated with a node
242 The node corresponding with input number
247 Raised if the requested nodeId is not in the graph.
248 IncompatibleGraphError
249 Raised if the nodeId was built with a different graph than is not
250 this instance (or a graph instance that produced this instance
251 through and operation such as subset)
253 if nodeId.buildId != self.
_buildId_buildId:
255 return self._nodeIdMap[nodeId]
258 """Return all the `Quantum` associated with a `TaskDef`.
263 The `TaskDef` for which `Quantum` are to be queried
267 frozenset of `Quantum`
268 The `set` of `Quantum` that is associated with the specified
271 return frozenset(self.
_quanta_quanta[taskDef])
274 """Return all the `QuantumNodes` associated with a `TaskDef`.
279 The `TaskDef` for which `Quantum` are to be queried
283 frozenset of `QuantumNodes`
284 The `frozenset` of `QuantumNodes` that is associated with the
287 return frozenset(self._taskToQuantumNode[taskDef])
290 """Find all tasks that have the specified dataset type name as an
295 datasetTypeName : `str`
296 A string representing the name of a dataset type to be queried,
297 can also accept a `DatasetTypeName` which is a `NewType` of str for
298 type safety in static type checking.
302 tasks : iterable of `TaskDef`
303 `TaskDef` objects that have the specified `DatasetTypeName` as an
304 input, list will be empty if no tasks use specified
305 `DatasetTypeName` as an input.
310 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
312 return (c
for c
in self.
_datasetDict_datasetDict.getInputs(datasetTypeName))
315 """Find all tasks that have the specified dataset type name as an
320 datasetTypeName : `str`
321 A string representing the name of a dataset type to be queried,
322 can also accept a `DatasetTypeName` which is a `NewType` of str for
323 type safety in static type checking.
328 `TaskDef` that outputs `DatasetTypeName` as an output or None if
329 none of the tasks produce this `DatasetTypeName`.
334 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
336 return self.
_datasetDict_datasetDict.getOutput(datasetTypeName)
339 """Find all tasks that are associated with the specified dataset type
344 datasetTypeName : `str`
345 A string representing the name of a dataset type to be queried,
346 can also accept a `DatasetTypeName` which is a `NewType` of str for
347 type safety in static type checking.
351 result : iterable of `TaskDef`
352 `TaskDef` objects that are associated with the specified
358 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
362 if output
is not None:
363 results = chain(results, (output,))
367 """Determine which `TaskDef` objects in this graph are associated
368 with a `str` representing a task name (looks at the taskName property
369 of `TaskDef` objects).
371 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
372 multiple times in a graph with different labels.
377 Name of a task to search for
381 result : list of `TaskDef`
382 List of the `TaskDef` objects that have the name specified.
383 Multiple values are returned in the case that a task is used
384 multiple times with different labels.
388 split = task.taskName.split(
'.')
389 if split[-1] == taskName:
394 """Determine which `TaskDef` objects in this graph are associated
395 with a `str` representing a tasks label.
400 Name of a task to search for
405 `TaskDef` objects that has the specified label.
408 if label == task.label:
413 """Return all the `Quantum` that contain a specified `DatasetTypeName`.
417 datasetTypeName : `str`
418 The name of the dataset type to search for as a string,
419 can also accept a `DatasetTypeName` which is a `NewType` of str for
420 type safety in static type checking.
424 result : `set` of `QuantumNode` objects
425 A `set` of `QuantumNode`s that contain specified `DatasetTypeName`
430 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
433 tasks = self.
_datasetDict_datasetDict.getAll(datasetTypeName)
434 result: Set[Quantum] =
set()
435 result = result.union(*(self.
_quanta_quanta[task]
for task
in tasks))
439 """Check if specified quantum appears in the graph as part of a node.
444 The quantum to search for
449 The result of searching for the quantum
451 for qset
in self.
_quanta_quanta.values():
457 """Write out the graph as a dot graph.
461 output : str or `io.BufferedIOBase`
462 Either a filesystem path to write to, or a file handle object
466 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
467 """Create a new graph object that contains the subset of the nodes
468 specified as input. Node number is preserved.
472 nodes : `QuantumNode` or iterable of `QuantumNode`
476 graph : instance of graph type
477 An instance of the type from which the subset was created
479 if not isinstance(nodes, Iterable):
481 quantumSubgraph = self.
_connectedQuanta_connectedQuanta.subgraph(nodes).nodes
482 quantumMap = defaultdict(set)
485 for node
in quantumSubgraph:
486 quantumMap[node.taskDef].add(node.quantum)
488 newInst =
type(self)({})
489 newInst._buildGraphs(quantumMap, _quantumToNodeId={n.quantum: n.nodeId
for n
in nodes},
494 """Generate a list of subgraphs where each is connected.
498 result : list of `QuantumGraph`
499 A list of graphs that are each connected
501 return tuple(self.
subsetsubset(connectedSet)
502 for connectedSet
in nx.weakly_connected_components(self.
_connectedQuanta_connectedQuanta))
505 """Return a set of `QuantumNode` that are direct inputs to a specified
511 The node of the graph for which inputs are to be determined
516 All the nodes that are direct inputs to specified node
521 """Return a set of `QuantumNode` that are direct outputs of a specified
527 The node of the graph for which outputs are to be determined
532 All the nodes that are direct outputs to specified node
537 """Return a graph of `QuantumNode` that are direct inputs and outputs
543 The node of the graph for which connected nodes are to be
548 graph : graph of `QuantumNode`
549 All the nodes that are directly connected to specified node
553 return self.
subsetsubset(nodes)
556 """Return a graph of the specified node and all the ancestor nodes
557 directly reachable by walking edges.
562 The node for which all ansestors are to be determined
566 graph of `QuantumNode`
567 Graph of node and all of its ansestors
569 predecessorNodes = nx.ancestors(self.
_connectedQuanta_connectedQuanta, node)
570 predecessorNodes.add(node)
571 return self.
subsetsubset(predecessorNodes)
573 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
574 """Check a graph for the presense of cycles and returns the edges of
575 any cycles found, or an empty list if there is no cycle.
579 result : list of tuple of `QuantumNode`, `QuantumNode`
580 A list of any graph edges that form a cycle, or an empty list if
581 there is no cycle. Empty list to so support if graph.find_cycle()
582 syntax as an empty list is falsy.
586 except nx.NetworkXNoCycle:
590 """Save `QuantumGraph` to the specified URI.
594 uri : `ButlerURI` or `str`
595 URI to where the graph should be saved.
598 butlerUri = ButlerURI(uri)
599 if butlerUri.getExtension()
not in (
".qgraph"):
600 raise TypeError(f
"Can currently only save a graph in qgraph format not {uri}")
601 butlerUri.write(buffer)
604 def loadUri(cls, uri: Union[ButlerURI, str], universe: DimensionUniverse,
605 nodes: Optional[Iterable[int]] =
None,
606 graphID: Optional[BuildId] =
None
608 """Read `QuantumGraph` from a URI.
612 uri : `ButlerURI` or `str`
613 URI from where to load the graph.
614 universe: `~lsst.daf.butler.DimensionUniverse`
615 DimensionUniverse instance, not used by the method itself but
616 needed to ensure that registry data structures are initialized.
617 nodes: iterable of `int` or None
618 Numbers that correspond to nodes in the graph. If specified, only
619 these nodes will be loaded. Defaults to None, in which case all
620 nodes will be loaded.
621 graphID : `str` or `None`
622 If specified this ID is verified against the loaded graph prior to
623 loading any Nodes. This defaults to None in which case no
628 graph : `QuantumGraph`
629 Resulting QuantumGraph instance.
634 Raised if pickle contains instance of a type other than
637 Raised if one or more of the nodes requested is not in the
638 `QuantumGraph` or if graphID parameter does not match the graph
639 being loaded or if the supplied uri does not point at a valid
640 `QuantumGraph` save file.
645 Reading Quanta from pickle requires existence of singleton
646 DimensionUniverse which is usually instantiated during Registry
647 initialization. To make sure that DimensionUniverse exists this method
648 accepts dummy DimensionUniverse argument.
657 if uri.getExtension()
in (
".pickle",
".pkl"):
658 with uri.as_local()
as local, open(local.ospath,
"rb")
as fd:
659 warnings.warn(
"Pickle graphs are deprecated, please re-save your graph with the save method")
660 qgraph = pickle.load(fd)
661 elif uri.getExtension()
in (
'.qgraph'):
663 qgraph = loader.load(nodes, graphID)
665 raise ValueError(
"Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
666 if not isinstance(qgraph, QuantumGraph):
667 raise TypeError(f
"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
670 def save(self, file: io.IO[bytes]):
671 """Save QuantumGraph to a file.
673 Presently we store QuantumGraph in pickle format, this could
674 potentially change in the future if better format is found.
678 file : `io.BufferedIOBase`
679 File to write pickle data open in binary mode.
684 def _buildSaveObject(self) -> bytearray:
698 dump = lzma.compress(pickle.dumps(taskDef, protocol=protocol))
699 taskDefMap[taskDef.label] = (count, count+len(dump))
701 pickleData.append(dump)
707 taskDefMap[
'__GraphBuildID'] = self.
graphIDgraphID
711 node = copy.copy(node)
712 taskDef = node.taskDef
723 object.__setattr__(node,
'taskDef', taskDef.label)
726 dump = lzma.compress(pickle.dumps(node, protocol=protocol))
727 pickleData.append(dump)
728 nodeMap[node.nodeId.number] = (count, count+len(dump))
732 taskDef_pickle = pickle.dumps(taskDefMap, protocol=protocol)
735 map_pickle = pickle.dumps(nodeMap, protocol=protocol)
739 map_lengths = struct.pack(STRUCT_FMT_STRING, SAVE_VERSION, len(taskDef_pickle), len(map_pickle))
747 buffer.extend(MAGIC_BYTES)
748 buffer.extend(map_lengths)
749 buffer.extend(taskDef_pickle)
750 buffer.extend(map_pickle)
760 buffer.extend(pickleData.popleft())
764 def load(cls, file: io.IO[bytes], universe: DimensionUniverse,
765 nodes: Optional[Iterable[int]] =
None,
766 graphID: Optional[BuildId] =
None
768 """Read QuantumGraph from a file that was made by `save`.
772 file : `io.IO` of bytes
773 File with pickle data open in binary mode.
774 universe: `~lsst.daf.butler.DimensionUniverse`
775 DimensionUniverse instance, not used by the method itself but
776 needed to ensure that registry data structures are initialized.
777 nodes: iterable of `int` or None
778 Numbers that correspond to nodes in the graph. If specified, only
779 these nodes will be loaded. Defaults to None, in which case all
780 nodes will be loaded.
781 graphID : `str` or `None`
782 If specified this ID is verified against the loaded graph prior to
783 loading any Nodes. This defaults to None in which case no
788 graph : `QuantumGraph`
789 Resulting QuantumGraph instance.
794 Raised if pickle contains instance of a type other than
797 Raised if one or more of the nodes requested is not in the
798 `QuantumGraph` or if graphID parameter does not match the graph
799 being loaded or if the supplied uri does not point at a valid
800 `QuantumGraph` save file.
804 Reading Quanta from pickle requires existence of singleton
805 DimensionUniverse which is usually instantiated during Registry
806 initialization. To make sure that DimensionUniverse exists this method
807 accepts dummy DimensionUniverse argument.
812 qgraph = pickle.load(file)
813 warnings.warn(
"Pickle graphs are deprecated, please re-save your graph with the save method")
814 except pickle.UnpicklingError:
816 qgraph = loader.load(nodes, graphID)
817 if not isinstance(qgraph, QuantumGraph):
818 raise TypeError(f
"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
822 """Iterate over the `taskGraph` attribute in topological order
827 `TaskDef` objects in topological order
829 yield from nx.topological_sort(self.
taskGraphtaskGraph)
833 """Returns the ID generated by the graph at construction time
837 def __iter__(self) -> Generator[QuantumNode, None, None]:
847 """Stores a compact form of the graph as a list of graph nodes, and a
848 tuple of task labels and task configs. The full graph can be
849 reconstructed with this information, and it preseves the ordering of
852 return {
"nodesList":
list(self)}
855 """Reconstructs the state of the graph from the information persisted
858 quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
859 quantumToNodeId: Dict[Quantum, NodeId] = {}
860 quantumNode: QuantumNode
861 for quantumNode
in state[
'nodesList']:
862 quanta[quantumNode.taskDef].add(quantumNode.quantum)
863 quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
864 _buildId = quantumNode.nodeId.buildId
if state[
'nodesList']
else None
865 self.
_buildGraphs_buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)
868 if not isinstance(other, QuantumGraph):
870 if len(self) != len(other):
873 if node
not in other:
std::vector< SchemaItem< Flag > > * items
_T determineConnectionsOfQuantumNode(_T self, QuantumNode node)
nx.DiGraph taskGraph(self)
Optional[TaskDef] findTaskDefByLabel(self, str label)
def __init__(self, Mapping[TaskDef, Set[Quantum]] quanta)
Set[QuantumNode] determineOutputsOfQuantumNode(self, QuantumNode node)
bool __contains__(self, QuantumNode node)
List[TaskDef] findTaskDefByName(self, str taskName)
Tuple[_T,...] subsetToConnected(_T self)
Optional[TaskDef] findTaskWithOutput(self, DatasetTypeName datasetTypeName)
Tuple[DatasetTypeName,...] allDatasetTypes(self)
Iterable[QuantumNode] inputQuanta(self)
QuantumGraph load(cls, io.IO[bytes] file, DimensionUniverse universe, Optional[Iterable[int]] nodes=None, Optional[BuildId] graphID=None)
Iterable[TaskDef] findTasksWithInput(self, DatasetTypeName datasetTypeName)
_T subset(_T self, Union[QuantumNode, Iterable[QuantumNode]] nodes)
def __setstate__(self, dict state)
QuantumNode getQuantumNodeByNodeId(self, NodeId nodeId)
bool checkQuantumInGraph(self, Quantum quantum)
List[Tuple[QuantumNode, QuantumNode]] findCycle(self)
def save(self, io.IO[bytes] file)
Set[Quantum] findQuantaWithDSType(self, DatasetTypeName datasetTypeName)
bytearray _buildSaveObject(self)
Generator[QuantumNode, None, None] __iter__(self)
Generator[TaskDef, None, None] iterTaskGraph(self)
Set[QuantumNode] determineInputsToQuantumNode(self, QuantumNode node)
Iterable[QuantumNode] outputQuanta(self)
QuantumGraph loadUri(cls, Union[ButlerURI, str] uri, DimensionUniverse universe, Optional[Iterable[int]] nodes=None, Optional[BuildId] graphID=None)
FrozenSet[Quantum] getQuantaForTask(self, TaskDef taskDef)
bool __eq__(self, object other)
def _buildGraphs(self, Mapping[TaskDef, Set[Quantum]] quanta, *Optional[Mapping[Quantum, NodeId]] _quantumToNodeId=None, Optional[BuildId] _buildId=None)
_T determineAncestorsOfQuantumNode(_T self, QuantumNode node)
Iterable[TaskDef] tasksWithDSType(self, DatasetTypeName datasetTypeName)
def writeDotGraph(self, Union[str, io.BufferedIOBase] output)
FrozenSet[QuantumNode] getNodesForTask(self, TaskDef taskDef)
daf::base::PropertyList * list
daf::base::PropertySet * set
typing.Generator[BaseConnection, None, None] iterConnections(PipelineTaskConnections connections, Union[str, Iterable[str]] connectionType)