24 #include "pybind11/pybind11.h"
25 #include "pybind11/eigen.h"
26 #include "pybind11/stl.h"
30 #include "ndarray/pybind11.h"
31 #include "ndarray/eigen.h"
37 using namespace pybind11::literals;
// Convenience aliases for the pybind11 class wrappers created below.
// MixtureUpdateRestriction and Mixture are held by std::shared_ptr on the
// Python side; MixtureComponent is given no explicit holder type.
// NOTE(review): the leading numbers on these lines look like extraction
// artifacts (original line numbers) and are not valid C++ -- confirm upstream.
44 using PyMixtureComponent = py::class_<MixtureComponent>;
45 using PyMixtureUpdateRestriction =
46 py::class_<MixtureUpdateRestriction, std::shared_ptr<MixtureUpdateRestriction>>;
47 using PyMixture = py::class_<Mixture, std::shared_ptr<Mixture>>;
// Registers the MixtureComponent class on `mod` and returns its pybind11 wrapper.
// Bound here: getDimension, getMu/setMu, getSigma/setSigma, two `project`
// overloads, and constructors taking (dim) and (weight, mu, sigma).
// NOTE(review): this copy of the function is incomplete -- both `project`
// bindings end in a trailing comma with no remaining arguments or `);`, the
// function's closing lines (presumably `return cls;` and `}`) are absent, and
// the leading numbers on many lines are not valid C++. Restore from upstream.
49 static PyMixtureComponent declareMixtureComponent(
py::module &mod) {
50 PyMixtureComponent
cls(mod,
"MixtureComponent");
51 cls.def(
"getDimension", &MixtureComponent::getDimension);
// Getter/setter pairs for the component's mu and sigma.
53 cls.def(
"getMu", &MixtureComponent::getMu);
54 cls.def(
"setMu", &MixtureComponent::setMu);
55 cls.def(
"getSigma", &MixtureComponent::getSigma);
56 cls.def(
"setSigma", &MixtureComponent::setSigma);
// `project` is overloaded in C++; each binding casts to an explicit const
// member-function pointer type to select the (int) or (int, int) overload.
57 cls.def(
"project", (MixtureComponent (MixtureComponent::*)(
int)
const) & MixtureComponent::project,
// NOTE(review): keyword-argument spec and closing `);` are missing here.
59 cls.def(
"project", (MixtureComponent (MixtureComponent::*)(
int,
int)
const) & MixtureComponent::project,
// NOTE(review): keyword-argument spec and closing `);` are missing here.
// Constructors: from a dimension alone, or from (weight, mu, sigma) given as
// Scalar, Vector and Matrix.
61 cls.def(py::init<int>(),
"dim"_a);
62 cls.def(py::init<Scalar, Vector const &, Matrix const &>(),
"weight"_a,
"mu"_a,
"sigma"_a);
// NOTE(review): the rest of this function is not present in this copy.
// Registers the MixtureUpdateRestriction class on `mod` (held by
// std::shared_ptr, per the PyMixtureUpdateRestriction alias), binding
// getDimension and a constructor taking an int `dim`.
// NOTE(review): the end of this function (presumably `return cls;` and `}`)
// is missing from this copy, and the leading numbers on these lines are not
// valid C++. Restore from upstream.
68 static PyMixtureUpdateRestriction declareMixtureUpdateRestriction(
py::module &mod) {
69 PyMixtureUpdateRestriction
cls(mod,
"MixtureUpdateRestriction");
70 cls.def(
"getDimension", &MixtureUpdateRestriction::getDimension);
71 cls.def(py::init<int>(),
"dim"_a);
// Registers the Mixture class on `mod` (held by std::shared_ptr, per the
// PyMixture alias). It adds:
//   * afw persistence helpers;
//   * the Python sequence protocol (__iter__ / __getitem__ / __len__);
//   * simple accessors;
//   * the numerical methods evaluate, evaluateComponents,
//     evaluateDerivatives, draw and updateEM.
// NOTE(review): this copy is heavily truncated. Several cls.def calls are
// missing their callable, member pointer, keyword arguments or closing `);`
// (marked below). The function's end is absent, and the leading numbers on
// many lines are not valid C++. Restore the missing lines from upstream.
76 static PyMixture declareMixture(
py::module &mod) {
77 PyMixture
cls(mod,
"Mixture");
// Attach the afw table-persistence methods for Mixture.
78 afw::table::io::python::addPersistableMethods<Mixture>(
cls);
// Iterate over the mixture through its begin()/end(). keep_alive<0, 1> keeps
// `self` (argument 1) alive for as long as the returned iterator (0) exists.
79 cls.def(
"__iter__", [](Mixture &
self) {
return py::make_iterator(
self.
begin(),
self.
end()); },
80 py::keep_alive<0, 1>());
// NOTE(review): the callable for __getitem__ is missing between the name and
// the return-value policy in this binding.
81 cls.def(
"__getitem__",
83 py::return_value_policy::reference_internal);
84 cls.def(
"__len__", &Mixture::size);
85 cls.def(
"getComponentCount", &Mixture::getComponentCount);
89 cls.def(
"getDimension", &Mixture::getDimension);
90 cls.def(
"normalize", &Mixture::normalize);
91 cls.def(
"shift", &Mixture::shift,
"dim"_a,
"offset"_a);
// `threshold` defaults to 0.0 when called from Python.
92 cls.def(
"clip", &Mixture::clip,
"threshold"_a = 0.0);
93 cls.def(
"getDegreesOfFreedom", &Mixture::getDegreesOfFreedom);
// NOTE(review): the setDegreesOfFreedom binding below is cut off after the
// member pointer. The opening of the binding that the following lambda belongs
// to (presumably `cls.def("evaluate", ...`) is not present, and neither is the
// lambda's closing brace.
94 cls.def(
"setDegreesOfFreedom", &Mixture::setDegreesOfFreedom,
// Evaluate a single component at one point. The 1-D ndarray is converted with
// ndarray::asEigenMatrix and forwarded to Mixture::evaluate.
97 [](Mixture
const &
self, MixtureComponent
const &component,
98 ndarray::Array<Scalar, 1, 0>
const &array) ->
Scalar {
99 return self.evaluate(component, ndarray::asEigenMatrix(array));
101 "component"_a,
"x"_a);
// Evaluate the whole mixture at one point, using the same asEigenMatrix
// conversion.
// NOTE(review): the opening of this binding and the lambda's closing brace and
// keyword arguments are missing in this copy.
103 [](Mixture
const &
self, ndarray::Array<Scalar, 1, 0>
const &array) ->
Scalar {
104 return self.evaluate(ndarray::asEigenMatrix(array));
// Array-valued evaluate overload: takes a const 2-D array and a writable 1-D
// array. It is selected by casting to an explicit const member-function pointer
// type.
// NOTE(review): the member name after `&` and the keyword arguments / `);` are
// missing from this binding.
107 cls.def(
"evaluate", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
108 ndarray::Array<Scalar, 1, 0>
const &)
const) &
111 cls.def(
"evaluateComponents", &Mixture::evaluateComponents,
"x"_a,
"p"_a);
// py::overload_cast with py::const_ selects the const evaluateDerivatives
// overload taking (x, gradient, hessian) arrays.
112 cls.def(
"evaluateDerivatives",
113 py::overload_cast<ndarray::Array<Scalar const, 1, 1>
const &,
114 ndarray::Array<Scalar,1,1>
const &,
115 ndarray::Array<Scalar,2,1>
const &>(&Mixture::evaluateDerivatives, py::const_),
116 "x"_a,
"gradient"_a,
"hessian"_a);
117 cls.def(
"draw", &Mixture::draw,
"rng"_a,
"x"_a);
// Three updateEM overloads: (x, w), (x, w, restriction) and (x, restriction).
// Each has Python-side defaults tau1 = 0.0 and tau2 = 0.5. The casts are to
// non-const member-function pointers.
// NOTE(review): in each of the three bindings below, the member name after `&`
// is missing in this copy.
118 cls.def(
"updateEM", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
119 ndarray::Array<Scalar const, 1, 0>
const &,
Scalar,
Scalar)) &
121 "x"_a,
"w"_a,
"tau1"_a = 0.0,
"tau2"_a = 0.5);
122 cls.def(
"updateEM", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
123 ndarray::Array<Scalar const, 1, 0>
const &,
124 MixtureUpdateRestriction
const &restriction,
Scalar,
Scalar)) &
126 "x"_a,
"w"_a,
"restriction"_a,
"tau1"_a = 0.0,
"tau2"_a = 0.5);
127 cls.def(
"updateEM", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
128 MixtureUpdateRestriction
const &restriction,
Scalar,
Scalar)) &
130 "x"_a,
"restriction"_a,
"tau1"_a = 0.0,
"tau2"_a = 0.5);
// Constructor from (int, Mixture::ComponentList &, Scalar).
// NOTE(review): this binding is truncated after "components"_a. The keyword
// for the Scalar argument, the closing `);` and the rest of the function are
// missing.
132 cls.def(py::init<int, Mixture::ComponentList &, Scalar>(),
"dim"_a,
"components"_a,
// Module-initialization fragment. It:
//   * imports lsst.afw.math into the interpreter (presumably so afw types used
//     by these bindings, e.g. the `rng` argument of draw, are registered --
//     confirm);
//   * registers the three classes on `mod`;
//   * exposes MixtureComponent and MixtureUpdateRestriction as the nested
//     attributes Mixture.Component and Mixture.UpdateRestriction.
// NOTE(review): the enclosing definition (presumably a PYBIND11_MODULE body)
// begins outside this copy, so its header and closing brace are not shown. The
// leading numbers on these lines are not valid C++.
140 py::module::import(
"lsst.afw.math");
142 auto clsMixtureComponent = declareMixtureComponent(mod);
143 auto clsMixtureUpdateRestriction = declareMixtureUpdateRestriction(mod);
144 auto clsMixture = declareMixture(mod);
145 clsMixture.attr(
"Component") = clsMixtureComponent;
146 clsMixture.attr(
"UpdateRestriction") = clsMixtureUpdateRestriction;