22 """Module defining Apdb class and related methods.
25 from __future__
import annotations
27 __all__ = [
"ApdbSqlConfig",
"ApdbSql"]
29 from contextlib
import contextmanager
33 from typing
import Any, Dict, Iterable, Iterator, List, Optional, Tuple
37 from lsst.sphgeom import HtmPixelization, LonLat, Region, UnitVector3d
39 from sqlalchemy
import (func, sql)
40 from sqlalchemy.pool
import NullPool
41 from .apdb
import Apdb, ApdbConfig
42 from .apdbSchema
import ApdbTables, TableDef
43 from .apdbSqlSchema
import ApdbSqlSchema
44 from .timer
import Timer
47 _LOG = logging.getLogger(__name__)
50 def _split(seq: Iterable, nItems: int) -> Iterator[List]:
51 """Split a sequence into smaller sequences"""
58 def _coerce_uint64(df: pandas.DataFrame) -> pandas.DataFrame:
59 """Change type of the uint64 columns to int64, return copy of data frame.
61 names = [c[0]
for c
in df.dtypes.items()
if c[1] == np.uint64]
62 return df.astype({name: np.int64
for name
in names})
def _make_midPointTai_start(visit_time: dafBase.DateTime, months: int) -> float:
    """Calculate starting point for time-based source search.

    Parameters
    ----------
    visit_time : `lsst.daf.base.DateTime`
        Time of current visit.
    months : `int`
        Number of months in the sources history.

    Returns
    -------
    `float`
        A ``midPointTai`` starting point, MJD time.
    """
    # One month is approximated as exactly 30 days; this is a history cutoff,
    # not a calendar computation.
    return visit_time.get(system=dafBase.DateTime.MJD) - months * 30
86 def _ansi_session(engine: sqlalchemy.engine.Engine) -> Iterator[sqlalchemy.engine.Connection]:
87 """Returns a connection, makes sure that ANSI mode is set for MySQL
89 with engine.begin()
as conn:
90 if engine.name ==
'mysql':
91 conn.execute(sql.text(
"SET SESSION SQL_MODE = 'ANSI'"))
97 """APDB configuration class for SQL implementation (ApdbSql).
101 doc=
"SQLAlchemy database connection URI"
105 doc=
"Transaction isolation level, if unset then backend-default value "
106 "is used, except for SQLite backend where we use READ_UNCOMMITTED. "
107 "Some backends may not support every allowed value.",
109 "READ_COMMITTED":
"Read committed",
110 "READ_UNCOMMITTED":
"Read uncommitted",
111 "REPEATABLE_READ":
"Repeatable read",
112 "SERIALIZABLE":
"Serializable"
119 doc=
"If False then disable SQLAlchemy connection pool. "
120 "Do not use connection pool when forking.",
125 doc=
"Maximum time to wait time for database lock to be released before "
126 "exiting. Defaults to sqlachemy defaults if not set.",
132 doc=
"If True then pass SQLAlchemy echo option.",
137 doc=
"Indexing mode for DiaObject table",
139 'baseline':
"Index defined in baseline schema",
140 'pix_id_iov':
"(pixelId, objectId, iovStart) PK",
141 'last_object_table':
"Separate DiaObjectLast table"
147 doc=
"HTM indexing level",
152 doc=
"Max number of ranges in HTM envelope",
158 doc=
"Name of a HTM index column for DiaObject and DiaSource tables"
162 default=[
"ra",
"decl"],
163 doc=
"Names ra/dec columns in DiaObject table"
167 doc=
"List of columns to read from DiaObject, by default read all columns",
172 doc=
"If True (default) then use \"upsert\" for DiaObjectsLast table",
177 doc=
"Prefix to add to table names and index names",
182 doc=
"If True then run EXPLAIN SQL command on each executed query",
187 doc=
"If True then print/log timing information",
194 raise ValueError(
"ra_dec_columns must have exactly two column names")
198 """Implementation of APDB interface based on SQL database.
200 The implementation is configured via standard ``pex_config`` mechanism
201 using `ApdbSqlConfig` configuration class. For an example of different
202 configurations check ``config/`` folder.
206 config : `ApdbSqlConfig`
207 Configuration object.
210 ConfigClass = ApdbSqlConfig
216 _LOG.debug(
"APDB Configuration:")
217 _LOG.debug(
" dia_object_index: %s", self.
configconfig.dia_object_index)
218 _LOG.debug(
" read_sources_months: %s", self.
configconfig.read_sources_months)
219 _LOG.debug(
" read_forced_sources_months: %s", self.
configconfig.read_forced_sources_months)
220 _LOG.debug(
" dia_object_columns: %s", self.
configconfig.dia_object_columns)
221 _LOG.debug(
" object_last_replace: %s", self.
configconfig.object_last_replace)
222 _LOG.debug(
" schema_file: %s", self.
configconfig.schema_file)
223 _LOG.debug(
" extra_schema_file: %s", self.
configconfig.extra_schema_file)
224 _LOG.debug(
" schema prefix: %s", self.
configconfig.prefix)
228 kw = dict(echo=self.
configconfig.sql_echo)
229 conn_args: Dict[str, Any] = dict()
230 if not self.
configconfig.connection_pool:
231 kw.update(poolclass=NullPool)
232 if self.
configconfig.isolation_level
is not None:
233 kw.update(isolation_level=self.
configconfig.isolation_level)
234 elif self.
configconfig.db_url.startswith(
"sqlite"):
236 kw.update(isolation_level=
"READ_UNCOMMITTED")
237 if self.
configconfig.connection_timeout
is not None:
238 if self.
configconfig.db_url.startswith(
"sqlite"):
239 conn_args.update(timeout=self.
configconfig.connection_timeout)
240 elif self.
configconfig.db_url.startswith((
"postgresql",
"mysql")):
241 conn_args.update(connect_timeout=self.
configconfig.connection_timeout)
242 kw.update(connect_args=conn_args)
243 self.
_engine_engine = sqlalchemy.create_engine(self.
configconfig.db_url, **kw)
246 dia_object_index=self.
configconfig.dia_object_index,
247 schema_file=self.
configconfig.schema_file,
248 extra_schema_file=self.
configconfig.extra_schema_file,
249 prefix=self.
configconfig.prefix,
250 htm_index_column=self.
configconfig.htm_index_column)
255 """Returns dictionary with the table names and row counts.
257 Used by ``ap_proto`` to keep track of the size of the database tables.
258 Depending on database technology this could be expensive operation.
263 Dict where key is a table name and value is a row count.
266 tables: List[sqlalchemy.schema.Table] = [
268 if self.
configconfig.dia_object_index ==
'last_object_table':
269 tables.append(self.
_schema_schema.objects_last)
271 stmt = sql.select([func.count()]).select_from(table)
272 count = self.
_engine_engine.scalar(stmt)
273 res[table.name] = count
277 def tableDef(self, table: ApdbTables) -> Optional[TableDef]:
279 return self.
_schema_schema.tableSchemas.get(table)
289 table: sqlalchemy.schema.Table
290 if self.
configconfig.dia_object_index ==
'last_object_table':
291 table = self.
_schema_schema.objects_last
293 table = self.
_schema_schema.objects
294 if not self.
configconfig.dia_object_columns:
295 query = table.select()
297 columns = [table.c[col]
for col
in self.
configconfig.dia_object_columns]
298 query = sql.select(columns)
301 htm_index_column = table.columns[self.
configconfig.htm_index_column]
304 for low, upper
in pixel_ranges:
307 exprlist.append(htm_index_column == low)
309 exprlist.append(sql.expression.between(htm_index_column, low, upper))
310 query = query.where(sql.expression.or_(*exprlist))
313 if self.
configconfig.dia_object_index !=
'last_object_table':
314 query = query.where(table.c.validityEnd ==
None)
316 _LOG.debug(
"query: %s", query)
318 if self.
configconfig.explain:
323 with Timer(
'DiaObject select', self.
configconfig.timer):
324 with self.
_engine_engine.begin()
as conn:
325 objects = pandas.read_sql_query(query, conn)
326 _LOG.debug(
"found %s DiaObjects", len(objects))
330 object_ids: Optional[Iterable[int]],
333 if self.
configconfig.read_sources_months == 0:
334 _LOG.debug(
"Skip DiaSources fetching")
337 if object_ids
is None:
344 object_ids: Optional[Iterable[int]],
346 """Return catalog of DiaForcedSource instances from a given region.
350 region : `lsst.sphgeom.Region`
351 Region to search for DIASources.
352 object_ids : iterable [ `int` ], optional
353 List of DiaObject IDs to further constrain the set of returned
354 sources. If list is empty then empty catalog is returned with a
356 visit_time : `lsst.daf.base.DateTime`
357 Time of the current visit.
361 catalog : `pandas.DataFrame`, or `None`
362 Catalog containing DiaSource records. `None` is returned if
363 ``read_sources_months`` configuration parameter is set to 0.
368 Raised if ``object_ids`` is `None`.
372 Even though base class allows `None` to be passed for ``object_ids``,
373 this class requires ``object_ids`` to be not-`None`.
374 `NotImplementedError` is raised if `None` is passed.
376 This method returns DiaForcedSource catalog for a region with additional
377 filtering based on DiaObject IDs. Only a subset of DiaSource history
378 is returned limited by ``read_forced_sources_months`` config parameter,
379 w.r.t. ``visit_time``. If ``object_ids`` is empty then an empty catalog
380 is always returned with a correct schema (columns/types).
383 if self.
configconfig.read_forced_sources_months == 0:
384 _LOG.debug(
"Skip DiaForceSources fetching")
387 if object_ids
is None:
389 raise NotImplementedError(
"Region-based selection is not supported")
393 midPointTai_start = _make_midPointTai_start(visit_time, self.
configconfig.read_forced_sources_months)
394 _LOG.debug(
"midPointTai_start = %.6f", midPointTai_start)
396 table: sqlalchemy.schema.Table = self.
_schema_schema.forcedSources
397 with Timer(
'DiaForcedSource select', self.
configconfig.timer):
400 _LOG.debug(
"found %s DiaForcedSources", len(sources))
404 visit_time: dafBase.DateTime,
405 objects: pandas.DataFrame,
406 sources: Optional[pandas.DataFrame] =
None,
407 forced_sources: Optional[pandas.DataFrame] =
None) ->
None:
414 if sources
is not None:
419 if forced_sources
is not None:
425 if self.
_engine_engine.name ==
'postgresql':
428 _LOG.info(
"Running VACUUM on all tables")
429 connection = self.
_engine_engine.raw_connection()
430 ISOLATION_LEVEL_AUTOCOMMIT = 0
431 connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
432 cursor = connection.cursor()
433 cursor.execute(
"VACUUM ANALYSE")
439 table: sqlalchemy.schema.Table = self.
_schema_schema.objects
442 stmt = sql.select([func.count()]).select_from(table).where(table.c.nDiaSources == 1)
443 stmt = stmt.where(table.c.validityEnd ==
None)
446 with self.
_engine_engine.begin()
as conn:
447 count = conn.scalar(stmt)
451 def _getDiaSourcesInRegion(self, region: Region, visit_time:
dafBase.DateTime
452 ) -> pandas.DataFrame:
453 """Returns catalog of DiaSource instances from given region.
457 region : `lsst.sphgeom.Region`
458 Region to search for DIASources.
459 visit_time : `lsst.daf.base.DateTime`
460 Time of the current visit.
464 catalog : `pandas.DataFrame`
465 Catalog containing DiaSource records.
469 midPointTai_start = _make_midPointTai_start(visit_time, self.
configconfig.read_sources_months)
470 _LOG.debug(
"midPointTai_start = %.6f", midPointTai_start)
472 table: sqlalchemy.schema.Table = self.
_schema_schema.sources
473 query = table.select()
476 htm_index_column = table.columns[self.
configconfig.htm_index_column]
479 for low, upper
in pixel_ranges:
482 exprlist.append(htm_index_column == low)
484 exprlist.append(sql.expression.between(htm_index_column, low, upper))
485 time_filter = table.columns[
"midPointTai"] > midPointTai_start
486 where = sql.expression.and_(sql.expression.or_(*exprlist), time_filter)
487 query = query.where(where)
490 with Timer(
'DiaSource select', self.
configconfig.timer):
491 with _ansi_session(self.
_engine_engine)
as conn:
492 sources = pandas.read_sql_query(query, conn)
493 _LOG.debug(
"found %s DiaSources", len(sources))
496 def _getDiaSourcesByIDs(self, object_ids: List[int], visit_time:
dafBase.DateTime
497 ) -> pandas.DataFrame:
498 """Returns catalog of DiaSource instances given set of DiaObject IDs.
503 Collection of DiaObject IDs
504 visit_time : `lsst.daf.base.DateTime`
505 Time of the current visit.
509 catalog : `pandas.DataFrame`
510 Catalog contaning DiaSource records.
514 midPointTai_start = _make_midPointTai_start(visit_time, self.
configconfig.read_sources_months)
515 _LOG.debug(
"midPointTai_start = %.6f", midPointTai_start)
517 table: sqlalchemy.schema.Table = self.
_schema_schema.sources
518 with Timer(
'DiaSource select', self.
configconfig.timer):
519 sources = self.
_getSourcesByIDs_getSourcesByIDs(table, object_ids, midPointTai_start)
521 _LOG.debug(
"found %s DiaSources", len(sources))
524 def _getSourcesByIDs(self, table: sqlalchemy.schema.Table,
525 object_ids: List[int],
526 midPointTai_start: float
527 ) -> pandas.DataFrame:
528 """Returns catalog of DiaSource or DiaForcedSource instances given set
533 table : `sqlalchemy.schema.Table`
536 Collection of DiaObject IDs
537 midPointTai_start : `float`
538 Earliest midPointTai to retrieve.
542 catalog : `pandas.DataFrame`
543 Catalog contaning DiaSource records. `None` is returned if
544 ``read_sources_months`` configuration parameter is set to 0 or
545 when ``object_ids`` is empty.
547 sources: Optional[pandas.DataFrame] =
None
548 with _ansi_session(self.
_engine_engine)
as conn:
549 if len(object_ids) <= 0:
550 _LOG.debug(
"ID list is empty, just fetch empty result")
551 query = table.select().where(
False)
552 sources = pandas.read_sql_query(query, conn)
554 for ids
in _split(sorted(object_ids), 1000):
555 query = f
'SELECT * FROM "{table.name}" WHERE '
558 ids_str =
",".join(str(id)
for id
in ids)
559 query += f
'"diaObjectId" IN ({ids_str})'
560 query += f
' AND "midPointTai" > {midPointTai_start}'
563 df = pandas.read_sql_query(sql.text(query), conn)
567 sources = sources.append(df)
568 assert sources
is not None,
"Catalog cannot be None"
571 def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time:
dafBase.DateTime) ->
None:
572 """Store catalog of DiaObjects from current visit.
576 objs : `pandas.DataFrame`
577 Catalog with DiaObject records.
578 visit_time : `lsst.daf.base.DateTime`
582 ids = sorted(objs[
'diaObjectId'])
583 _LOG.debug(
"first object ID: %d", ids[0])
587 table: sqlalchemy.schema.Table = self.
_schema_schema.objects
591 dt = visit_time.toPython()
594 with _ansi_session(self.
_engine_engine)
as conn:
596 ids_str =
",".join(str(id)
for id
in ids)
598 if self.
configconfig.dia_object_index ==
'last_object_table':
602 table = self.
_schema_schema.objects_last
603 do_replace = self.
configconfig.object_last_replace
607 if not do_replace
or isinstance(objs, pandas.DataFrame):
608 query =
'DELETE FROM "' + table.name +
'" '
609 query +=
'WHERE "diaObjectId" IN (' + ids_str +
') '
611 if self.
configconfig.explain:
615 with Timer(table.name +
' delete', self.
configconfig.timer):
616 res = conn.execute(sql.text(query))
617 _LOG.debug(
"deleted %s objects", res.rowcount)
619 extra_columns: Dict[str, Any] = dict(lastNonForcedSource=dt)
620 with Timer(
"DiaObjectLast insert", self.
configconfig.timer):
621 objs = _coerce_uint64(objs)
622 for col, data
in extra_columns.items():
624 objs.to_sql(
"DiaObjectLast", conn, if_exists=
'append',
629 table = self.
_schema_schema.objects
630 query =
'UPDATE "' + table.name +
'" '
631 query +=
"SET \"validityEnd\" = '" + str(dt) +
"' "
632 query +=
'WHERE "diaObjectId" IN (' + ids_str +
') '
633 query +=
'AND "validityEnd" IS NULL'
637 if self.
configconfig.explain:
641 with Timer(table.name +
' truncate', self.
configconfig.timer):
642 res = conn.execute(sql.text(query))
643 _LOG.debug(
"truncated %s intervals", res.rowcount)
646 table = self.
_schema_schema.objects
647 extra_columns = dict(lastNonForcedSource=dt, validityStart=dt,
649 with Timer(
"DiaObject insert", self.
configconfig.timer):
650 objs = _coerce_uint64(objs)
651 for col, data
in extra_columns.items():
653 objs.to_sql(
"DiaObject", conn, if_exists=
'append',
656 def _storeDiaSources(self, sources: pandas.DataFrame) ->
None:
657 """Store catalog of DiaSources from current visit.
661 sources : `pandas.DataFrame`
662 Catalog containing DiaSource records
665 with _ansi_session(self.
_engine_engine)
as conn:
667 with Timer(
"DiaSource insert", self.
configconfig.timer):
668 sources = _coerce_uint64(sources)
669 sources.to_sql(
"DiaSource", conn, if_exists=
'append', index=
False)
671 def _storeDiaForcedSources(self, sources: pandas.DataFrame) ->
None:
672 """Store a set of DiaForcedSources from current visit.
676 sources : `pandas.DataFrame`
677 Catalog containing DiaForcedSource records
681 with _ansi_session(self.
_engine_engine)
as conn:
683 with Timer(
"DiaForcedSource insert", self.
configconfig.timer):
684 sources = _coerce_uint64(sources)
685 sources.to_sql(
"DiaForcedSource", conn, if_exists=
'append', index=
False)
687 def _explain(self, query: str, conn: sqlalchemy.engine.Connection) ->
None:
688 """Run the query with explain
691 _LOG.info(
"explain for query: %s...", query[:64])
693 if conn.engine.name ==
'mysql':
694 query =
"EXPLAIN EXTENDED " + query
696 query =
"EXPLAIN " + query
698 res = conn.execute(sql.text(query))
700 _LOG.info(
"explain: %s", res.keys())
702 _LOG.info(
"explain: %s", row)
704 _LOG.info(
"EXPLAIN returned nothing")
706 def _htm_indices(self, region: Region) -> List[Tuple[int, int]]:
707 """Generate a set of HTM indices covering specified region.
711 region: `sphgeom.Region`
712 Region that needs to be indexed.
716 Sequence of ranges, range is a tuple (minHtmID, maxHtmID).
718 _LOG.debug(
'region: %s', region)
719 indices = self.
pixelatorpixelator.envelope(region, self.
configconfig.htm_max_ranges)
721 if _LOG.isEnabledFor(logging.DEBUG):
722 for irange
in indices.ranges():
723 _LOG.debug(
'range: %s %s', self.
pixelatorpixelator.toString(irange[0]),
724 self.
pixelatorpixelator.toString(irange[1]))
726 return indices.ranges()
728 def _add_obj_htm_index(self, df: pandas.DataFrame) -> pandas.DataFrame:
729 """Calculate HTM index for each record and add it to a DataFrame.
733 This overrides any existing column in a DataFrame with the same name
734 (pixelId). Original DataFrame is not changed, copy of a DataFrame is
738 htm_index = np.zeros(df.shape[0], dtype=np.int64)
739 ra_col, dec_col = self.
configconfig.ra_dec_columns
740 for i, (ra, dec)
in enumerate(zip(df[ra_col], df[dec_col])):
742 idx = self.
pixelatorpixelator.index(uv3d)
745 df[self.
configconfig.htm_index_column] = htm_index
748 def _add_src_htm_index(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
749 """Add pixelId column to DiaSource catalog.
753 This method copies pixelId value from a matching DiaObject record.
754 DiaObject catalog needs to have a pixelId column filled by
755 ``_add_obj_htm_index`` method and DiaSource records need to be
756 associated to DiaObjects via ``diaObjectId`` column.
758 This overrides any existing column in a DataFrame with the same name
759 (pixelId). Original DataFrame is not changed, copy of a DataFrame is
762 pixel_id_map: Dict[int, int] = {
763 diaObjectId: pixelId
for diaObjectId, pixelId
764 in zip(objs[
"diaObjectId"], objs[self.
configconfig.htm_index_column])
770 htm_index = np.zeros(sources.shape[0], dtype=np.int64)
771 for i, diaObjId
in enumerate(sources[
"diaObjectId"]):
772 htm_index[i] = pixel_id_map[diaObjId]
773 sources = sources.copy()
774 sources[self.
configconfig.htm_index_column] = htm_index
Class for handling dates/times, including MJD, UTC, and TAI.
None _storeDiaForcedSources(self, pandas.DataFrame sources)
Optional[pandas.DataFrame] getDiaForcedSources(self, Region region, Optional[Iterable[int]] object_ids, dafBase.DateTime visit_time)
pandas.DataFrame _getDiaSourcesByIDs(self, List[int] object_ids, dafBase.DateTime visit_time)
pandas.DataFrame _add_src_htm_index(self, pandas.DataFrame sources, pandas.DataFrame objs)
None store(self, dafBase.DateTime visit_time, pandas.DataFrame objects, Optional[pandas.DataFrame] sources=None, Optional[pandas.DataFrame] forced_sources=None)
pandas.DataFrame getDiaObjects(self, Region region)
None _storeDiaSources(self, pandas.DataFrame sources)
None _explain(self, str query, sqlalchemy.engine.Connection conn)
List[Tuple[int, int]] _htm_indices(self, Region region)
None makeSchema(self, bool drop=False)
pandas.DataFrame _getDiaSourcesInRegion(self, Region region, dafBase.DateTime visit_time)
pandas.DataFrame _getSourcesByIDs(self, sqlalchemy.schema.Table table, List[int] object_ids, float midPointTai_start)
Optional[pandas.DataFrame] getDiaSources(self, Region region, Optional[Iterable[int]] object_ids, dafBase.DateTime visit_time)
Optional[TableDef] tableDef(self, ApdbTables table)
Dict[str, int] tableRowCount(self)
def __init__(self, ApdbSqlConfig config)
int countUnassociatedObjects(self)
pandas.DataFrame _add_obj_htm_index(self, pandas.DataFrame df)
None _storeDiaObjects(self, pandas.DataFrame objs, dafBase.DateTime visit_time)
HtmPixelization provides HTM indexing of points and regions.
UnitVector3d is a unit vector in ℝ³ with components stored in double precision.
daf::base::PropertyList * list