apdbCassandra.py
1# This file is part of dax_apdb.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
24__all__ = ["ApdbCassandraConfig", "ApdbCassandra"]
25
26import logging
27from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Set, Tuple, Union, cast
28
29import numpy as np
30import pandas
31
32# If cassandra-driver is not there the module can still be imported
33# but ApdbCassandra cannot be instantiated.
34try:
35 import cassandra
36 import cassandra.query
37 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile
38 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy
39 CASSANDRA_IMPORTED = True
40except ImportError:
41 CASSANDRA_IMPORTED = False
42
43import felis.types
44import lsst.daf.base as dafBase
45from felis.simple import Table
46from lsst import sphgeom
47from lsst.pex.config import ChoiceField, Field, ListField
48from lsst.utils.iteration import chunk_iterable
49
50from .apdb import Apdb, ApdbConfig
51from .apdbCassandraSchema import ApdbCassandraSchema, ExtraTables
52from .apdbSchema import ApdbTables
53from .cassandra_utils import literal, pandas_dataframe_factory, quote_id, raw_data_factory, select_concurrent
54from .pixelization import Pixelization
55from .timer import Timer
56
57_LOG = logging.getLogger(__name__)
58
59
60class CassandraMissingError(Exception):
61 def __init__(self) -> None:
62 super().__init__("cassandra-driver module cannot be imported")
63
64
65class ApdbCassandraConfig(ApdbConfig):
66
67 contact_points = ListField[str](
68 doc="The list of contact points to try connecting for cluster discovery.",
69 default=["127.0.0.1"]
70 )
71 private_ips = ListField[str](
72 doc="List of internal IP addresses for contact_points.",
73 default=[]
74 )
75 keyspace = Field[str](
76 doc="Default keyspace for operations.",
77 default="apdb"
78 )
79 read_consistency = Field[str](
80 doc="Name for consistency level of read operations, default: QUORUM, can be ONE.",
81 default="QUORUM"
82 )
83 write_consistency = Field[str](
84 doc="Name for consistency level of write operations, default: QUORUM, can be ONE.",
85 default="QUORUM"
86 )
87 read_timeout = Field[float](
88 doc="Timeout in seconds for read operations.",
89 default=120.
90 )
91 write_timeout = Field[float](
92 doc="Timeout in seconds for write operations.",
93 default=10.
94 )
95 read_concurrency = Field[int](
96 doc="Concurrency level for read operations.",
97 default=500
98 )
99 protocol_version = Field[int](
100 doc="Cassandra protocol version to use, default is V4",
101 default=cassandra.ProtocolVersion.V4 if CASSANDRA_IMPORTED else 0
102 )
103 dia_object_columns = ListField[str](
104 doc="List of columns to read from DiaObject, by default read all columns",
105 default=[]
106 )
107 prefix = Field[str](
108 doc="Prefix to add to table names",
109 default=""
110 )
111 part_pixelization = ChoiceField[str](
112 allowed=dict(htm="HTM pixelization", q3c="Q3C pixelization", mq3c="MQ3C pixelization"),
113 doc="Pixelization used for partitioning index.",
114 default="mq3c"
115 )
116 part_pix_level = Field[int](
117 doc="Pixelization level used for partitioning index.",
118 default=10
119 )
120 part_pix_max_ranges = Field[int](
121 doc="Max number of ranges in pixelization envelope",
122 default=64
123 )
124 ra_dec_columns = ListField[str](
125 default=["ra", "decl"],
126 doc="Names of ra/dec columns in DiaObject table"
127 )
128 timer = Field[bool](
129 doc="If True then print/log timing information",
130 default=False
131 )
132 time_partition_tables = Field[bool](
133 doc="Use per-partition tables for sources instead of partitioning by time",
134 default=True
135 )
136 time_partition_days = Field[int](
137 doc="Time partitioning granularity in days, this value must not be changed"
138 " after database is initialized",
139 default=30
140 )
141 time_partition_start = Field[str](
142 doc="Starting time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI."
143 " This is used only when time_partition_tables is True.",
144 default="2018-12-01T00:00:00"
145 )
146 time_partition_end = Field[str](
147 doc="Ending time for per-partition tables, in yyyy-mm-ddThh:mm:ss format, in TAI."
148 " This is used only when time_partition_tables is True.",
149 default="2030-01-01T00:00:00"
150 )
151 query_per_time_part = Field[bool](
152 default=False,
153 doc="If True then build a separate query for each time partition, otherwise build a single query. "
154 "This is only used when time_partition_tables is False in schema config."
155 )
156 query_per_spatial_part = Field[bool](
157 default=False,
158 doc="If True then build one query per spatial partition, otherwise build a single query. "
159 )
160 pandas_delay_conv = Field[bool](
161 default=True,
162 doc="If True then combine result rows before converting to pandas. "
163 )
164
165
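# --- Illustrative sketch (not part of the original module) ------------------
# A minimal example of how a client might assemble a configuration for the
# ApdbCassandra class defined below. The host addresses and field values here
# are placeholders, not defaults shipped with dax_apdb.
def _example_cassandra_config() -> ApdbCassandraConfig:
    """Return a sample configuration (illustration only)."""
    config = ApdbCassandraConfig()
    config.contact_points = ["10.0.0.1", "10.0.0.2"]  # placeholder hosts
    config.keyspace = "apdb"
    config.part_pixelization = "mq3c"
    config.part_pix_level = 10
    config.time_partition_tables = True
    config.time_partition_days = 30
    return config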
166if CASSANDRA_IMPORTED:
167
168 class _AddressTranslator(AddressTranslator):
169 """Translate internal IP address to external.
170
171 Only used for docker-based setup, not a viable long-term solution.
172 """
173 def __init__(self, public_ips: List[str], private_ips: List[str]):
174 self._map = dict((k, v) for k, v in zip(private_ips, public_ips))
175
176 def translate(self, private_ip: str) -> str:
177 return self._map.get(private_ip, private_ip)
178
179
181 """Implementation of APDB database on to of Apache Cassandra.
182
183 The implementation is configured via standard ``pex_config`` mechanism
184 using `ApdbCassandraConfig` configuration class. For examples of
185 different configurations check the config/ folder.
186
187 Parameters
188 ----------
189 config : `ApdbCassandraConfig`
190 Configuration object.
191 """
192
193 partition_zero_epoch = dafBase.DateTime(1970, 1, 1, 0, 0, 0, dafBase.DateTime.TAI)
194 """Start time for partition 0, this should never be changed."""
195
196 def __init__(self, config: ApdbCassandraConfig):
197
198 if not CASSANDRA_IMPORTED:
199 raise CassandraMissingError()
200
201 config.validate()
202 self.config = config
203
204 _LOG.debug("ApdbCassandra Configuration:")
205 for key, value in self.config.items():
206 _LOG.debug(" %s: %s", key, value)
207
208 self._pixelization = Pixelization(
209 config.part_pixelization, config.part_pix_level, config.part_pix_max_ranges
210 )
211
212 addressTranslator: Optional[AddressTranslator] = None
213 if config.private_ips:
214 addressTranslator = _AddressTranslator(config.contact_points, config.private_ips)
215
216 self._keyspace = config.keyspace
217
218 self._cluster = Cluster(execution_profiles=self._makeProfiles(config),
219 contact_points=self.config.contact_points,
220 address_translator=addressTranslator,
221 protocol_version=self.config.protocol_version)
222 self._session = self._cluster.connect()
223 # Disable result paging
224 self._session.default_fetch_size = None
225
226 self._schema = ApdbCassandraSchema(session=self._session,
227 keyspace=self._keyspace,
228 schema_file=self.config.schema_file,
229 schema_name=self.config.schema_name,
230 prefix=self.config.prefix,
231 time_partition_tables=self.config.time_partition_tables)
232 self._partition_zero_epoch_mjd = self.partition_zero_epoch.get(system=dafBase.DateTime.MJD)
233
234 # Cache for prepared statements
235 self._prepared_statements: Dict[str, cassandra.query.PreparedStatement] = {}
236
237 def __del__(self) -> None:
238 self._cluster.shutdown()
239
240 def tableDef(self, table: ApdbTables) -> Optional[Table]:
241 # docstring is inherited from a base class
242 return self._schema.tableSchemas.get(table)
243
244 def makeSchema(self, drop: bool = False) -> None:
245 # docstring is inherited from a base class
246
247 if self.config.time_partition_tables:
248 time_partition_start = dafBase.DateTime(self.config.time_partition_start, dafBase.DateTime.TAI)
249 time_partition_end = dafBase.DateTime(self.config.time_partition_end, dafBase.DateTime.TAI)
250 part_range = (
251 self._time_partition(time_partition_start),
252 self._time_partition(time_partition_end) + 1
253 )
254 self._schema.makeSchema(drop=drop, part_range=part_range)
255 else:
256 self._schema.makeSchema(drop=drop)
257
258 def getDiaObjects(self, region: sphgeom.Region) -> pandas.DataFrame:
259 # docstring is inherited from a base class
260
261 sp_where = self._spatial_where(region)
262 _LOG.debug("getDiaObjects: #partitions: %s", len(sp_where))
263
264 table_name = self._schema.tableName(ApdbTables.DiaObjectLast)
265 query = f'SELECT * from "{self._keyspace}"."{table_name}"'
266 statements: List[Tuple] = []
267 for where, params in sp_where:
268 full_query = f"{query} WHERE {where}"
269 if params:
270 statement = self._prep_statement(full_query)
271 else:
272 # If there are no params then it is likely that query has a
273 # bunch of literals rendered already, no point trying to
274 # prepare it because it's not reusable.
275 statement = cassandra.query.SimpleStatement(full_query)
276 statements.append((statement, params))
277 _LOG.debug("getDiaObjects: #queries: %s", len(statements))
278
279 with Timer('DiaObject select', self.config.timer):
280 objects = cast(
281 pandas.DataFrame,
282 select_concurrent(
283 self._session, statements, "read_pandas_multi", self.config.read_concurrency
284 )
285 )
286
287 _LOG.debug("found %s DiaObjects", objects.shape[0])
288 return objects
289
290 def getDiaSources(self, region: sphgeom.Region,
291 object_ids: Optional[Iterable[int]],
292 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
293 # docstring is inherited from a base class
294 months = self.config.read_sources_months
295 if months == 0:
296 return None
297 mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
298 mjd_start = mjd_end - months*30
299
300 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaSource)
301
302 def getDiaForcedSources(self, region: sphgeom.Region,
303 object_ids: Optional[Iterable[int]],
304 visit_time: dafBase.DateTime) -> Optional[pandas.DataFrame]:
305 # docstring is inherited from a base class
306 months = self.config.read_forced_sources_months
307 if months == 0:
308 return None
309 mjd_end = visit_time.get(system=dafBase.DateTime.MJD)
310 mjd_start = mjd_end - months*30
311
312 return self._getSources(region, object_ids, mjd_start, mjd_end, ApdbTables.DiaForcedSource)
313
314 def getDiaObjectsHistory(self,
315 start_time: dafBase.DateTime,
316 end_time: dafBase.DateTime,
317 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame:
318 # docstring is inherited from a base class
319
320 sp_where = self._spatial_where(region, use_ranges=True)
321 tables, temporal_where = self._temporal_where(ApdbTables.DiaObject, start_time, end_time, True)
322
323 # Build all queries
324 statements: List[Tuple] = []
325 for table in tables:
326 prefix = f'SELECT * from "{self._keyspace}"."{table}"'
327 statements += list(self._combine_where(prefix, sp_where, temporal_where, "ALLOW FILTERING"))
328 _LOG.debug("getDiaObjectsHistory: #queries: %s", len(statements))
329
330 # Run all selects in parallel
331 with Timer("DiaObject history", self.config.timer):
332 catalog = cast(
333 pandas.DataFrame,
334 select_concurrent(
335 self._session, statements, "read_pandas_multi", self.config.read_concurrency
336 )
337 )
338
339 # precise filtering on validityStart
340 validity_start = start_time.toPython()
341 validity_end = end_time.toPython()
342 catalog = cast(
343 pandas.DataFrame,
344 catalog[(catalog["validityStart"] >= validity_start) & (catalog["validityStart"] < validity_end)]
345 )
346
347 _LOG.debug("found %d DiaObjects", catalog.shape[0])
348 return catalog
349
350 def getDiaSourcesHistory(self,
351 start_time: dafBase.DateTime,
352 end_time: dafBase.DateTime,
353 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame:
354 # docstring is inherited from a base class
355 return self._getSourcesHistory(ApdbTables.DiaSource, start_time, end_time, region)
356
357 def getDiaForcedSourcesHistory(self,
358 start_time: dafBase.DateTime,
359 end_time: dafBase.DateTime,
360 region: Optional[sphgeom.Region] = None) -> pandas.DataFrame:
361 # docstring is inherited from a base class
362 return self._getSourcesHistory(ApdbTables.DiaForcedSource, start_time, end_time, region)
363
364 def getSSObjects(self) -> pandas.DataFrame:
365 # docstring is inherited from a base class
366 tableName = self._schema.tableName(ApdbTables.SSObject)
367 query = f'SELECT * from "{self._keyspace}"."{tableName}"'
368
369 objects = None
370 with Timer('SSObject select', self.config.timer):
371 result = self._session.execute(query, execution_profile="read_pandas")
372 objects = result._current_rows
373
374 _LOG.debug("found %s SSObjects", objects.shape[0])
375 return objects
376
377 def store(self,
378 visit_time: dafBase.DateTime,
379 objects: pandas.DataFrame,
380 sources: Optional[pandas.DataFrame] = None,
381 forced_sources: Optional[pandas.DataFrame] = None) -> None:
382 # docstring is inherited from a base class
383
384 # fill region partition column for DiaObjects
385 objects = self._add_obj_part(objects)
386 self._storeDiaObjects(objects, visit_time)
387
388 if sources is not None:
389 # copy apdb_part column from DiaObjects to DiaSources
390 sources = self._add_src_part(sources, objects)
391 self._storeDiaSources(ApdbTables.DiaSource, sources, visit_time)
392 self._storeDiaSourcesPartitions(sources, visit_time)
393
394 if forced_sources is not None:
395 forced_sources = self._add_fsrc_part(forced_sources, objects)
396 self._storeDiaSources(ApdbTables.DiaForcedSource, forced_sources, visit_time)
397
398 def storeSSObjects(self, objects: pandas.DataFrame) -> None:
399 # docstring is inherited from a base class
400 self._storeObjectsPandas(objects, ApdbTables.SSObject)
401
402 def reassignDiaSources(self, idMap: Mapping[int, int]) -> None:
403 # docstring is inherited from a base class
404
405 # To update a record we need to know its exact primary key (including
406 # partition key) so we start by querying for diaSourceId to find the
407 # primary keys.
408
409 table_name = self._schema.tableName(ExtraTables.DiaSourceToPartition)
410 # split it into 1k IDs per query
411 selects: List[Tuple] = []
412 for ids in chunk_iterable(idMap.keys(), 1_000):
413 ids_str = ",".join(str(item) for item in ids)
414 selects.append((
415 (f'SELECT "diaSourceId", "apdb_part", "apdb_time_part" FROM "{self._keyspace}"."{table_name}"'
416 f' WHERE "diaSourceId" IN ({ids_str})'),
417 {}
418 ))
419
420 # No need for DataFrame here, read data as tuples.
421 result = cast(
422 List[Tuple[int, int, int]],
423 select_concurrent(self._session, selects, "read_tuples", self.config.read_concurrency)
424 )
425
426 # Make mapping from source ID to its partition.
427 id2partitions: Dict[int, Tuple[int, int]] = {}
428 for row in result:
429 id2partitions[row[0]] = row[1:]
430
431 # make sure we know partitions for each ID
432 if set(id2partitions) != set(idMap):
433 missing = ",".join(str(item) for item in set(idMap) - set(id2partitions))
434 raise ValueError(f"Following DiaSource IDs do not exist in the database: {missing}")
435
436 queries = cassandra.query.BatchStatement()
437 table_name = self._schema.tableName(ApdbTables.DiaSource)
438 for diaSourceId, ssObjectId in idMap.items():
439 apdb_part, apdb_time_part = id2partitions[diaSourceId]
440 values: Tuple
441 if self.config.time_partition_tables:
442 query = (
443 f'UPDATE "{self._keyspace}"."{table_name}_{apdb_time_part}"'
444 ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
445 ' WHERE "apdb_part" = ? AND "diaSourceId" = ?'
446 )
447 values = (ssObjectId, apdb_part, diaSourceId)
448 else:
449 query = (
450 f'UPDATE "{self._keyspace}"."{table_name}"'
451 ' SET "ssObjectId" = ?, "diaObjectId" = NULL'
452 ' WHERE "apdb_part" = ? AND "apdb_time_part" = ? AND "diaSourceId" = ?'
453 )
454 values = (ssObjectId, apdb_part, apdb_time_part, diaSourceId)
455 queries.add(self._prep_statement(query), values)
456
457 _LOG.debug("%s: will update %d records", table_name, len(idMap))
458 with Timer(table_name + ' update', self.config.timer):
459 self._session.execute(queries, execution_profile="write")
460
461 def dailyJob(self) -> None:
462 # docstring is inherited from a base class
463 pass
464
465 def countUnassociatedObjects(self) -> int:
466 # docstring is inherited from a base class
467
468 # It's too inefficient to implement it for Cassandra in current schema.
469 raise NotImplementedError()
470
471 def _makeProfiles(self, config: ApdbCassandraConfig) -> Mapping[Any, ExecutionProfile]:
472 """Make all execution profiles used in the code."""
473
474 if config.private_ips:
475 loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=config.contact_points)
476 else:
477 loadBalancePolicy = RoundRobinPolicy()
478
479 pandas_row_factory: Callable
480 if not config.pandas_delay_conv:
481 pandas_row_factory = pandas_dataframe_factory
482 else:
483 pandas_row_factory = raw_data_factory
484
485 read_tuples_profile = ExecutionProfile(
486 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
487 request_timeout=config.read_timeout,
488 row_factory=cassandra.query.tuple_factory,
489 load_balancing_policy=loadBalancePolicy,
490 )
491 read_pandas_profile = ExecutionProfile(
492 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
493 request_timeout=config.read_timeout,
494 row_factory=pandas_dataframe_factory,
495 load_balancing_policy=loadBalancePolicy,
496 )
497 read_pandas_multi_profile = ExecutionProfile(
498 consistency_level=getattr(cassandra.ConsistencyLevel, config.read_consistency),
499 request_timeout=config.read_timeout,
500 row_factory=pandas_row_factory,
501 load_balancing_policy=loadBalancePolicy,
502 )
503 write_profile = ExecutionProfile(
504 consistency_level=getattr(cassandra.ConsistencyLevel, config.write_consistency),
505 request_timeout=config.write_timeout,
506 load_balancing_policy=loadBalancePolicy,
507 )
508 # To replace default DCAwareRoundRobinPolicy
509 default_profile = ExecutionProfile(
510 load_balancing_policy=loadBalancePolicy,
511 )
512 return {
513 "read_tuples": read_tuples_profile,
514 "read_pandas": read_pandas_profile,
515 "read_pandas_multi": read_pandas_multi_profile,
516 "write": write_profile,
517 EXEC_PROFILE_DEFAULT: default_profile,
518 }
519
520 def _getSources(self, region: sphgeom.Region,
521 object_ids: Optional[Iterable[int]],
522 mjd_start: float,
523 mjd_end: float,
524 table_name: ApdbTables) -> pandas.DataFrame:
525 """Returns catalog of DiaSource instances given set of DiaObject IDs.
526
527 Parameters
528 ----------
529 region : `lsst.sphgeom.Region`
530 Spherical region.
531 object_ids :
532 Collection of DiaObject IDs
533 mjd_start : `float`
534 Lower bound of time interval.
535 mjd_end : `float`
536 Upper bound of time interval.
537 table_name : `ApdbTables`
538 Name of the table.
539
540 Returns
541 -------
542 catalog : `pandas.DataFrame`
543 Catalog containing DiaSource records. Empty catalog is returned if
544 ``object_ids`` is empty.
545 """
546 object_id_set: Set[int] = set()
547 if object_ids is not None:
548 object_id_set = set(object_ids)
549 if len(object_id_set) == 0:
550 return self._make_empty_catalog(table_name)
551
552 sp_where = self._spatial_where(region)
553 tables, temporal_where = self._temporal_where(table_name, mjd_start, mjd_end)
554
555 # Build all queries
556 statements: List[Tuple] = []
557 for table in tables:
558 prefix = f'SELECT * from "{self._keyspace}"."{table}"'
559 statements += list(self._combine_where(prefix, sp_where, temporal_where))
560 _LOG.debug("_getSources %s: #queries: %s", table_name, len(statements))
561
562 with Timer(table_name.name + ' select', self.config.timer):
563 catalog = cast(
564 pandas.DataFrame,
565 select_concurrent(
566 self._session, statements, "read_pandas_multi", self.config.read_concurrency
567 )
568 )
569
570 # filter by given object IDs
571 if len(object_id_set) > 0:
572 catalog = cast(pandas.DataFrame, catalog[catalog["diaObjectId"].isin(object_id_set)])
573
574 # precise filtering on midPointTai
575 catalog = cast(pandas.DataFrame, catalog[catalog["midPointTai"] > mjd_start])
576
577 _LOG.debug("found %d %ss", catalog.shape[0], table_name.name)
578 return catalog
579
580 def _getSourcesHistory(
581 self,
582 table: ApdbTables,
583 start_time: dafBase.DateTime,
584 end_time: dafBase.DateTime,
585 region: Optional[sphgeom.Region] = None,
586 ) -> pandas.DataFrame:
587 """Returns catalog of DiaSource instances given set of DiaObject IDs.
588
589 Parameters
590 ----------
591 table : `ApdbTables`
592 Name of the table.
593 start_time : `dafBase.DateTime`
594 Starting time for DiaSource history search. DiaSource record is
595 selected when its ``midPointTai`` falls into an interval between
596 ``start_time`` (inclusive) and ``end_time`` (exclusive).
597 end_time : `dafBase.DateTime`
598 Upper limit on time for DiaSource history search.
599 region : `lsst.sphgeom.Region`
600 Spherical region.
601
602 Returns
603 -------
604 catalog : `pandas.DataFrame`
605 Catalog containing DiaSource records.
606 """
607 sp_where = self._spatial_where(region, use_ranges=False)
608 tables, temporal_where = self._temporal_where(table, start_time, end_time, True)
609
610 # Build all queries
611 statements: List[Tuple] = []
612 for table_name in tables:
613 prefix = f'SELECT * from "{self._keyspace}"."{table_name}"'
614 statements += list(self._combine_where(prefix, sp_where, temporal_where, "ALLOW FILTERING"))
615 _LOG.debug("_getSourcesHistory: #queries: %s", len(statements))
616
617 # Run all selects in parallel
618 with Timer(f"{table.name} history", self.config.timer):
619 catalog = cast(
620 pandas.DataFrame,
621 select_concurrent(
622 self._session, statements, "read_pandas_multi", self.config.read_concurrency
623 )
624 )
625
626 # precise filtering on validityStart
627 period_start = start_time.get(system=dafBase.DateTime.MJD)
628 period_end = end_time.get(system=dafBase.DateTime.MJD)
629 catalog = cast(
630 pandas.DataFrame,
631 catalog[(catalog["midPointTai"] >= period_start) & (catalog["midPointTai"] < period_end)]
632 )
633
634 _LOG.debug("found %d %ss", catalog.shape[0], table.name)
635 return catalog
636
637 def _storeDiaObjects(self, objs: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
638 """Store catalog of DiaObjects from current visit.
639
640 Parameters
641 ----------
642 objs : `pandas.DataFrame`
643 Catalog with DiaObject records
644 visit_time : `lsst.daf.base.DateTime`
645 Time of the current visit.
646 """
647 visit_time_dt = visit_time.toPython()
648 extra_columns = dict(lastNonForcedSource=visit_time_dt)
649 self._storeObjectsPandas(objs, ApdbTables.DiaObjectLast, extra_columns=extra_columns)
650
651 extra_columns["validityStart"] = visit_time_dt
652 time_part: Optional[int] = self._time_partition(visit_time)
653 if not self.config.time_partition_tables:
654 extra_columns["apdb_time_part"] = time_part
655 time_part = None
656
657 self._storeObjectsPandas(objs, ApdbTables.DiaObject, extra_columns=extra_columns, time_part=time_part)
658
659 def _storeDiaSources(self, table_name: ApdbTables, sources: pandas.DataFrame,
660 visit_time: dafBase.DateTime) -> None:
661 """Store catalog of DIASources or DIAForcedSources from current visit.
662
663 Parameters
664 ----------
665 sources : `pandas.DataFrame`
666 Catalog containing DiaSource records
667 visit_time : `lsst.daf.base.DateTime`
668 Time of the current visit.
669 """
670 time_part: Optional[int] = self._time_partition(visit_time)
671 extra_columns = {}
672 if not self.config.time_partition_tables:
673 extra_columns["apdb_time_part"] = time_part
674 time_part = None
675
676 self._storeObjectsPandas(sources, table_name, extra_columns=extra_columns, time_part=time_part)
677
678 def _storeDiaSourcesPartitions(self, sources: pandas.DataFrame, visit_time: dafBase.DateTime) -> None:
679 """Store mapping of diaSourceId to its partitioning values.
680
681 Parameters
682 ----------
683 sources : `pandas.DataFrame`
684 Catalog containing DiaSource records
685 visit_time : `lsst.daf.base.DateTime`
686 Time of the current visit.
687 """
688 id_map = cast(pandas.DataFrame, sources[["diaSourceId", "apdb_part"]])
689 extra_columns = {
690 "apdb_time_part": self._time_partition(visit_time),
691 }
692
693 self._storeObjectsPandas(
694 id_map, ExtraTables.DiaSourceToPartition, extra_columns=extra_columns, time_part=None
695 )
696
697 def _storeObjectsPandas(self, objects: pandas.DataFrame, table_name: Union[ApdbTables, ExtraTables],
698 extra_columns: Optional[Mapping] = None,
699 time_part: Optional[int] = None) -> None:
700 """Generic store method.
701
702 Takes catalog of records and stores them in a table.
703
704 Parameters
705 ----------
706 objects : `pandas.DataFrame`
707 Catalog containing object records
708 table_name : `ApdbTables`
709 Name of the table as defined in APDB schema.
710 extra_columns : `dict`, optional
711 Mapping (column_name, column_value) which gives column values to add
712 to every row, only if column is missing in catalog records.
713 time_part : `int`, optional
714 If not `None` then insert into a per-partition table.
715 """
716 # use extra columns if specified
717 if extra_columns is None:
718 extra_columns = {}
719 extra_fields = list(extra_columns.keys())
720
721 df_fields = [
722 column for column in objects.columns if column not in extra_fields
723 ]
724
725 column_map = self._schema.getColumnMap(table_name)
726 # list of columns (as in cat schema)
727 fields = [column_map[field].name for field in df_fields if field in column_map]
728 fields += extra_fields
729
730 # check that all partitioning and clustering columns are defined
731 required_columns = self._schema.partitionColumns(table_name) \
732 + self._schema.clusteringColumns(table_name)
733 missing_columns = [column for column in required_columns if column not in fields]
734 if missing_columns:
735 raise ValueError(f"Primary key columns are missing from catalog: {missing_columns}")
736
737 qfields = [quote_id(field) for field in fields]
738 qfields_str = ','.join(qfields)
739
740 with Timer(table_name.name + ' query build', self.config.timer):
741
742 table = self._schema.tableName(table_name)
743 if time_part is not None:
744 table = f"{table}_{time_part}"
745
746 holders = ','.join(['?']*len(qfields))
747 query = f'INSERT INTO "{self._keyspace}"."{table}" ({qfields_str}) VALUES ({holders})'
748 statement = self._prep_statement(query)
749 queries = cassandra.query.BatchStatement()
750 for rec in objects.itertuples(index=False):
751 values = []
752 for field in df_fields:
753 if field not in column_map:
754 continue
755 value = getattr(rec, field)
756 if column_map[field].datatype is felis.types.Timestamp:
757 if isinstance(value, pandas.Timestamp):
758 value = literal(value.to_pydatetime())
759 else:
760 # Assume it's seconds since epoch, Cassandra
761 # datetime is in milliseconds
762 value = int(value*1000)
763 values.append(literal(value))
764 for field in extra_fields:
765 value = extra_columns[field]
766 values.append(literal(value))
767 queries.add(statement, values)
768
769 _LOG.debug("%s: will store %d records", self._schema.tableName(table_name), objects.shape[0])
770 with Timer(table_name.name + ' insert', self.config.timer):
771 self._session.execute(queries, timeout=self.config.write_timeout, execution_profile="write")
772
773 def _add_obj_part(self, df: pandas.DataFrame) -> pandas.DataFrame:
774 """Calculate spacial partition for each record and add it to a
775 DataFrame.
776
777 Notes
778 -----
779 This overrides any existing column in a DataFrame with the same name
780 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is
781 returned.
782 """
783 # calculate pixelization index for every DiaObject
784 apdb_part = np.zeros(df.shape[0], dtype=np.int64)
785 ra_col, dec_col = self.config.ra_dec_columns
786 for i, (ra, dec) in enumerate(zip(df[ra_col], df[dec_col])):
787 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
788 idx = self._pixelization.pixel(uv3d)
789 apdb_part[i] = idx
790 df = df.copy()
791 df["apdb_part"] = apdb_part
792 return df
793
794 def _add_src_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
795 """Add apdb_part column to DiaSource catalog.
796
797 Notes
798 -----
799 This method copies apdb_part value from a matching DiaObject record.
800 DiaObject catalog needs to have an apdb_part column filled by
801 ``_add_obj_part`` method and DiaSource records need to be
802 associated to DiaObjects via ``diaObjectId`` column.
803
804 This overrides any existing column in a DataFrame with the same name
805 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is
806 returned.
807 """
808 pixel_id_map: Dict[int, int] = {
809 diaObjectId: apdb_part for diaObjectId, apdb_part
810 in zip(objs["diaObjectId"], objs["apdb_part"])
811 }
812 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
813 ra_col, dec_col = self.config.ra_dec_columns
814 for i, (diaObjId, ra, dec) in enumerate(zip(sources["diaObjectId"],
815 sources[ra_col], sources[dec_col])):
816 if diaObjId == 0:
817 # DiaSources associated with SolarSystemObjects do not have an
818 # associated DiaObject, so compute their partition from their
819 # own ra/dec instead.
820 uv3d = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(ra, dec))
821 idx = self._pixelization.pixel(uv3d)
822 apdb_part[i] = idx
823 else:
824 apdb_part[i] = pixel_id_map[diaObjId]
825 sources = sources.copy()
826 sources["apdb_part"] = apdb_part
827 return sources
828
829 def _add_fsrc_part(self, sources: pandas.DataFrame, objs: pandas.DataFrame) -> pandas.DataFrame:
830 """Add apdb_part column to DiaForcedSource catalog.
831
832 Notes
833 -----
834 This method copies apdb_part value from a matching DiaObject record.
835 DiaObject catalog needs to have an apdb_part column filled by
836 ``_add_obj_part`` method and DiaForcedSource records need to be
837 associated to DiaObjects via ``diaObjectId`` column.
838
839 This overrides any existing column in a DataFrame with the same name
840 (apdb_part). Original DataFrame is not changed, copy of a DataFrame is
841 returned.
842 """
843 pixel_id_map: Dict[int, int] = {
844 diaObjectId: apdb_part for diaObjectId, apdb_part
845 in zip(objs["diaObjectId"], objs["apdb_part"])
846 }
847 apdb_part = np.zeros(sources.shape[0], dtype=np.int64)
848 for i, diaObjId in enumerate(sources["diaObjectId"]):
849 apdb_part[i] = pixel_id_map[diaObjId]
850 sources = sources.copy()
851 sources["apdb_part"] = apdb_part
852 return sources
853
854 def _time_partition(self, time: Union[float, dafBase.DateTime]) -> int:
855 """Calculate time partiton number for a given time.
856
857 Parameters
858 ----------
859 time : `float` or `lsst.daf.base.DateTime`
860 Time for which to calculate partition number. Can be float to mean
861 MJD.
862
863 Returns
864 -------
865 partition : `int`
866 Partition number for a given time.
867 """
868 if isinstance(time, dafBase.DateTime):
869 mjd = time.get(system=dafBase.DateTime.MJD)
870 else:
871 mjd = time
872 days_since_epoch = mjd - self._partition_zero_epoch_mjd
873 partition = int(days_since_epoch) // self.config.time_partition_days
874 return partition
875
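    # Worked example for _time_partition above (values are illustrative):
    # with the default 30-day partitions and the partition-zero epoch of
    # 1970-01-01 (MJD 40587), a time of MJD 58484 (2019-01-01) gives
    # (58484 - 40587) // 30 = 596, i.e. partition number 596.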
876 def _make_empty_catalog(self, table_name: ApdbTables) -> pandas.DataFrame:
877 """Make an empty catalog for a table with a given name.
878
879 Parameters
880 ----------
881 table_name : `ApdbTables`
882 Name of the table.
883
884 Returns
885 -------
886 catalog : `pandas.DataFrame`
887 An empty catalog.
888 """
889 table = self._schema.tableSchemas[table_name]
890
891 data = {
892 columnDef.name: pandas.Series(dtype=self._schema.column_dtype(columnDef.datatype))
893 for columnDef in table.columns
894 }
895 return pandas.DataFrame(data)
896
897 def _prep_statement(self, query: str) -> cassandra.query.PreparedStatement:
898 """Convert query string into prepared statement."""
899 stmt = self._prepared_statements.get(query)
900 if stmt is None:
901 stmt = self._session.prepare(query)
902 self._prepared_statements[query] = stmt
903 return stmt
904
905 def _combine_where(
906 self,
907 prefix: str,
908 where1: List[Tuple[str, Tuple]],
909 where2: List[Tuple[str, Tuple]],
910 suffix: Optional[str] = None,
911 ) -> Iterator[Tuple[cassandra.query.Statement, Tuple]]:
912 """Make cartesian product of two parts of WHERE clause into a series
913 of statements to execute.
914
915 Parameters
916 ----------
917 prefix : `str`
918 Initial statement prefix that comes before WHERE clause, e.g.
919 "SELECT * from Table"
920 """
921 # If lists are empty use special sentinels.
922 if not where1:
923 where1 = [("", ())]
924 if not where2:
925 where2 = [("", ())]
926
927 for expr1, params1 in where1:
928 for expr2, params2 in where2:
929 full_query = prefix
930 wheres = []
931 if expr1:
932 wheres.append(expr1)
933 if expr2:
934 wheres.append(expr2)
935 if wheres:
936 full_query += " WHERE " + " AND ".join(wheres)
937 if suffix:
938 full_query += " " + suffix
939 params = params1 + params2
940 if params:
941 statement = self._prep_statement(full_query)
942 else:
943 # If there are no params then it is likely that query
944 # has a bunch of literals rendered already, no point
945 # trying to prepare it.
946 statement = cassandra.query.SimpleStatement(full_query)
947 yield (statement, params)
948
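    # Example of what _combine_where above yields (values illustrative):
    # with prefix 'SELECT * from "t"', where1 = [('"apdb_part" = ?', (1,)),
    # ('"apdb_part" = ?', (2,))] and where2 = [('"apdb_time_part" = ?', (5,))],
    # it yields two prepared statements of the form
    # 'SELECT * from "t" WHERE "apdb_part" = ? AND "apdb_time_part" = ?'
    # with parameters (1, 5) and (2, 5) respectively.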
949 def _spatial_where(
950 self, region: Optional[sphgeom.Region], use_ranges: bool = False
951 ) -> List[Tuple[str, Tuple]]:
952 """Generate expressions for spatial part of WHERE clause.
953
954 Parameters
955 ----------
956 region : `sphgeom.Region`
957 Spatial region for query results.
958 use_ranges : `bool`
959 If True then use pixel ranges ("apdb_part >= p1 AND apdb_part <=
960 p2") instead of exact list of pixels. Should be set to True for
961 large regions covering very many pixels.
962
963 Returns
964 -------
965 expressions : `list` [ `tuple` ]
966 Empty list is returned if ``region`` is `None`, otherwise a list
967 of one or more (expression, parameters) tuples
968 """
969 if region is None:
970 return []
971 if use_ranges:
972 pixel_ranges = self._pixelization.envelope(region)
973 expressions: List[Tuple[str, Tuple]] = []
974 for lower, upper in pixel_ranges:
975 upper -= 1
976 if lower == upper:
977 expressions.append(('"apdb_part" = ?', (lower, )))
978 else:
979 expressions.append(('"apdb_part" >= ? AND "apdb_part" <= ?', (lower, upper)))
980 return expressions
981 else:
982 pixels = self._pixelization.pixels(region)
983 if self.config.query_per_spatial_part:
984 return [('"apdb_part" = ?', (pixel,)) for pixel in pixels]
985 else:
986 pixels_str = ",".join([str(pix) for pix in pixels])
987 return [(f'"apdb_part" IN ({pixels_str})', ())]
988
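    # Example of the WHERE fragments produced by _spatial_where above
    # (pixel indices are illustrative): for a region covering pixels
    # 10, 11 and 17 it returns [('"apdb_part" IN (10,11,17)', ())] by
    # default, one ('"apdb_part" = ?', (pixel,)) tuple per pixel when
    # config.query_per_spatial_part is True, and with use_ranges=True
    # it returns [('"apdb_part" >= ? AND "apdb_part" <= ?', (10, 11)),
    # ('"apdb_part" = ?', (17,))].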
989 def _temporal_where(
990 self,
991 table: ApdbTables,
992 start_time: Union[float, dafBase.DateTime],
993 end_time: Union[float, dafBase.DateTime],
994 query_per_time_part: Optional[bool] = None,
995 ) -> Tuple[List[str], List[Tuple[str, Tuple]]]:
996 """Generate table names and expressions for temporal part of WHERE
997 clauses.
998
999 Parameters
1000 ----------
1001 table : `ApdbTables`
1002 Table to select from.
1003 start_time : `dafBase.DateTime` or `float`
1004 Starting Datetime or MJD value of the time range.
1005 end_time : `dafBase.DateTime` or `float`
1006 Ending Datetime or MJD value of the time range.
1007 query_per_time_part : `bool`, optional
1008 If None then use ``query_per_time_part`` from configuration.
1009
1010 Returns
1011 -------
1012 tables : `list` [ `str` ]
1013 List of the table names to query.
1014 expressions : `list` [ `tuple` ]
1015 A list of zero or more (expression, parameters) tuples.
1016 """
1017 tables: List[str]
1018 temporal_where: List[Tuple[str, Tuple]] = []
1019 table_name = self._schema.tableName(table)
1020 time_part_start = self._time_partition(start_time)
1021 time_part_end = self._time_partition(end_time)
1022 time_parts = list(range(time_part_start, time_part_end + 1))
1023 if self.config.time_partition_tables:
1024 tables = [f"{table_name}_{part}" for part in time_parts]
1025 else:
1026 tables = [table_name]
1027 if query_per_time_part is None:
1028 query_per_time_part = self.config.query_per_time_part
1029 if query_per_time_part:
1030 temporal_where = [
1031 ('"apdb_time_part" = ?', (time_part,)) for time_part in time_parts
1032 ]
1033 else:
1034 time_part_list = ",".join([str(part) for part in time_parts])
1035 temporal_where = [(f'"apdb_time_part" IN ({time_part_list})', ())]
1036
1037 return tables, temporal_where
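# --- Usage sketch (illustrative, not part of the original module) -----------
# Shows one way a client might query DiaObjects through this class. The
# connection settings and the query region are placeholders; a real client
# would take them from its own configuration system.
def _example_usage() -> None:
    config = ApdbCassandraConfig()
    config.contact_points = ["127.0.0.1"]  # placeholder contact point
    config.keyspace = "apdb"

    apdb = ApdbCassandra(config)

    # Select DiaObjects inside a small circle (0.1 degree radius) around
    # an arbitrary position.
    center = sphgeom.UnitVector3d(sphgeom.LonLat.fromDegrees(10.0, -5.0))
    region = sphgeom.Circle(center, sphgeom.Angle.fromDegrees(0.1))
    objects = apdb.getDiaObjects(region)
    _LOG.info("found %d DiaObjects", objects.shape[0])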