apdbSql.py
# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining Apdb class and related methods.
"""

from __future__ import annotations

__all__ = ["ApdbSqlConfig", "ApdbSql"]

import logging
from collections.abc import Callable, Iterable, Mapping, MutableMapping
from typing import Any, Dict, List, Optional, Tuple, cast

import lsst.daf.base as dafBase
import numpy as np
import pandas
import sqlalchemy
from felis.simple import Table
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
from lsst.utils.iteration import chunk_iterable
from sqlalchemy import func, inspection, sql
from sqlalchemy.engine import Inspector
from sqlalchemy.pool import NullPool

from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
from .apdbSchema import ApdbTables
from .apdbSqlSchema import ApdbSqlSchema, ExtraTables
from .timer import Timer

_LOG = logging.getLogger(__name__)

if pandas.__version__.partition(".")[0] == "1":

    class _ConnectionHackSA2(sqlalchemy.engine.Connectable):
        """Terrible hack to work around Pandas 1's incomplete support for
        SQLAlchemy 2.

        We need to pass a Connection instance to Pandas methods, but in
        SQLAlchemy 2 the Connection class lost the ``connect`` method that
        Pandas relies on.
        """

        def __init__(self, connection: sqlalchemy.engine.Connection):
            self._connection = connection

        def connect(self, **kwargs: Any) -> Any:
            return self

        @property
        def execute(self) -> Callable:
            return self._connection.execute

        @property
        def execution_options(self) -> Callable:
            return self._connection.execution_options

        @property
        def connection(self) -> Any:
            return self._connection.connection

        def __enter__(self) -> sqlalchemy.engine.Connection:
            return self._connection

        def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
            # Do not close the connection here.
            pass

    @inspection._inspects(_ConnectionHackSA2)
    def _connection_insp(conn: _ConnectionHackSA2) -> Inspector:
        return Inspector._construct(Inspector._init_connection, conn._connection)

else:
    # Pandas 2.0 supports SQLAlchemy 2 correctly.
    def _ConnectionHackSA2(  # type: ignore[no-redef]
        conn: sqlalchemy.engine.Connectable,
    ) -> sqlalchemy.engine.Connectable:
        return conn

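
# A minimal usage sketch (illustrative only, not part of this module's API):
# the wrapper above is handed to Pandas wherever a live SQLAlchemy 2
# Connection must masquerade as a Connectable. With Pandas 2 the wrapper is
# the identity function, so the same call works for both branches.
def _example_read_with_hack(
    engine: sqlalchemy.engine.Engine, query: sqlalchemy.sql.Select
) -> pandas.DataFrame:
    # Open a transaction-scoped connection and let Pandas execute the query
    # through the compatibility wrapper.
    with engine.begin() as conn:
        return pandas.read_sql_query(query, _ConnectionHackSA2(conn))
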

def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
    """Change the type of uint64 columns to int64, returning a copy of the
    data frame.
    """
    names = [c[0] for c in df.dtypes.items() if c[1] == np.uint64]
    return df.astype({name: np.int64 for name in names})

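
# Illustrative check (hypothetical data, not executed by this module): shows
# why the coercion above exists -- database drivers generally cannot bind
# numpy uint64 values.
def _example_coerce_uint64() -> None:
    df = pandas.DataFrame({"diaObjectId": np.array([1, 2], dtype=np.uint64)})
    # After coercion the column can be written through SQLAlchemy safely.
    assert _coerce_uint64(df)["diaObjectId"].dtype == np.int64
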

def _make_midpointMjdTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of the current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    time : `float`
        A ``midpointMjdTai`` starting point, MJD time.
    """
    # TODO: `system` must be consistent with the code in ap_association
    # (see DM-31996)
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30

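
# Worked example (hypothetical values, not executed by this module): for a
# visit at MJD 60000.0 (TAI) and a 12-month history window the search starts
# at 60000.0 - 12 * 30 = 59640.0; note that a "month" is 30 days here.
def _example_midpointMjdTai_start() -> float:
    visit_time = dafBase.DateTime(60000.0, dafBase.DateTime.MJD, dafBase.DateTime.TAI)
    return _make_midpointMjdTai_start(visit_time, 12)  # == 59640.0
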

class ApdbSqlConfig(ApdbConfig):
    """APDB configuration class for SQL implementation (ApdbSql)."""

    db_url = Field[str](doc="SQLAlchemy database connection URI")
    isolation_level = ChoiceField[str](
        doc=(
            "Transaction isolation level, if unset then backend-default value "
            "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
            "Some backends may not support every allowed value."
        ),
        allowed={
            "READ_COMMITTED": "Read committed",
            "READ_UNCOMMITTED": "Read uncommitted",
            "REPEATABLE_READ": "Repeatable read",
            "SERIALIZABLE": "Serializable",
        },
        default=None,
        optional=True,
    )
    connection_pool = Field[bool](
        doc="If False then disable SQLAlchemy connection pool. Do not use connection pool when forking.",
        default=True,
    )
    connection_timeout = Field[float](
        doc=(
            "Maximum time to wait for a database lock to be released before "
            "exiting. Defaults to sqlalchemy defaults if not set."
        ),
        default=None,
        optional=True,
    )
    sql_echo = Field[bool](doc="If True then pass SQLAlchemy echo option.", default=False)
    dia_object_index = ChoiceField[str](
        doc="Indexing mode for DiaObject table",
        allowed={
            "baseline": "Index defined in baseline schema",
            "pix_id_iov": "(pixelId, objectId, iovStart) PK",
            "last_object_table": "Separate DiaObjectLast table",
        },
        default="baseline",
    )
    htm_level = Field[int](doc="HTM indexing level", default=20)
    htm_max_ranges = Field[int](doc="Max number of ranges in HTM envelope", default=64)
    htm_index_column = Field[str](
        default="pixelId", doc="Name of a HTM index column for DiaObject and DiaSource tables"
    )
    ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
    dia_object_columns = ListField[str](
        doc="List of columns to read from DiaObject, by default read all columns", default=[]
    )
    prefix = Field[str](doc="Prefix to add to table names and index names", default="")
    namespace = Field[str](
        doc=(
            "Namespace or schema name for all tables in APDB database. "
            "Presently only works for PostgreSQL backend. "
            "If schema with this name does not exist it will be created when "
            "APDB tables are created."
        ),
        default=None,
        optional=True,
    )
    timer = Field[bool](doc="If True then print/log timing information", default=False)

    def validate(self) -> None:
        super().validate()
        if len(self.ra_dec_columns) != 2:
            raise ValueError("ra_dec_columns must have exactly two column names")

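
# A minimal configuration sketch (illustrative; the database path is
# hypothetical): a file-based SQLite APDB with pooling disabled, e.g. for a
# process that forks workers.
def _example_sqlite_config() -> ApdbSqlConfig:
    config = ApdbSqlConfig()
    config.db_url = "sqlite:///apdb.db"
    config.connection_pool = False  # NullPool will be used instead
    config.validate()
    return config
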

class ApdbSqlTableData(ApdbTableData):
    """Implementation of ApdbTableData that wraps sqlalchemy Result."""

    def __init__(self, result: sqlalchemy.engine.Result):
        self._keys = list(result.keys())
        self._rows: list[tuple] = cast(list[tuple], list(result.fetchall()))

    def column_names(self) -> list[str]:
        return self._keys

    def rows(self) -> Iterable[tuple]:
        return self._rows

210 """Implementation of APDB interface based on SQL database.
211
212 The implementation is configured via standard ``pex_config`` mechanism
213 using `ApdbSqlConfig` configuration class. For an example of different
214 configurations check ``config/`` folder.
215
216 Parameters
217 ----------
218 config : `ApdbSqlConfig`
219 Configuration object.
220 """
221
222 ConfigClass = ApdbSqlConfig
223
    def __init__(self, config: ApdbSqlConfig):
        config.validate()
        self.config = config

        _LOG.debug("APDB Configuration:")
        _LOG.debug("    dia_object_index: %s", self.config.dia_object_index)
        _LOG.debug("    read_sources_months: %s", self.config.read_sources_months)
        _LOG.debug("    read_forced_sources_months: %s", self.config.read_forced_sources_months)
        _LOG.debug("    dia_object_columns: %s", self.config.dia_object_columns)
        _LOG.debug("    schema_file: %s", self.config.schema_file)
        _LOG.debug("    extra_schema_file: %s", self.config.extra_schema_file)
        _LOG.debug("    schema prefix: %s", self.config.prefix)

        # The engine is reused between multiple processes; make sure that we
        # don't share connections by disabling the pool (using NullPool).
        kw: MutableMapping[str, Any] = dict(echo=self.config.sql_echo)
        conn_args: Dict[str, Any] = dict()
        if not self.config.connection_pool:
            kw.update(poolclass=NullPool)
        if self.config.isolation_level is not None:
            kw.update(isolation_level=self.config.isolation_level)
        elif self.config.db_url.startswith("sqlite"):  # type: ignore
            # Use READ_UNCOMMITTED as default value for sqlite.
            kw.update(isolation_level="READ_UNCOMMITTED")
        if self.config.connection_timeout is not None:
            if self.config.db_url.startswith("sqlite"):
                conn_args.update(timeout=self.config.connection_timeout)
            elif self.config.db_url.startswith(("postgresql", "mysql")):
                conn_args.update(connect_timeout=self.config.connection_timeout)
        kw.update(connect_args=conn_args)
        self._engine = sqlalchemy.create_engine(self.config.db_url, **kw)

        self._schema = ApdbSqlSchema(
            engine=self._engine,
            dia_object_index=self.config.dia_object_index,
            schema_file=self.config.schema_file,
            schema_name=self.config.schema_name,
            prefix=self.config.prefix,
            namespace=self.config.namespace,
            htm_index_column=self.config.htm_index_column,
            use_insert_id=config.use_insert_id,
        )

        self.pixelator = HtmPixelization(self.config.htm_level)
        self.use_insert_id = self._schema.has_insert_id

    def tableRowCount(self) -> Dict[str, int]:
        """Return a dictionary of the table names and row counts.

        Used by ``ap_proto`` to keep track of the size of the database tables.
        Depending on database technology this could be an expensive operation.

        Returns
        -------
        row_counts : `dict`
            Dict where key is a table name and value is a row count.
        """
        res = {}
        tables = [ApdbTables.DiaObject, ApdbTables.DiaSource, ApdbTables.DiaForcedSource]
        if self.config.dia_object_index == "last_object_table":
            tables.append(ApdbTables.DiaObjectLast)
        with self._engine.begin() as conn:
            for table in tables:
                sa_table = self._schema.get_table(table)
                stmt = sql.select(func.count()).select_from(sa_table)
                count: int = conn.execute(stmt).scalar_one()
                res[table.name] = count

        return res

    def tableDef(self, table: ApdbTables) -> Optional[Table]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class
        self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        # decide what columns we need
        if self.config.dia_object_index == "last_object_table":
            table_enum = ApdbTables.DiaObjectLast
        else:
            table_enum = ApdbTables.DiaObject
        table = self._schema.get_table(table_enum)
        if not self.config.dia_object_columns:
            columns = self._schema.get_apdb_columns(table_enum)
        else:
            columns = [table.c[col] for col in self.config.dia_object_columns]
        query = sql.select(*columns)

        # build selection
        query = query.where(self._filterRegion(table, region))

        # select latest version of objects
        if self.config.dia_object_index != "last_object_table":
            query = query.where(table.c.validityEnd == None)  # noqa: E711

        # _LOG.debug("query: %s", query)

        # execute select
        with Timer("DiaObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, _ConnectionHackSA2(conn))
        _LOG.debug("found %s DiaObjects", len(objects))
        return objects

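    # Example usage of getDiaObjects above (illustrative sketch, hypothetical
    # values): callers build a sphgeom region and pass it in, e.g.
    #
    #     from lsst.sphgeom import Angle, Circle
    #     center = UnitVector3d(LonLat.fromDegrees(45.0, -30.0))
    #     region = Circle(center, Angle.fromDegrees(0.1))
    #     objects = apdb.getDiaObjects(region)
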
    def getDiaSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        if self.config.read_sources_months == 0:
            _LOG.debug("Skip DiaSources fetching")
            return None

        if object_ids is None:
            # region-based select
            return self._getDiaSourcesInRegion(region, visit_time)
        else:
            return self._getDiaSourcesByIDs(list(object_ids), visit_time)

    def getDiaForcedSources(
        self, region: Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
    ) -> Optional[pandas.DataFrame]:
        """Return catalog of DiaForcedSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DiaForcedSources.
        object_ids : iterable [ `int` ], optional
            List of DiaObject IDs to further constrain the set of returned
            sources. If the list is empty then an empty catalog is returned
            with a correct schema.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaForcedSource records. `None` is returned if
            the ``read_forced_sources_months`` configuration parameter is set
            to 0.

        Raises
        ------
        NotImplementedError
            Raised if ``object_ids`` is `None`.

        Notes
        -----
        Even though the base class allows `None` to be passed for
        ``object_ids``, this class requires ``object_ids`` to be not-`None`.
        `NotImplementedError` is raised if `None` is passed.

        This method returns a DiaForcedSource catalog for a region with
        additional filtering based on DiaObject IDs. Only a subset of the
        DiaForcedSource history is returned, limited by the
        ``read_forced_sources_months`` config parameter, w.r.t. ``visit_time``.
        If ``object_ids`` is empty then an empty catalog is always returned
        with a correct schema (columns/types).
        """

        if self.config.read_forced_sources_months == 0:
            _LOG.debug("Skip DiaForcedSources fetching")
            return None

        if object_ids is None:
            # This implementation does not support region-based selection.
            raise NotImplementedError("Region-based selection is not supported")

        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_forced_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaForcedSource select", self.config.timer):
            sources = self._getSourcesByIDs(
                ApdbTables.DiaForcedSource, list(object_ids), midpointMjdTai_start
            )

        _LOG.debug("found %s DiaForcedSources", len(sources))
        return sources

    def getInsertIds(self) -> list[ApdbInsertId] | None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            return None

        table = self._schema.get_table(ExtraTables.DiaInsertId)
        assert table is not None, "has_insert_id=True means it must be defined"
        query = sql.select(table.columns["insert_id"]).order_by(table.columns["insert_time"])
        with Timer("DiaObject insert id select", self.config.timer):
            with self._engine.connect() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return [ApdbInsertId(row) for row in result.scalars()]

    def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
        # docstring is inherited from a base class
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history storage")

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        insert_ids = [id.id for id in ids]
        where_clause = table.columns["insert_id"].in_(insert_ids)
        stmt = table.delete().where(where_clause)
        with self._engine.begin() as conn:
            conn.execute(stmt)

    def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaObject, ExtraTables.DiaObjectInsertId)

    def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaSource, ExtraTables.DiaSourceInsertId)

    def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
        # docstring is inherited from a base class
        return self._get_history(ids, ApdbTables.DiaForcedSource, ExtraTables.DiaForcedSourceInsertId)

    def _get_history(
        self,
        ids: Iterable[ApdbInsertId],
        table_enum: ApdbTables,
        history_table_enum: ExtraTables,
    ) -> ApdbTableData:
        """Common implementation of the history methods."""
        if not self._schema.has_insert_id:
            raise ValueError("APDB is not configured for history retrieval")

        table = self._schema.get_table(table_enum)
        history_table = self._schema.get_table(history_table_enum)

        join = table.join(history_table)
        insert_ids = [id.id for id in ids]
        history_id_column = history_table.columns["insert_id"]
        apdb_columns = self._schema.get_apdb_columns(table_enum)
        where_clause = history_id_column.in_(insert_ids)
        query = sql.select(history_id_column, *apdb_columns).select_from(join).where(where_clause)

        # execute select
        with Timer(f"{table.name} history select", self.config.timer):
            with self._engine.begin() as conn:
                result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(query)
                return ApdbSqlTableData(result)

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class

        columns = self._schema.get_apdb_columns(ApdbTables.SSObject)
        query = sql.select(*columns)

        # execute select
        with Timer("SSObject select", self.config.timer):
            with self._engine.begin() as conn:
                objects = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s SSObjects", len(objects))
        return objects

    def store(
        self,
        visit_time: dafBase.DateTime,
        objects: pandas.DataFrame,
        sources: Optional[pandas.DataFrame] = None,
        forced_sources: Optional[pandas.DataFrame] = None,
    ) -> None:
        # docstring is inherited from a base class

        # We want to run all inserts in one transaction.
        with self._engine.begin() as connection:
            insert_id: ApdbInsertId | None = None
            if self._schema.has_insert_id:
                insert_id = ApdbInsertId.new_insert_id()
                self._storeInsertId(insert_id, visit_time, connection)

            # fill pixelId column for DiaObjects
            objects = self._add_obj_htm_index(objects)
            self._storeDiaObjects(objects, visit_time, insert_id, connection)

            if sources is not None:
                # copy pixelId column from DiaObjects to DiaSources
                sources = self._add_src_htm_index(sources, objects)
                self._storeDiaSources(sources, insert_id, connection)

            if forced_sources is not None:
                self._storeDiaForcedSources(forced_sources, insert_id, connection)

    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class

        idColumn = "ssObjectId"
        table = self._schema.get_table(ApdbTables.SSObject)

        # everything to be done in single transaction
        with self._engine.begin() as conn:
            # Find record IDs that already exist. Some types like np.int64 can
            # cause issues with sqlalchemy, convert them to int.
            ids = sorted(int(oid) for oid in objects[idColumn])

            query = sql.select(table.columns[idColumn]).where(table.columns[idColumn].in_(ids))
            result = conn.execute(query)
            knownIds = set(row.ssObjectId for row in result)

            filter = objects[idColumn].isin(knownIds)
            toUpdate = cast(pandas.DataFrame, objects[filter])
            toInsert = cast(pandas.DataFrame, objects[~filter])

            # insert new records
            if len(toInsert) > 0:
                toInsert.to_sql(
                    table.name, _ConnectionHackSA2(conn), if_exists="append", index=False, schema=table.schema
                )

            # update existing records
            if len(toUpdate) > 0:
                whereKey = f"{idColumn}_param"
                update = table.update().where(table.columns[idColumn] == sql.bindparam(whereKey))
                toUpdate = toUpdate.rename({idColumn: whereKey}, axis="columns")
                values = toUpdate.to_dict("records")
                result = conn.execute(update, values)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        table = self._schema.get_table(ApdbTables.DiaSource)
        query = table.update().where(table.columns["diaSourceId"] == sql.bindparam("srcId"))

        with self._engine.begin() as conn:
            # Need to make sure that every ID exists in the database, but
            # executemany may not support rowcount, so iterate and check what
            # is missing.
            missing_ids: List[int] = []
            for key, value in idMap.items():
                params = dict(srcId=key, diaObjectId=0, ssObjectId=value)
                result = conn.execute(query, params)
                if result.rowcount == 0:
                    missing_ids.append(key)
            if missing_ids:
                missing = ",".join(str(item) for item in missing_ids)
                raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

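    # Example for reassignDiaSources above (illustrative, hypothetical IDs):
    # reassign two DiaSources to SSObject 789, clearing their DiaObject
    # association:
    #
    #     apdb.reassignDiaSources({123: 789, 456: 789})
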
    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # Retrieve the DiaObject table.
        table: sqlalchemy.schema.Table = self._schema.get_table(ApdbTables.DiaObject)

        # Construct the sql statement.
        stmt = sql.select(func.count()).select_from(table).where(table.c.nDiaSources == 1)
        stmt = stmt.where(table.c.validityEnd == None)  # noqa: E711

        # Return the count.
        with self._engine.begin() as conn:
            count = conn.execute(stmt).scalar_one()

        return count

    def _getDiaSourcesInRegion(self, region: Region, visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances from a given region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region to search for DIASources.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        table = self._schema.get_table(ApdbTables.DiaSource)
        columns = self._schema.get_apdb_columns(ApdbTables.DiaSource)
        query = sql.select(*columns)

        # build selection
        time_filter = table.columns["midpointMjdTai"] > midpointMjdTai_start
        where = sql.expression.and_(self._filterRegion(table, region), time_filter)
        query = query.where(where)

        # execute select
        with Timer("DiaSource select", self.config.timer):
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time: dafBase.DateTime) -> pandas.DataFrame:
        """Return catalog of DiaSource instances given a set of DiaObject IDs.

        Parameters
        ----------
        object_ids : `list` [ `int` ]
            Collection of DiaObject IDs.
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        # TODO: DateTime.MJD must be consistent with code in ap_association,
        # alternatively we can fill midpointMjdTai ourselves in store()
        midpointMjdTai_start = _make_midpointMjdTai_start(visit_time, self.config.read_sources_months)
        _LOG.debug("midpointMjdTai_start = %.6f", midpointMjdTai_start)

        with Timer("DiaSource select", self.config.timer):
            sources = self._getSourcesByIDs(ApdbTables.DiaSource, object_ids, midpointMjdTai_start)

        _LOG.debug("found %s DiaSources", len(sources))
        return sources

    def _getSourcesByIDs(
        self, table_enum: ApdbTables, object_ids: List[int], midpointMjdTai_start: float
    ) -> pandas.DataFrame:
        """Return catalog of DiaSource or DiaForcedSource instances given a
        set of DiaObject IDs.

        Parameters
        ----------
        table_enum : `ApdbTables`
            Enum member identifying the database table to query.
        object_ids : `list` [ `int` ]
            Collection of DiaObject IDs.
        midpointMjdTai_start : `float`
            Earliest midpointMjdTai to retrieve.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records. An empty catalog with a
            correct schema is returned when ``object_ids`` is empty.
        """
        table = self._schema.get_table(table_enum)
        columns = self._schema.get_apdb_columns(table_enum)

        sources: Optional[pandas.DataFrame] = None
        if len(object_ids) <= 0:
            _LOG.debug("ID list is empty, just fetch empty result")
            query = sql.select(*columns).where(sql.literal(False))
            with self._engine.begin() as conn:
                sources = pandas.read_sql_query(query, conn)
        else:
            data_frames: list[pandas.DataFrame] = []
            for ids in chunk_iterable(sorted(object_ids), 1000):
                query = sql.select(*columns)

                # Some types like np.int64 can cause issues with
                # sqlalchemy, convert them to int.
                int_ids = [int(oid) for oid in ids]

                # select by object id
                query = query.where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(int_ids),
                        table.columns["midpointMjdTai"] > midpointMjdTai_start,
                    )
                )

                # execute select
                with self._engine.begin() as conn:
                    data_frames.append(pandas.read_sql_query(query, conn))

            if len(data_frames) == 1:
                sources = data_frames[0]
            else:
                sources = pandas.concat(data_frames)
        assert sources is not None, "Catalog cannot be None"
        return sources

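    # Note on _getSourcesByIDs above (illustrative): chunk_iterable splits a
    # long ID list into batches of 1000 so each query's IN() list stays a
    # manageable size, roughly:
    #
    #     list(chunk_iterable([1, 2, 3, 4, 5], 2))  # chunks (1, 2), (3, 4), (5,)
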
    def _storeInsertId(
        self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime, connection: sqlalchemy.engine.Connection
    ) -> None:
        dt = visit_time.toPython()

        table = self._schema.get_table(ExtraTables.DiaInsertId)

        stmt = table.insert().values(insert_id=insert_id.id, insert_time=dt)
        connection.execute(stmt)

    def _storeDiaObjects(
        self,
        objs: pandas.DataFrame,
        visit_time: dafBase.DateTime,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records.
        visit_time : `lsst.daf.base.DateTime`
            Time of the visit.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Database connection.
        """

        # Some types like np.int64 can cause issues with sqlalchemy, convert
        # them to int.
        ids = sorted(int(oid) for oid in objs["diaObjectId"])
        _LOG.debug("first object ID: %d", ids[0])

        # TODO: Need to verify that we are using correct scale here for
        # DATETIME representation (see DM-31996).
        dt = visit_time.toPython()

        # everything to be done in single transaction
        if self.config.dia_object_index == "last_object_table":
            # Insert and replace all records in the LAST table; mysql and
            # postgres only have non-standard upsert features, so do a
            # portable delete-then-insert instead.
            table = self._schema.get_table(ApdbTables.DiaObjectLast)

            # Drop the previous objects (pandas cannot upsert).
            query = table.delete().where(table.columns["diaObjectId"].in_(ids))

            with Timer(table.name + " delete", self.config.timer):
                res = connection.execute(query)
            _LOG.debug("deleted %s objects", res.rowcount)

            # DiaObjectLast is a subset of DiaObject, strip missing columns
            last_column_names = [column.name for column in table.columns]
            last_objs = objs[last_column_names]
            last_objs = _coerce_uint64(last_objs)

            if "lastNonForcedSource" in last_objs.columns:
                # lastNonForcedSource is defined NOT NULL, fill it with visit
                # time just in case.
                last_objs["lastNonForcedSource"].fillna(dt, inplace=True)
            else:
                extra_column = pandas.Series([dt] * len(objs), name="lastNonForcedSource")
                last_objs.set_index(extra_column.index, inplace=True)
                last_objs = pandas.concat([last_objs, extra_column], axis="columns")

            with Timer("DiaObjectLast insert", self.config.timer):
                last_objs.to_sql(
                    table.name,
                    _ConnectionHackSA2(connection),
                    if_exists="append",
                    index=False,
                    schema=table.schema,
                )
        else:
            # truncate existing validity intervals
            table = self._schema.get_table(ApdbTables.DiaObject)

            update = (
                table.update()
                .values(validityEnd=dt)
                .where(
                    sql.expression.and_(
                        table.columns["diaObjectId"].in_(ids),
                        table.columns["validityEnd"].is_(None),
                    )
                )
            )

            # _LOG.debug("query: %s", query)

            with Timer(table.name + " truncate", self.config.timer):
                res = connection.execute(update)
            _LOG.debug("truncated %s intervals", res.rowcount)

        objs = _coerce_uint64(objs)

        # Fill additional columns
        extra_columns: List[pandas.Series] = []
        if "validityStart" in objs.columns:
            objs["validityStart"] = dt
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="validityStart"))
        if "validityEnd" in objs.columns:
            objs["validityEnd"] = None
        else:
            extra_columns.append(pandas.Series([None] * len(objs), name="validityEnd"))
        if "lastNonForcedSource" in objs.columns:
            # lastNonForcedSource is defined NOT NULL, fill it with visit time
            # just in case.
            objs["lastNonForcedSource"].fillna(dt, inplace=True)
        else:
            extra_columns.append(pandas.Series([dt] * len(objs), name="lastNonForcedSource"))
        if extra_columns:
            objs.set_index(extra_columns[0].index, inplace=True)
            objs = pandas.concat([objs] + extra_columns, axis="columns")

        # Insert history data
        table = self._schema.get_table(ApdbTables.DiaObject)
        history_data: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history_data = objs[pk_names].to_dict("records")
            for row in history_data:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaObjectInsertId)
            history_stmt = history_table.insert()

        # insert new versions
        with Timer("DiaObject insert", self.config.timer):
            objs.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history_data)

    def _storeDiaSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store catalog of DiaSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Database connection.
        """
        table = self._schema.get_table(ApdbTables.DiaSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _storeDiaForcedSources(
        self,
        sources: pandas.DataFrame,
        insert_id: ApdbInsertId | None,
        connection: sqlalchemy.engine.Connection,
    ) -> None:
        """Store a set of DiaForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaForcedSource records.
        insert_id : `ApdbInsertId` or `None`
            Insert identifier.
        connection : `sqlalchemy.engine.Connection`
            Database connection.
        """
        table = self._schema.get_table(ApdbTables.DiaForcedSource)

        # Insert history data
        history: list[dict] = []
        history_stmt: Any = None
        if insert_id is not None:
            pk_names = [column.name for column in table.primary_key]
            history = sources[pk_names].to_dict("records")
            for row in history:
                row["insert_id"] = insert_id.id
            history_table = self._schema.get_table(ExtraTables.DiaForcedSourceInsertId)
            history_stmt = history_table.insert()

        # everything to be done in single transaction
        with Timer("DiaForcedSource insert", self.config.timer):
            sources = _coerce_uint64(sources)
            sources.to_sql(
                table.name,
                _ConnectionHackSA2(connection),
                if_exists="append",
                index=False,
                schema=table.schema,
            )
            if history_stmt is not None:
                connection.execute(history_stmt, history)

    def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
        """Generate a set of HTM indices covering specified region.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Region that needs to be indexed.

        Returns
        -------
        ranges : `list` [ `tuple` ]
            Sequence of ranges, each range is a tuple (minHtmID, maxHtmID).
        """
        _LOG.debug("region: %s", region)
        indices = self.pixelator.envelope(region, self.config.htm_max_ranges)

        return indices.ranges()

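    # Example for _htm_indices above (illustrative, hypothetical values): a
    # small circle at the configured HTM level is covered by at most
    # ``htm_max_ranges`` index ranges, e.g.
    #
    #     from lsst.sphgeom import Angle, Circle
    #     center = UnitVector3d(LonLat.fromDegrees(0.0, 0.0))
    #     circle = Circle(center, Angle.fromDegrees(0.01))
    #     ranges = HtmPixelization(20).envelope(circle, 64).ranges()
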
    def _filterRegion(self, table: sqlalchemy.schema.Table, region: Region) -> sql.ColumnElement:
        """Make SQLAlchemy expression for selecting records in a region."""
        htm_index_column = table.columns[self.config.htm_index_column]
        exprlist = []
        pixel_ranges = self._htm_indices(region)
        for low, upper in pixel_ranges:
            # Ranges are half-open [low, upper); make the upper bound
            # inclusive for use with equality/BETWEEN.
            upper -= 1
            if low == upper:
                exprlist.append(htm_index_column == low)
            else:
                exprlist.append(sql.expression.between(htm_index_column, low, upper))

        return sql.expression.or_(*exprlist)

    def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate HTM index for each record and add it to a DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed; a copy of the
        DataFrame is returned.
        """
        # calculate HTM index for every DiaObject
        htm_index = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = UnitVector3d(LonLat.fromDegrees(ra, dec))
            idx = self.pixelator.index(uv3d)
            htm_index[i] = idx
        df = df.copy()
        df[self.config.htm_index_column] = htm_index
        return df

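    # Example for _add_obj_htm_index above (illustrative, hypothetical
    # coordinates): the per-row value is the plain sphgeom point index, e.g.
    #
    #     uv3d = UnitVector3d(LonLat.fromDegrees(45.0, -30.0))
    #     pixel = HtmPixelization(20).index(uv3d)
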
    def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add pixelId column to DiaSource catalog.

        Notes
        -----
        This method copies the pixelId value from a matching DiaObject record.
        The DiaObject catalog needs to have a pixelId column filled by the
        ``_add_obj_htm_index`` method, and DiaSource records need to be
        associated to DiaObjects via the ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (pixelId). The original DataFrame is not changed; a copy of the
        DataFrame is returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: pixelId
            for diaObjectId, pixelId in zip(objs["diaObjectId"], objs[self.config.htm_index_column])
        }
        # DiaSources associated with SolarSystemObjects do not have an
        # associated DiaObject hence we skip them and set their htmIndex
        # value to 0.
        pixel_id_map[0] = 0
        htm_index = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            htm_index[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources[self.config.htm_index_column] = htm_index
        return sources