import logging
import numpy as np
import seisbench.models as sbm
from seisbench.models.base import WaveformModel
from mspasspy.ccore.seismic import TimeSeries, TimeSeriesEnsemble
from mspasspy.ccore.algorithms.basic import TimeWindow
from obspy import UTCDateTime
[docs]
def _clip_window_to_data(window_start, window_end, data_start, data_end):
    """Intersect a requested time window with the span covered by the data.

    All four arguments are epoch (UTC timestamp) seconds. When the requested
    window does not overlap the data at all, a warning is logged and the full
    data span is returned instead; otherwise the window is clipped to the data.

    :return: ``(start, end)`` tuple of epoch seconds.
    """
    if window_end < data_start or window_start > data_end:
        logging.warning(
            "Time window is out of the time range of the time series. Adjusting the time window to the time range of the time series."
        )
        return data_start, data_end
    return max(window_start, data_start), min(window_end, data_end)


def _find_p_probability_trace(pred_st):
    """Return the P-phase probability trace from an annotation stream, or None.

    A trace whose channel is exactly ``PhaseNet_P`` is always preferred. If none
    exists, the first trace whose channel ends in ``_P`` is returned, so other
    seisbench models can be used too (assumes they name their P output channel
    ``<ModelName>_P`` -- TODO confirm against the installed seisbench version).
    """
    fallback = None
    for tr in pred_st:
        channel = tr.stats.channel
        if channel == "PhaseNet_P":
            return tr
        if fallback is None and channel.endswith("_P"):
            fallback = tr
    return fallback


def annotate_arrival_time(
    timeseries: TimeSeries,
    threshold=0.2,
    time_window: TimeWindow = None,
    model: WaveformModel = None,
    model_args: dict = None,
):
    """
    Predict the arrival time of the P wave using the provided seisbench WaveformModel.

    The picks are saved as a dictionary in the input TimeSeries object under the
    key 'p_wave_picks'. Each key is a pick time as a UTC epoch timestamp (float)
    and each value is the model's probability for that sample (float). Samples
    whose probability is at or above ``threshold`` are all reported, so one
    arrival may produce several consecutive entries.

    Side effects: ``timeseries`` is converted to absolute time in place
    (``timeseries.rtoa()``) and its 'p_wave_picks' entry is overwritten.
    ``time_window`` is NOT modified.

    :param timeseries: The time series data to predict the arrival time.
    :param threshold: The probability threshold (0-1) to filter p-wave picks.
      Any picks with probability less than the threshold will be removed.
      Values outside [0, 1] are replaced by the default 0.2 with a warning.
    :param time_window: The time window (in utc timestamp) to filter the predicted
      arrival time. It is clipped to the data span; if it does not overlap the
      data at all, the whole time series is used. If not provided, the whole
      time series is used.
    :param model: The model used to predict the arrival time. If None, a
      pretrained ``seisbench.models.PhaseNet`` is loaded.
    :param model_args: Only used when ``model`` is None. The value of its
      ``"name"`` key selects the pretrained PhaseNet weights (default "stead");
      other keys are currently ignored.
    :type timeseries: mspasspy.ccore.seismic.TimeSeries
    :type threshold: float
    :type time_window: mspasspy.ccore.algorithms.basic.TimeWindow defined as absolute time in UTC
    :type model: seisbench.models.base.WaveformModel
    :type model_args: dict
    :return: None. Results are stored on ``timeseries``.
    """
    default_threshold = 0.2
    if not 0 <= threshold <= 1:
        # The original message used "{}}", which made str.format raise
        # ValueError instead of falling back to the default threshold.
        logging.warning(
            "Threshold should be in the range of [0, 1]. Using default threshold %s",
            default_threshold,
        )
        threshold = default_threshold

    # Picks are reported as absolute UTC timestamps, so the data must be in
    # absolute time before it is converted to an obspy Stream.
    timeseries.rtoa()

    if model is None:
        # 'stead' model was trained on STEAD for 100 epochs with a learning rate of 0.01.
        # use sbm.PhaseNet.list_pretrained(details=True) to list out other supported models
        # when using this model, please reference the SeisBench publications listed at https://github.com/seisbench/seisbench
        pretrained_model = (model_args or {}).get("name", "stead")
        model = sbm.PhaseNet.from_pretrained(pretrained_model)

    ts_ensemble = TimeSeriesEnsemble()
    ts_ensemble.member.append(timeseries)
    stream = ts_ensemble.toStream()
    if len(stream) == 0:
        timeseries["p_wave_picks"] = {}
        logging.warning("Converted obspy Stream is empty; no P-wave picks computed.")
        return

    data_start = stream[0].stats.starttime.timestamp  # UTC timestamp
    data_end = stream[0].stats.endtime.timestamp  # UTC timestamp

    # Clip into local values so the caller's TimeWindow object is left untouched.
    window = None
    if time_window is not None:
        window = _clip_window_to_data(
            time_window.start, time_window.end, data_start, data_end
        )
        stream = stream.trim(UTCDateTime(window[0]), UTCDateTime(window[1]))

    # Prediction result is the pick probability over time, one trace per output label.
    pred_st = model.annotate(stream)
    trace = _find_p_probability_trace(pred_st)
    if trace is None:
        timeseries["p_wave_picks"] = {}
        logging.warning("Model annotation output does not contain a PhaseNet_P trace.")
        return

    probabilities = trace.data
    indices = np.where(probabilities >= threshold)[0]
    timestamps = trace.times("timestamp")[indices]
    if window is not None:
        # The annotation trace may extend beyond the trimmed window; drop those samples.
        in_window = (timestamps >= window[0]) & (timestamps <= window[1])
        timestamps = timestamps[in_window]
        indices = indices[in_window]

    # Plain Python floats (equal to the numpy scalars they replace) keep the
    # stored dictionary free of numpy types.
    timeseries["p_wave_picks"] = {
        float(ts): float(probabilities[i]) for ts, i in zip(timestamps, indices)
    }