Source code for mspasspy.algorithms.ml.arrival

import logging
import numpy as np
import seisbench.models as sbm

from seisbench.models.base import WaveformModel
from mspasspy.ccore.seismic import TimeSeries, TimeSeriesEnsemble
from mspasspy.ccore.algorithms.basic import TimeWindow
from mspasspy.algorithms.window import WindowData
from obspy import UTCDateTime


[docs] def annotate_arrival_time( timeseries: TimeSeries, threshold=0.2, time_window: TimeWindow = None, model: WaveformModel = None, model_args: dict = None, ): """ Predict the arrival time of the P wave using the provided seisbench WaveformModel. The arrival time will be saved as a dictionary in the input TimeSeries object and can be accessed using the key 'p_wave_picks'. In the dictionay, the key is the arrival time in the utc timestamp format, and the value is the probability of the pick. :param timeseries: The time series data to predict the arrival time. :param threshold: The probability threshold (0-1) to filter p-wave picks. Any picks with probability less than the threshold will be removed. Default value is 0.2. :param time_window: The time window (in utc timestamp) to filter the predicted arrival time. If not provided, the whole time series will be used. :param model: The model used to predict the arrival time. :param model_args: arguments to initialize a new model if not provided :type timeseries: mspasspy.ccore.seismic.TimeSeries :type threshold: float :type time_window: mspasspy.ccore.algorithms.basic.TimeWindow defined as absolute time in UTC :type model: seisbench.models.base.WaveformModel :type model_args: dict """ default_threshold = 0.2 # Check the input arguments if not 0 <= threshold <= 1: logging.warning( "Threshold should be in the range of [0, 1]. Using default threshold {}}".format( default_threshold ) ) threshold = default_threshold # convert timeseries to absolute time timeseries.rtoa() # load pretrained model based on the args if not provided if model == None: # 'stead' model was trained on STEAD for 100 epochs with a learning rate of 0.01. # use sbm.PhaseNet.list_pretrained(details=True) to list out other supported models # when using this model, please reference the SeisBench publications listed at https://github.com/seisbench/seisbench pretrained_model = ( "stead" if (model_args == None or "name" not in model_args) else model_args["name"] ) model = sbm.PhaseNet.from_pretrained(pretrained_model) ts_ensemble = TimeSeriesEnsemble() ts_ensemble.member.append(timeseries) stream = ts_ensemble.toStream() # apply the window if provided and convert time series to stream start_time_utc = stream[0].stats.starttime.timestamp # UTC timestamp end_time_utc = stream[0].stats.endtime.timestamp # UTC timestamp # adjust the time window if it is out of the time range of the time series if time_window: if time_window.end < start_time_utc or time_window.start > end_time_utc: time_window.start = start_time_utc time_window.end = end_time_utc logging.warning( "Time window is out of the time range of the time series. Adjusting the time window to the time range of the time series." ) if time_window.end > end_time_utc: time_window.end = end_time_utc if time_window.start < start_time_utc: time_window.start = start_time_utc windowed_stream = ( stream.trim(UTCDateTime(time_window.start), UTCDateTime(time_window.end)) if time_window else stream ) # prediction result is the probability for picks over time pred_st = model.annotate(windowed_stream) # Step 1: Access the probability data for tr in pred_st: if tr.stats.channel == "PhaseNet_P": trace = tr break data = trace.data # Step 2: Find all the index with probability value greater than the threshold indices = np.where(data >= threshold)[0] # Step 3: Calculate the corresponding time in utc timestamp timestamps = trace.times("timestamp")[indices] # Step 4: Create a dictionary with timestamps as keys and probability values as values p_wave_picks = {ts: data[i] for ts, i in zip(timestamps, indices)} # Step 5: Save the arrival time dictionary in absolute time timeseries["p_wave_picks"] = p_wave_picks