22 """Bunch of common classes and methods for use in unit tests.
25 __all__ = [
"AddTaskConfig",
"AddTask",
"AddTaskFactoryMock"]
import itertools
import logging

import numpy

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
from lsst.daf.butler import Butler, Config, DatasetType

from .. import connectionTypes as cT
from ... import base as pipeBase
38 _LOG = logging.getLogger(__name__)
class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector"),
                         defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}):
    """Connections for AddTask, has one input and two outputs,
    plus one init output.

    Dataset type names are parametrized by the ``in_tmpl``/``out_tmpl``
    templates so chained tasks can be wired together (see
    ``makeSimplePipeline``).
    """
    input = cT.Input(name="add_dataset{in_tmpl}",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    output = cT.Output(name="add_dataset{out_tmpl}",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    output2 = cT.Output(name="add2_dataset{out_tmpl}",
                        dimensions=["instrument", "detector"],
                        storageClass="NumpyArray",
                        doc="Output dataset type for this task")
    initout = cT.InitOutput(name="add_init_output{out_tmpl}",
                            storageClass="NumpyArray",
                            doc="Init Output dataset type for this task")
class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Config for AddTask.
    """
    # Constant added to the input array by AddTask.run (applied twice for
    # the second output).
    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)
class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, has some extras useful for specific
    unit tests.
    """
    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):
        # NOTE(review): the ``taskFactory`` guard and the ``run`` signature
        # were reconstructed; confirm against the upstream module.  The
        # bookkeeping below only applies when a factory has attached itself
        # to this task (see AddTaskFactoryMock.makeTask).
        if self.taskFactory:
            # Simulate a failure on the Nth execution when the factory asked
            # for it, and count executions for test assertions.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)
class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.

    Parameters
    ----------
    stopAt : `int`, optional
        Execution count at which AddTask.run should raise; -1 (default)
        means never.
    """
    def __init__(self, stopAt=-1):
        # NOTE(review): body reconstructed from the attributes AddTask.run
        # reads (``countExec`` and ``stopAt``) — confirm against upstream.
        self.countExec = 0  # incremented by AddTask.run per execution
        self.stopAt = stopAt  # AddTask.run raises when countExec reaches this

    def loadTaskClass(self, taskName):
        # Only knows how to load the one task used by these tests; returns
        # (task class, task name).
        if taskName == "AddTask":
            return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        # NOTE(review): the ``config is None`` / ``overrides`` guards and the
        # final return were reconstructed — confirm against upstream.
        if config is None:
            config = taskClass.ConfigClass()
            if overrides:
                overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None)
        # Attach ourselves so the task can do its bookkeeping.
        task.taskFactory = self
        return task
def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        toExpandedPipeline on a `~lsst.pipe.base.Pipeline` object
    """
    for taskDef in pipeline:
        # Per-task config dataset plus the shared "packages" dataset; both
        # have empty dimensions.
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        packagesDatasetType = DatasetType("packages", {},
                                          storageClass="Packages",
                                          universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # Component dataset types are not registered directly; only
            # composite (non-component) types go into the registry.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` function by calling this first and passing the result
    to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline or
        if no instrument should be added then an empty string or `None`, by
        default `None`.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # Chain the tasks: task N reads dataset template "_N" and writes "_N+1",
    # so each task's output feeds the next task's input.
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl+1)
    # An empty string or None skips instrument registration.
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExisting=False, inMemory=True,
                     userQuery=""):
    """Make simple QuantumGraph for tests.

    Makes simple one-task pipeline with AddTask, sets up in-memory
    registry and butler, fills them with minimal data, and generates
    QuantumGraph with all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is None.
    skipExisting : `bool`, optional
        If `True` (default), a Quantum is not created if all its outputs
        already exist.
    inMemory : `bool`, optional
        If true make in-memory repository.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance
    """
    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")

        # NOTE(review): repo-creation branch partially reconstructed (the
        # Config() creation, the inMemory guard, and the collection name) —
        # confirm against the upstream module.
        config = Config()
        if not inMemory:
            # On-disk repository: SQLite registry plus a file datastore.
            config["registry", "db"] = f"sqlite:///{root}/gen3.sqlite"
            config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
        repo = butlerTests.makeTestRepo(root, {}, config=config)
        collection = "test"
        butler = Butler(butler=repo, run=collection)

    # Add dataset types to registry.
    registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())

    instrument = pipeline.getInstrument()
    if instrument is not None:
        if isinstance(instrument, str):
            # NOTE(review): reconstructed — resolve the importable instrument
            # name to a class before asking it for its name.
            from lsst.utils import doImport
            instrument = doImport(instrument)
        instrumentName = instrument.getName()
    else:
        instrumentName = "INSTR"

    # Add all needed dimensions to registry.
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
    # NOTE(review): extra detector record fields (e.g. full_name) were lost in
    # extraction; reconstructed minimally — confirm against upstream.
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0,
                                                         full_name="det0"))

    # Add the initial input dataset the first task in the chain reads.
    data = numpy.array([0., 1., 2., 5.])
    butler.put(data, "add_dataset0", instrument=instrumentName, detector=0)

    # Make the graph.
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    # NOTE(review): makeGraph arguments partially reconstructed — confirm
    # against the upstream module and GraphBuilder.makeGraph signature.
    qgraph = builder.makeGraph(
        pipeline,
        collections=[butler.run],
        run=butler.run,
        userQuery=userQuery
    )

    return butler, qgraph
def applyConfigOverrides(self, name, config)
def run(self, skyInfo, tempExpRefList, imageScalerList, weightList, altMaskList=None, mask=None, supplementaryData=None)