22 """Module defining connection classes for PipelineTask.
# Public names exported by a star-import of this module.
__all__ = [
    "PipelineTaskConnections",
    "InputQuantizedConnection",
    "OutputQuantizedConnection",
    "DeferredDatasetRef",
    "iterConnections",
]
28 from collections
import UserDict, namedtuple
29 from types
import SimpleNamespace
31 from typing
import Union, Iterable
36 from .
import config
as configMod
37 from .connectionTypes
import (InitInput, InitOutput, Input, PrerequisiteInput,
38 Output, BaseConnection)
39 from lsst.daf.butler
import DatasetRef, DatasetType, NamedKeyDict, Quantum
41 if typing.TYPE_CHECKING:
42 from .config
import PipelineTaskConfig
46 """Exception raised when dataset type is configured as scalar
47 but there are multiple data IDs in a Quantum for that dataset.
52 """This is a special dict class used by PipelineTaskConnectionsMetaclass
54 This dict is used in PipelineTaskConnection class creation, as the
55 dictionary that is initially used as __dict__. It exists to
56 intercept connection fields declared in a PipelineTaskConnection, and
57 what name is used to identify them. The names are then added to class
58 level list according to the connection type of the class attribute. The
59 names are also used as keys in a class level dictionary associated with
60 the corresponding class attribute. This information is a duplicate of
61 what exists in __dict__, but provides a simple place to lookup and
62 iterate on only these variables.
69 self.data[
'inputs'] = []
70 self.data[
'prerequisiteInputs'] = []
71 self.data[
'outputs'] = []
72 self.data[
'initInputs'] = []
73 self.data[
'initOutputs'] = []
74 self.data[
'allConnections'] = {}
77 if isinstance(value, Input):
78 self.data[
'inputs'].
append(name)
79 elif isinstance(value, PrerequisiteInput):
80 self.data[
'prerequisiteInputs'].
append(name)
81 elif isinstance(value, Output):
82 self.data[
'outputs'].
append(name)
83 elif isinstance(value, InitInput):
84 self.data[
'initInputs'].
append(name)
85 elif isinstance(value, InitOutput):
86 self.data[
'initOutputs'].
append(name)
89 if isinstance(value, BaseConnection):
90 object.__setattr__(value,
'varName', name)
91 self.data[
'allConnections'][name] = value
97 """Metaclass used in the declaration of PipelineTaskConnections classes
105 if isinstance(base, PipelineTaskConnectionsMetaclass):
106 for name, value
in base.allConnections.items():
111 dimensionsValueError = TypeError(
"PipelineTaskConnections class must be created with a dimensions "
112 "attribute which is an iterable of dimension names")
114 if name !=
'PipelineTaskConnections':
117 if 'dimensions' not in kwargs:
119 if hasattr(base,
'dimensions'):
120 kwargs[
'dimensions'] = base.dimensions
122 if 'dimensions' not in kwargs:
123 raise dimensionsValueError
125 if isinstance(kwargs[
'dimensions'], str):
126 raise TypeError(
"Dimensions must be iterable of dimensions, got str,"
127 "possibly omitted trailing comma")
128 if not isinstance(kwargs[
'dimensions'], typing.Iterable):
129 raise TypeError(
"Dimensions must be iterable of dimensions")
130 dct[
'dimensions'] =
set(kwargs[
'dimensions'])
131 except TypeError
as exc:
132 raise dimensionsValueError
from exc
136 stringFormatter = string.Formatter()
138 for obj
in dct[
'allConnections'].values():
141 for param
in stringFormatter.parse(nameValue):
142 if param[1]
is not None:
143 allTemplates.add(param[1])
148 for base
in bases[::-1]:
149 if hasattr(base,
'defaultTemplates'):
150 mergeDict.update(base.defaultTemplates)
151 if 'defaultTemplates' in kwargs:
152 mergeDict.update(kwargs[
'defaultTemplates'])
154 if len(mergeDict) > 0:
155 kwargs[
'defaultTemplates'] = mergeDict
160 if len(allTemplates) > 0
and 'defaultTemplates' not in kwargs:
161 raise TypeError(
"PipelineTaskConnection class contains templated attribute names, but no "
162 "defaut templates were provided, add a dictionary attribute named "
163 "defaultTemplates which contains the mapping between template key and value")
164 if len(allTemplates) > 0:
166 defaultTemplateKeys =
set(kwargs[
'defaultTemplates'].
keys())
167 templateDifference = allTemplates.difference(defaultTemplateKeys)
168 if templateDifference:
169 raise TypeError(f
"Default template keys were not provided for {templateDifference}")
173 nameTemplateIntersection = allTemplates.intersection(
set(dct[
'allConnections'].
keys()))
174 if len(nameTemplateIntersection) > 0:
175 raise TypeError(f
"Template parameters cannot share names with Class attributes"
176 f
" (conflicts are {nameTemplateIntersection}).")
177 dct[
'defaultTemplates'] = kwargs.get(
'defaultTemplates', {})
181 for connectionName
in (
"inputs",
"prerequisiteInputs",
"outputs",
"initInputs",
"initOutputs"):
182 dct[connectionName] = frozenset(dct[connectionName])
185 return super().
__new__(cls, name, bases, dict(dct))
197 """A Namespace to map defined variable names of connections to their
198 `lsst.daf.butler.DatasetRef`s
200 This class maps the names used to define a connection on a
201 PipelineTaskConnectionsClass to the corresponding
202 `lsst.daf.butler.DatasetRef`s provided by a `lsst.daf.butler.Quantum`
203 instance. This will be a quantum of execution based on the graph created
204 by examining all the connections defined on the
205 `PipelineTaskConnectionsClass`.
210 object.__setattr__(self,
"_attributes",
set())
def __setattr__(self, name: str, value: typing.Union[DatasetRef, typing.List[DatasetRef]]):
    """Store an attribute on this object and record its name.

    Parameters
    ----------
    name : `str`
        Name of the attribute being assigned.
    value : `lsst.daf.butler.DatasetRef` or `list` of them
        Value to bind to ``name``.
    """
    # Store the value first, so that a failed assignment never leaves a
    # tracked name without a matching attribute. ``object.__setattr__`` is
    # used (as it is for ``_attributes`` itself) to avoid re-entering this
    # method.
    object.__setattr__(self, name, value)
    # Capture the attribute name so that ``__iter__`` and ``keys`` can
    # report it later.
    self._attributes.add(name)
218 object.__delattr__(self, name)
219 self._attributes.remove(name)
def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union[DatasetRef,
                                       typing.List[DatasetRef]]], None, None]:
    """Iterate over the attributes tracked by this QuantizedConnection.

    Each step yields a ``(name, value)`` tuple for one name recorded in
    ``self._attributes``, with the value looked up on this instance. This
    is similar to ``dict.items()``, but it walks the namespace attributes
    rather than dictionary keys.
    """
    for attributeName in self._attributes:
        yield attributeName, getattr(self, attributeName)
def keys(self) -> typing.Generator[str, None, None]:
    """Iterate over the names of all attributes added to this
    QuantizedConnection.
    """
    for attributeName in self._attributes:
        yield attributeName
243 class OutputQuantizedConnection(QuantizedConnection):
248 """Class which denotes that a datasetRef should be treated as deferred when
249 interacting with the butler
253 datasetRef : `lsst.daf.butler.DatasetRef`
254 The `lsst.daf.butler.DatasetRef` that will be eventually used to
261 """PipelineTaskConnections is a class used to declare desired IO when a
262 PipelineTask is run by an activator
266 config : `PipelineTaskConfig`
267 A `PipelineTaskConfig` class instance whose class has been configured
268 to use this `PipelineTaskConnectionsClass`
272 ``PipelineTaskConnection`` classes are created by declaring class
273 attributes of types defined in `lsst.pipe.base.connectionTypes` and are
276 * ``InitInput`` - Defines connections in a quantum graph which are used as
277 inputs to the ``__init__`` function of the `PipelineTask` corresponding
279 * ``InitOutput`` - Defines connections in a quantum graph which are to be
280 persisted using a butler at the end of the ``__init__`` function of the
281 `PipelineTask` corresponding to this class. The variable name used to
282 define this connection should be the same as an attribute name on the
283 `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
284 the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
285 a `PipelineTask` instance should have an attribute
286 ``self.outputSchema`` defined. Its value is what will be saved by the
288 * ``PrerequisiteInput`` - An input connection type that defines a
289 `lsst.daf.butler.DatasetType` that must be present at execution time,
290 but that will not be used during the course of creating the quantum
291 graph to be executed. These most often are things produced outside the
292 processing pipeline, such as reference catalogs.
293 * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be used
294 in the ``run`` method of a `PipelineTask`. The name used to declare
295 class attribute must match a function argument name in the ``run``
296 method of a `PipelineTask`. E.g. If the ``PipelineTaskConnections``
297 defines an ``Input`` with the name ``calexp``, then the corresponding
298 signature should be ``PipelineTask.run(calexp, ...)``
299 * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
300 execution of a `PipelineTask`. The name used to declare the connection
301 must correspond to an attribute of a `Struct` that is returned by a
302 `PipelineTask` ``run`` method. E.g. if an output connection is
303 defined with the name ``measCat``, then the corresponding
304 ``PipelineTask.run`` method must return ``Struct(measCat=X,..)`` where
305 X matches the ``storageClass`` type defined on the output connection.
307 The process of declaring a ``PipelineTaskConnection`` class involves
308 parameters passed in the declaration statement.
310 The first parameter is ``dimensions`` which is an iterable of strings which
311 defines the unit of processing the run method of a corresponding
312 `PipelineTask` will operate on. These dimensions must match dimensions that
313 exist in the butler registry which will be used in executing the
314 corresponding `PipelineTask`.
316 The second parameter is labeled ``defaultTemplates`` and is conditionally
317 optional. The name attributes of connections can be specified as python
318 format strings, with named format arguments. If any of the name parameters
319 on connections defined in a `PipelineTaskConnections` class contain a
320 template, then a default template value must be specified in the
321 ``defaultTemplates`` argument. This is done by passing a dictionary with
322 keys corresponding to a template identifier, and values corresponding to
323 the value to use as a default when formatting the string. For example if
324 ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
325 ``defaultTemplates = {'input': 'deep'}``.
327 Once a `PipelineTaskConnections` class is created, it is used in the
328 creation of a `PipelineTaskConfig`. This is further documented in the
329 documentation of `PipelineTaskConfig`. For the purposes of this
330 documentation, the relevant information is that the config class allows
331 configuration of connection names by users when running a pipeline.
333 Instances of a `PipelineTaskConnections` class are used by the pipeline
334 task execution framework to introspect what a corresponding `PipelineTask`
335 will require, and what it will produce.
339 >>> from lsst.pipe.base import connectionTypes as cT
340 >>> from lsst.pipe.base import PipelineTaskConnections
341 >>> from lsst.pipe.base import PipelineTaskConfig
342 >>> class ExampleConnections(PipelineTaskConnections,
343 ... dimensions=("A", "B"),
344 ... defaultTemplates={"foo": "Example"}):
345 ... inputConnection = cT.Input(doc="Example input",
346 ... dimensions=("A", "B"),
347 ... storageClass="Exposure",
348 ... name="{foo}Dataset")
349 ... outputConnection = cT.Output(doc="Example output",
350 ... dimensions=("A", "B"),
351 ... storageClass="Exposure",
352 ... name="{foo}output")
353 >>> class ExampleConfig(PipelineTaskConfig,
354 ... pipelineConnections=ExampleConnections):
356 >>> config = ExampleConfig()
357 >>> config.connections.foo = "Modified"
358 >>> config.connections.outputConnection = "TotallyDifferent"
359 >>> connections = ExampleConnections(config=config)
360 >>> assert(connections.inputConnection.name == "ModifiedDataset")
361 >>> assert(connections.outputConnection.name == "TotallyDifferent")
364 def __init__(self, *, config:
'PipelineTaskConfig' =
None):
372 if config
is None or not isinstance(config, configMod.PipelineTaskConfig):
373 raise ValueError(
"PipelineTaskConnections must be instantiated with"
374 " a PipelineTaskConfig instance")
379 templateValues = {name: getattr(config.connections, name)
for name
in getattr(self,
380 'defaultTemplates').
keys()}
393 OutputQuantizedConnection]:
394 """Builds QuantizedConnections corresponding to input Quantum
398 quantum : `lsst.daf.butler.Quantum`
399 Quantum object which defines the inputs and outputs for a given
404 retVal : `tuple` of (`InputQuantizedConnection`,
405 `OutputQuantizedConnection`) Namespaces mapping attribute names
406 (identifiers of connections) to butler references defined in the
407 input `lsst.daf.butler.Quantum`
413 for refs, names
in zip((inputDatasetRefs, outputDatasetRefs),
416 for attributeName
in names:
418 attribute = getattr(self, attributeName)
420 if attribute.name
in quantum.inputs:
422 quantumInputRefs = quantum.inputs[attribute.name]
425 if attribute.deferLoad:
429 if not attribute.multiple:
430 if len(quantumInputRefs) > 1:
432 f
"Received multiple datasets "
433 f
"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
434 f
"for scalar connection {attributeName} "
435 f
"({quantumInputRefs[0].datasetType.name}) "
436 f
"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
438 if len(quantumInputRefs) == 0:
440 quantumInputRefs = quantumInputRefs[0]
442 setattr(refs, attributeName, quantumInputRefs)
444 elif attribute.name
in quantum.outputs:
445 value = quantum.outputs[attribute.name]
448 if not attribute.multiple:
451 setattr(refs, attributeName, value)
455 raise ValueError(f
"Attribute with name {attributeName} has no counterpoint "
457 return inputDatasetRefs, outputDatasetRefs
459 def adjustQuantum(self, datasetRefMap: NamedKeyDict[DatasetType, typing.Set[DatasetRef]]
460 ) -> NamedKeyDict[DatasetType, typing.Set[DatasetRef]]:
461 """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects
462 in the `lsst.daf.butler.core.Quantum` during the graph generation stage
465 The base class implementation simply checks that input connections with
466 ``multiple`` set to `False` have no more than one dataset.
470 datasetRefMap : `NamedKeyDict`
471 Mapping from dataset type to a `set` of
472 `lsst.daf.butler.DatasetRef` objects
476 datasetRefMap : `NamedKeyDict`
477 Modified mapping of input with possibly adjusted
478 `lsst.daf.butler.DatasetRef` objects.
483 Raised if any `Input` or `PrerequisiteInput` connection has
484 ``multiple`` set to `False`, but multiple datasets.
486 Overrides of this function have the option of raising an Exception
487 if a field in the input does not satisfy a need for a corresponding
488 pipelineTask, i.e. no reference catalogs are found.
492 refs = datasetRefMap[connection.name]
493 if not connection.multiple
and len(refs) > 1:
495 f
"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
496 f
"for scalar connection {connection.name} ({refs[0].datasetType.name})."
502 connectionType: Union[str, Iterable[str]]
503 ) -> typing.Generator[BaseConnection,
None,
None]:
504 """Creates an iterator over the selected connections type which yields
505 all the defined connections of that type.
509 connections: `PipelineTaskConnections`
510 An instance of a `PipelineTaskConnections` object that will be iterated
512 connectionType: `str` or iterable of `str`
513 The type of connections to iterate over, valid values are inputs,
514 outputs, prerequisiteInputs, initInputs, initOutputs.
518 connection: `BaseConnection`
519 A connection defined on the input connections object of the type
520 supplied. The yielded value will be a derived type of
523 if isinstance(connectionType, str):
524 connectionType = (connectionType,)
525 for name
in itertools.chain.from_iterable(getattr(connections, ct)
for ct
in connectionType):
526 yield getattr(connections, name)