"""Module defining connection classes for PipelineTask.
"""

# Public API of this module.
__all__ = ["PipelineTaskConnections", "InputQuantizedConnection",
           "OutputQuantizedConnection", "DeferredDatasetRef",
           "iterConnections"]
import itertools
import string
import typing

from collections import UserDict, namedtuple
from types import SimpleNamespace

from . import config as configMod
from .connectionTypes import (BaseConnection, InitInput, InitOutput, Input,
                              Output, PrerequisiteInput)
from lsst.daf.butler import DatasetRef, DatasetType, NamedKeyDict, Quantum

if typing.TYPE_CHECKING:
    from .config import PipelineTaskConfig
class ScalarError(TypeError):
    """Exception raised when dataset type is configured as scalar
    but there are multiple data IDs in a Quantum for that dataset.
    """
    # No extra behavior beyond TypeError; the distinct type lets callers
    # catch scalar-connection violations specifically.
    pass
class PipelineTaskConnectionDict(UserDict):
    """This is a special dict class used by PipelineTaskConnectionMetaclass

    This dict is used in PipelineTaskConnection class creation, as the
    dictionary that is initially used as __dict__. It exists to
    intercept connection fields declared in a PipelineTaskConnection, and
    what name is used to identify them. The names are then added to class
    level list according to the connection type of the class attribute. The
    names are also used as keys in a class level dictionary associated with
    the corresponding class attribute. This information is a duplicate of
    what exists in __dict__, but provides a simple place to lookup and
    iterate on only these variables.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize class-level variables that track, per connection type,
        # the attribute names declared on the class being constructed.
        self.data['inputs'] = []
        self.data['prerequisiteInputs'] = []
        self.data['outputs'] = []
        self.data['initInputs'] = []
        self.data['initOutputs'] = []
        self.data['allConnections'] = {}

    def __setitem__(self, name, value):
        # Record the attribute name in the list that matches the declared
        # connection type.
        if isinstance(value, Input):
            self.data['inputs'].append(name)
        elif isinstance(value, PrerequisiteInput):
            self.data['prerequisiteInputs'].append(name)
        elif isinstance(value, Output):
            self.data['outputs'].append(name)
        elif isinstance(value, InitInput):
            self.data['initInputs'].append(name)
        elif isinstance(value, InitOutput):
            self.data['initOutputs'].append(name)
        # This is not an elif on purpose: it must run for every subclass of
        # BaseConnection, recording the variable name used in the declaration
        # and registering the connection in the combined lookup table.
        if isinstance(value, BaseConnection):
            object.__setattr__(value, 'varName', name)
            self.data['allConnections'][name] = value
        # Defer to the default dict behavior for the actual storage.
        super().__setitem__(name, value)
class PipelineTaskConnectionsMetaclass(type):
    """Metaclass used in the declaration of PipelineTaskConnections classes
    """

    def __prepare__(name, bases, **kwargs):  # noqa: N805
        # Create an instance of our special dict to catch and track all
        # variables that are instances of connectionTypes.BaseConnection.
        # Copy any existing connections declared on parent classes so that
        # subclasses inherit them.
        dct = PipelineTaskConnectionDict()
        for base in bases:
            if isinstance(base, PipelineTaskConnectionsMetaclass):
                for connectionName, value in base.allConnections.items():
                    dct[connectionName] = value
        return dct

    def __new__(cls, name, bases, dct, **kwargs):
        dimensionsValueError = TypeError("PipelineTaskConnections class must be created with a dimensions "
                                         "attribute which is an iterable of dimension names")

        if name != 'PipelineTaskConnections':
            # Verify that dimensions are passed as a keyword in the class
            # declaration; fall back to a base class value if present.
            if 'dimensions' not in kwargs:
                for base in bases:
                    if hasattr(base, 'dimensions'):
                        kwargs['dimensions'] = base.dimensions
                        break
                if 'dimensions' not in kwargs:
                    raise dimensionsValueError
            try:
                dct['dimensions'] = set(kwargs['dimensions'])
            except TypeError as exc:
                raise dimensionsValueError from exc

            # Lookup any python string templates that may have been used in
            # the declaration of the name field of a class connection
            # attribute.
            allTemplates = set()
            stringFormatter = string.Formatter()
            for obj in dct['allConnections'].values():
                nameValue = obj.name
                for param in stringFormatter.parse(nameValue):
                    # param[1] is the field name of a format replacement, or
                    # None for literal text segments.
                    if param[1] is not None:
                        allTemplates.add(param[1])

            # Merge defaultTemplates dictionaries from base classes (oldest
            # first so later bases and the declaration itself win).
            mergeDict = {}
            for base in bases[::-1]:
                if hasattr(base, 'defaultTemplates'):
                    mergeDict.update(base.defaultTemplates)
            if 'defaultTemplates' in kwargs:
                mergeDict.update(kwargs['defaultTemplates'])
            if len(mergeDict) > 0:
                kwargs['defaultTemplates'] = mergeDict

            # Verify that if templated strings were used, defaults were
            # supplied as an argument in the declaration of the class.
            if len(allTemplates) > 0 and 'defaultTemplates' not in kwargs:
                raise TypeError("PipelineTaskConnection class contains templated attribute names, but no "
                                "default templates were provided, add a dictionary attribute named "
                                "defaultTemplates which contains the mapping between template key and value")
            if len(allTemplates) > 0:
                # Verify all templates have a default, and throw if they do
                # not.
                defaultTemplateKeys = set(kwargs['defaultTemplates'].keys())
                templateDifference = allTemplates.difference(defaultTemplateKeys)
                if templateDifference:
                    raise TypeError(f"Default template keys were not provided for {templateDifference}")
                # Verify that templates do not share names with variable names
                # used for a connection; this is needed because of how
                # template values are stored on an associated config class.
                nameTemplateIntersection = allTemplates.intersection(set(dct['allConnections'].keys()))
                if len(nameTemplateIntersection) > 0:
                    raise TypeError(f"Template parameters cannot share names with Class attributes"
                                    f" (conflicts are {nameTemplateIntersection}).")
            dct['defaultTemplates'] = kwargs.get('defaultTemplates', {})

        # Convert all the connection containers into frozensets so they cannot
        # be modified at the class attribute level.
        for connectionName in ("inputs", "prerequisiteInputs", "outputs", "initInputs", "initOutputs"):
            dct[connectionName] = frozenset(dct[connectionName])
        # Our custom dict type must be turned into an actual dict to be used
        # in type.__new__.
        return super().__new__(cls, name, bases, dict(dct))

    def __init__(cls, name, bases, dct, **kwargs):
        # Override the default type.__init__ to consume the keyword arguments
        # supplied at class-declaration time (dimensions, defaultTemplates),
        # which type.__init__ would otherwise reject.
        super().__init__(name, bases, dct)
class QuantizedConnection(SimpleNamespace):
    """A Namespace to map defined variable names of connections to their
    `lsst.daf.butler.DatasetRef`s

    This class maps the names used to define a connection on a
    PipelineTaskConnectionsClass to the corresponding
    `lsst.daf.butler.DatasetRef`s provided by a `lsst.daf.butler.Quantum`
    instance. This will be a quantum of execution based on the graph created
    by examining all the connections defined on the
    `PipelineTaskConnectionsClass`.
    """

    def __init__(self, **kwargs):
        # Create a variable to track what attributes are added. This is used
        # later when iterating over this QuantizedConnection instance.
        # object.__setattr__ is used to bypass our own __setattr__ so the
        # tracking set itself is not tracked.
        object.__setattr__(self, "_attributes", set())

    def __setattr__(self, name: str, value: typing.Union["DatasetRef", typing.List["DatasetRef"]]):
        # Capture the attribute name as it is added to this object
        object.__setattr__(self, name, value)
        self._attributes.add(name)

    def __delattr__(self, name):
        object.__delattr__(self, name)
        self._attributes.remove(name)

    def __iter__(self) -> typing.Generator[typing.Tuple[str, typing.Union["DatasetRef",
                                           typing.List["DatasetRef"]]], None, None]:
        """Make an Iterator for this QuantizedConnection

        Iterating over a QuantizedConnection will yield a tuple with the name
        of an attribute and the value associated with that name. This is
        similar to dict.items() but is on the namespace attributes rather than
        dict keys.
        """
        yield from ((name, getattr(self, name)) for name in self._attributes)

    def keys(self) -> typing.Generator[str, None, None]:
        """Returns an iterator over all the attributes added to a
        QuantizedConnection class
        """
        yield from self._attributes
# InputQuantizedConnection is referenced by __all__ and buildDatasetRefs but
# its declaration was missing here; both subclasses exist only to distinguish
# input references from output references by type.
class InputQuantizedConnection(QuantizedConnection):
    pass


class OutputQuantizedConnection(QuantizedConnection):
    pass
class DeferredDatasetRef(namedtuple("DeferredDatasetRefBase", "datasetRef")):
    """Class which denotes that a datasetRef should be treated as deferred when
    interacting with the butler

    Parameters
    ----------
    datasetRef : `lsst.daf.butler.DatasetRef`
        The `lsst.daf.butler.DatasetRef` that will be eventually used to
        resolve a dataset
    """
    # Prevent per-instance __dict__ creation; this is a pure value wrapper.
    __slots__ = ()
class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
    """PipelineTaskConnections is a class used to declare desired IO when a
    PipelineTask is run by an activator

    Parameters
    ----------
    config : `PipelineTaskConfig`
        A `PipelineTaskConfig` class instance whose class has been configured
        to use this `PipelineTaskConnectionsClass`

    Notes
    -----
    ``PipelineTaskConnection`` classes are created by declaring class
    attributes of types defined in `lsst.pipe.base.connectionTypes` and are
    listed as follows:

    * ``InitInput`` - Defines connections in a quantum graph which are used as
      inputs to the ``__init__`` function of the `PipelineTask` corresponding
      to this class
    * ``InitOuput`` - Defines connections in a quantum graph which are to be
      persisted using a butler at the end of the ``__init__`` function of the
      `PipelineTask` corresponding to this class. The variable name used to
      define this connection should be the same as an attribute name on the
      `PipelineTask` instance. E.g. if an ``InitOutput`` is declared with
      the name ``outputSchema`` in a ``PipelineTaskConnections`` class, then
      a `PipelineTask` instance should have an attribute
      ``self.outputSchema`` defined. Its value is what will be saved by the
      activator framework.
    * ``PrerequisiteInput`` - An input connection type that defines a
      `lsst.daf.butler.DatasetType` that must be present at execution time,
      but that will not be used during the course of creating the quantum
      graph to be executed. These most often are things produced outside the
      processing pipeline, such as reference catalogs.
    * ``Input`` - Input `lsst.daf.butler.DatasetType` objects that will be used
      in the ``run`` method of a `PipelineTask`. The name used to declare
      class attribute must match a function argument name in the ``run``
      method of a `PipelineTask`. E.g. If the ``PipelineTaskConnections``
      defines an ``Input`` with the name ``calexp``, then the corresponding
      signature should be ``PipelineTask.run(calexp, ...)``
    * ``Output`` - A `lsst.daf.butler.DatasetType` that will be produced by an
      execution of a `PipelineTask`. The name used to declare the connection
      must correspond to an attribute of a `Struct` that is returned by a
      `PipelineTask` ``run`` method. E.g. if an output connection is
      defined with the name ``measCat``, then the corresponding
      ``PipelineTask.run`` method must return ``Struct(measCat=X,..)`` where
      X matches the ``storageClass`` type defined on the output connection.

    The process of declaring a ``PipelineTaskConnection`` class involves
    parameters passed in the declaration statement.

    The first parameter is ``dimensions`` which is an iterable of strings which
    defines the unit of processing the run method of a corresponding
    `PipelineTask` will operate on. These dimensions must match dimensions that
    exist in the butler registry which will be used in executing the
    corresponding `PipelineTask`.

    The second parameter is labeled ``defaultTemplates`` and is conditionally
    optional. The name attributes of connections can be specified as python
    format strings, with named format arguments. If any of the name parameters
    on connections defined in a `PipelineTaskConnections` class contain a
    template, then a default template value must be specified in the
    ``defaultTemplates`` argument. This is done by passing a dictionary with
    keys corresponding to a template identifier, and values corresponding to
    the value to use as a default when formatting the string. For example if
    ``ConnectionClass.calexp.name = '{input}Coadd_calexp'`` then
    ``defaultTemplates`` = {'input': 'deep'}.

    Once a `PipelineTaskConnections` class is created, it is used in the
    creation of a `PipelineTaskConfig`. This is further documented in the
    documentation of `PipelineTaskConfig`. For the purposes of this
    documentation, the relevant information is that the config class allows
    configuration of connection names by users when running a pipeline.

    Instances of a `PipelineTaskConnections` class are used by the pipeline
    task execution framework to introspect what a corresponding `PipelineTask`
    will require, and what it will produce.

    Examples
    --------
    >>> from lsst.pipe.base import connectionTypes as cT
    >>> from lsst.pipe.base import PipelineTaskConnections
    >>> from lsst.pipe.base import PipelineTaskConfig
    >>> class ExampleConnections(PipelineTaskConnections,
    ...                          dimensions=("A", "B"),
    ...                          defaultTemplates={"foo": "Example"}):
    ...     inputConnection = cT.Input(doc="Example input",
    ...                                dimensions=("A", "B"),
    ...                                storageClass=Exposure,
    ...                                name="{foo}Dataset")
    ...     outputConnection = cT.Output(doc="Example output",
    ...                                  dimensions=("A", "B"),
    ...                                  storageClass=Exposure,
    ...                                  name="{foo}output")
    >>> class ExampleConfig(PipelineTaskConfig,
    ...                     pipelineConnections=ExampleConnections):
    ...     pass
    >>> config = ExampleConfig()
    >>> config.connections.foo = "Modified"
    >>> config.connections.outputConnection = "TotallyDifferent"
    >>> connections = ExampleConnections(config=config)
    >>> assert(connections.inputConnection.name == "ModifiedDataset")
    >>> assert(connections.outputConnection.name == "TotallyDifferent")
    """

    def __init__(self, *, config: 'PipelineTaskConfig' = None):
        # The class-level containers are frozensets (set by the metaclass);
        # convert them to mutable sets on the instance so subclasses may
        # adjust them in __init__ overrides.
        self.inputs = set(self.inputs)
        self.prerequisiteInputs = set(self.prerequisiteInputs)
        self.outputs = set(self.outputs)
        self.initInputs = set(self.initInputs)
        self.initOutputs = set(self.initOutputs)

        if config is None or not isinstance(config, configMod.PipelineTaskConfig):
            raise ValueError("PipelineTaskConnections must be instantiated with"
                             " a PipelineTaskConfig instance")
        self.config = config
        # Extract the template names that were defined in the config instance
        # by looping over the keys of the defaultTemplates dict specified at
        # class declaration time.
        templateValues = {name: getattr(config.connections, name)
                          for name in getattr(self, 'defaultTemplates').keys()}
        # Extract the configured value corresponding to each connection
        # variable, formatting any name templates with the resolved template
        # values; i.e. for each connection identifier, populate an override
        # for the connection.name attribute.
        self._nameOverrides = {name: getattr(config.connections, name).format(**templateValues)
                               for name in self.allConnections.keys()}
        # connections.name corresponds to a dataset type name; create a
        # reverse mapping that goes from dataset type name to the attribute
        # identifier (variable name) on the connection class.
        self._typeNameToVarName = {v: k for k, v in self._nameOverrides.items()}
386 OutputQuantizedConnection]:
387 """Builds QuantizedConnections corresponding to input Quantum
391 quantum : `lsst.daf.butler.Quantum`
392 Quantum object which defines the inputs and outputs for a given
397 retVal : `tuple` of (`InputQuantizedConnection`,
398 `OutputQuantizedConnection`) Namespaces mapping attribute names
399 (identifiers of connections) to butler references defined in the
400 input `lsst.daf.butler.Quantum`
406 for refs, names
in zip((inputDatasetRefs, outputDatasetRefs),
409 for attributeName
in names:
411 attribute = getattr(self, attributeName)
413 if attribute.name
in quantum.predictedInputs:
415 quantumInputRefs = quantum.predictedInputs[attribute.name]
418 if attribute.deferLoad:
422 if not attribute.multiple:
423 if len(quantumInputRefs) > 1:
425 f
"Received multiple datasets "
426 f
"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
427 f
"for scalar connection {attributeName} "
428 f
"({quantumInputRefs[0].datasetType.name}) "
429 f
"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
431 if len(quantumInputRefs) == 0:
433 quantumInputRefs = quantumInputRefs[0]
435 setattr(refs, attributeName, quantumInputRefs)
437 elif attribute.name
in quantum.outputs:
438 value = quantum.outputs[attribute.name]
441 if not attribute.multiple:
444 setattr(refs, attributeName, value)
448 raise ValueError(f
"Attribute with name {attributeName} has no counterpoint "
450 return inputDatasetRefs, outputDatasetRefs
452 def adjustQuantum(self, datasetRefMap: NamedKeyDict[DatasetType, typing.Set[DatasetRef]]
453 ) -> NamedKeyDict[DatasetType, typing.Set[DatasetRef]]:
454 """Override to make adjustments to `lsst.daf.butler.DatasetRef` objects
455 in the `lsst.daf.butler.core.Quantum` during the graph generation stage
458 The base class implementation simply checks that input connections with
459 ``multiple`` set to `False` have no more than one dataset.
463 datasetRefMap : `NamedKeyDict`
464 Mapping from dataset type to a `set` of
465 `lsst.daf.butler.DatasetRef` objects
469 datasetRefMap : `NamedKeyDict`
470 Modified mapping of input with possibly adjusted
471 `lsst.daf.butler.DatasetRef` objects.
476 Raised if any `Input` or `PrerequisiteInput` connection has
477 ``multiple`` set to `False`, but multiple datasets.
479 Overrides of this function have the option of raising an Exception
480 if a field in the input does not satisfy a need for a corresponding
481 pipelineTask, i.e. no reference catalogs are found.
485 refs = datasetRefMap[connection.name]
486 if not connection.multiple
and len(refs) > 1:
488 f
"Found multiple datasets {', '.join(str(r.dataId) for r in refs)} "
489 f
"for scalar connection {connection.name} ({refs[0].datasetType.name})."
def iterConnections(connections: "PipelineTaskConnections", connectionType: str) -> typing.Generator:
    """Creates an iterator over the selected connections type which yields
    all the defined connections of that type.

    Parameters
    ----------
    connections: `PipelineTaskConnections`
        An instance of a `PipelineTaskConnections` object that will be iterated
        over.
    connectionType: `str`
        The type of connections to iterate over, valid values are inputs,
        outputs, prerequisiteInputs, initInputs, initOutputs.

    Yields
    ------
    connection: `BaseConnection`
        A connection defined on the input connections object of the type
        supplied. The yielded value will be a derived type of
        `BaseConnection`.
    """
    # The container attribute (e.g. connections.inputs) holds attribute
    # names; yield the connection object each name refers to.
    for name in getattr(connections, connectionType):
        yield getattr(connections, name)