diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a3604c..0c2b7ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ Copyright (c) 2021-2026 Claudio Satriano ## [unreleased] +- Optional parallel processing for `scan_templates`: time chunks are + dispatched to worker processes, spreading the cross-correlation across + cores and overlapping waveform downloads. On by default with automatic + worker-count detection; configurable with the `template_scan_nprocs` + config option or the `--nprocs` command-line option (`1` disables it). - New interactive curses pager for all ``print_`` commands (``print_catalog``, ``print_pairs``, ``print_families``). Automatically activated when output is a terminal; use ``--no-pager`` diff --git a/requake/config/configspec.conf b/requake/config/configspec.conf index f44fcd0..a9edea0 100644 --- a/requake/config/configspec.conf +++ b/requake/config/configspec.conf @@ -112,6 +112,9 @@ template_end_time = string(default=2021-08-24T00:00:00) time_chunk = float(default=3600) ## Overlap between time chunks (in seconds) time_chunk_overlap = float(default=60) +## Number of worker processes for the template scan. +## 0 means automatic selection; 1 disables parallelism. +template_scan_nprocs = integer(min=0, default=0) ## Minimum ratio between cross-correlation (cc) and median absolute deviation ## (MAD) of cross-correlation (cc_mad). 
# ---------------------------------------------------------------------------
# Optional parallel scan over time chunks.
#
# Time chunks are independent units of work: each chunk scans all templates
# over one time window. Workers run in separate processes so the CPU-bound
# cross-correlation is spread across cores and the FDSN download I/O of
# different chunks overlaps. The config singleton is rebuilt inside each
# worker from a pickle-safe snapshot, because the network clients it holds
# are not pickleable. Detections produced twice in the overlap between
# adjacent chunks are deduplicated by the database UNIQUE constraint on
# (family_number, trace_id, evid), exactly as in the serial scan.
# ---------------------------------------------------------------------------
# Templates shared by all the chunks handled by one worker process; filled
# in by _scan_templates_worker_initializer() and empty in the parent process.
_worker_templates = []


def _template_time_chunks(
        start_time=None, end_time=None, time_chunk=None, overlap=None):
    """
    Build the list of (t0, t1) windows scanned by the template scan.

    The windows match the serial scan exactly: a new window starts every
    ``time_chunk`` seconds and spans ``time_chunk + time_chunk_overlap``.

    Every argument is optional: an argument left to ``None`` is read from
    the config parameter of the same meaning, so a call without arguments
    describes the configured scan.

    :param start_time: start of the first window
        (default: ``config.template_start_time``)
    :param end_time: last allowed window start, inclusive
        (default: ``config.template_end_time``)
    :param time_chunk: step between window starts, in seconds
        (default: ``config.time_chunk``)
    :type time_chunk: float
    :param overlap: extra length appended to each window, in seconds
        (default: ``config.time_chunk_overlap``)
    :type overlap: float
    :return: list of (starttime, endtime) tuples
    :rtype: list
    :raises ValueError: if ``time_chunk`` is not strictly positive
    """
    if start_time is None:
        start_time = config.template_start_time
    if end_time is None:
        end_time = config.template_end_time
    if time_chunk is None:
        time_chunk = config.time_chunk
    if overlap is None:
        overlap = config.time_chunk_overlap
    # A zero, negative or NaN step never moves past end_time: without this
    # check the loop below would append windows until memory is exhausted.
    if not time_chunk > 0:
        raise ValueError(
            f'time_chunk must be strictly positive, got {time_chunk!r}'
        )
    chunks = []
    time = start_time
    while time <= end_time:
        chunks.append((time, time + time_chunk + overlap))
        time += time_chunk
    return chunks


def _resolve_template_scan_nprocs(nchunks):
    """
    Resolve the effective number of worker processes for the template scan.

    The value comes from the ``--nprocs`` command-line option, falling back
    to the ``template_scan_nprocs`` config parameter. A value of 0 selects
    the number of CPUs reported by the system (minus one when more than one
    is reported) and a value of 1 disables parallelism. Missing or negative
    values are treated as 0. The result is capped by ``nchunks``.

    :param nchunks: number of time chunks to process
    :type nchunks: int
    :return: effective number of worker processes (at least 1)
    :rtype: int
    """
    import multiprocessing
    cli_nprocs = getattr(config.args, 'nprocs', None)
    config_nprocs = getattr(config, 'template_scan_nprocs', 0)
    requested = cli_nprocs if cli_nprocs is not None else config_nprocs
    if requested is None or requested < 0:
        requested = 0
    if requested == 0:
        try:
            base_nprocs = multiprocessing.cpu_count()
        except NotImplementedError:
            # The CPU count cannot be determined on this platform:
            # fall back to a single process, i.e. to the serial scan.
            base_nprocs = 1
        if base_nprocs > 1:
            # Leave one CPU free for the parent process.
            base_nprocs -= 1
    else:
        base_nprocs = requested
    # Never less than one process, never more processes than chunks.
    return min(max(1, base_nprocs), max(1, nchunks))


def _scan_templates_worker_initializer(cfg_dict, templates):
    """
    Initialize a worker process for the parallel template scan.

    The pickle-safe config snapshot is restored into the module-level config
    singleton, the network clients are recreated inside the worker process
    and the templates are stored for reuse across the chunks handled by the
    worker.

    The client-connection and logging helpers are shared with the catalog
    scan to keep a single definition of how a worker connects to data
    services.

    :param cfg_dict: pickle-safe config snapshot from
        :func:`requake.config.to_picklable_config_dict`
    :type cfg_dict: dict
    :param templates: list of template traces
    :type templates: list
    """
    import signal
    from ..config import from_picklable_config_dict
    from .scan_catalog_workers import (
        _connect_worker_clients, _silence_worker_console_logging
    )
    global _worker_templates
    _silence_worker_console_logging()
    # Workers ignore Ctrl-C: the interrupt is left to the parent process.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    restored_cfg = from_picklable_config_dict(cfg_dict)
    config.clear()
    config.update(restored_cfg)
    _connect_worker_clients()
    _worker_templates = templates


def _scan_chunk_worker(time_range):
    """
    Scan all templates over one time chunk (worker process entry point).

    The same :func:`_scan_family_template` used by the serial scan is called
    here, so the detection logic is identical. Progress written to stdout by
    that function is suppressed in the worker to keep the parent progress
    line readable. Templates raising :class:`NoWaveformError` are skipped.

    :param time_range: (starttime, endtime) of the chunk
    :type time_range: tuple
    :return: list of detection tuples (family_number, trace_id, event, cc_max)
    :rtype: list
    """
    import contextlib
    t0, t1 = time_range
    detections = []
    try:
        with contextlib.redirect_stdout(io.StringIO()):
            for template in _worker_templates:
                try:
                    detection = _scan_family_template(template, t0, t1)
                except NoWaveformError:
                    continue
                if detection is not None:
                    detections.append(detection)
    finally:
        # Clear the cache even when an unexpected error escapes, so that a
        # failed chunk does not leave its traces in the worker process.
        trace_cache.clear()
    return detections


def _scan_templates_parallel(templates, nprocs):
    """
    Scan templates over continuous data using a pool of worker processes.

    Each time chunk is dispatched to a worker that scans all templates over
    that window. Detections are collected in the parent process and written
    to the database with one call to ``write_template_detections``, once
    every chunk has been scanned. Duplicate detections from the overlap
    between adjacent chunks are removed by the database UNIQUE constraint.

    If the scan is interrupted or a worker fails, the chunks not yet started
    are cancelled and the exception is re-raised; nothing is written to the
    database in that case.

    :param templates: list of template traces
    :type templates: list
    :param nprocs: number of worker processes
    :type nprocs: int
    """
    from concurrent.futures import ProcessPoolExecutor, as_completed
    from ..config import to_picklable_config_dict
    chunks = _template_time_chunks()
    nchunks = len(chunks)
    cfg_dict = to_picklable_config_dict(config)
    logger.info(
        'Parallel template scan: %d time chunks, %d worker processes',
        nchunks, nprocs
    )
    # Results are stored by chunk index: futures complete in any order.
    ordered = [None] * nchunks
    with ProcessPoolExecutor(
        max_workers=nprocs,
        initializer=_scan_templates_worker_initializer,
        initargs=(cfg_dict, templates),
    ) as executor:
        futures = {
            executor.submit(_scan_chunk_worker, chunk): idx
            for idx, chunk in enumerate(chunks)
        }
        try:
            for done, future in enumerate(as_completed(futures), start=1):
                ordered[futures[future]] = future.result()
                sys.stdout.write(f'Scanned {done}/{nchunks} time chunks\r')
                # No newline is written: flush so that the progress line
                # shows up on a buffered stdout.
                sys.stdout.flush()
        except BaseException:
            # Workers ignore SIGINT and leaving the ``with`` block waits for
            # every submitted chunk: cancel the ones not yet started, so
            # that Ctrl-C or a worker failure stops the scan promptly.
            for pending in futures:
                pending.cancel()
            raise
        finally:
            # Terminate the progress line, also when the scan is aborted.
            sys.stdout.write('\n')
    # Flatten in chunk order so that, for detections seen twice in the overlap
    # between two chunks, the later chunk wins the database REPLACE, matching
    # the serial scan order.
    detections = [
        detection for chunk_detections in ordered
        for detection in chunk_detections
    ]
    if detections:
        write_template_detections(detections, append=True)
+ +:copyright: + 2021-2026 Claudio Satriano , + Marius Yvard +:license: + GNU General Public License v3.0 or later + (https://www.gnu.org/licenses/gpl-3.0-standalone.html) +""" + +import unittest +import tempfile +import sqlite3 +import multiprocessing +from argparse import Namespace +from unittest.mock import patch +from obspy import UTCDateTime +from requake.catalog import RequakeEvent +from requake.config import config +from requake.database.db import get_db_path +from requake.database.templates import write_template_detections +from requake.waveforms import NoWaveformError +import importlib +st = importlib.import_module('requake.scan.scan_templates') + + +def _serial_time_chunks(start, end, time_chunk, overlap): + """Independent reference implementation of the serial chunk stepping.""" + chunks = [] + time = start + while time <= end: + chunks.append((time, time + time_chunk + overlap)) + time += time_chunk + return chunks + + +class TestTemplateTimeChunks(unittest.TestCase): + """The parallel scan must cover exactly the serial time windows.""" + + def test_time_chunks_match_serial_stepping(self): + """_template_time_chunks reproduces the serial while-loop windows.""" + start = UTCDateTime('2021-08-23T00:00:00') + end = start + 10000.0 + with patch.dict( + config, + { + 'template_start_time': start, + 'template_end_time': end, + 'time_chunk': 3600.0, + 'time_chunk_overlap': 60.0, + }, + clear=False, + ): + chunks = st._template_time_chunks() + self.assertEqual( + chunks, _serial_time_chunks(start, end, 3600.0, 60.0) + ) + self.assertEqual(len(chunks), 3) + + +class TestResolveTemplateScanNprocs(unittest.TestCase): + """Worker-count resolution: CLI over config, 0 auto, 1 serial, capped.""" + + def _auto(self): + """Expected automatic worker count on the running machine.""" + ncpu = multiprocessing.cpu_count() + return ncpu - 1 if ncpu > 1 else 1 + + def _resolve(self, nchunks, cli_nprocs, config_nprocs): + """Resolve nprocs with a patched CLI and config value. 
+ + ``patch.object`` is used rather than ``patch.dict`` so that the value + is set through ``Config.__setattr__``, keeping the attribute and the + dict item in sync (a stale ``config.args`` instance attribute left by + another test would otherwise shadow a ``patch.dict`` item). + """ + with patch.object( + config, 'args', Namespace(nprocs=cli_nprocs) + ), patch.object( + config, 'template_scan_nprocs', config_nprocs, create=True + ): + return st._resolve_template_scan_nprocs(nchunks) + + def test_config_one_disables_parallelism(self): + """template_scan_nprocs = 1 forces the serial fast path.""" + self.assertEqual(self._resolve(3, None, 1), 1) + + def test_config_zero_is_auto_capped_by_chunks(self): + """template_scan_nprocs = 0 selects auto, capped by nchunks.""" + self.assertEqual(self._resolve(10, None, 0), min(self._auto(), 10)) + self.assertEqual(self._resolve(1, None, 0), 1) + + def test_cli_overrides_config(self): + """--nprocs takes precedence over the config value.""" + self.assertEqual(self._resolve(5, 1, 0), 1) + self.assertEqual(self._resolve(10, 4, 0), 4) + self.assertEqual(self._resolve(2, 4, 0), 2) + + def test_cli_zero_is_auto(self): + """--nprocs 0 means auto even when the config value is set.""" + self.assertEqual(self._resolve(10, 0, 8), min(self._auto(), 10)) + + +class TestScanChunkWorker(unittest.TestCase): + """The chunk worker delegates to the shared detection function.""" + + def test_worker_collects_and_skips_missing_data(self): + """Detections are collected; NoWaveformError templates are skipped.""" + def fake_scan_family_template(template, t0, t1): + if template == 'with_data': + return (0, 'XX.TEST.00.BHZ', 'detection', 0.9) + if template == 'no_data': + raise NoWaveformError('no data') + return None + + t0 = UTCDateTime('2021-08-23T00:00:00') + t1 = t0 + 3660.0 + with patch.object( + st, '_scan_family_template', fake_scan_family_template + ), patch.object( + st, '_worker_templates', ['with_data', 'no_data', 'no_detection'] + ): + 
results = st._scan_chunk_worker((t0, t1)) + self.assertEqual(results, [(0, 'XX.TEST.00.BHZ', 'detection', 0.9)]) + + +class TestOverlapDedup(unittest.TestCase): + """Overlap-zone duplicates collapse by evid; the later write wins.""" + + def setUp(self): + """Create a temporary directory for the test database.""" + self.test_dir = tempfile.TemporaryDirectory() + self.addCleanup(self.test_dir.cleanup) + + def _patch_runtime_config(self): + """Point the global config to a temporary database.""" + return patch.dict( + config, + { + 'outdir': self.test_dir.name, + 'args': Namespace(outdir=self.test_dir.name, template=True), + }, + clear=False, + ) + + def _detection(self, evid, cc_max): + """Build a detection tuple for one fixed family and trace.""" + event = RequakeEvent( + evid=evid, + orig_time=UTCDateTime('2021-08-23T00:30:00'), + lon=10.0, + lat=45.0, + depth=10.0, + trace_id='XX.TEST.00.BHZ', + ) + return (0, 'XX.TEST.00.BHZ', event, cc_max) + + def test_same_evid_collapses_keeping_last(self): + """Two detections with one evid yield one row, last writer wins.""" + # Same event detected in two overlapping chunks: identical evid, + # marginally different cc_max. Flattening in chunk order means the + # later chunk is written last and wins the database REPLACE. + detections = [self._detection('reqk2021aaaaaa', 0.80), + self._detection('reqk2021aaaaaa', 0.90)] + with self._patch_runtime_config(): + write_template_detections(detections, append=False) + conn = sqlite3.connect(get_db_path()) + try: + rows = conn.execute( + 'SELECT evid, cc_max FROM template_detections' + ).fetchall() + finally: + conn.close() + self.assertEqual(len(rows), 1) + self.assertAlmostEqual(rows[0][1], 0.90, places=6) + + +if __name__ == '__main__': + unittest.main()