24 #include "pybind11/pybind11.h" 
   25 #include "pybind11/eigen.h" 
   26 #include "pybind11/stl.h" 
   30 #include "ndarray/pybind11.h" 
   31 #include "ndarray/eigen.h" 
   37 using namespace pybind11::literals;
 
   // Python binding type for MixtureComponent. No holder type is listed, so
   // pybind11's default holder (std::unique_ptr) owns instances created from Python.
   44 using PyMixtureComponent = py::class_<MixtureComponent>;

   // Python binding type for MixtureUpdateRestriction, held by std::shared_ptr
   // so C++ and Python can share ownership of the same restriction object.
   45 using PyMixtureUpdateRestriction =

   46         py::class_<MixtureUpdateRestriction, std::shared_ptr<MixtureUpdateRestriction>>;

   // Python binding type for Mixture, held by std::shared_ptr. The trailing
   // template arguments register PersistableFacade<Mixture> and Persistable as
   // Python base classes; pybind11 requires base types to be bound before this
   // class_ is constructed, which is why declareMixture calls
   // declarePersistableFacade<Mixture> first. Persistable is presumably bound
   // by an lsst.afw module imported elsewhere -- NOTE(review): confirm.
   47 using PyMixture = py::class_<Mixture, std::shared_ptr<Mixture>, afw::table::io::PersistableFacade<Mixture>,

   48                              afw::table::io::Persistable>;
 
   50 static PyMixtureComponent declareMixtureComponent(
py::module &mod) {
 
   51     PyMixtureComponent 
cls(mod, 
"MixtureComponent");
 
   52     cls.def(
"getDimension", &MixtureComponent::getDimension);
 
   54     cls.def(
"getMu", &MixtureComponent::getMu);
 
   55     cls.def(
"setMu", &MixtureComponent::setMu);
 
   56     cls.def(
"getSigma", &MixtureComponent::getSigma);
 
   57     cls.def(
"setSigma", &MixtureComponent::setSigma);
 
   58     cls.def(
"project", (MixtureComponent (MixtureComponent::*)(
int) 
const) & MixtureComponent::project,
 
   60     cls.def(
"project", (MixtureComponent (MixtureComponent::*)(
int, 
int) 
const) & MixtureComponent::project,
 
   62     cls.def(py::init<int>(), 
"dim"_a);
 
   63     cls.def(py::init<Scalar, Vector const &, Matrix const &>(), 
"weight"_a, 
"mu"_a, 
"sigma"_a);
 
   69 static PyMixtureUpdateRestriction declareMixtureUpdateRestriction(
py::module &mod) {
 
   70     PyMixtureUpdateRestriction 
cls(mod, 
"MixtureUpdateRestriction");
 
   71     cls.def(
"getDimension", &MixtureUpdateRestriction::getDimension);
 
   72     cls.def(py::init<int>(), 
"dim"_a);
 
   77 static PyMixture declareMixture(
py::module &mod) {
 
   78     afw::table::io::python::declarePersistableFacade<Mixture>(mod, 
"Mixture");
 
   79     PyMixture 
cls(mod, 
"Mixture");
 
   80     cls.def(
"__iter__", [](Mixture &
self) { 
return py::make_iterator(
self.
begin(), 
self.
end()); },
 
   81             py::keep_alive<0, 1>());
 
   82     cls.def(
"__getitem__",
 
   84             py::return_value_policy::reference_internal);
 
   85     cls.def(
"__len__", &Mixture::size);
 
   86     cls.def(
"getComponentCount", &Mixture::getComponentCount);
 
   90     cls.def(
"getDimension", &Mixture::getDimension);
 
   91     cls.def(
"normalize", &Mixture::normalize);
 
   92     cls.def(
"shift", &Mixture::shift, 
"dim"_a, 
"offset"_a);
 
   93     cls.def(
"clip", &Mixture::clip, 
"threshold"_a = 0.0);
 
   94     cls.def(
"getDegreesOfFreedom", &Mixture::getDegreesOfFreedom);
 
   95     cls.def(
"setDegreesOfFreedom", &Mixture::setDegreesOfFreedom,
 
   98             [](Mixture 
const &
self, MixtureComponent 
const &component,
 
   99                ndarray::Array<Scalar, 1, 0> 
const &array) -> 
Scalar {
 
  100                 return self.evaluate(component, ndarray::asEigenMatrix(array));
 
  102             "component"_a, 
"x"_a);
 
  104             [](Mixture 
const &
self, ndarray::Array<Scalar, 1, 0> 
const &array) -> 
Scalar {
 
  105                 return self.evaluate(ndarray::asEigenMatrix(array));
 
  108     cls.def(
"evaluate", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1> 
const &,
 
  109                                            ndarray::Array<Scalar, 1, 0> 
const &) 
const) &
 
  112     cls.def(
"evaluateComponents", &Mixture::evaluateComponents, 
"x"_a, 
"p"_a);
 
  113     cls.def(
"evaluateDerivatives",
 
  114             py::overload_cast<ndarray::Array<Scalar const, 1, 1> 
const &,
 
  115                                ndarray::Array<Scalar,1,1> 
const &,
 
  116                                ndarray::Array<Scalar,2,1> 
const &>(&Mixture::evaluateDerivatives, py::const_),
 
  117             "x"_a, 
"gradient"_a, 
"hessian"_a);
 
  118     cls.def(
"draw", &Mixture::draw, 
"rng"_a, 
"x"_a);
 
  119     cls.def(
"updateEM", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1> 
const &,
 
  120                                            ndarray::Array<Scalar const, 1, 0> 
const &, 
Scalar, 
Scalar)) &
 
  122             "x"_a, 
"w"_a, 
"tau1"_a = 0.0, 
"tau2"_a = 0.5);
 
  123     cls.def(
"updateEM", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1> 
const &,
 
  124                                            ndarray::Array<Scalar const, 1, 0> 
const &,
 
  125                                            MixtureUpdateRestriction 
const &restriction, 
Scalar, 
Scalar)) &
 
  127             "x"_a, 
"w"_a, 
"restriction"_a, 
"tau1"_a = 0.0, 
"tau2"_a = 0.5);
 
  128     cls.def(
"updateEM", (
void (Mixture::*)(ndarray::Array<Scalar const, 2, 1> 
const &,
 
  129                                            MixtureUpdateRestriction 
const &restriction, 
Scalar, 
Scalar)) &
 
  131             "x"_a, 
"restriction"_a, 
"tau1"_a = 0.0, 
"tau2"_a = 0.5);
 
  133     cls.def(py::init<int, Mixture::ComponentList &, Scalar>(), 
"dim"_a, 
"components"_a,
 
  141     py::module::import(
"lsst.afw.math");
 
  143     auto clsMixtureComponent = declareMixtureComponent(mod);
 
  144     auto clsMixtureUpdateRestriction = declareMixtureUpdateRestriction(mod);
 
  145     auto clsMixture = declareMixture(mod);
 
  146     clsMixture.attr(
"Component") = clsMixtureComponent;
 
  147     clsMixture.attr(
"UpdateRestriction") = clsMixtureUpdateRestriction;