__all__ = [
    "makeQuantum",
    "runTestQuantum",
    "assertValidOutput",
]

import collections.abc
import itertools
import unittest.mock

from lsst.daf.butler import DataCoordinate, DatasetRef, Quantum, StorageClassFactory
from lsst.pipe.base import ButlerQuantumContext
35 """Create a Quantum for a particular data ID(s).
39 task : `lsst.pipe.base.PipelineTask`
40 The task whose processing the quantum represents.
41 butler : `lsst.daf.butler.Butler`
42 The collection the quantum refers to.
43 dataId: any data ID type
44 The data ID of the quantum. Must have the same dimensions as
45 ``task``'s connections class.
46 ioDataIds : `collections.abc.Mapping` [`str`]
47 A mapping keyed by input/output names. Values must be data IDs for
48 single connections and sequences of data IDs for multiple connections.
52 quantum : `lsst.daf.butler.Quantum`
53 A quantum for ``task``, when called with ``dataIds``.
55 quantum = Quantum(taskClass=
type(task), dataId=dataId)
56 connections = task.config.ConnectionsClass(config=task.config)
59 for name
in itertools.chain(connections.inputs, connections.prerequisiteInputs):
60 connection = connections.__getattribute__(name)
61 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
62 ids = _normalizeDataIds(ioDataIds[name])
64 quantum.addPredictedInput(_refFromConnection(butler, connection, id))
65 for name
in connections.outputs:
66 connection = connections.__getattribute__(name)
67 _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
68 ids = _normalizeDataIds(ioDataIds[name])
70 quantum.addOutput(_refFromConnection(butler, connection, id))
73 raise ValueError(
"Mismatch in input data.")
from e
76 def _checkDataIdMultiplicity(name, dataIds, multiple):
77 """Test whether data IDs are scalars for scalar connections and sequences
78 for multiple connections.
83 The name of the connection being tested.
84 dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
85 The data ID(s) provided for the connection.
87 The ``multiple`` field of the connection.
92 Raised if ``dataIds`` and ``multiple`` do not match.
95 if not isinstance(dataIds, collections.abc.Sequence):
96 raise ValueError(f
"Expected multiple data IDs for {name}, got {dataIds}.")
99 if not isinstance(dataIds, collections.abc.Mapping):
100 raise ValueError(f
"Expected single data ID for {name}, got {dataIds}.")
103 def _normalizeDataIds(dataIds):
104 """Represent both single and multiple data IDs as a list.
108 dataIds : any data ID type or `~collections.abc.Sequence` thereof
109 The data ID(s) provided for a particular input or output connection.
113 normalizedIds : `~collections.abc.Sequence` [data ID]
114 A sequence equal to ``dataIds`` if it was already a sequence, or
115 ``[dataIds]`` if it was a single ID.
117 if isinstance(dataIds, collections.abc.Sequence):
def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId : any data ID type
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.

    Raises
    ------
    ValueError
        Raised if the dataset type is not registered, or if ``dataId`` is not
        compatible with the dataset type's dimensions.
    """
    universe = butler.registry.dimensions
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # "skypix" is a placeholder, not a real dimension, so defer to the
    # registry (rather than the connection) for the concrete dataset type.
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    try:
        # Verify the dataset type is registered; the return value is unused.
        butler.registry.getDatasetType(datasetType.name)
    except KeyError as e:
        raise ValueError(f"Invalid dataset type {connection.name}.") from e
    try:
        return DatasetRef(datasetType=datasetType, dataId=dataId)
    except KeyError as e:
        raise ValueError(
            f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible."
        ) from e
167 """Run a PipelineTask on a Quantum.
171 task : `lsst.pipe.base.PipelineTask`
172 The task to run on the quantum.
173 butler : `lsst.daf.butler.Butler`
174 The collection to run on.
175 quantum : `lsst.daf.butler.Quantum`
178 Whether or not to replace ``task``'s ``run`` method. The default of
179 `True` is recommended unless ``run`` needs to do real work (e.g.,
180 because the test needs real output datasets).
184 run : `unittest.mock.Mock` or `None`
185 If ``mockRun`` is set, the mock that replaced ``run``. This object can
186 be queried for the arguments ``runQuantum`` passed to ``run``.
189 connections = task.config.ConnectionsClass(config=task.config)
190 inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
192 with unittest.mock.patch.object(task,
"run")
as mock, \
193 unittest.mock.patch(
"lsst.pipe.base.ButlerQuantumContext.put"):
194 task.runQuantum(butlerQc, inputRefs, outputRefs)
197 task.runQuantum(butlerQc, inputRefs, outputRefs)
202 """Test that the output of a call to ``run`` conforms to its own connections.
206 task : `lsst.pipe.base.PipelineTask`
207 The task whose connections need validation. This is a fully-configured
208 task object to support features such as optional outputs.
209 result : `lsst.pipe.base.Struct`
210 A result object produced by calling ``task.run``.
215 Raised if ``result`` does not match what's expected from ``task's``
218 connections = task.config.ConnectionsClass(config=task.config)
219 recoveredOutputs = result.getDict()
221 for name
in connections.outputs:
222 connection = connections.__getattribute__(name)
225 output = recoveredOutputs[name]
227 raise AssertionError(f
"No such output: {name}")
229 if connection.multiple:
230 if not isinstance(output, collections.abc.Sequence):
231 raise AssertionError(f
"Expected {name} to be a sequence, got {output} instead.")
234 if isinstance(output, collections.abc.Sequence) \
236 StorageClassFactory().getStorageClass(connection.storageClass).pytype,
237 collections.abc.Sequence):
238 raise AssertionError(f
"Expected {name} to be a single value, got {output} instead.")