24 #include "pybind11/pybind11.h"
25 #include "pybind11/eigen.h"
26 #include "pybind11/stl.h"
28 #include "ndarray/pybind11.h"
33 using namespace pybind11::literals;
40 using Sampler = TruncatedGaussianSampler;
41 using Evaluator = TruncatedGaussianEvaluator;
42 using LogEvaluator = TruncatedGaussianLogEvaluator;
44 using PyTruncatedGaussian = py::class_<TruncatedGaussian, std::shared_ptr<TruncatedGaussian>>;
45 using PySampler = py::class_<Sampler, std::shared_ptr<Sampler>>;
46 using PyEvaluator = py::class_<Evaluator, std::shared_ptr<Evaluator>>;
47 using PyLogEvaluator = py::class_<LogEvaluator, std::shared_ptr<LogEvaluator>>;
52 template <
typename Class,
typename PyClass>
55 cls.def(py::init<TruncatedGaussian const &>(),
"parent"_a);
57 (
Scalar (Class::*)(ndarray::Array<Scalar const, 1, 1>
const &)
const) & Class::operator(),
59 cls.def(
"__call__", (
void (Class::*)(ndarray::Array<Scalar const, 2, 1>
const &,
60 ndarray::Array<Scalar, 1, 1>
const &)
const) &
62 "alpha"_a,
"output"_a);
69 py::module::import(
"lsst.afw.math");
71 PyTruncatedGaussian
cls(mod,
"TruncatedGaussian");
72 py::enum_<TruncatedGaussian::SampleStrategy>(
cls,
"SampleStrategy")
73 .value(
"DIRECT_WITH_REJECTION", TruncatedGaussian::DIRECT_WITH_REJECTION)
74 .value(
"ALIGN_AND_WEIGHT", TruncatedGaussian::ALIGN_AND_WEIGHT)
76 cls.def_static(
"fromSeriesParameters", &TruncatedGaussian::fromSeriesParameters,
"q0"_a,
"gradient"_a,
78 cls.def_static(
"fromStandardParameters", &TruncatedGaussian::fromStandardParameters,
"mean"_a,
80 cls.def(
"sample", (Sampler (TruncatedGaussian::*)(TruncatedGaussian::SampleStrategy)
const) &
81 TruncatedGaussian::sample,
83 cls.def(
"sample", (Sampler (TruncatedGaussian::*)(
Scalar)
const) & TruncatedGaussian::sample,
84 "minRejectionEfficiency"_a = 0.1);
85 cls.def(
"evaluateLog", &TruncatedGaussian::evaluateLog);
86 cls.def(
"evaluate", &TruncatedGaussian::evaluate);
87 cls.def(
"getDim", &TruncatedGaussian::getDim);
88 cls.def(
"maximize", &TruncatedGaussian::maximize);
89 cls.def(
"getUntruncatedFraction", &TruncatedGaussian::getUntruncatedFraction);
90 cls.def(
"getLogPeakAmplitude", &TruncatedGaussian::getLogPeakAmplitude);
91 cls.def(
"getLogIntegral", &TruncatedGaussian::getLogIntegral);
93 cls.attr(
"LogEvaluator") = declareEvaluator<LogEvaluator, PyLogEvaluator>(mod,
"LogEvaluator");
94 cls.attr(
"Evaluator") = declareEvaluator<Evaluator, PyEvaluator>(mod,
"Evaluator");
96 PySampler clsSampler(mod,
"TruncatedGaussianSampler");
97 clsSampler.def(py::init<TruncatedGaussian const &, TruncatedGaussian::SampleStrategy>(),
"parent"_a,
99 clsSampler.def(
"__call__",
100 (
Scalar (Sampler::*)(afw::math::Random &, ndarray::Array<Scalar, 1, 1>
const &)
const) &
103 clsSampler.def(
"__call__", (
void (Sampler::*)(afw::math::Random &, ndarray::Array<Scalar, 2, 1>
const &,
104 ndarray::Array<Scalar, 1, 1>
const &,
bool)
const) &
106 "rng"_a,
"alpha"_a,
"weights"_a,
"multiplyWeights"_a =
false);
108 cls.attr(
"Sampler") = clsSampler;