LSST Data Management Base Package
apdbCassandra.py
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"]
25
26import logging
27import uuid
28from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union, cast
29
30import numpy as np
31import pandas
32
33# If cassandra-driver is not there the module can still be imported
34# but ApdbCassandra cannot be instantiated.
35try:
36 import cassandra
37 import cassandra.query
38 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile
39 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy
40
41 CASSANDRA_IMPORTED = True
42except ImportError:
43 CASSANDRA_IMPORTED = False
44
45import felis.types
46import lsst.daf.base as dafBase
47from felis.simple import Table
48from lsst import sphgeom
49from lsst.pex.config import ChoiceField, Field, ListField
50from lsst.utils.iteration import chunk_iterable
51
52from .apdb import Apdb, ApdbConfig, ApdbInsertId, ApdbTableData
53from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables
54from .apdbSchema import ApdbTables
55from .cassandra_utils import (
56 ApdbCassandraTableData,
57 literal,
58 pandas_dataframe_factory,
59 quote_id,
60 raw_data_factory,
61 select_concurrent,
62)
63from .pixelization import Pixelization
64from .timer import Timer
65
66_LOG = logging.getLogger(__name__)
67
68
69class CassandraMissingError(Exception):
70 def __init__(self) -> None:
71 super().__init__("cassandra-driver module cannot be imported")
72
73
74class ApdbCassandraConfig(ApdbConfig):
75 contact_points = ListField[str](
76 doc="The list of contact points to try connecting for cluster discovery.", default=["127.0.0.1"]
77 )
78 private_ips = ListField[str](doc="List of internal IP addresses for contact_points.", default=[])
79 keyspace = Field[str](doc="Default keyspace for operations.", default="apdb")
80 read_consistency = Field[str](
81 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.", default="QUORUM"
82 )
83 write_consistency = Field[str](
84 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.", default="QUORUM"
85 )
86 read_timeout = Field[float](doc="Timeout in seconds for read operations.", default=120.0)
87 write_timeout = Field[float](doc="Timeout in seconds for write operations.", default=10.0)
88 read_concurrency = Field[int](doc="Concurrency level for read operations.", default=500)
89 protocol_version = Field[int](
90 doc="Cassandra protocol version to use, default is V4",
91 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0,
92 )
93 dia_object_columns = ListField[str](
94 doc="List of columns to read from DiaObject[Last], by default read all columns", default=[]
95 )
96 prefix = Field[str](doc="Prefix to add to table names", default="")
97 part_pixelization = ChoiceField[str](
98 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"),
99 doc="Pixelization used for partitioning index.",
100 default="mq3c",
101 )
102 part_pix_level = Field[int](doc="Pixelization level used for partitioning index.", default=10)
103 part_pix_max_ranges = Field[int](doc="Max number of ranges in pixelization envelope", default=64)
104 ra_dec_columns = ListField[str](default=["ra", "dec"], doc="Names of ra/dec columns in DiaObject table")
105 timer = Field[bool](doc="If True then print/log timing information", default=False)
106 time_partition_tables = Field[bool](
107 doc="Use per-partition tables for sources instead of partitioning by time", default=True
108 )
109 time_partition_days = Field[int](
110 doc=(
111 "Time partitioning granularity in days, this value must not be changed after database is "
112 "initialized"
113 ),
114 default=30,
115 )
116 time_partition_start = Field[str](
117 doc=(
118 "Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. "
119 "This is used only when time_partition_tables is True."
120 ),
121 default="2018-12-01T00:00:00",
122 )
123 time_partition_end = Field[str](
124 doc=(
125 "Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI. "
126 "This is used only when time_partition_tables is True."
127 ),
128 default="2030-01-01T00:00:00",
129 )
130 query_per_time_part = Field[bool](
131 default=False,
132 doc=(
133 "If True then build separate query for each time partition, otherwise build one single query. "
134 "This is only used when time_partition_tables is False in schema config."
135 ),
136 )
137 query_per_spatial_part = Field[bool](
138 default=False,
139 doc="If True then build one query per spatial partition, otherwise build single query.",
140 )
141
142
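# A minimal sketch, not part of the original module, of how the configuration
# fields defined above might be overridden before constructing ApdbCassandra.
# The helper name and the field values shown (hosts, keyspace) are
# illustrative assumptions only.
def _example_cassandra_config() -> ApdbCassandraConfig:
    """Return an illustrative ApdbCassandraConfig with a few overrides."""
    config = ApdbCassandraConfig()
    # Hosts used for cluster discovery (hypothetical addresses).
    config.contact_points = ["10.0.0.1", "10.0.0.2"]
    # Keyspace holding all APDB tables.
    config.keyspace = "apdb"
    # Keep per-partition source tables with the default 30-day granularity.
    config.time_partition_tables = True
    config.time_partition_days = 30
    return config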
143if CASSANDRA_IMPORTED:
144
145 class _AddressTranslator(AddressTranslator):
146 """Translate internal IP address to external.
147
148 Only used for docker-based setup; not a viable long-term solution.
149 """
150
151 def __init__(self, public_ips: List[str], private_ips: List[str]):
152 self._map = dict((k, v) for k, v in zip(private_ips, public_ips))
153
154 def translate(self, private_ip: str) -> str:
155 return self._map.get(private_ip, private_ip)
156
157
158def _quote_column(name: str) -> str:
159 """Quote column name"""
160 if name.islower():
161 return name
162 else:
163 return f'"{name}"'
164
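# A quick illustration of _quote_column above: all-lowercase names pass
# through unchanged, while mixed-case names are double-quoted so that
# Cassandra preserves their case (the example names are arbitrary):
#
#     _quote_column("ra")          -> ra
#     _quote_column("diaObjectId") -> "diaObjectId"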
165
167 """Implementation of APDB database on to of Apache Cassandra.
168
169 The implementation is configured via the standard ``pex_config`` mechanism
170 using the `ApdbCassandraConfig` configuration class. For examples of
171 different configurations check the config/ folder.
172
173 Parameters
174 ----------
175 config : `ApdbCassandraConfig`
176 Configuration object.
177 """
178
179 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI)
180 """Start time for partition 0, this should never be changed."""
181
182 def __init__(self, config: ApdbCassandraConfig):
183 if not CASSANDRA_IMPORTED:
184 raise CassandraMissingError()
185
186 config.validate()
187 self.config = config
188
189 _LOG.debug("ApdbCassandra Configuration:")
190 for key, value in self.config.items():
191 _LOG.debug(" %s: %s", key, value)
192
193 self._pixelization = Pixelization(
194 config.part_pixelization, config.part_pix_level, config.part_pix_max_ranges
195 )
196
197 addressTranslator: Optional[AddressTranslator] = None
198 if config.private_ips:
199 addressTranslator = _AddressTranslator(config.contact_points, config.private_ips)
200
201 self._keyspace = config.keyspace
202
203 self._cluster = Cluster(
204 execution_profiles=self._makeProfiles(config),
205 contact_points=self.config.contact_points,
206 address_translator=addressTranslator,
207 protocol_version=self.config.protocol_version,
208 )
209 self._session = self._cluster.connect()
210 # Disable result paging
211 self._session.default_fetch_size = None
212
213 self._schema = ApdbCassandraSchema(
214 session=self._session,
215 keyspace=self._keyspace,
216 schema_file=self.config.schema_file,
217 schema_name=self.config.schema_name,
218 prefix=self.config.prefix,
219 time_partition_tables=self.config.time_partition_tables,
220 use_insert_id=self.config.use_insert_id,
221 )
222 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD)
223
224 # Cache for prepared statements
225 self._prepared_statements: Dict[str, cassandra.query.PreparedStatement] = {}
226
227 def __del__(self) -> None:
228 self._cluster.shutdown()
229
230 def tableDef(self, table: ApdbTables) -> Optional[Table]:
231 # docstring is inherited from a base class
232 return self._schema.tableSchemas.get(table)
233
234 def makeSchema(self, drop: bool = False) -> None:
235 # docstring is inherited from a base class
236
237 if self.config.time_partition_tables:
238 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI)
239 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI)
240 part_range = (
241 self._time_partition(time_partition_start),
242 self._time_partition(time_partition_end) + 1,
243 )
244 self._schema.makeSchema(drop=drop, part_range=part_range)
245 else:
246 self._schema.makeSchema(drop=drop)
247
248 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
249 # docstring is inherited from a base class
250
251 sp_where = self._spatial_where(region)
252 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where))
253
254 # We need to exclude extra partitioning columns from result.
255 column_names = self._schema.apdbColumnNames(ApdbTables.DiaObjectLast)
256 what = ",".join(_quote_column(column) for column in column_names)
257
258 table_name = self._schema.tableName(ApdbTables.DiaObjectLast)
259 query = f'SELECT {what} from "{self._keyspace}"."{table_name}"'
260 statements: List[Tuple] = []
261 for where, params in sp_where:
262 full_query = f"{query} WHERE {where}"
263 if params:
264 statement = self._prep_statement(full_query)
265 else:
266 # If there are no params then it is likely that query has a
267 # bunch of literals rendered already, no point trying to
268 # prepare it because it's not reusable.
269 statement = cassandra.query.SimpleStatement(full_query)
270 statements.append((statement, params))
271 _LOG.debug("getDiaObjects: #queries: %s", len(statements))
272
273 with Timer("DiaObject select", self.config.timer):
274 objects = cast(
275 pandas.DataFrame,
276 select_concurrent(
277 self._session, statements, "read_pandas_multi", self.config.read_concurrency
278 ),
279 )
280
281 _LOG.debug("found %s DiaObjects", objects.shape[0])
282 return objects
283
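    # For reference, the statements built in getDiaObjects above have the
    # general shape (keyspace, table name and pixel indices are illustrative):
    #
    #     SELECT <columns> from "apdb"."DiaObjectLast" WHERE "apdb_part" IN (1234,1235,1236)
    #
    # with one statement per entry returned by _spatial_where, all executed
    # concurrently through select_concurrent.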
284 def getDiaSources(
285 self, region: sphgeom.Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
286 ) -> Optional[pandas.DataFrame]:
287 # docstring is inherited from a base class
288 months = self.config.read_sources_months
289 if months == 0:
290 return None
291 mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
292 mjd_start = mjd_end - months * 30
293
294 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource)
295
296 def getDiaForcedSources(
297 self, region: sphgeom.Region, object_ids: Optional[Iterable[int]], visit_time: dafBase.DateTime
298 ) -> Optional[pandas.DataFrame]:
299 # docstring is inherited from a base class
300 months = self.config.read_forced_sources_months
301 if months == 0:
302 return None
303 mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
304 mjd_start = mjd_end - months * 30
305
306 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource)
307
308 def getInsertIds(self) -> list[ApdbInsertId] | None:
309 # docstring is inherited from a base class
310 if not self._schema.has_insert_id:
311 return None
312
313 # everything goes into a single partition
314 partition = 0
315
316 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
317 query = f'SELECT insert_time, insert_id FROM "{self._keyspace}"."{table_name}" WHERE partition = ?'
318
319 result = self._session.execute(
320 self._prep_statement(query),
321 (partition,),
322 timeout=self.config.read_timeout,
323 execution_profile="read_tuples",
324 )
325 # order by insert_time
326 rows = sorted(result)
327 return [ApdbInsertId(row[1]) for row in rows]
328
329 def deleteInsertIds(self, ids: Iterable[ApdbInsertId]) -> None:
330 # docstring is inherited from a base class
331 if not self._schema.has_insert_id:
332 raise ValueError("APDB is not configured for history storage")
333
334 insert_ids = [id.id for id in ids]
335 params = ",".join("?" * len(insert_ids))
336
337 # everything goes into a single partition
338 partition = 0
339
340 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
341 query = (
342 f'DELETE FROM "{self._keyspace}"."{table_name}" WHERE partition = ? and insert_id IN ({params})'
343 )
344
345 self._session.execute(
346 self._prep_statement(query),
347 [partition] + insert_ids,
348 timeout=self.config.write_timeout,
349 )
350
351 def getDiaObjectsHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
352 # docstring is inherited from a base class
353 return self._get_history(ExtraTables.DiaObjectInsertId, ids)
354
355 def getDiaSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
356 # docstring is inherited from a base class
357 return self._get_history(ExtraTables.DiaSourceInsertId, ids)
358
359 def getDiaForcedSourcesHistory(self, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
360 # docstring is inherited from a base class
361 return self._get_history(ExtraTables.DiaForcedSourceInsertId, ids)
362
363 def getSSObjects(self) -> pandas.DataFrame:
364 # docstring is inherited from a base class
365 tableName = self._schema.tableName(ApdbTables.SSObject)
366 query = f'SELECT * from "{self._keyspace}"."{tableName}"'
367
368 objects = None
369 with Timer("SSObject select", self.config.timer):
370 result = self._session.execute(query, execution_profile="read_pandas")
371 objects = result._current_rows
372
373 _LOG.debug("found %s DiaObjects", objects.shape[0])
374 return objects
375
376 def store(
377 self,
378 visit_time: dafBase.DateTime,
379 objects: pandas.DataFrame,
380 sources: Optional[pandas.DataFrame] = None,
381 forced_sources: Optional[pandas.DataFrame] = None,
382 ) -> None:
383 # docstring is inherited from a base class
384
385 insert_id: ApdbInsertId | None = None
386 if self._schema.has_insert_id:
387 insert_id = ApdbInsertId.new_insert_id()
388 self._storeInsertId(insert_id, visit_time)
389
390 # fill region partition column for DiaObjects
391 objects = self._add_obj_part(objects)
392 self._storeDiaObjects(objects, visit_time, insert_id)
393
394 if sources is not None:
395 # copy apdb_part column from DiaObjects to DiaSources
396 sources = self._add_src_part(sources, objects)
397 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time, insert_id)
398 self._storeDiaSourcesPartitions(sources, visit_time, insert_id)
399
400 if forced_sources is not None:
401 forced_sources = self._add_fsrc_part(forced_sources, objects)
402 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time, insert_id)
403
404 def storeSSObjects(self, objects: pandas.DataFrame) -> None:
405 # docstring is inherited from a base class
406 self._storeObjectsPandas(objects, ApdbTables.SSObject)
407
408 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
409 # docstring is inherited from a base class
410
411 # To update a record we need to know its exact primary key (including
412 # partition key) so we start by querying for diaSourceId to find the
413 # primary keys.
414
415 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition)
416 # split it into 1k IDs per query
417 selects: List[Tuple] = []
418 for ids in chunk_iterable(idMap.keys(), 1_000):
419 ids_str = ",".join(str(item) for item in ids)
420 selects.append(
421 (
422 (
423 'SELECT "diaSourceId", "apdb_part", "apdb_time_part", "insert_id" '
424 f'FROM "{self._keyspace}"."{table_name}" WHERE "diaSourceId" IN ({ids_str})'
425 ),
426 {},
427 )
428 )
429
430 # No need for DataFrame here, read data as tuples.
431 result = cast(
432 List[Tuple[int, int, int, uuid.UUID | None]],
433 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency),
434 )
435
436 # Make mapping from source ID to its partition.
437 id2partitions: Dict[int, Tuple[int, int]] = {}
438 id2insert_id: Dict[int, ApdbInsertId] = {}
439 for row in result:
440 id2partitions[row[0]] = row[1:3]
441 if row[3] is not None:
442 id2insert_id[row[0]] = ApdbInsertId(row[3])
443
444 # make sure we know partitions for each ID
445 if set(id2partitions) != set(idMap):
446 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions))
447 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
448
449 # Reassign in standard tables
450 queries = cassandra.query.BatchStatement()
451 table_name = self._schema.tableName(ApdbTables.DiaSource)
452 for diaSourceId, ssObjectId in idMap.items():
453 apdb_part, apdb_time_part = id2partitions[diaSourceId]
454 values: Tuple
455 if self.config.time_partition_tables:
456 query = (
457 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"'
458 ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
459 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?'
460 )
461 values = (ssObjectId, apdb_part, diaSourceId)
462 else:
463 query = (
464 f'UPDATE "{self._keyspace}"."{table_name}"'
465 ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
466 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?'
467 )
468 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId)
469 queries.add(self._prep_statement(query), values)
470
471 # Reassign in history tables, only if history is enabled
472 if id2insert_id:
473 # Filter out insert ids that have been deleted already. There is a
474 # potential race with concurrent removal of insert IDs, but it
475 # should be handled by WHERE in UPDATE.
476 known_ids = set()
477 if insert_ids := self.getInsertIds():
478 known_ids = set(insert_ids)
479 id2insert_id = {key: value for key, value in id2insert_id.items() if value in known_ids}
480 if id2insert_id:
481 table_name = self._schema.tableName(ExtraTables.DiaSourceInsertId)
482 for diaSourceId, ssObjectId in idMap.items():
483 if insert_id := id2insert_id.get(diaSourceId):
484 query = (
485 f'UPDATE "{self._keyspace}"."{table_name}" '
486 ' SET "ssObjectId" = ?, "diaObjectId" = NULL '
487 'WHERE "insert_id" = ? AND "diaSourceId" = ?'
488 )
489 values = (ssObjectId, insert_id.id, diaSourceId)
490 queries.add(self._prep_statement(query), values)
491
492 _LOG.debug("%s: will update %d records", table_name, len(idMap))
493 with Timer(table_name + " update", self.config.timer):
494 self._session.execute(queries, execution_profile="write")
495
496 def dailyJob(self) -> None:
497 # docstring is inherited from a base class
498 pass
499
500 def countUnassociatedObjects(self) -> int:
501 # docstring is inherited from a base class
502
503 # It's too inefficient to implement it for Cassandra in the current schema.
504 raise NotImplementedError()
505
506 def _makeProfiles(self, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]:
507 """Make all execution profiles used in the code."""
508
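        # The profiles built below differ mainly in row factory and in the
        # consistency/timeout settings taken from the config: "read_tuples"
        # returns plain tuples, "read_pandas" returns pandas DataFrames,
        # "read_raw" returns raw column/row data; the "*_multi" variants are
        # meant for use with select_concurrent, and "write" applies the
        # write-side consistency level and timeout.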
509 if config.private_ips:
510 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points)
511 else:
512 loadBalancePolicy = RoundRobinPolicy()
513
514 read_tuples_profile = ExecutionProfile(
515 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
516 request_timeout=config.read_timeout,
517 row_factory=cassandra.query.tuple_factory,
518 load_balancing_policy=loadBalancePolicy,
519 )
520 read_pandas_profile = ExecutionProfile(
521 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
522 request_timeout=config.read_timeout,
523 row_factory=pandas_dataframe_factory,
524 load_balancing_policy=loadBalancePolicy,
525 )
526 read_raw_profile = ExecutionProfile(
527 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
528 request_timeout=config.read_timeout,
529 row_factory=raw_data_factory,
530 load_balancing_policy=loadBalancePolicy,
531 )
532 # Profile to use with select_concurrent to return pandas data frame
533 read_pandas_multi_profile = ExecutionProfile(
534 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
535 request_timeout=config.read_timeout,
536 row_factory=pandas_dataframe_factory,
537 load_balancing_policy=loadBalancePolicy,
538 )
539 # Profile to use with select_concurrent to return raw data (columns and
540 # rows)
541 read_raw_multi_profile = ExecutionProfile(
542 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
543 request_timeout=config.read_timeout,
544 row_factory=raw_data_factory,
545 load_balancing_policy=loadBalancePolicy,
546 )
547 write_profile = ExecutionProfile(
548 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency),
549 request_timeout=config.write_timeout,
550 load_balancing_policy=loadBalancePolicy,
551 )
552 # To replace default DCAwareRoundRobinPolicy
553 default_profile = ExecutionProfile(
554 load_balancing_policy=loadBalancePolicy,
555 )
556 return {
557 "read_tuples": read_tuples_profile,
558 "read_pandas": read_pandas_profile,
559 "read_raw": read_raw_profile,
560 "read_pandas_multi": read_pandas_multi_profile,
561 "read_raw_multi": read_raw_multi_profile,
562 "write": write_profile,
563 EXEC_PROFILE_DEFAULT: default_profile,
564 }
565
566 def _getSources(
567 self,
568 region: sphgeom.Region,
569 object_ids: Optional[Iterable[int]],
570 mjd_start: float,
571 mjd_end: float,
572 table_name: ApdbTables,
573 ) -> pandas.DataFrame:
574 """Returns catalog of DiaSource instances given set of DiaObject IDs.
575
576 Parameters
577 ----------
578 region : `lsst.sphgeom.Region`
579 Spherical region.
580 object_ids :
581 Collection of DiaObject IDs
582 mjd_start : `float`
583 Lower bound of time interval.
584 mjd_end : `float`
585 Upper bound of time interval.
586 table_name : `ApdbTables`
587 Name of the table.
588
589 Returns
590 -------
591 catalog : `pandas.DataFrame`
592 Catalog containing DiaSource records. Empty catalog is returned if
593 ``object_ids`` is empty.
594 """
595 object_id_set: Set[int] = set()
596 if object_ids is not None:
597 object_id_set = set(object_ids)
598 if len(object_id_set) == 0:
599 return self._make_empty_catalog(table_name)
600
601 sp_where = self._spatial_where(region)
602 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end)
603
604 # We need to exclude extra partitioning columns from result.
605 column_names = self._schema.apdbColumnNames(table_name)
606 what = ",".join(_quote_column(column) for column in column_names)
607
608 # Build all queries
609 statements: List[Tuple] = []
610 for table in tables:
611 prefix = f'SELECT {what} from "{self._keyspace}"."{table}"'
612 statements += list(self._combine_where(prefix, sp_where, temporal_where))
613 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))
614
615 with Timer(table_name.name + " select", self.config.timer):
616 catalog = cast(
617 pandas.DataFrame,
618 select_concurrent(
619 self._session, statements, "read_pandas_multi", self.config.read_concurrency
620 ),
621 )
622
623 # filter by given object IDs
624 if len(object_id_set) > 0:
625 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])
626
627 # precise filtering on midpointMjdTai
628 catalog = cast(pandas.DataFrame, catalog[catalog["midpointMjdTai"] > mjd_start])
629
630 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
631 return catalog
632
633 def _get_history(self, table: ExtraTables, ids: Iterable[ApdbInsertId]) -> ApdbTableData:
634 """Return records from a particular table given set of insert IDs."""
635 if not self._schema.has_insert_id:
636 raise ValueError("APDB is not configured for history retrieval")
637
638 insert_ids = [id.id for id in ids]
639 params = ",".join("?" * len(insert_ids))
640
641 table_name = self._schema.tableName(table)
642 # I know that history table schema has only regular APDB columns plus
643 # an insert_id column, and this is exactly what we need to return from
644 # this method, so selecting a star is fine here.
645 query = f'SELECT * FROM "{self._keyspace}"."{table_name}" WHERE insert_id IN ({params})'
646 statement = self._prep_statement(query)
647
648 with Timer("DiaObject history", self.config.timer):
649 result = self._session.execute(statement, insert_ids, execution_profile="read_raw")
650 table_data = cast(ApdbCassandraTableData, result._current_rows)
651 return table_data
652
653 def _storeInsertId(self, insert_id: ApdbInsertId, visit_time: dafBase.DateTime) -> None:
654 # Cassandra timestamp uses milliseconds since epoch
655 timestamp = visit_time.nsecs() // 1_000_000
656
657 # everything goes into a single partition
658 partition = 0
659
660 table_name = self._schema.tableName(ExtraTables.DiaInsertId)
661 query = (
662 f'INSERT INTO "{self._keyspace}"."{table_name}" (partition, insert_id, insert_time) '
663 "VALUES (?, ?, ?)"
664 )
665
666 self._session.execute(
667 self._prep_statement(query),
668 (partition, insert_id.id, timestamp),
669 timeout=self.config.write_timeout,
670 execution_profile="write",
671 )
672
673 def _storeDiaObjects(
674 self, objs: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None
675 ) -> None:
676 """Store catalog of DiaObjects from current visit.
677
678 Parameters
679 ----------
680 objs : `pandas.DataFrame`
681 Catalog with DiaObject records
682 visit_time : `lsst.daf.base.DateTime`
683 Time of the current visit.
684 """
685 visit_time_dt = visit_time.toPython()
686 extra_columns = dict(lastNonForcedSource=visit_time_dt)
687 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns)
688
689 extra_columns["validityStart"] = visit_time_dt
690 time_part: Optional[int] = self._time_partition(visit_time)
691 if not self.config.time_partition_tables:
692 extra_columns["apdb_time_part"] = time_part
693 time_part = None
694
695 self._storeObjectsPandas(objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part)
696
697 if insert_id is not None:
698 extra_columns = dict(insert_id=insert_id.id, validityStart=visit_time_dt)
699 self._storeObjectsPandas(objs, ExtraTables.DiaObjectInsertId, extra_columns=extra_columns)
700
701 def _storeDiaSources(
702 self,
703 table_name: ApdbTables,
704 sources: pandas.DataFrame,
705 visit_time: dafBase.DateTime,
706 insert_id: ApdbInsertId | None,
707 ) -> None:
708 """Store catalog of DIASources or DIAForcedSources from current visit.
709
710 Parameters
711 ----------
712 sources : `pandas.DataFrame`
713 Catalog containing DiaSource records
714 visit_time : `lsst.daf.base.DateTime`
715 Time of the current visit.
716 """
717 time_part: Optional[int] = self._time_partition(visit_time)
718 extra_columns: dict[str, Any] = {}
719 if not self.config.time_partition_tables:
720 extra_columns["apdb_time_part"] = time_part
721 time_part = None
722
723 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part)
724
725 if insert_id is not None:
726 extra_columns = dict(insert_id=insert_id.id)
727 if table_name is ApdbTables.DiaSource:
728 extra_table = ExtraTables.DiaSourceInsertId
729 else:
730 extra_table = ExtraTables.DiaForcedSourceInsertId
731 self._storeObjectsPandas(sources, extra_table, extra_columns=extra_columns)
732
733 def _storeDiaSourcesPartitions(
734 self, sources: pandas.DataFrame, visit_time: dafBase.DateTime, insert_id: ApdbInsertId | None
735 ) -> None:
736 """Store mapping of diaSourceId to its partitioning values.
737
738 Parameters
739 ----------
740 sources : `pandas.DataFrame`
741 Catalog containing DiaSource records
742 visit_time : `lsst.daf.base.DateTime`
743 Time of the current visit.
744 """
745 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]])
746 extra_columns = {
747 "apdb_time_part": self._time_partition(visit_time),
748 "insert_id": insert_id.id if insert_id is not None else None,
749 }
750
751 self._storeObjectsPandas(
752 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None
753 )
754
755 def _storeObjectsPandas(
756 self,
757 records: pandas.DataFrame,
758 table_name: Union[ApdbTables, ExtraTables],
759 extra_columns: Optional[Mapping] = None,
760 time_part: Optional[int] = None,
761 ) -> None:
762 """Generic store method.
763
764 Takes Pandas catalog and stores a bunch of records in a table.
765
766 Parameters
767 ----------
768 records : `pandas.DataFrame`
769 Catalog containing object records
770 table_name : `ApdbTables` or `ExtraTables`
771 Name of the table as defined in APDB schema.
772 extra_columns : `dict`, optional
773 Mapping (column_name, column_value) which gives fixed values for
774 columns in each row, overrides values in ``records`` if matching
775 columns exist there.
776 time_part : `int`, optional
777 If not `None` then insert into a per-partition table.
778
779 Notes
780 -----
781 If Pandas catalog contains additional columns not defined in table
782 schema they are ignored. Catalog does not have to contain all columns
783 defined in a table, but partition and clustering keys must be present
784 in a catalog or ``extra_columns``.
785 """
786 # use extra columns if specified
787 if extra_columns is None:
788 extra_columns = {}
789 extra_fields = list(extra_columns.keys())
790
791 # Fields that will come from dataframe.
792 df_fields = [column for column in records.columns if column not in extra_fields]
793
794 column_map = self._schema.getColumnMap(table_name)
795 # list of columns (as in felis schema)
796 fields = [column_map[field].name for field in df_fields if field in column_map]
797 fields += extra_fields
798
799 # check that all partitioning and clustering columns are defined
800 required_columns = self._schema.partitionColumns(table_name) + self._schema.clusteringColumns(
801 table_name
802 )
803 missing_columns = [column for column in required_columns if column not in fields]
804 if missing_columns:
805 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")
806
807 qfields = [quote_id(field) for field in fields]
808 qfields_str = ",".join(qfields)
809
810 with Timer(table_name.name + " query build", self.config.timer):
811 table = self._schema.tableName(table_name)
812 if time_part is not None:
813 table = f"{table}_{time_part}"
814
815 holders = ",".join(["?"] * len(qfields))
816 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})'
817 statement = self._prep_statement(query)
818 queries = cassandra.query.BatchStatement()
819 for rec in records.itertuples(index=False):
820 values = []
821 for field in df_fields:
822 if field not in column_map:
823 continue
824 value = getattr(rec, field)
825 if column_map[field].datatype is felis.types.Timestamp:
826 if isinstance(value, pandas.Timestamp):
827 value = literal(value.to_pydatetime())
828 else:
829 # Assume it's seconds since epoch, Cassandra
830 # datetime is in milliseconds
831 value = int(value * 1000)
832 values.append(literal(value))
833 for field in extra_fields:
834 value = extra_columns[field]
835 values.append(literal(value))
836 queries.add(statement, values)
837
838 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), records.shape[0])
839 with Timer(table_name.name + " insert", self.config.timer):
840 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write")
841
842 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
843 """Calculate spatial partition for each record and add it to a
844 DataFrame.
845
846 Notes
847 -----
848 This overrides any existing column in a DataFrame with the same name
849 (apdb_part). The original DataFrame is not changed; a copy of the
850 DataFrame is returned.
851 """
852 # calculate HTM index for every DiaObject
853 apdb_part = np.zeros(df.shape[0], dtype=np.int64)
854 ra_col, dec_col = self.config.ra_dec_columns
855 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
856 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
857 idx = self._pixelization.pixel(uv3d)
858 apdb_part[i] = idx
859 df = df.copy()
860 df["apdb_part"] = apdb_part
861 return df
862
863 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
864 """Add apdb_part column to DiaSource catalog.
865
866 Notes
867 -----
868 This method copies apdb_part value from a matching DiaObject record.
869 DiaObject catalog needs to have an apdb_part column filled by
870 the ``_add_obj_part`` method and DiaSource records need to be
871 associated with DiaObjects via the ``diaObjectId`` column.
872
873 This overrides any existing column in a DataFrame with the same name
874 (apdb_part). The original DataFrame is not changed; a copy of the
875 DataFrame is returned.
876 """
877 pixel_id_map: Dict[int, int] = {
878 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"])
879 }
880 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
881 ra_col, dec_col = self.config.ra_dec_columns
882 for i, (diaObjId, ra, dec) in enumerate(
883 zip(sources["diaObjectId"], sources[ra_col], sources[dec_col])
884 ):
885 if diaObjId == 0:
886 # DiaSources associated with SolarSystemObjects do not have an
887 # associated DiaObject, hence we skip the lookup and set the
888 # partition based on the source's own ra/dec
889 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
890 idx = self._pixelization.pixel(uv3d)
891 apdb_part[i] = idx
892 else:
893 apdb_part[i] = pixel_id_map[diaObjId]
894 sources = sources.copy()
895 sources["apdb_part"] = apdb_part
896 return sources
897
898 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
899 """Add apdb_part column to DiaForcedSource catalog.
900
901 Notes
902 -----
903 This method copies apdb_part value from a matching DiaObject record.
904 DiaObject catalog needs to have an apdb_part column filled by
905 the ``_add_obj_part`` method and DiaForcedSource records need to be
906 associated with DiaObjects via the ``diaObjectId`` column.
907
908 This overrides any existing column in a DataFrame with the same name
909 (apdb_part). The original DataFrame is not changed; a copy of the
910 DataFrame is returned.
911 """
912 pixel_id_map: Dict[int, int] = {
913 diaObjectId: apdb_part for diaObjectId, apdb_part in zip(objs["diaObjectId"], objs["apdb_part"])
914 }
915 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
916 for i, diaObjId in enumerate(sources["diaObjectId"]):
917 apdb_part[i] = pixel_id_map[diaObjId]
918 sources = sources.copy()
919 sources["apdb_part"] = apdb_part
920 return sources
921
922 def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int:
923 """Calculate time partiton number for a given time.
924
925 Parameters
926 ----------
927 time : `float` or `lsst.daf.base.DateTime`
928 Time for which to calculate partition number. Can be float to mean
929 MJD.
930
931 Returns
932 -------
933 partition : `int`
934 Partition number for a given time.
935 """
936 if isinstance(time, dafBase.DateTime):
937 mjd = time.get(system=dafBase.DateTime.MJD)
938 else:
939 mjd = time
940 days_since_epoch = mjd - self._partition_zero_epoch_mjd
941 partition = int(days_since_epoch) // self.config.time_partition_days
942 return partition
943
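    # Worked example for _time_partition above: partition_zero_epoch is
    # 1970-01-01 TAI, i.e. MJD 40587, and time_partition_days defaults to 30,
    # so a visit at MJD 60000 falls into partition (60000 - 40587) // 30 = 647.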
944 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
945 """Make an empty catalog for a table with a given name.
946
947 Parameters
948 ----------
949 table_name : `ApdbTables`
950 Name of the table.
951
952 Returns
953 -------
954 catalog : `pandas.DataFrame`
955 An empty catalog.
956 """
957 table = self._schema.tableSchemas[table_name]
958
959 data = {
960 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype))
961 for columnDef in table.columns
962 }
963 return pandas.DataFrame(data)
964
965 def _prep_statement(self, query: str) -> cassandra.query.PreparedStatement:
966 """Convert query string into prepared statement."""
967 stmt = self._prepared_statements.get(query)
968 if stmt is None:
969 stmt = self._session.prepare(query)
970 self._prepared_statements[query] = stmt
971 return stmt
972
973 def _combine_where(
974 self,
975 prefix: str,
976 where1: List[Tuple[str, Tuple]],
977 where2: List[Tuple[str, Tuple]],
978 suffix: Optional[str] = None,
979 ) -> Iterator[Tuple[cassandra.query.Statement, Tuple]]:
980 """Make cartesian product of two parts of WHERE clause into a series
981 of statements to execute.
982
983 Parameters
984 ----------
985 prefix : `str`
986 Initial statement prefix that comes before WHERE clause, e.g.
987 "SELECT * from Table"
988 """
989 # If lists are empty use special sentinels.
990 if not where1:
991 where1 = [("", ())]
992 if not where2:
993 where2 = [("", ())]
994
995 for expr1, params1 in where1:
996 for expr2, params2 in where2:
997 full_query = prefix
998 wheres = []
999 if expr1:
1000 wheres.append(expr1)
1001 if expr2:
1002 wheres.append(expr2)
1003 if wheres:
1004 full_query += " WHERE " + " AND ".join(wheres)
1005 if suffix:
1006 full_query += " " + suffix
1007 params = params1 + params2
1008 if params:
1009 statement = self._prep_statement(full_query)
1010 else:
1011 # If there are no params then it is likely that query
1012 # has a bunch of literals rendered already, no point
1013 # trying to prepare it.
1014 statement = cassandra.query.SimpleStatement(full_query)
1015 yield (statement, params)
1016
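    # To illustrate _combine_where above: with two spatial fragments
    # ('"apdb_part" = ?' with params (1,) and (2,)) and two temporal
    # fragments ('"apdb_time_part" = ?' with params (10,) and (11,)),
    # the cartesian product yields four statements of the form
    #
    #     <prefix> WHERE "apdb_part" = ? AND "apdb_time_part" = ?
    #
    # with parameter tuples (1, 10), (1, 11), (2, 10) and (2, 11).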
1017 def _spatial_where(
1018 self, region: Optional[sphgeom.Region], use_ranges: bool = False
1019 ) -> List[Tuple[str, Tuple]]:
1020 """Generate expressions for spatial part of WHERE clause.
1021
1022 Parameters
1023 ----------
1024 region : `sphgeom.Region`
1025 Spatial region for query results.
1026 use_ranges : `bool`
1027 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <=
1028 p2") instead of exact list of pixels. Should be set to True for
1029 large regions covering very many pixels.
1030
1031 Returns
1032 -------
1033 expressions : `list` [ `tuple` ]
1034 Empty list is returned if ``region`` is `None`, otherwise a list
1035 of one or more (expression, parameters) tuples
1036 """
1037 if region is None:
1038 return []
1039 if use_ranges:
1040 pixel_ranges = self._pixelization.envelope(region)
1041 expressions: List[Tuple[str, Tuple]] = []
1042 for lower, upper in pixel_ranges:
1043 upper -= 1
1044 if lower == upper:
1045 expressions.append(('"apdb_part" = ?', (lower,)))
1046 else:
1047 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper)))
1048 return expressions
1049 else:
1050 pixels = self._pixelization.pixels(region)
1051 if self.config.query_per_spatial_part:
1052 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels]
1053 else:
1054 pixels_str = ",".join([str(pix) for pix in pixels])
1055 return [(f'"apdb_part" IN ({pixels_str})', ())]
1056
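    # The two output shapes produced by _spatial_where above, with
    # illustrative pixel indices:
    #
    #     use_ranges=True : [('"apdb_part" >= ? AND "apdb_part" <= ?', (100, 120))]
    #     use_ranges=False: [('"apdb_part" IN (100,101,102)', ())]
    #
    # (or one '"apdb_part" = ?' tuple per pixel when
    # config.query_per_spatial_part is True).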
1057 def _temporal_where(
1058 self,
1059 table: ApdbTables,
1060 start_time: Union[float, dafBase.DateTime],
1061 end_time: Union[float, dafBase.DateTime],
1062 query_per_time_part: Optional[bool] = None,
1063 ) -> Tuple[List[str], List[Tuple[str, Tuple]]]:
1064 """Generate table names and expressions for temporal part of WHERE
1065 clauses.
1066
1067 Parameters
1068 ----------
1069 table : `ApdbTables`
1070 Table to select from.
1071 start_time : `dafBase.DateTime` or `float`
1072 Starting DateTime or MJD value of the time range.
1073 end_time : `dafBase.DateTime` or `float`
1074 Ending DateTime or MJD value of the time range.
1075 query_per_time_part : `bool`, optional
1076 If None then use ``query_per_time_part`` from configuration.
1077
1078 Returns
1079 -------
1080 tables : `list` [ `str` ]
1081 List of the table names to query.
1082 expressions : `list` [ `tuple` ]
1083 A list of zero or more (expression, parameters) tuples.
1084 """
1085 tables: List[str]
1086 temporal_where: List[Tuple[str, Tuple]] = []
1087 table_name = self._schema.tableName(table)
1088 time_part_start = self._time_partition(start_time)
1089 time_part_end = self._time_partition(end_time)
1090 time_parts = list(range(time_part_start, time_part_end + 1))
1091 if self.config.time_partition_tables:
1092 tables = [f"{table_name}_{part}" for part in time_parts]
1093 else:
1094 tables = [table_name]
1095 if query_per_time_part is None:
1096 query_per_time_part = self.config.query_per_time_part
1097 if query_per_time_part:
1098 temporal_where = [('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts]
1099 else:
1100 time_part_list = ",".join([str(part) for part in time_parts])
1101 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())]
1102
1103 return tables, temporal_where
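
# A minimal end-to-end sketch of the public API implemented above, assuming a
# reachable Cassandra cluster. The helper name, contact point, region and
# visit time are illustrative assumptions; the sketch relies on the imports at
# the top of this module.
def _example_apdb_session() -> pandas.DataFrame:
    """Create the APDB schema and read DiaObjects/DiaSources for a region."""
    config = ApdbCassandraConfig()
    config.contact_points = ["127.0.0.1"]  # hypothetical single-node cluster
    config.keyspace = "apdb"

    apdb = ApdbCassandra(config)
    # Create all APDB tables for this configuration.
    apdb.makeSchema(drop=False)

    # A 1-degree circular region around an arbitrary sky position.
    center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(45.0, -30.0))
    region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(1.0))

    visit_time = dafBase.DateTime("2025-01-01T00:00:00", dafBase.DateTime.TAI)
    objects = apdb.getDiaObjects(region)
    sources = apdb.getDiaSources(region, objects["diaObjectId"], visit_time)
    _LOG.info(
        "read %s DiaObjects and %s DiaSources",
        len(objects),
        0 if sources is None else len(sources),
    )
    return objects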