# Public API of this test-utility module.
__all__ = [
    "makeQuantum",
    "runTestQuantum",
    "assertValidOutput",
]
import collections.abc
import itertools
import unittest.mock
from collections import defaultdict

from lsst.daf.butler import DataCoordinate, DatasetRef, Quantum, StorageClassFactory
from lsst.pipe.base import ButlerQuantumContext
def makeQuantum(task, butler, dataId, ioDataIds):
    """Create a Quantum for a particular data ID(s).

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose processing the quantum represents.
    butler : `lsst.daf.butler.Butler`
        The collection the quantum refers to.
    dataId : any data ID type
        The data ID of the quantum. Must have the same dimensions as
        ``task``'s connections class.
    ioDataIds : `collections.abc.Mapping` [`str`]
        A mapping keyed by input/output names. Values must be data IDs for
        single connections and sequences of data IDs for multiple connections.

    Returns
    -------
    quantum : `lsst.daf.butler.Quantum`
        A quantum for ``task``, when called with ``dataIds``.

    Raises
    ------
    ValueError
        Raised if a connection is missing from ``ioDataIds``, or if a data
        ID has the wrong multiplicity for its connection.
    """
    connections = task.config.ConnectionsClass(config=task.config)

    try:
        inputs = defaultdict(list)
        outputs = defaultdict(list)
        # Prerequisite inputs are resolved exactly like ordinary inputs.
        for name in itertools.chain(connections.inputs, connections.prerequisiteInputs):
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                inputs[ref.datasetType].append(ref)
        for name in connections.outputs:
            connection = connections.__getattribute__(name)
            _checkDataIdMultiplicity(name, ioDataIds[name], connection.multiple)
            ids = _normalizeDataIds(ioDataIds[name])
            for id in ids:
                ref = _refFromConnection(butler, connection, id)
                outputs[ref.datasetType].append(ref)
        quantum = Quantum(taskClass=type(task),
                          dataId=dataId,
                          inputs=inputs,
                          outputs=outputs)
        return quantum
    except KeyError as e:
        # A connection name absent from ioDataIds surfaces as a KeyError;
        # re-raise with the original lookup failure chained as the cause.
        raise ValueError("Mismatch in input data.") from e
84 def _checkDataIdMultiplicity(name, dataIds, multiple):
85 """Test whether data IDs are scalars for scalar connections and sequences
86 for multiple connections.
91 The name of the connection being tested.
92 dataIds : any data ID type or `~collections.abc.Sequence` [data ID]
93 The data ID(s) provided for the connection.
95 The ``multiple`` field of the connection.
100 Raised if ``dataIds`` and ``multiple`` do not match.
103 if not isinstance(dataIds, collections.abc.Sequence):
104 raise ValueError(f
"Expected multiple data IDs for {name}, got {dataIds}.")
107 if not isinstance(dataIds, collections.abc.Mapping):
108 raise ValueError(f
"Expected single data ID for {name}, got {dataIds}.")
111 def _normalizeDataIds(dataIds):
112 """Represent both single and multiple data IDs as a list.
116 dataIds : any data ID type or `~collections.abc.Sequence` thereof
117 The data ID(s) provided for a particular input or output connection.
121 normalizedIds : `~collections.abc.Sequence` [data ID]
122 A sequence equal to ``dataIds`` if it was already a sequence, or
123 ``[dataIds]`` if it was a single ID.
125 if isinstance(dataIds, collections.abc.Sequence):
def _refFromConnection(butler, connection, dataId, **kwargs):
    """Create a DatasetRef for a connection in a collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        The collection to point to.
    connection : `lsst.pipe.base.connectionTypes.DimensionedConnection`
        The connection defining the dataset type to point to.
    dataId
        The data ID for the dataset to point to.
    **kwargs
        Additional keyword arguments used to augment or construct
        a `~lsst.daf.butler.DataCoordinate`.

    Returns
    -------
    ref : `lsst.daf.butler.DatasetRef`
        A reference to a dataset compatible with ``connection``, with ID
        ``dataId``, in the collection pointed to by ``butler``.

    Raises
    ------
    ValueError
        Raised if the dataset type is not registered, or if ``dataId`` is
        not compatible with it.
    """
    universe = butler.registry.dimensions
    dataId = DataCoordinate.standardize(dataId, **kwargs, universe=universe)

    # NOTE(review): for "skypix" connections the type is looked up in the
    # registry rather than built from the connection — presumably because
    # makeDatasetType cannot resolve the skypix placeholder; confirm.
    if "skypix" in connection.dimensions:
        datasetType = butler.registry.getDatasetType(connection.name)
    else:
        datasetType = connection.makeDatasetType(universe)

    try:
        # Verify the dataset type is registered before building a ref to it.
        butler.registry.getDatasetType(datasetType.name)
    except KeyError:
        raise ValueError(f"Invalid dataset type {connection.name}.")
    try:
        ref = DatasetRef(datasetType=datasetType, dataId=dataId)
        return ref
    except KeyError as e:
        raise ValueError(f"Dataset type ({connection.name}) and ID {dataId.byName()} not compatible.") \
            from e
def _resolveTestQuantumInputs(butler, quantum):
    """Look up all input datasets a test quantum in the `Registry` to resolve
    all `DatasetRef` objects (i.e. ensure they have not-`None` ``id`` and
    ``run`` attributes).

    Parameters
    ----------
    quantum : `~lsst.daf.butler.Quantum`
        Single Quantum instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler.

    Raises
    ------
    ValueError
        Raised if an unresolved ref cannot be found in the butler's
        collections.
    """
    for refsForDatasetType in quantum.inputs.values():
        newRefsForDatasetType = []
        for ref in refsForDatasetType:
            # Only unresolved refs (no dataset ID yet) need a registry lookup.
            if ref.id is None:
                resolvedRef = butler.registry.findDataset(ref.datasetType, ref.dataId,
                                                          collections=butler.collections)
                if resolvedRef is None:
                    raise ValueError(
                        f"Cannot find {ref.datasetType.name} with id {ref.dataId} "
                        f"in collections {butler.collections}."
                    )
                newRefsForDatasetType.append(resolvedRef)
            else:
                newRefsForDatasetType.append(ref)
        # Slice-assign so the quantum's own container sees the resolved refs.
        refsForDatasetType[:] = newRefsForDatasetType
def runTestQuantum(task, butler, quantum, mockRun=True):
    """Run a PipelineTask on a Quantum.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task to run on the quantum.
    butler : `lsst.daf.butler.Butler`
        The collection to run on.
    quantum : `lsst.daf.butler.Quantum`
        The quantum to run.
    mockRun : `bool`
        Whether or not to replace ``task``'s ``run`` method. The default of
        `True` is recommended unless ``run`` needs to do real work (e.g.,
        because the test needs real output datasets).

    Returns
    -------
    run : `unittest.mock.Mock` or `None`
        If ``mockRun`` is set, the mock that replaced ``run``. This object can
        be queried for the arguments ``runQuantum`` passed to ``run``.
    """
    _resolveTestQuantumInputs(butler, quantum)
    butlerQc = ButlerQuantumContext(butler, quantum)
    connections = task.config.ConnectionsClass(config=task.config)
    inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
    if mockRun:
        # Also mock ButlerQuantumContext.put so no (mocked, invalid)
        # outputs are actually written to the butler.
        with unittest.mock.patch.object(task, "run") as mock, \
                unittest.mock.patch("lsst.pipe.base.ButlerQuantumContext.put"):
            task.runQuantum(butlerQc, inputRefs, outputRefs)
            return mock
    else:
        task.runQuantum(butlerQc, inputRefs, outputRefs)
        return None
def assertValidOutput(task, result):
    """Test that the output of a call to ``run`` conforms to its own
    connections.

    Parameters
    ----------
    task : `lsst.pipe.base.PipelineTask`
        The task whose connections need validation. This is a fully-configured
        task object to support features such as optional outputs.
    result : `lsst.pipe.base.Struct`
        A result object produced by calling ``task.run``.

    Raises
    ------
    AssertionError
        Raised if ``result`` does not match what's expected from ``task's``
        connections.
    """
    connections = task.config.ConnectionsClass(config=task.config)
    recoveredOutputs = result.getDict()

    for name in connections.outputs:
        connection = connections.__getattribute__(name)
        # The output must exist at all.
        try:
            output = recoveredOutputs[name]
        except KeyError:
            raise AssertionError(f"No such output: {name}")
        # The output's multiplicity must match the connection's.
        if connection.multiple:
            if not isinstance(output, collections.abc.Sequence):
                raise AssertionError(f"Expected {name} to be a sequence, got {output} instead.")
        else:
            # A sequence value is acceptable for a single connection only
            # when the storage class itself is a sequence type.
            if isinstance(output, collections.abc.Sequence) \
                    and not issubclass(
                        StorageClassFactory().getStorageClass(connection.storageClass).pytype,
                        collections.abc.Sequence):
                raise AssertionError(f"Expected {name} to be a single value, got {output} instead.")