24 #include "pybind11/pybind11.h"
25 #include "pybind11/eigen.h"
26 #include "pybind11/stl.h"
30 #include "ndarray/pybind11.h"
31 #include "ndarray/eigen.h"
// Brings pybind11's user-defined literals into scope; the `"name"_a` argument-name
// syntax used by every cls.def(...) below relies on this.
37 using namespace pybind11::literals;
// Short aliases for the pybind11 class wrappers returned by the declare* helpers below.
// MixtureComponent is bound with pybind11's default holder type.
44 using PyMixtureComponent = py::class_<MixtureComponent>;
// MixtureUpdateRestriction is held by std::shared_ptr on the Python side.
45 using PyMixtureUpdateRestriction =
46 py::class_<MixtureUpdateRestriction, std::shared_ptr<MixtureUpdateRestriction>>;
// Mixture is held by std::shared_ptr and names PersistableFacade<Mixture> and Persistable
// as base classes. pybind11 requires those bases to be registered before this class is created.
47 using PyMixture = py::class_<Mixture, std::shared_ptr<Mixture>, afw::table::io::PersistableFacade<Mixture>,
48 afw::table::io::Persistable>;
// Binds MixtureComponent into `mod` under the Python name "MixtureComponent".
// Declared to return the resulting class wrapper (PyMixtureComponent).
50 static PyMixtureComponent declareMixtureComponent(
py::module &mod) {
51 PyMixtureComponent
cls(mod,
"MixtureComponent");
// Accessors forwarded directly to the C++ member functions.
52 cls.def(
"getDimension", &MixtureComponent::getDimension);
54 cls.def(
"getMu", &MixtureComponent::getMu);
55 cls.def(
"setMu", &MixtureComponent::setMu);
56 cls.def(
"getSigma", &MixtureComponent::getSigma);
57 cls.def(
"setSigma", &MixtureComponent::setSigma);
// project() is overloaded in C++. Each overload is selected with an explicit cast to a
// const member-function pointer: the first takes one int, the second takes two ints.
58 cls.def(
"project", (MixtureComponent (MixtureComponent::*)(
int)
const) & MixtureComponent::project,
// NOTE(review): this cls.def(...) call is never closed in this copy. The line after the
// cast (presumably the argument name(s) and ")") is missing; restore it from the original.
60 cls.def(
"project", (MixtureComponent (MixtureComponent::*)(
int,
int)
const) & MixtureComponent::project,
// NOTE(review): same problem; the closing part of this cls.def(...) call is missing here.
// Constructors: one taking the dimension, one taking (weight, mu, sigma).
62 cls.def(py::init<int>(),
"dim"_a);
63 cls.def(py::init<Scalar, Vector const &, Matrix const &>(),
"weight"_a,
"mu"_a,
"sigma"_a);
// NOTE(review): the function body ends here without `return cls;` or a closing brace in
// this copy, although the declared return type is PyMixtureComponent; confirm against the original.
// Binds MixtureUpdateRestriction into `mod` under the Python name "MixtureUpdateRestriction".
// Declared to return the resulting class wrapper (PyMixtureUpdateRestriction).
69 static PyMixtureUpdateRestriction declareMixtureUpdateRestriction(
py::module &mod) {
70 PyMixtureUpdateRestriction
cls(mod,
"MixtureUpdateRestriction");
71 cls.def(
"getDimension", &MixtureUpdateRestriction::getDimension);
// Constructor taking only the dimension.
72 cls.def(py::init<int>(),
"dim"_a);
// NOTE(review): no `return cls;` or closing brace is visible after this point in this copy;
// confirm against the original source.
// Binds Mixture into `mod` under the Python name "Mixture".
// Declared to return the resulting class wrapper (PyMixture).
77 static PyMixture declareMixture(
py::module &mod) {
// Declares the PersistableFacade<Mixture> binding before Mixture itself, presumably so the
// base class named in PyMixture is already registered when `cls` is created. Confirm.
78 afw::table::io::python::declarePersistableFacade<Mixture>(mod,
"Mixture");
79 PyMixture
cls(mod,
"Mixture");
// Python iteration over the mixture's begin()/end() range. keep_alive<0, 1> ties the
// lifetime of `self` to the returned iterator, so the Mixture is not collected mid-iteration.
80 cls.def(
"__iter__", [](Mixture &
self) {
return py::make_iterator(
self.
begin(),
self.
end()); },
81 py::keep_alive<0, 1>());
// Indexing returns a reference whose lifetime is tied to `self` (reference_internal).
// NOTE(review): the callable for "__getitem__" is missing between the name and the policy
// in this copy; restore it from the original source.
82 cls.def(
"__getitem__",
84 py::return_value_policy::reference_internal);
// len(mixture) maps to Mixture::size.
85 cls.def(
"__len__", &Mixture::size);
86 cls.def(
"getComponentCount", &Mixture::getComponentCount);
90 cls.def(
"getDimension", &Mixture::getDimension);
91 cls.def(
"normalize", &Mixture::normalize);
92 cls.def(
"shift", &Mixture::shift,
"dim"_a,
"offset"_a);
// clip() exposes a default threshold of 0.0 to Python callers.
93 cls.def(
"clip", &Mixture::clip,
"threshold"_a = 0.0);
94 cls.def(
"getDegreesOfFreedom", &Mixture::getDegreesOfFreedom);
// NOTE(review): the "setDegreesOfFreedom" def below is not closed in this copy. The lambda that
// follows has no visible cls.def(...) header (presumably a "evaluate" binding, given its body).
// Restore both from the original source.
95 cls.def(
"setDegreesOfFreedom", &Mixture::setDegreesOfFreedom,
// Evaluates a single component at a 1-D ndarray point. The array is converted to an Eigen view
// with ndarray::asEigenMatrix before calling Mixture::evaluate(component, ...).
// NOTE(review): the lambda's closing brace and the separator before the argument names are missing here.
98 [](Mixture
const &
self, MixtureComponent
const &component,
99 ndarray::Array<Scalar, 1, 0>
const &array) ->
Scalar {
100 return self.evaluate(component, ndarray::asEigenMatrix(array));
102 "component"_a,
"x"_a);
// Evaluates the full mixture at a 1-D ndarray point, converted the same way.
// NOTE(review): this lambda's cls.def(...) header, closing brace and argument names are
// missing in this copy.
104 [](Mixture
const &
self, ndarray::Array<Scalar, 1, 0>
const &array) ->
Scalar {
105 return self.evaluate(ndarray::asEigenMatrix(array));
// Array overload of evaluate, selected by casting to a const member-function pointer taking a
// 2-D input array and a 1-D output array.
// NOTE(review): the target after the trailing `&` (presumably Mixture::evaluate plus argument
// names and ")") is missing in this copy.
108 cls.def(
"evaluate", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
109 ndarray::Array<Scalar, 1, 0>
const &)
const) &
112 cls.def(
"evaluateComponents", &Mixture::evaluateComponents,
"x"_a,
"p"_a);
// py::overload_cast with py::const_ picks the const overload of evaluateDerivatives that takes
// (1-D x, 1-D gradient, 2-D hessian) arrays.
113 cls.def(
"evaluateDerivatives",
114 py::overload_cast<ndarray::Array<Scalar const, 1, 1>
const &,
115 ndarray::Array<Scalar,1,1>
const &,
116 ndarray::Array<Scalar,2,1>
const &>(&Mixture::evaluateDerivatives, py::const_),
117 "x"_a,
"gradient"_a,
"hessian"_a);
118 cls.def(
"draw", &Mixture::draw,
"rng"_a,
"x"_a);
// Three updateEM overloads, each selected through a non-const member-function pointer:
//   (x, w, tau1, tau2), (x, w, restriction, tau1, tau2), (x, restriction, tau1, tau2).
// All expose Python defaults tau1 = 0.0 and tau2 = 0.5.
// NOTE(review): in each of the three, the target after the trailing `&` (presumably
// Mixture::updateEM) is missing in this copy.
119 cls.def(
"updateEM", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
120 ndarray::Array<Scalar const, 1, 0>
const &,
Scalar,
Scalar)) &
122 "x"_a,
"w"_a,
"tau1"_a = 0.0,
"tau2"_a = 0.5);
123 cls.def(
"updateEM", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
124 ndarray::Array<Scalar const, 1, 0>
const &,
125 MixtureUpdateRestriction
const &restriction,
Scalar,
Scalar)) &
127 "x"_a,
"w"_a,
"restriction"_a,
"tau1"_a = 0.0,
"tau2"_a = 0.5);
128 cls.def(
"updateEM", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
129 MixtureUpdateRestriction
const &restriction,
Scalar,
Scalar)) &
131 "x"_a,
"restriction"_a,
"tau1"_a = 0.0,
"tau2"_a = 0.5);
// Constructor taking (dim, components, <Scalar>).
// NOTE(review): the name of the third argument, the closing ")", `return cls;` and the function's
// closing brace are missing in this copy.
133 cls.def(py::init<int, Mixture::ComponentList &, Scalar>(),
"dim"_a,
"components"_a,
// Imports lsst.afw.math before any bindings are declared, presumably so types bound in that
// module are registered with pybind11 before they appear in signatures here. Confirm.
141 py::module::import(
"lsst.afw.math");
// Declare the three bound classes. Mixture is declared last.
143 auto clsMixtureComponent = declareMixtureComponent(mod);
144 auto clsMixtureUpdateRestriction = declareMixtureUpdateRestriction(mod);
145 auto clsMixture = declareMixture(mod);
// Also expose the component and restriction classes as attributes of the Python Mixture class,
// so they are reachable as Mixture.Component and Mixture.UpdateRestriction.
146 clsMixture.attr(
"Component") = clsMixtureComponent;
147 clsMixture.attr(
"UpdateRestriction") = clsMixtureUpdateRestriction;