24#include "pybind11/pybind11.h"
26#include "pybind11/eigen.h"
27#include "pybind11/stl.h"
31#include "ndarray/pybind11.h"
32#include "ndarray/eigen.h"
38using namespace pybind11::literals;
45using PyMixtureComponent = py::class_<MixtureComponent>;
46using PyMixtureUpdateRestriction =
47 py::class_<MixtureUpdateRestriction, std::shared_ptr<MixtureUpdateRestriction>>;
48using PyMixture = py::class_<Mixture, std::shared_ptr<Mixture>>;
50PyMixtureComponent declareMixtureComponent(lsst::cpputils::python::WrapperCollection &wrappers) {
52 PyMixtureComponent(wrappers.
module,
"MixtureComponent"), [](
auto &mod,
auto &cls) {
53 cls.def(
"getDimension", &MixtureComponent::getDimension);
54 cls.def_readwrite(
"weight", &MixtureComponent::weight);
55 cls.def(
"getMu", &MixtureComponent::getMu);
56 cls.def(
"setMu", &MixtureComponent::setMu);
57 cls.def(
"getSigma", &MixtureComponent::getSigma);
58 cls.def(
"setSigma", &MixtureComponent::setSigma);
59 cls.def(
"project", (MixtureComponent (MixtureComponent::*)(int) const) &MixtureComponent::project,
61 cls.def(
"project", (MixtureComponent (MixtureComponent::*)(int, int) const) &MixtureComponent::project,
63 cls.def(py::init<int>(),
"dim"_a);
64 cls.def(py::init<Scalar, Vector const &, Matrix const &>(),
"weight"_a,
"mu"_a,
"sigma"_a);
65 cpputils::python::addOutputOp(cls,
"__str__");
66 cpputils::python::addOutputOp(cls,
"__repr__");
70PyMixtureUpdateRestriction declareMixtureUpdateRestriction(lsst::cpputils::python::WrapperCollection &wrappers) {
71 return wrappers.
wrapType(PyMixtureUpdateRestriction(
72 wrappers.
module,
"MixtureUpdateRestriction"), [](
auto &mod,
auto &cls) {
73 cls.def(
"getDimension", &MixtureUpdateRestriction::getDimension);
74 cls.def(py::init<int>(),
"dim"_a);
79PyMixture declareMixture(lsst::cpputils::python::WrapperCollection &wrappers) {
80 return wrappers.
wrapType(PyMixture(wrappers.
module,
"Mixture"), [](
auto &mod,
auto &cls) {
81 afw::table::io::python::addPersistableMethods<Mixture>(cls);
82 cls.def(
"__iter__", [](Mixture &self) { return py::make_iterator(self.begin(), self.end()); },
83 py::keep_alive<0, 1>());
84 cls.def(
"__getitem__",
86 py::return_value_policy::reference_internal);
101 ndarray::Array<Scalar, 1, 0>
const &array) ->
Scalar {
102 return self.evaluate(component, ndarray::asEigenMatrix(array));
104 "component"_a,
"x"_a);
106 [](
Mixture const &self, ndarray::Array<Scalar, 1, 0>
const &array) ->
Scalar {
107 return self.evaluate(ndarray::asEigenMatrix(array));
110 cls.def(
"evaluate", (
void (
Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
111 ndarray::Array<Scalar, 1, 0>
const &)
const) &
115 cls.def(
"evaluateDerivatives",
116 py::overload_cast<ndarray::Array<Scalar const, 1, 1>
const &,
117 ndarray::Array<Scalar, 1, 1>
const &,
119 "x"_a,
"gradient"_a,
"hessian"_a);
121 cls.def(
"updateEM", (
void (
Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
122 ndarray::Array<Scalar const, 1, 0>
const &,
Scalar,
Scalar)) &
124 "x"_a,
"w"_a,
"tau1"_a = 0.0,
"tau2"_a = 0.5);
125 cls.def(
"updateEM", (
void (
Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
126 ndarray::Array<Scalar const, 1, 0>
const &,
129 "x"_a,
"w"_a,
"restriction"_a,
"tau1"_a = 0.0,
"tau2"_a = 0.5);
130 cls.def(
"updateEM", (
void (
Mixture::*)(ndarray::Array<Scalar const, 2, 1>
const &,
133 "x"_a,
"restriction"_a,
"tau1"_a = 0.0,
"tau2"_a = 0.5);
135 cls.def(py::init<int, Mixture::ComponentList &, Scalar>(),
"dim"_a,
"components"_a,
144 auto clsMixtureComponent = declareMixtureComponent(wrappers);
145 auto clsMixtureUpdateRestriction = declareMixtureUpdateRestriction(wrappers);
146 auto clsMixture = declareMixture(wrappers);
147 clsMixture.attr(
"Component") = clsMixtureComponent;
148 clsMixture.attr(
"UpdateRestriction") = clsMixtureUpdateRestriction;
A helper class for subdividing pybind11 module across multiple translation units (i....
PyType wrapType(PyType cls, ClassWrapperCallback function, bool setModuleName=true)
Add a type (class or enum) wrapper, deferring method and other attribute definitions until finish() i...
pybind11::module module
The module object passed to the PYBIND11_MODULE block that contains this WrapperCollection.
A weighted Student's T or Gaussian distribution used as a component in a Mixture.
void updateEM(ndarray::Array< Scalar const, 2, 1 > const &x, ndarray::Array< Scalar const, 1, 0 > const &w, Scalar tau1=0.0, Scalar tau2=0.5)
Perform an Expectation-Maximization step, updating the component parameters to match the given weight...
Scalar getDegreesOfFreedom() const
Get the number of degrees of freedom in the component Student's T distributions (inf=Gaussian)
std::size_t size() const
Return the number of components.
std::shared_ptr< Mixture > project(int dim) const
Project the distribution onto the given dimensions (marginalize over all others)
Scalar evaluate(Component const &component, Eigen::MatrixBase< Derived > const &x) const
Evaluate the probability density at the given point for the given component distribution.
void normalize()
Iterate over all components, rescaling their weights so they sum to one.
void evaluateComponents(ndarray::Array< Scalar const, 2, 1 > const &x, ndarray::Array< Scalar, 2, 1 > const &p) const
Evaluate the contributions of each component to the full probability at the given points.
void evaluateDerivatives(ndarray::Array< Scalar const, 1, 1 > const &x, ndarray::Array< Scalar, 1, 1 > const &gradient, ndarray::Array< Scalar, 2, 1 > const &hessian) const
Evaluate the derivative of the distribution at the given point.
void shift(int dim, Scalar offset)
Shift the mixture in the given dimension, adding the given offset to all mu vectors.
virtual std::shared_ptr< Mixture > clone() const
Polymorphic deep copy.
virtual int getComponentCount() const
Return the number of components.
void draw(afw::math::Random &rng, ndarray::Array< Scalar, 2, 1 > const &x) const
Draw random variates from the distribution.
int getDimension() const
Return the number of dimensions.
void setDegreesOfFreedom(Scalar df=std::numeric_limits< Scalar >::infinity())
Set the number of degrees of freedom in the component Student's T distributions (inf=Gaussian)
std::size_t clip(Scalar threshold=0.0)
Iterate over all components, removing those with weight less than or equal to threshold.
Helper class used to define restrictions to the form of the component parameters in Mixture::updateEM...
void addOutputOp(PyClass &cls, std::string const &method)
Add __str__ or __repr__ method implemented by operator<<.
std::size_t cppIndex(std::ptrdiff_t size, std::ptrdiff_t i)
Compute a C++ index from a Python index (negative values count from the end) and range-check.
void wrapMixture(lsst::cpputils::python::WrapperCollection &wrappers)
double Scalar
Typedefs to be used for probability and parameter values.