apdbCassandra.py
# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["ApdbCassandraConfig", "ApdbCassandra"]

import logging
import numpy as np
import pandas
from typing import Any, cast, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union

# If cassandra-driver is not there the module can still be imported
# but ApdbCassandra cannot be instantiated.
try:
    import cassandra
    from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
    from cassandra.policies import RoundRobinPolicy, WhiteListRoundRobinPolicy, AddressTranslator
    import cassandra.query
    CASSANDRA_IMPORTED = True
except ImportError:
    CASSANDRA_IMPORTED = False

import lsst.daf.base as dafBase
from lsst import sphgeom
from lsst.pex.config import ChoiceField, Field, ListField
from lsst.utils.iteration import chunk_iterable
from .timer import Timer
from .apdb import Apdb, ApdbConfig
from .apdbSchema import ApdbTables, TableDef
from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables
from .cassandra_utils import (
    literal,
    pandas_dataframe_factory,
    quote_id,
    raw_data_factory,
    select_concurrent,
)
from .pixelization import Pixelization

_LOG = logging.getLogger(__name__)


class CassandraMissingError(Exception):
    def __init__(self) -> None:
        super().__init__("cassandra-driver module cannot be imported")


class ApdbCassandraConfig(ApdbConfig):

    contact_points = ListField(
        dtype=str,
        doc="The list of contact points to try connecting for cluster discovery.",
        default=["127.0.0.1"]
    )
    private_ips = ListField(
        dtype=str,
        doc="List of internal IP addresses for contact_points.",
        default=[]
    )
    keyspace = Field(
        dtype=str,
        doc="Default keyspace for operations.",
        default="apdb"
    )
    read_consistency = Field(
        dtype=str,
        doc="Name for consistency level of read operations, default: QUORUM, can be ONE.",
        default="QUORUM"
    )
    write_consistency = Field(
        dtype=str,
        doc="Name for consistency level of write operations, default: QUORUM, can be ONE.",
        default="QUORUM"
    )
    read_timeout = Field(
        dtype=float,
        doc="Timeout in seconds for read operations.",
        default=120.
    )
    write_timeout = Field(
        dtype=float,
        doc="Timeout in seconds for write operations.",
        default=10.
    )
    read_concurrency = Field(
        dtype=int,
        doc="Concurrency level for read operations.",
        default=500
    )
    protocol_version = Field(
        dtype=int,
        doc="Cassandra protocol version to use, default is V4",
        default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0
    )
    dia_object_columns = ListField(
        dtype=str,
        doc="List of columns to read from DiaObject, by default read all columns",
        default=[]
    )
    prefix = Field(
        dtype=str,
        doc="Prefix to add to table names",
        default=""
    )
    part_pixelization = ChoiceField(
        dtype=str,
        allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"),
        doc="Pixelization used for partitioning index.",
        default="mq3c"
    )
    part_pix_level = Field(
        dtype=int,
        doc="Pixelization level used for partitioning index.",
        default=10
    )
    part_pix_max_ranges = Field(
        dtype=int,
        doc="Max number of ranges in pixelization envelope",
        default=64
    )
    ra_dec_columns = ListField(
        dtype=str,
        default=["ra", "decl"],
        doc="Names ra/dec columns in DiaObject table"
    )
    timer = Field(
        dtype=bool,
        doc="If True then print/log timing information",
        default=False
    )
    time_partition_tables = Field(
        dtype=bool,
        doc="Use per-partition tables for sources instead of partitioning by time",
        default=True
    )
    time_partition_days = Field(
        dtype=int,
        doc="Time partitioning granularity in days, this value must not be changed"
            " after database is initialized",
        default=30
    )
    time_partition_start = Field(
        dtype=str,
        doc="Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI."
            " This is used only when time_partition_tables is True.",
        default="2018-12-01T00:00:00"
    )
    time_partition_end = Field(
        dtype=str,
        doc="Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI."
            " This is used only when time_partition_tables is True.",
        default="2030-01-01T00:00:00"
    )
    query_per_time_part = Field(
        dtype=bool,
        default=False,
        doc="If True then build separate query for each time partition, otherwise build one single query. "
            "This is only used when time_partition_tables is False in schema config."
    )
    query_per_spatial_part = Field(
        dtype=bool,
        default=False,
        doc="If True then build one query per spatial partition, otherwise build single query."
    )
    pandas_delay_conv = Field(
        dtype=bool,
        default=True,
        doc="If True then combine result rows before converting to pandas."
    )


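# A minimal, illustrative configuration sketch (not taken from the package
# documentation; the host addresses and chosen values below are assumptions
# used only to show how the fields above fit together):
#
#     config = ApdbCassandraConfig()
#     config.contact_points = ["10.0.0.1", "10.0.0.2"]
#     config.keyspace = "apdb"
#     config.part_pixelization = "mq3c"
#     config.time_partition_tables = True
#     apdb = ApdbCassandra(config)
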
if CASSANDRA_IMPORTED:

    class _AddressTranslator(AddressTranslator):
        """Translate internal IP address to external.

        Only used for docker-based setup, not viable long-term solution.
        """
        def __init__(self, public_ips: List[str], private_ips: List[str]):
            self._map = dict((k, v) for k, v in zip(private_ips, public_ips))

        def translate(self, private_ip: str) -> str:
            return self._map.get(private_ip, private_ip)


206 """Implementation of APDB database on to of Apache Cassandra.
207
208 The implementation is configured via standard ``pex_config`` mechanism
209 using `ApdbCassandraConfig` configuration class. For an example of
210 different configurations check config/ folder.
211
212 Parameters
213 ----------
214 config : `ApdbCassandraConfig`
215 Configuration object.
216 """
217
218 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI)
219 """Start time for partition 0, this should never be changed."""
220
    def __init__(self, config: ApdbCassandraConfig):

        if not CASSANDRA_IMPORTED:
            raise CassandraMissingError()

        self.config = config

        _LOG.debug("ApdbCassandra Configuration:")
        for key, value in self.config.items():
            _LOG.debug(" %s: %s", key, value)

        self._pixelization = Pixelization(
            config.part_pixelization, config.part_pix_level, config.part_pix_max_ranges
        )

        addressTranslator: Optional[AddressTranslator] = None
        if config.private_ips:
            addressTranslator = _AddressTranslator(config.contact_points, config.private_ips)

        self._keyspace = config.keyspace

        self._cluster = Cluster(execution_profiles=self._makeProfiles(config),
                                contact_points=self.config.contact_points,
                                address_translator=addressTranslator,
                                protocol_version=self.config.protocol_version)
        self._session = self._cluster.connect()
        # Disable result paging
        self._session.default_fetch_size = None

        self._schema = ApdbCassandraSchema(session=self._session,
                                           keyspace=self._keyspace,
                                           schema_file=self.config.schema_file,
                                           schema_name=self.config.schema_name,
                                           prefix=self.config.prefix,
                                           time_partition_tables=self.config.time_partition_tables)
        self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD)

        # Cache for prepared statements
        self._prepared_statements: Dict[str, cassandra.query.PreparedStatement] = {}

    def tableDef(self, table: ApdbTables) -> Optional[TableDef]:
        # docstring is inherited from a base class
        return self._schema.tableSchemas.get(table)

    def makeSchema(self, drop: bool = False) -> None:
        # docstring is inherited from a base class

        if self.config.time_partition_tables:
            time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI)
            time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI)
            part_range = (
                self._time_partition(time_partition_start),
                self._time_partition(time_partition_end) + 1
            )
            self._schema.makeSchema(drop=drop, part_range=part_range)
        else:
            self._schema.makeSchema(drop=drop)

    def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
        # docstring is inherited from a base class

        sp_where = self._spatial_where(region)
        _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where))

        table_name = self._schema.tableName(ApdbTables.DiaObjectLast)
        query = f'SELECT * from "{self._keyspace}"."{table_name}"'
        statements: List[Tuple] = []
        for where, params in sp_where:
            full_query = f"{query} WHERE {where}"
            if params:
                statement = self._prep_statement(full_query)
            else:
                # If there are no params then it is likely that query has a
                # bunch of literals rendered already, no point trying to
                # prepare it because it's not reusable.
                statement = cassandra.query.SimpleStatement(full_query)
            statements.append((statement, params))
        _LOG.debug("getDiaObjects: #queries: %s", len(statements))

        with Timer('DiaObject select', self.config.timer):
            objects = cast(
                pandas.DataFrame,
                select_concurrent(
                    self._session, statements, "read_pandas_multi", self.config.read_concurrency
                )
            )

        _LOG.debug("found %s DiaObjects", objects.shape[0])
        return objects

    def getDiaSources(self, region: sphgeom.Region,
                      object_ids: Optional[Iterable[int]],
                      visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        months = self.config.read_sources_months
        if months == 0:
            return None
        mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
        mjd_start = mjd_end - months*30

        return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource)

    def getDiaForcedSources(self, region: sphgeom.Region,
                            object_ids: Optional[Iterable[int]],
                            visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
        # docstring is inherited from a base class
        months = self.config.read_forced_sources_months
        if months == 0:
            return None
        mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
        mjd_start = mjd_end - months*30

        return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource)

    def getDiaObjectsHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[sphgeom.Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class

        sp_where = self._spatial_where(region, use_ranges=True)
        tables, temporal_where = self._temporal_where(ApdbTables.DiaObject, start_time, end_time, True)

        # Build all queries
        statements: List[Tuple] = []
        for table in tables:
            prefix = f'SELECT * from "{self._keyspace}"."{table}"'
            statements += list(self._combine_where(prefix, sp_where, temporal_where, "ALLOW FILTERING"))
        _LOG.debug("getDiaObjectsHistory: #queries: %s", len(statements))

        # Run all selects in parallel
        with Timer("DiaObject history", self.config.timer):
            catalog = cast(
                pandas.DataFrame,
                select_concurrent(
                    self._session, statements, "read_pandas_multi", self.config.read_concurrency
                )
            )

        # precise filtering on validityStart
        validity_start = start_time.toPython()
        validity_end = end_time.toPython()
        catalog = cast(
            pandas.DataFrame,
            catalog[(catalog["validityStart"] >= validity_start) & (catalog["validityStart"] < validity_end)]
        )

        _LOG.debug("found %d DiaObjects", catalog.shape[0])
        return catalog

    def getDiaSourcesHistory(self,
                             start_time: dafBase.DateTime,
                             end_time: dafBase.DateTime,
                             region: Optional[sphgeom.Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class
        return self._getSourcesHistory(ApdbTables.DiaSource, start_time, end_time, region)

    def getDiaForcedSourcesHistory(self,
                                   start_time: dafBase.DateTime,
                                   end_time: dafBase.DateTime,
                                   region: Optional[sphgeom.Region] = None) -> pandas.DataFrame:
        # docstring is inherited from a base class
        return self._getSourcesHistory(ApdbTables.DiaForcedSource, start_time, end_time, region)

    def getSSObjects(self) -> pandas.DataFrame:
        # docstring is inherited from a base class
        tableName = self._schema.tableName(ApdbTables.SSObject)
        query = f'SELECT * from "{self._keyspace}"."{tableName}"'

        objects = None
        with Timer('SSObject select', self.config.timer):
            result = self._session.execute(query, execution_profile="read_pandas")
            objects = result._current_rows

        _LOG.debug("found %s SSObjects", objects.shape[0])
        return objects

    def store(self,
              visit_time: dafBase.DateTime,
              objects: pandas.DataFrame,
              sources: Optional[pandas.DataFrame] = None,
              forced_sources: Optional[pandas.DataFrame] = None) -> None:
        # docstring is inherited from a base class

        # fill region partition column for DiaObjects
        objects = self._add_obj_part(objects)
        self._storeDiaObjects(objects, visit_time)

        if sources is not None:
            # copy apdb_part column from DiaObjects to DiaSources
            sources = self._add_src_part(sources, objects)
            self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time)
            self._storeDiaSourcesPartitions(sources, visit_time)

        if forced_sources is not None:
            forced_sources = self._add_fsrc_part(forced_sources, objects)
            self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time)

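    # Illustrative sketch of a single-visit store() call (not part of the API
    # documentation; the variable names and the way the catalogs are produced
    # are assumptions for illustration only):
    #
    #     visit_time = dafBase.DateTime("2022-01-01T00:00:00", dafBase.DateTime.TAI)
    #     apdb = ApdbCassandra(config)
    #     # objects/sources/forced_sources are pandas.DataFrame catalogs
    #     # produced by upstream association code.
    #     apdb.store(visit_time, objects, sources, forced_sources)
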
    def storeSSObjects(self, objects: pandas.DataFrame) -> None:
        # docstring is inherited from a base class
        self._storeObjectsPandas(objects, ApdbTables.SSObject)

    def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
        # docstring is inherited from a base class

        # To update a record we need to know its exact primary key (including
        # partition key) so we start by querying for diaSourceId to find the
        # primary keys.

        table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition)
        # split it into 1k IDs per query
        selects: List[Tuple] = []
        for ids in chunk_iterable(idMap.keys(), 1_000):
            ids_str = ",".join(str(item) for item in ids)
            selects.append((
                (f'SELECT "diaSourceId", "apdb_part", "apdb_time_part" FROM "{self._keyspace}"."{table_name}"'
                 f' WHERE "diaSourceId" IN ({ids_str})'),
                {}
            ))

        # No need for DataFrame here, read data as tuples.
        result = cast(
            List[Tuple[int, int, int]],
            select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency)
        )

        # Make mapping from source ID to its partition.
        id2partitions: Dict[int, Tuple[int, int]] = {}
        for row in result:
            id2partitions[row[0]] = row[1:]

        # make sure we know partitions for each ID
        if set(id2partitions) != set(idMap):
            missing = ",".join(str(item) for item in set(idMap) - set(id2partitions))
            raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")

        queries = cassandra.query.BatchStatement()
        table_name = self._schema.tableName(ApdbTables.DiaSource)
        for diaSourceId, ssObjectId in idMap.items():
            apdb_part, apdb_time_part = id2partitions[diaSourceId]
            values: Tuple
            if self.config.time_partition_tables:
                query = (
                    f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"'
                    ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
                    ' WHERE "apdb_part" = ? AND "diaSourceId" = ?'
                )
                values = (ssObjectId, apdb_part, diaSourceId)
            else:
                query = (
                    f'UPDATE "{self._keyspace}"."{table_name}"'
                    ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
                    ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?'
                )
                values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId)
            queries.add(self._prep_statement(query), values)

        _LOG.debug("%s: will update %d records", table_name, len(idMap))
        with Timer(table_name + ' update', self.config.timer):
            self._session.execute(queries, execution_profile="write")

    def dailyJob(self) -> None:
        # docstring is inherited from a base class
        pass

    def countUnassociatedObjects(self) -> int:
        # docstring is inherited from a base class

        # It's too inefficient to implement it for Cassandra in current schema.
        raise NotImplementedError()

    def _makeProfiles(self, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]:
        """Make all execution profiles used in the code."""

        if config.private_ips:
            loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points)
        else:
            loadBalancePolicy = RoundRobinPolicy()

        pandas_row_factory: Callable
        if not config.pandas_delay_conv:
            pandas_row_factory = pandas_dataframe_factory
        else:
            pandas_row_factory = raw_data_factory

        read_tuples_profile = ExecutionProfile(
            consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
            request_timeout=config.read_timeout,
            row_factory=cassandra.query.tuple_factory,
            load_balancing_policy=loadBalancePolicy,
        )
        read_pandas_profile = ExecutionProfile(
            consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
            request_timeout=config.read_timeout,
            row_factory=pandas_dataframe_factory,
            load_balancing_policy=loadBalancePolicy,
        )
        read_pandas_multi_profile = ExecutionProfile(
            consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
            request_timeout=config.read_timeout,
            row_factory=pandas_row_factory,
            load_balancing_policy=loadBalancePolicy,
        )
        write_profile = ExecutionProfile(
            consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency),
            request_timeout=config.write_timeout,
            load_balancing_policy=loadBalancePolicy,
        )
        # To replace default DCAwareRoundRobinPolicy
        default_profile = ExecutionProfile(
            load_balancing_policy=loadBalancePolicy,
        )
        return {
            "read_tuples": read_tuples_profile,
            "read_pandas": read_pandas_profile,
            "read_pandas_multi": read_pandas_multi_profile,
            "write": write_profile,
            EXEC_PROFILE_DEFAULT: default_profile,
        }

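    # Descriptive note: the profile names returned above ("read_tuples",
    # "read_pandas", "read_pandas_multi", "write") are the same strings
    # passed as ``execution_profile=`` in the execute/select_concurrent
    # calls elsewhere in this class, so the consistency level, timeout and
    # row factory chosen here apply to those reads and writes.
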
    def _getSources(self, region: sphgeom.Region,
                    object_ids: Optional[Iterable[int]],
                    mjd_start: float,
                    mjd_end: float,
                    table_name: ApdbTables) -> pandas.DataFrame:
        """Returns catalog of DiaSource instances given set of DiaObject IDs.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            Spherical region.
        object_ids :
            Collection of DiaObject IDs
        mjd_start : `float`
            Lower bound of time interval.
        mjd_end : `float`
            Upper bound of time interval.
        table_name : `ApdbTables`
            Name of the table.

        Returns
        -------
        catalog : `pandas.DataFrame`, or `None`
            Catalog containing DiaSource records. Empty catalog is returned if
            ``object_ids`` is empty.
        """
        object_id_set: Set[int] = set()
        if object_ids is not None:
            object_id_set = set(object_ids)
            if len(object_id_set) == 0:
                return self._make_empty_catalog(table_name)

        sp_where = self._spatial_where(region)
        tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end)

        # Build all queries
        statements: List[Tuple] = []
        for table in tables:
            prefix = f'SELECT * from "{self._keyspace}"."{table}"'
            statements += list(self._combine_where(prefix, sp_where, temporal_where))
        _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))

        with Timer(table_name.name + ' select', self.config.timer):
            catalog = cast(
                pandas.DataFrame,
                select_concurrent(
                    self._session, statements, "read_pandas_multi", self.config.read_concurrency
                )
            )

        # filter by given object IDs
        if len(object_id_set) > 0:
            catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])

        # precise filtering on midPointTai
        catalog = cast(pandas.DataFrame, catalog[catalog["midPointTai"] > mjd_start])

        _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
        return catalog

    def _getSourcesHistory(
        self,
        table: ApdbTables,
        start_time: dafBase.DateTime,
        end_time: dafBase.DateTime,
        region: Optional[sphgeom.Region] = None,
    ) -> pandas.DataFrame:
        """Returns catalog of DiaSource instances within a given time period.

        Parameters
        ----------
        table : `ApdbTables`
            Name of the table.
        start_time : `dafBase.DateTime`
            Starting time for DiaSource history search. DiaSource record is
            selected when its ``midPointTai`` falls into an interval between
            ``start_time`` (inclusive) and ``end_time`` (exclusive).
        end_time : `dafBase.DateTime`
            Upper limit on time for DiaSource history search.
        region : `lsst.sphgeom.Region`
            Spherical region.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Catalog containing DiaSource records.
        """
        sp_where = self._spatial_where(region, use_ranges=False)
        tables, temporal_where = self._temporal_where(table, start_time, end_time, True)

        # Build all queries
        statements: List[Tuple] = []
        for table_name in tables:
            prefix = f'SELECT * from "{self._keyspace}"."{table_name}"'
            statements += list(self._combine_where(prefix, sp_where, temporal_where, "ALLOW FILTERING"))
        _LOG.debug("_getSourcesHistory: #queries: %s", len(statements))

        # Run all selects in parallel
        with Timer(f"{table.name} history", self.config.timer):
            catalog = cast(
                pandas.DataFrame,
                select_concurrent(
                    self._session, statements, "read_pandas_multi", self.config.read_concurrency
                )
            )

        # precise filtering on midPointTai
        period_start = start_time.get(system=dafBase.DateTime.MJD)
        period_end = end_time.get(system=dafBase.DateTime.MJD)
        catalog = cast(
            pandas.DataFrame,
            catalog[(catalog["midPointTai"] >= period_start) & (catalog["midPointTai"] < period_end)]
        )

        _LOG.debug("found %d %ss", catalog.shape[0], table.name)
        return catalog

    def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store catalog of DiaObjects from current visit.

        Parameters
        ----------
        objs : `pandas.DataFrame`
            Catalog with DiaObject records
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.
        """
        visit_time_dt = visit_time.toPython()
        extra_columns = dict(lastNonForcedSource=visit_time_dt)
        self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns)

        extra_columns["validityStart"] = visit_time_dt
        time_part: Optional[int] = self._time_partition(visit_time)
        if not self.config.time_partition_tables:
            extra_columns["apdb_time_part"] = time_part
            time_part = None

        self._storeObjectsPandas(objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part)

    def _storeDiaSources(self, table_name: ApdbTables, sources: pandas.DataFrame,
                         visit_time: dafBase.DateTime) -> None:
        """Store catalog of DIASources or DIAForcedSources from current visit.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.
        """
        time_part: Optional[int] = self._time_partition(visit_time)
        extra_columns = {}
        if not self.config.time_partition_tables:
            extra_columns["apdb_time_part"] = time_part
            time_part = None

        self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part)

    def _storeDiaSourcesPartitions(self, sources: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
        """Store mapping of diaSourceId to its partitioning values.

        Parameters
        ----------
        sources : `pandas.DataFrame`
            Catalog containing DiaSource records
        visit_time : `lsst.daf.base.DateTime`
            Time of the current visit.
        """
        id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]])
        extra_columns = {
            "apdb_time_part": self._time_partition(visit_time),
        }

        self._storeObjectsPandas(
            id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None
        )

    def _storeObjectsPandas(self, objects: pandas.DataFrame, table_name: Union[ApdbTables, ExtraTables],
                            extra_columns: Optional[Mapping] = None,
                            time_part: Optional[int] = None) -> None:
        """Generic store method.

        Takes catalog of records and stores a bunch of objects in a table.

        Parameters
        ----------
        objects : `pandas.DataFrame`
            Catalog containing object records
        table_name : `ApdbTables`
            Name of the table as defined in APDB schema.
        extra_columns : `dict`, optional
            Mapping (column_name, column_value) which gives column values to
            add to every row, only if column is missing in catalog records.
        time_part : `int`, optional
            If not `None` then insert into a per-partition table.
        """
        # use extra columns if specified
        if extra_columns is None:
            extra_columns = {}
        extra_fields = list(extra_columns.keys())

        df_fields = [
            column for column in objects.columns if column not in extra_fields
        ]

        column_map = self._schema.getColumnMap(table_name)
        # list of columns (as in cat schema)
        fields = [column_map[field].name for field in df_fields if field in column_map]
        fields += extra_fields

        # check that all partitioning and clustering columns are defined
        required_columns = self._schema.partitionColumns(table_name) \
            + self._schema.clusteringColumns(table_name)
        missing_columns = [column for column in required_columns if column not in fields]
        if missing_columns:
            raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")

        qfields = [quote_id(field) for field in fields]
        qfields_str = ','.join(qfields)

        with Timer(table_name.name + ' query build', self.config.timer):

            table = self._schema.tableName(table_name)
            if time_part is not None:
                table = f"{table}_{time_part}"

            holders = ','.join(['?']*len(qfields))
            query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})'
            statement = self._prep_statement(query)
            queries = cassandra.query.BatchStatement()
            for rec in objects.itertuples(index=False):
                values = []
                for field in df_fields:
                    if field not in column_map:
                        continue
                    value = getattr(rec, field)
                    if column_map[field].type == "DATETIME":
                        if isinstance(value, pandas.Timestamp):
                            value = literal(value.to_pydatetime())
                        else:
                            # Assume it's seconds since epoch, Cassandra
                            # datetime is in milliseconds
                            value = int(value*1000)
                    values.append(literal(value))
                for field in extra_fields:
                    value = extra_columns[field]
                    values.append(literal(value))
                queries.add(statement, values)

        _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), objects.shape[0])
        with Timer(table_name.name + ' insert', self.config.timer):
            self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write")

    def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
        """Calculate spatial partition for each record and add it to a
        DataFrame.

        Notes
        -----
        This overrides any existing column in a DataFrame with the same name
        (apdb_part). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        # calculate pixel index for every DiaObject using configured
        # pixelization
        apdb_part = np.zeros(df.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
            uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
            idx = self._pixelization.pixel(uv3d)
            apdb_part[i] = idx
        df = df.copy()
        df["apdb_part"] = apdb_part
        return df

    def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add apdb_part column to DiaSource catalog.

        Notes
        -----
        This method copies apdb_part value from a matching DiaObject record.
        DiaObject catalog needs to have an apdb_part column filled by
        ``_add_obj_part`` method and DiaSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (apdb_part). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: apdb_part for diaObjectId, apdb_part
            in zip(objs["diaObjectId"], objs["apdb_part"])
        }
        apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
        ra_col, dec_col = self.config.ra_dec_columns
        for i, (diaObjId, ra, dec) in enumerate(zip(sources["diaObjectId"],
                                                    sources[ra_col], sources[dec_col])):
            if diaObjId == 0:
                # DiaSources associated with SolarSystemObjects do not have an
                # associated DiaObject hence we skip them and set partition
                # based on its own ra/dec
                uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
                idx = self._pixelization.pixel(uv3d)
                apdb_part[i] = idx
            else:
                apdb_part[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources["apdb_part"] = apdb_part
        return sources

    def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
        """Add apdb_part column to DiaForcedSource catalog.

        Notes
        -----
        This method copies apdb_part value from a matching DiaObject record.
        DiaObject catalog needs to have an apdb_part column filled by
        ``_add_obj_part`` method and DiaForcedSource records need to be
        associated to DiaObjects via ``diaObjectId`` column.

        This overrides any existing column in a DataFrame with the same name
        (apdb_part). Original DataFrame is not changed, copy of a DataFrame is
        returned.
        """
        pixel_id_map: Dict[int, int] = {
            diaObjectId: apdb_part for diaObjectId, apdb_part
            in zip(objs["diaObjectId"], objs["apdb_part"])
        }
        apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
        for i, diaObjId in enumerate(sources["diaObjectId"]):
            apdb_part[i] = pixel_id_map[diaObjId]
        sources = sources.copy()
        sources["apdb_part"] = apdb_part
        return sources

    def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int:
        """Calculate time partition number for a given time.

        Parameters
        ----------
        time : `float` or `lsst.daf.base.DateTime`
            Time for which to calculate partition number. Can be float to mean
            MJD.

        Returns
        -------
        partition : `int`
            Partition number for a given time.
        """
        if isinstance(time, dafBase.DateTime):
            mjd = time.get(system=dafBase.DateTime.MJD)
        else:
            mjd = time
        days_since_epoch = mjd - self._partition_zero_epoch_mjd
        partition = int(days_since_epoch) // self.config.time_partition_days
        return partition

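    # Worked example of the partition arithmetic above (illustrative only):
    # with the default time_partition_days=30 and partition_zero_epoch at
    # 1970-01-01 (MJD 40587), a visit at MJD 59580 (2022-01-01) gives
    # days_since_epoch = 59580 - 40587 = 18993 and partition 18993 // 30 = 633.
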
    def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
        """Make an empty catalog for a table with a given name.

        Parameters
        ----------
        table_name : `ApdbTables`
            Name of the table.

        Returns
        -------
        catalog : `pandas.DataFrame`
            An empty catalog.
        """
        table = self._schema.tableSchemas[table_name]

        data = {columnDef.name: pandas.Series(dtype=columnDef.dtype) for columnDef in table.columns}
        return pandas.DataFrame(data)

    def _prep_statement(self, query: str) -> cassandra.query.PreparedStatement:
        """Convert query string into prepared statement."""
        stmt = self._prepared_statements.get(query)
        if stmt is None:
            stmt = self._session.prepare(query)
            self._prepared_statements[query] = stmt
        return stmt

    def _combine_where(
        self,
        prefix: str,
        where1: List[Tuple[str, Tuple]],
        where2: List[Tuple[str, Tuple]],
        suffix: Optional[str] = None,
    ) -> Iterator[Tuple[cassandra.query.Statement, Tuple]]:
        """Make cartesian product of two parts of WHERE clause into a series
        of statements to execute.

        Parameters
        ----------
        prefix : `str`
            Initial statement prefix that comes before WHERE clause, e.g.
            "SELECT * from Table"
        """
        # If lists are empty use special sentinels.
        if not where1:
            where1 = [("", ())]
        if not where2:
            where2 = [("", ())]

        for expr1, params1 in where1:
            for expr2, params2 in where2:
                full_query = prefix
                wheres = []
                if expr1:
                    wheres.append(expr1)
                if expr2:
                    wheres.append(expr2)
                if wheres:
                    full_query += " WHERE " + " AND ".join(wheres)
                if suffix:
                    full_query += " " + suffix
                params = params1 + params2
                if params:
                    statement = self._prep_statement(full_query)
                else:
                    # If there are no params then it is likely that query
                    # has a bunch of literals rendered already, no point
                    # trying to prepare it.
                    statement = cassandra.query.SimpleStatement(full_query)
                yield (statement, params)

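    # Illustrative sketch of what _combine_where() yields (all values below
    # are made up): with prefix='SELECT * from "apdb"."DiaSource_633"',
    # where1=[('"apdb_part" = ?', (1234,))] and
    # where2=[('"apdb_time_part" = ?', (633,))], the single generated query is
    #   SELECT * from "apdb"."DiaSource_633"
    #   WHERE "apdb_part" = ? AND "apdb_time_part" = ?
    # with parameters (1234, 633); two entries in each input list would yield
    # four statements (the cartesian product).
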
    def _spatial_where(
        self, region: Optional[sphgeom.Region], use_ranges: bool = False
    ) -> List[Tuple[str, Tuple]]:
        """Generate expressions for spatial part of WHERE clause.

        Parameters
        ----------
        region : `sphgeom.Region`
            Spatial region for query results.
        use_ranges : `bool`
            If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <=
            p2") instead of exact list of pixels. Should be set to True for
            large regions covering very many pixels.

        Returns
        -------
        expressions : `list` [ `tuple` ]
            Empty list is returned if ``region`` is `None`, otherwise a list
            of one or more (expression, parameters) tuples
        """
        if region is None:
            return []
        if use_ranges:
            pixel_ranges = self._pixelization.envelope(region)
            expressions: List[Tuple[str, Tuple]] = []
            for lower, upper in pixel_ranges:
                upper -= 1
                if lower == upper:
                    expressions.append(('"apdb_part" = ?', (lower, )))
                else:
                    expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper)))
            return expressions
        else:
            pixels = self._pixelization.pixels(region)
            if self.config.query_per_spatial_part:
                return [('"apdb_part" = ?', (pixel,)) for pixel in pixels]
            else:
                pixels_str = ",".join([str(pix) for pix in pixels])
                return [(f'"apdb_part" IN ({pixels_str})', ())]

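    # Illustrative sketch of _spatial_where() output (pixel values made up):
    # for a region covering pixels [1234, 1235] and the default
    # query_per_spatial_part=False it returns
    #   [('"apdb_part" IN (1234,1235)', ())]
    # while query_per_spatial_part=True yields one prepared expression per
    # pixel: [('"apdb_part" = ?', (1234,)), ('"apdb_part" = ?', (1235,))].
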
    def _temporal_where(
        self,
        table: ApdbTables,
        start_time: Union[float, dafBase.DateTime],
        end_time: Union[float, dafBase.DateTime],
        query_per_time_part: Optional[bool] = None,
    ) -> Tuple[List[str], List[Tuple[str, Tuple]]]:
        """Generate table names and expressions for temporal part of WHERE
        clauses.

        Parameters
        ----------
        table : `ApdbTables`
            Table to select from.
        start_time : `dafBase.DateTime` or `float`
            Starting Datetime or MJD value of the time range.
        end_time : `dafBase.DateTime` or `float`
            Ending Datetime or MJD value of the time range.
        query_per_time_part : `bool`, optional
            If None then use ``query_per_time_part`` from configuration.

        Returns
        -------
        tables : `list` [ `str` ]
            List of the table names to query.
        expressions : `list` [ `tuple` ]
            A list of zero or more (expression, parameters) tuples.
        """
        tables: List[str]
        temporal_where: List[Tuple[str, Tuple]] = []
        table_name = self._schema.tableName(table)
        time_part_start = self._time_partition(start_time)
        time_part_end = self._time_partition(end_time)
        time_parts = list(range(time_part_start, time_part_end + 1))
        if self.config.time_partition_tables:
            tables = [f"{table_name}_{part}" for part in time_parts]
        else:
            tables = [table_name]
            if query_per_time_part is None:
                query_per_time_part = self.config.query_per_time_part
            if query_per_time_part:
                temporal_where = [
                    ('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts
                ]
            else:
                time_part_list = ",".join([str(part) for part in time_parts])
                temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())]

        return tables, temporal_where
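
    # Illustrative sketch of _temporal_where() output (partition numbers made
    # up): for a DiaSource query spanning time partitions 633 and 634, with
    # time_partition_tables=True it returns
    #   (["DiaSource_633", "DiaSource_634"], [])
    # while with time_partition_tables=False and query_per_time_part=False it
    # returns (["DiaSource"], [('"apdb_time_part" IN (633,634)', ())]).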