22 """Code to load multi-Gaussian approximations to profiles from "The Tractor"
23 into a lsst.shapelet.MultiShapeletBasis.
25 Please see the README file in the data directory of the lsst.shapelet
26 package for more information.
36 from .radialProfile
import RadialProfile
37 from .multiShapeletBasis
import MultiShapeletBasis
38 from .shapeletFunction
import ShapeletFunction
42 """Register the pickled profiles in the data directory with the RadialProfile
45 This should only be called at import time by this module; it's only a function to
46 avoid polluting the module namespace with all the local variables used here.
48 dataDir = os.path.join(os.environ[
"SHAPELET_DIR"],
"data")
49 regex = re.compile(
r"([a-z]+\d?)_K(\d+)_MR(\d+)\.pickle")
50 for filename
in os.listdir(dataDir):
51 match = regex.match(filename)
55 nComponents = int(match.group(2))
56 maxRadius = int(match.group(3))
58 profile = RadialProfile.get(name)
60 warnings.warn(
"No C++ profile for multi-Gaussian pickle file '%s'" % filename)
62 with open(os.path.join(dataDir, filename),
'rb')
as stream:
63 if sys.version_info[0] >= 3:
64 array = pickle.load(stream, encoding=
'latin1')
66 array = pickle.load(stream)
67 amplitudes = array[:nComponents]
68 amplitudes /= amplitudes.sum()
69 variances = array[nComponents:]
70 if amplitudes.shape != (nComponents,)
or variances.shape != (nComponents,):
71 warnings.warn(
"Unknown format for multi-Gaussian pickle file '%s'" % filename)
73 basis = MultiShapeletBasis(1)
74 for amplitude, variance
in zip(amplitudes, variances):
75 radius = variance**0.5
76 matrix = numpy.array([[amplitude / ShapeletFunction.FLUX_FACTOR]], dtype=float)
77 basis.addComponent(radius, 0, matrix)
78 profile.registerBasis(basis, nComponents, maxRadius)
87 """Plot a single-element MultiShapeletBasis as a radial profile.
94 `True` to evaluate components.
97 coefficients = numpy.ones(1, dtype=float)
98 msf = basis.makeFunction(ellipse, coefficients)
102 n += len(msf.getComponents())
103 z = numpy.zeros((n,) + r.shape, dtype=float)
104 for j, x
in enumerate(r):
107 for i, sf
in enumerate(msf.getComponents()):
109 for j, x
in enumerate(r):
110 z[i+1, j] = evc(x, 0.0)
117 """Integrate the profiles to compare relative fluxes between the true profiles
118 and their approximations.
120 After normalizing by surface brightness at r=1 r_e, integrate the profiles to compare
121 relative fluxes between the true profiles and their approximations.
125 maxRadius : `float`, optional
126 Maximum radius to integrate the profile, in units of r_e.
127 nSteps : `int`, optional
128 Number of concrete points at which to evaluate the profile to
129 do the integration (we just use the trapezoidal rule).
133 fluxes : `dict` of `float` values
134 Dictionary of fluxes (``exp``, ``lux``, ``dev``, ``luv``, ``ser2``, ``ser3``,
135 ``ser5``, ``gexp``, ``glux``, ``gdev``, ``gluv``, ``gser2``, ``gser3``, ``gser5``)
137 radii = numpy.linspace(0.0, maxRadius, nSteps)
138 profiles = {name: RadialProfile.get(name)
for name
in (
"exp",
"lux",
"dev",
"luv",
139 "ser2",
"ser3",
"ser5")}
141 for name, profile
in profiles.items():
142 evaluated[name] = profile.evaluate(radii)
143 basis = profile.getBasis(8)
144 evaluated[
"g" + name] =
evaluateRadial(basis, radii, sbNormalize=
True, doComponents=
False)[0, :]
145 fluxes = {name: numpy.trapz(z*radii, radii)
for name, z
in evaluated.items()}
150 """Plot all the profiles defined in this module together.
152 Plot all the profiles defined in this module together: true exp and dev,
153 the SDSS softened/truncated lux and luv, and the multi-Gaussian approximations
158 doComponents : `bool`, optional
159 True, to plot the individual Gaussians that form the multi-Gaussian approximations.
163 figure : `matplotlib.figure.Figure`
164 Figure that contains the plot.
165 axes : `numpy.ndarray` of `matplotlib.axes.Axes`
166 A 2x4 NumPy array of matplotlib axes objects.
168 from matplotlib
import pyplot
169 fig = pyplot.figure(figsize=(9, 4.7))
170 axes = numpy.zeros((2, 4), dtype=object)
171 r1 = numpy.logspace(-3, 0, 1000, base=10)
172 r2 = numpy.linspace(1, 10, 1000)
176 axes[i, j] = fig.add_subplot(2, 4, i*4+j+1)
177 profiles = {name: RadialProfile.get(name)
for name
in (
"exp",
"lux",
"dev",
"luv")}
178 basis = {name: profiles[name].getBasis(8)
for name
in profiles}
179 z = numpy.zeros((2, 4), dtype=object)
180 colors = (
"k",
"g",
"b",
"r")
181 fig.subplots_adjust(wspace=0.025, hspace=0.025, bottom=0.15, left=0.1, right=0.98, top=0.92)
182 centers = [
None,
None]
184 for j
in range(0, 4, 2):
185 bbox0 = axes[i, j].get_position()
186 bbox1 = axes[i, j+1].get_position()
187 bbox1.x0 = bbox0.x1 - 0.06
189 centers[j//2] = 0.5*(bbox0.x0 + bbox1.x1)
190 axes[i, j].set_position(bbox0)
191 axes[i, j+1].set_position(bbox1)
192 for j
in range(0, 2):
193 z[0, j] = [
evaluateRadial(basis[k], r[j], sbNormalize=
True, doComponents=doComponents)
194 for k
in (
"exp",
"lux")]
195 z[0, j][0:0] = [profiles[k].evaluate(r[j])[numpy.newaxis, :]
for k
in (
"exp",
"lux")]
196 z[0, j+2] = [
evaluateRadial(basis[k], r[j], sbNormalize=
True, doComponents=doComponents)
197 for k
in (
"dev",
"luv")]
198 z[0, j+2][0:0] = [profiles[k].evaluate(r[j])[numpy.newaxis, :]
for k
in (
"dev",
"luv")]
199 methodNames = [[
"loglog",
"semilogy"], [
"semilogx",
"plot"]]
200 for j
in range(0, 4):
201 z[1, j] = [(z[0, j][0][0, :] - z[0, j][i][0, :])/z[0, j][0][0, :]
for i
in range(0, 4)]
203 method0 = getattr(axes[0, j], methodNames[0][j%2])
204 method1 = getattr(axes[1, j], methodNames[1][j%2])
207 handles.append(method0(r[j%2], y0[0, :], color=colors[k])[0])
209 for l
in range(1, y0.shape[0]):
210 method0(r[j%2], y0[l, :], color=colors[k], alpha=0.25)
211 method1(r[j%2], z[1, j][k], color=colors[k])
212 axes[0, j].set_xticklabels([])
213 axes[0, j].set_ylim(1E-6, 1E3)
214 axes[1, j].set_ylim(-0.2, 1.0)
215 for i, label
in enumerate((
"profile",
"relative error")):
216 axes[i, 0].set_ylabel(label)
217 for t
in axes[i, 0].get_yticklabels():
219 for j
in range(1, 4):
220 axes[0, j].set_yticklabels([])
221 axes[1, j].set_yticklabels([])
222 xticks = [[
'$\\mathdefault{10^{%d}}$' % i
for i
in range(-3, 1)],
223 [str(i)
for i
in range(1, 11)]]
226 for j
in range(0, 4):
227 axes[1, j].set_xticklabels(xticks[j%2])
228 for t
in axes[1, j].get_xticklabels():
230 fig.legend(handles, [
"exp/dev",
"lux/luv",
"approx exp/dev",
"approx lux/luv"],
231 loc=
'lower center', ncol=4)
232 fig.text(centers[0], 0.95,
"exponential", ha=
'center', weight=
'bold')
233 fig.text(centers[1], 0.95,
"de Vaucouleur", ha=
'center', weight=
'bold')