22 """Bunch of common classes and methods for use in unit tests.
# Public API of this test-support module.
__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]
31 from lsst.daf.butler
import Butler, Config, DatasetType
32 import lsst.daf.butler.tests
as butlerTests
34 from ...
import base
as pipeBase
35 from ..
import connectionTypes
as cT
37 _LOG = logging.getLogger(__name__)
47 return "SimpleInstrument"
class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector"),
                         defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}):
    """Connections for AddTask, has one input and two outputs,
    plus one init output.

    Dataset type names are parameterized by the ``in_tmpl``/``out_tmpl``
    templates so tests can chain several AddTask instances together.
    """
    input = cT.Input(name="add_dataset{in_tmpl}",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    output = cT.Output(name="add_dataset{out_tmpl}",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    output2 = cT.Output(name="add2_dataset{out_tmpl}",
                        dimensions=["instrument", "detector"],
                        storageClass="NumpyArray",
                        doc="Output dataset type for this task")
    initout = cT.InitOutput(name="add_init_output{out_tmpl}",
                            storageClass="NumpyArray",
                            doc="Init Output dataset type for this task")
class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)
class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, has some extras useful for specific
    unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    # InitOutputs for this task.
    initout = numpy.array([999])

    # Factory that makes instances; set by AddTaskFactoryMock.makeTask so the
    # task can report execution bookkeeping back to the factory.
    taskFactory = None

    def run(self, input):
        if self.taskFactory:
            # Bookkeeping for unit tests: fail deliberately when the factory
            # asked us to stop at this execution count.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)
class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """
    def __init__(self, stopAt=-1):
        # Incremented by AddTask.run on every successful execution.
        self.countExec = 0
        # When countExec reaches this value AddTask raises; -1 disables it.
        self.stopAt = stopAt

    def loadTaskClass(self, taskName):
        if taskName == "AddTask":
            return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
            if overrides:
                overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None)
        # Let the task find us for bookkeeping.
        task.taskFactory = self
        return task
def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance to register the dataset types with.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        toExpandedPipeline on a `~lsst.pipe.base.Pipeline` object.
    """
    # The "packages" dataset type does not depend on the task, so build it
    # once instead of rebuilding it on every loop iteration.
    packagesDatasetType = DatasetType("packages", {},
                                      storageClass="Packages",
                                      universe=registry.dimensions)
    for taskDef in pipeline:
        # Per-task config dataset type (name comes from the task definition).
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # Component dataset types are never registered directly.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` function by calling this first and passing the result
    to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline or
        if no instrument should be added then an empty string or `None`, by
        default `None`.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # Chain the tasks so each one consumes the previous task's output,
    # giving a well-defined execution order via data dependencies.
    for lvl in range(nQuanta):
        taskLabel = f"task{lvl}"
        pipeline.addTask(AddTask, taskLabel)
        pipeline.addConfigOverride(taskLabel, "connections.in_tmpl", f"{lvl}")
        pipeline.addConfigOverride(taskLabel, "connections.out_tmpl", f"{lvl+1}")
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExisting=False, inMemory=True,
                     userQuery=""):
    """Make simple QuantumGraph for tests.

    Makes simple one-task pipeline with AddTask, sets up in-memory
    registry and butler, fills them with minimal data, and generates
    QuantumGraph with all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is `None`.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already
        exist; `False` by default.
    inMemory : `bool`, optional
        If true make in-memory repository.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """
    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")

        config = Config()
        if not inMemory:
            # Use an on-disk SQLite registry and POSIX datastore instead of
            # the in-memory test defaults.
            config["registry", "db"] = f"sqlite:///{root}/gen3.sqlite"
            config["datastore", "cls"] = "lsst.daf.butler.datastores.posixDatastore.PosixDatastore"
        repo = butlerTests.makeTestRepo(root, {}, config=config)
        collection = "test"
        butler = Butler(butler=repo, run=collection)

        # Add dataset types to registry.
        registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())

        # Add all needed dimensions to registry.
        butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
        butler.registry.insertDimensionData("detector", dict(instrument="INSTR", id=0, full_name="det0"))

        # Add the initial input dataset that the first task reads.
        data = numpy.array([0., 1., 2., 5.])
        butler.put(data, "add_dataset0", instrument="INSTR", detector=0)

    # Make the graph.
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=[butler.run],
        run=butler.run,
        userQuery=userQuery
    )

    return butler, qgraph