import os
import pathlib
import time
from urllib.error import HTTPError
import astropy.units as u
import numpy as np
# Hacky monkey-patch: restore the `np.int` alias (removed in NumPy 1.24) for code that still references it.
if not hasattr(np, 'int'):
np.int = int
import pandas as pd
import requests
from astropy.coordinates import SkyCoord
from astropy.cosmology import LambdaCDM
import sys
if sys.version_info >= (3, 9):
import importlib.resources as pkg_resources
else:
import importlib_resources as pkg_resources
import importlib
from astropy.table import Table
from .diagnose import plot_match
from .helpers import GalaxyCatalog, Transient, setup_logger, sanitize_input, get_ned_specz
import logging
from collections import OrderedDict
import warnings
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import gc
from colorama import Fore, Style
# Parallel processing settings
[docs]
# Upper bound on worker processes: all but four of the detected cores, never
# fewer than one. os.cpu_count() returns None when the core count cannot be
# determined, so fall back to a single core instead of raising a TypeError.
NPROCESS_MAX = max((os.cpu_count() or 1) - 4, 1)
[docs]
def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of at most ``n`` items.

    Parameters
    ----------
    lst : sequence
        Any sliceable object supporting ``len()``.
    n : int
        Maximum size of each slice.

    Yields
    ------
    sequence
        Successive slices ``lst[k:k+n]``; the last one may be shorter.
    """
    starts = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in starts)
# Default survey releases
[docs]
# Release used for a catalog when the caller names it without a release
# (looked up by get_catalogs).
DEFAULT_RELEASES = {
    "glade": "latest",
    "decals": "dr9",
    "panstarrs": "dr2",
    "skymapper": "dr4",
    "rubin": "dp0.2",
}
[docs]
# Catalogs for which associate_transient computes and conditions on 'offset'
# only, regardless of the redshift/absmag priors requested.
ONLY_OFFSET_CATS = {"panstarrs", "skymapper", "rubin"}
# Suppress RuntimeWarnings whose message starts with "divide by zero encountered in divide".
# This is installed at import time and applies process-wide.
warnings.filterwarnings("ignore", category=RuntimeWarning, message="divide by zero encountered in divide")
[docs]
def infer_skycoord(row, coord_cols):
    """Build a SkyCoord from the coordinate columns of one catalog row.

    Parameters
    ----------
    row : pandas.DataFrame row.
        Row of transient_catalog holding the properties of a single transient.
    coord_cols : tuple of two strings
        Names of the two coordinate columns, in (RA, Dec) order.

    Returns
    -------
    astropy.coordinates.SkyCoord
        Coordinates of the transient to associate.
    """
    # First attempt: let astropy guess the frame from the column names.
    # Any failure here is deliberately ignored in favor of manual parsing.
    try:
        guessed = SkyCoord.guess_from_table(Table(rows=[row[coord_cols]], names=coord_cols))
    except Exception:
        guessed = None
    if guessed is not None:
        return guessed

    # Manual parsing: a string RA containing ':' is read as sexagesimal
    # (hour angle, degrees); everything else is converted to float degrees.
    ra_value = row[coord_cols[0]]
    looks_sexagesimal = isinstance(ra_value, str) and (":" in ra_value)
    if not looks_sexagesimal:
        return SkyCoord(float(ra_value) * u.deg, float(row[coord_cols[1]]) * u.deg)
    return SkyCoord(ra_value, row[coord_cols[1]], unit=(u.hourangle, u.deg))
[docs]
def consolidate_results(results, transient_catalog):
    """Merge per-transient association results back into the transient catalog.

    Parameters
    ----------
    results : dictionary
        Results from association; keys are row indices, and values are dictionaries
        of returned properties (or None for a failed association, which is skipped).
    transient_catalog : pd.DataFrame
        The dataset containing names, coordinates, and (optionally) redshift
        information for transients. Its index is matched against the "idx" entry
        of each result.

    Returns
    -------
    pd.DataFrame
        A new DataFrame with the host columns merged in. If a 'host_objID' column
        is present, rows lacking a host ID or total posterior are dropped.

    Raises
    ------
    ValueError
        If no result provides an "idx" entry (e.g. every association failed).
    """
    valid_results = [res for res in results.values() if res is not None]
    results_df = pd.DataFrame.from_records(valid_results)

    if "idx" not in results_df.columns:
        raise ValueError("No 'idx' column found in results, cannot update transient_catalog!")

    # Build exactly one entry per valid result, so the positional join below
    # stays row-aligned even when some results carry no "extra_cat_cols".
    extra_cat_cols_list = []
    for res in valid_results:
        extra = res.get("extra_cat_cols") if isinstance(res, dict) else None
        extra_cat_cols_list.append(extra or {})

    if any(extra_cat_cols_list):
        extra_cat_cols_df = pd.DataFrame.from_records(extra_cat_cols_list)
        results_df = results_df.join(extra_cat_cols_df)

    transient_catalog = transient_catalog.merge(
        results_df, left_index=True, right_on="idx", how="left"
    )

    # Convert all ID columns to string to avoid saving errors.
    # NOTE(review): the match is case-sensitive (so 'host_objID' is not touched)
    # and to_numeric(errors="coerce") turns non-numeric IDs into "nan" -- confirm
    # this is the intended treatment of user-supplied ID columns.
    id_cols = [col for col in transient_catalog.columns if col.endswith("id")]
    for col in id_cols:
        transient_catalog[col] = pd.to_numeric(transient_catalog[col], errors="coerce").astype("str")

    # Drop unassociated events in batch mode -- return a DF without host info
    # if association failed with only 1 transient.
    if 'host_objID' in transient_catalog.columns:
        required = [
            col for col in ('host_total_posterior', 'host_objID')
            if col in transient_catalog.columns
        ]
        transient_catalog = transient_catalog.dropna(subset=required)

    return transient_catalog.reset_index(drop=True)
[docs]
def save_results(transient_catalog, run_name=None, save_path='./', drop_unassociated=True):
    """Save the transient catalog results to a CSV file with a timestamp (and optional run name).

    The file is written to
    ``<save_path>/associated_transient_catalog_[<run_name>_]<unix time>.csv``.
    The output directory is created if it does not exist, and the DataFrame
    passed in is left unmodified.

    Parameters
    ----------
    transient_catalog : pandas.DataFrame
        A DataFrame containing the transient catalog data.
    run_name : str, optional
        A string identifier for the current run.
    save_path : str, optional
        The directory path where the CSV file will be saved. Defaults to the current directory ('./').
    drop_unassociated : bool, optional
        If True (and a 'host_objID' column exists), rows without a host ID or
        total posterior are left out of the saved file. Defaults to True.

    Returns
    -------
    None
    """
    ts = int(time.time())

    save_suffix = f"{ts}"
    if run_name is not None:
        save_suffix = f"{run_name}_{save_suffix}"
    save_name = pathlib.Path(save_path, f"associated_transient_catalog_{save_suffix}.csv")

    to_save = transient_catalog
    if drop_unassociated and ('host_objID' in transient_catalog.columns):
        required = [
            col for col in ('host_objID', 'host_total_posterior')
            if col in transient_catalog.columns
        ]
        # Non-inplace dropna: saving must not alter the caller's DataFrame.
        to_save = transient_catalog.dropna(subset=required)

    save_name.parent.mkdir(parents=True, exist_ok=True)
    to_save.to_csv(save_name, index=False)
[docs]
def log_host_properties(logger, transient_name, cat, host_idx, title, print_props, calc_host_props, condition_props):
    """Log selected host galaxy properties for a transient.

    Parameters
    ----------
    logger : logging.Logger
        Logger instance to output messages.
    transient_name : str
        Name of the transient.
    cat : GalaxyCatalog
        Catalog containing candidate host galaxies; ``cat.galaxies`` is read as
        a structured array (``dtype.names`` lists the available fields).
    host_idx : int
        Index of the host galaxy in the catalog.
    title : str
        Header text for the log output.
    print_props : list of str
        List of property names to log directly (e.g., 'objID', 'ra', 'dec').
        Properties missing from the catalog are skipped.
    calc_host_props : list of str
        List of properties (e.g., 'redshift', 'absmag', 'offset') for which mean and std are logged.
    condition_props : list of str
        List of properties (e.g., 'redshift', 'absmag', 'offset') for which posterior values are logged.

    Returns
    -------
    None
        Logs the formatted host properties as a single INFO message.
    """
    prop_lines = [f"\n {title} for {transient_name}:" + Style.RESET_ALL]

    # Label and format spec for every property we know how to print.
    prop_format = {
        "objID": ("objID", "{:s}"),
        'name': ("Name", "{:s}"),
        "ra": ("R.A. (deg)", "{:.6f}"),
        "dec": ("Dec. (deg)", "{:.6f}"),
        "redshift": ("Redshift", "{:.4f}"),
        "absmag": ("Absolute Magnitude", "{:.1f}"),
        "offset": (r"Transient Offset", "{:.1f}"),
        "posterior": ("Posterior", "{:.4e}"),
    }
    posterior_fmt = prop_format["posterior"][1]
    field_names = cat.galaxies.dtype.names

    def _render(fmt, value):
        # "{:s}" only accepts str; IDs and names may be stored as ints or bytes.
        if fmt == "{:s}":
            if isinstance(value, bytes):
                value = value.decode(errors="replace")
            return fmt.format(str(value))
        return fmt.format(value)

    # Directly printed properties.
    for prop in print_props:
        # Check the field exists *before* indexing the catalog with it.
        if prop not in field_names:
            continue
        values = cat.galaxies[prop]
        if not (0 <= host_idx < len(values)):
            continue
        raw_value = values[host_idx]
        if prop == 'name' and not str(raw_value).strip():
            continue  # no point logging an empty name

        # Format is looked up by the last '_'-separated token, so e.g.
        # 'total_posterior' uses the 'posterior' entry.
        label, fmt = prop_format.get(prop.split("_")[-1], (prop, "{:.4f}"))
        prop_lines.append(
            Fore.BLUE + f" {label}:" + Style.RESET_ALL + f" {_render(fmt, raw_value)}"
        )

    # Mean, std, and (optionally) posterior for the computed properties.
    for prop in calc_host_props:
        if f"{prop}_mean" not in field_names:
            continue
        label, fmt = prop_format.get(prop, (prop, "{:.4f}"))
        mean_value = fmt.format(cat.galaxies[f"{prop}_mean"][host_idx])
        std_value = fmt.format(cat.galaxies[f"{prop}_std"][host_idx])
        info = cat.galaxies[f"{prop}_info"][host_idx] if f"{prop}_info" in field_names else ""

        print_str = Fore.BLUE + f" {label}:" + Style.RESET_ALL + f" {mean_value} ± {std_value}"
        if prop == 'offset':
            print_str += " arcsec"
        if len(info) > 0:
            print_str += f" ({info})"
        prop_lines.append(print_str)

        # Only read the posterior field when it is actually going to be logged.
        if prop in condition_props:
            posterior = posterior_fmt.format(cat.galaxies[f"{prop}_posterior"][host_idx])
            prop_lines.append(Fore.BLUE + f" {label} Posterior:" + Style.RESET_ALL + f" {posterior}")

    logger.info("\n".join(prop_lines))
[docs]
def get_catalogs(user_input):
    """Convert user input into a dictionary mapping catalog names to release versions.

    Parameters
    ----------
    user_input : iterable
        An iterable of catalog entries, where each entry is either a string (catalog name)
        or a tuple/list ``(catalog name, release version)``. An entry without a
        release (plain string, 1-element sequence, or release ``None``) gets the
        release from ``DEFAULT_RELEASES``.

    Returns
    -------
    dict
        A dictionary with keys as sanitized catalog names and values as the corresponding
        release version, in input order. A repeated name keeps its last release.

    Raises
    ------
    KeyError
        If a release has to be defaulted for a catalog not in ``DEFAULT_RELEASES``.
    """
    catalogs = {}
    for entry in user_input:
        if isinstance(entry, str):
            raw_name, release = entry, None
        else:
            raw_name = entry[0]
            release = entry[1] if len(entry) > 1 else None

        cat_name = sanitize_input(raw_name)
        if release is None:
            try:
                release = DEFAULT_RELEASES[cat_name]
            except KeyError:
                raise KeyError(
                    f"No default release known for catalog '{cat_name}'; "
                    f"known catalogs are {sorted(DEFAULT_RELEASES)}."
                ) from None
        catalogs[cat_name] = release
    return catalogs
[docs]
def associate_transient(
    idx,
    row,
    glade_catalog,
    n_samples,
    priors,
    likes,
    cosmo,
    catalogs,
    cat_priority,
    name_col,
    coord_cols,
    redshift_col,
    cat_cols,
    log_fn,
    n_hosts=2,
    calc_host_props=False,
    verbose=0,
    coord_err_cols=('ra_err', 'dec_err'),
    strict_checking=False,
    warn_on_fallback=True,
    plot_match=False,
    best_redshift=False,
):
    """Associates a transient with its most likely host galaxy.

    Catalogs are queried in order; the first catalog yielding a host ends the search.

    Parameters
    ----------
    idx : int
        Index of the transient from a larger catalog (used to cross-match properties after association).
    row : pandas Series
        Full row of transient properties.
    glade_catalog : pandas.DataFrame
        GLADE catalog of galaxies, with sizes and photo-zs.
    n_samples : int
        Number of samples for the Monte Carlo sampling of associations.
    priors : dict
        Dictionary of priors for the run (at least one of redshift, offset, absolute magnitude).
    likes : dict
        Dictionary of likelihoods for the run (at least one of offset, absolute magnitude).
    cosmo : astropy.cosmology
        Assumed cosmology for the run.
    catalogs : iterable
        Source catalogs to query; each entry is a catalog name or a (name, release) pair.
    cat_priority : dict or None
        The priority order to run the associations (with value 1 will run first, 2nd will run 2nd, etc).
        If None, defaults to the order the catalogs are provided in.
    name_col : str
        Column of ``row`` holding the transient name.
    coord_cols : tuple of two strings
        Columns of ``row`` holding the transient coordinates.
    redshift_col : str
        Column of ``row`` holding the transient redshift (NaN is used if absent or unparseable).
    cat_cols : boolean
        If true, concatenates the source catalog fields to the returned dataframe.
    log_fn : str, optional
        The fn associated with the logger.Logger object.
    n_hosts : int
        Number of candidate hosts passed to the Transient.
    calc_host_props : boolean
        If true, calculates host galaxy properties even if not needed for association
    verbose : int
        The verbosity level of the output.
    coord_err_cols : tuple of strings
        The column names associated with positional uncertainties on the transient positions.
        A nominal 0.1 arcsec is used for missing or unparseable values.
    strict_checking : boolean, optional
        If true, raises error if catalog doesn't support conditioning on a property requested.
    warn_on_fallback : boolean, optional
        If true, raises warning if catalog doesn't support conditioning on a property requested.
    plot_match : boolean, optional
        If true (and the logger is at DEBUG level), attempts to generate a plot image.
    best_redshift : boolean, optional
        If True, queries NED for spectroscopic redshift when host is found within 1 arcsec.
        Default is False.

    Returns
    -------
    dict
        Association results keyed by field name: "idx", the matching catalog and
        release, query time, summary posteriors, ``host_*`` / ``host_<n>_*``
        properties of the ranked hosts (when one is found), and "extra_cat_cols"
        (a dict of catalog columns, empty if cat_cols=False).

    Raises
    ------
    ValueError
        If ``strict_checking`` is set and an offset-only catalog is combined
        with a redshift or absolute-magnitude prior.
    """
    logger = setup_logger(log_fn, verbose=verbose, is_main=False)

    def _split_entry(entry):
        # A catalog entry is either "name" or a (name, release) pair.
        if isinstance(entry, str):
            return entry, ""
        return entry[0], entry[1]

    condition_host_props = list(priors.keys())

    unsupported_props = {"redshift", "absmag"}.intersection(priors)
    # Compare catalog *names*, so (name, release) entries are recognised too.
    catalog_names = {_split_entry(entry)[0] for entry in catalogs}
    unsupported_catalogs = ONLY_OFFSET_CATS.intersection(catalog_names)

    if unsupported_props and unsupported_catalogs:
        msg = (
            f"{', '.join(sorted(unsupported_catalogs))} "
            f"{'does not support conditioning on' if len(unsupported_catalogs)==1 else 'do not support conditioning on'} "
            f"{', '.join(sorted(unsupported_props))}; falling back to 'offset' only for this subset."
        )
        if strict_checking:
            raise ValueError(
                msg + "\n\nInterested in contributing a photo-z estimator? "
                "Open an issue at https://github.com/alexandergagliano/Prost/issues."
            )
        if warn_on_fallback:
            logger.warning(msg)

    # The boolean flag selects which properties get computed; keep the resulting
    # list under its own name instead of overloading the parameter.
    if calc_host_props:
        host_props_to_calc = ['redshift', 'absmag', 'offset']
    else:
        host_props_to_calc = list(priors.keys())

    try:
        redshift = float(row[redshift_col]) if redshift_col in row else np.nan
    except (TypeError, ValueError):
        redshift = np.nan
        logger.warning("Could not parse provided redshift column as float.")

    try:
        ra_err = float(row[coord_err_cols[0]]) if coord_err_cols[0] in row else 0.1
        dec_err = float(row[coord_err_cols[1]]) if coord_err_cols[1] in row else 0.1
        position_err = (ra_err*u.arcsec, dec_err*u.arcsec)
    except (TypeError, ValueError):
        position_err = (0.1*u.arcsec, 0.1*u.arcsec)
        # If the user provided custom error columns, warn them that they won't be used.
        if (coord_err_cols[0] != 'ra_err') or (coord_err_cols[1] != 'dec_err'):
            logger.warning(f"Could not parse {coord_err_cols[0]} and {coord_err_cols[1]} as floats. Setting a nominal positional uncertainty of (0.1'', 0.1'').")

    transient = Transient(
        name=row[name_col],
        position=infer_skycoord(row, coord_cols),
        position_err=position_err,
        redshift=redshift,
        n_samples=n_samples,
        logger=logger,
        n_hosts=n_hosts,
    )

    logger.info(
        f"\n\nAssociating {transient.name} at RA, DEC = "
        f"{transient.position.ra.deg:.6f}, {transient.position.dec.deg:.6f} (redshift {redshift:.3f})"
    )

    for key, val in priors.items():
        transient.set_prior(key, val)

    for key, val in likes.items():
        transient.set_likelihood(key, val)

    if 'redshift' in priors:
        transient.gen_z_samples(n_samples=n_samples)

    # Result fields that exist whether or not a host is found.
    result = {
        "idx": idx,
        "best_cat": None,
        "best_cat_release": None,
        "query_time": np.nan,
        "smallcone_posterior": np.nan,
        "missedcat_posterior": np.nan,
        "extra_cat_cols": {}
    }

    # Fields extracted for each ranked host.
    fields = ["objID", 'name', "total_posterior", "ra", "dec", "redshift_mean", "redshift_std"]
    for prop in host_props_to_calc:
        fields.append(f"{prop}_mean")
        fields.append(f"{prop}_std")
        fields.append(f"{prop}_info")
        if prop in condition_host_props:
            fields.append(f"{prop}_posterior")
        if prop == "offset":
            fields.append("frac_offset_mean")
            fields.append("frac_offset_std")
    # 'redshift' contributes redshift_mean/std a second time; drop duplicates, keep order.
    fields = list(dict.fromkeys(fields))

    if cat_priority is not None:
        # Prioritize by catalog, then release. Entries may be plain names, so
        # never index into them directly.
        catalogs = sorted(
            catalogs,
            key=lambda entry: (
                cat_priority.get(_split_entry(entry)[0], float("inf")),
                _split_entry(entry)[1],
            ),
        )
        logger.info(f"Running association with the following catalog priorities: {catalogs}")

    catalog_dict = OrderedDict(get_catalogs(catalogs))

    for cat_name, cat_release in catalog_dict.items():
        if cat_name in ONLY_OFFSET_CATS:
            calc_host_props_cat = ['offset']
            condition_host_props_cat = ['offset']
        else:
            calc_host_props_cat = host_props_to_calc
            condition_host_props_cat = condition_host_props

        cat = GalaxyCatalog(name=cat_name, n_samples=n_samples, data=glade_catalog, release=cat_release)

        try:
            cat.get_candidates(transient, time_query=True, logger=logger, cosmo=cosmo, calc_host_props=calc_host_props_cat, cat_cols=cat_cols)
        except requests.exceptions.HTTPError:
            logger.warning(f"Candidate retrieval failed for {transient.name} in catalog {cat_name} due to an HTTPError.")
            continue

        if cat.ngals > 0:
            cat = transient.associate(cat, cosmo, condition_host_props=condition_host_props_cat)

            if transient.best_host != -1:
                print_props = ['objID', 'name', 'ra', 'dec', 'total_posterior']
                condition_props = list(priors.keys())
                ordinals = ["best", "2nd best", "3rd best", "4th best", "5th best", "6th best", "7th best", "8th best", "9th best", "10th best"]

                # Log properties for top n_hosts
                for i, host_idx in enumerate(transient.best_hosts):
                    rank_label = ordinals[i] if i < len(ordinals) else f"{i+1}th best"
                    log_host_properties(logger, transient.name, cat, host_idx, Fore.BLUE+f"\nProperties of {rank_label} host (in {cat_name} {cat_release})", print_props, host_props_to_calc, condition_props)

                # Populate results for all n_hosts: "host_*" for the best, "host_<rank>_*" for the rest.
                for i, host_idx in enumerate(transient.best_hosts):
                    key = "host" if i == 0 else f"host_{i+1}"
                    for field in fields:
                        result[f"{key}_{field}"] = cat.galaxies[field][host_idx]

                # Set additional metadata
                result.update({
                    "best_cat": cat_name,
                    "best_cat_release": cat_release,
                    "query_time": cat.query_time,
                    "smallcone_posterior": transient.smallcone_posterior,
                    "missedcat_posterior": transient.missedcat_posterior,
                    "any_posterior": transient.any_posterior,
                    "none_posterior": transient.none_posterior,
                })

                # Collect extra catalog columns if needed (for best host only)
                if cat_cols and len(transient.best_hosts) > 0:
                    result["extra_cat_cols"] = {field: cat.galaxies[field][transient.best_hosts[0]] for field in cat.cat_col_fields}

                # str() guards against a non-string (e.g. missing) name.
                if str(result['host_name']).startswith(("NGC", "M")):
                    logger.info(f"Matched host is {result['host_name']}!")

                logger.info(
                    f"Chosen galaxy has catalog ID of {result['host_objID']} "
                    f"and RA, DEC = {result['host_ra']:.6f}, {result['host_dec']:.6f}"
                )

                # Query NED for spectroscopic redshift if requested
                if best_redshift:
                    z_spec, z_spec_err, has_specz = get_ned_specz(
                        result['host_ra'],
                        result['host_dec'],
                        search_radius=1.0,
                        logger=logger
                    )
                    if has_specz:
                        # 'host_redshift_info' only exists when redshift was computed.
                        logger.info(
                            f"Updating host redshift from catalog value "
                            f"(z={result['host_redshift_mean']:.4f}, {result.get('host_redshift_info', '')}) "
                            f"to NED spectroscopic value (z={z_spec:.4f}±{z_spec_err:.4f})"
                        )
                        result['host_redshift_mean'] = z_spec
                        result['host_redshift_std'] = z_spec_err
                        result['host_redshift_info'] = 'SPEC'

                # Plotting is gated on the logger's effective level rather than on `verbose`.
                if plot_match and logger.getEffectiveLevel() == logging.DEBUG:
                    # The boolean `plot_match` parameter shadows the module-level
                    # plotting function, so import it under a private alias here.
                    from .diagnose import plot_match as _plot_match_fn

                    try:
                        _plot_match_fn(
                            [result["host_ra"]],
                            [result["host_dec"]],
                            result["host_redshift_mean"],
                            result["host_redshift_std"],
                            transient.position.ra.deg,
                            transient.position.dec.deg,
                            transient.name,
                            transient.redshift,
                            0,
                            f"{transient.name}_{cat_name}_{cat_release}",
                            logger
                        )
                    except HTTPError:
                        logger.warning("Couldn't get an image. Waiting 60s before moving on.")
                        time.sleep(60)
                        # NOTE(review): this skips the `break` below, so the next
                        # catalog is still queried after an image failure -- confirm
                        # that this is intended.
                        continue

                # Stop searching after first valid match
                break

    if transient.best_host == -1:
        logger.info("No good host found!")

    return result
[docs]
def associate_sample(
    transient_catalog,
    catalogs,
    name_col=None,
    coord_cols=None,
    redshift_col=None,
    cat_priority=None,
    run_name=None,
    priors=None,
    likes=None,
    n_samples=1000,
    verbose=1,
    n_hosts=2,
    parallel=True,
    save=True,
    save_path="./",
    log_path=None,
    cat_cols=False,
    progress_bar=False,
    cosmology=None,
    n_processes=None,
    calc_host_props=True,
    coord_err_cols=None,
    best_redshift=False,
    max_retries=3,
):
    """Wrapper function for associating sample of transients.

    Parameters
    ----------
    transient_catalog : pandas.DataFrame
        Dataframe containing transient name and coordinates. It is shuffled and
        re-indexed internally; the caller's object is not modified.
    catalogs : list
        List of catalogs to query (can include 'glade', 'decals', 'panstarrs')
    name_col : str or None
        Column with transient names (defaults to 'name'). If the column is
        missing, dummy names are written to a 'name' column.
    coord_cols : tuple of two strings or None
        Columns with transient coordinates (defaults to ('ra', 'dec')).
    redshift_col : str or None
        Column with transient redshifts (defaults to 'redshift').
    cat_priority : dict
        Dict of catalog priority (determines what gets run first)
    run_name : str or None
        Optional name for the run -- used to name logfiles
    priors : dict
        Dictionary of prior distributions on redshift, fractional offset, and/or absolute magnitude
    likes : dict
        Dictionary of likelihood distributions on redshift, fractional offset, absolute magnitude
    n_samples : int
        Number of samples to draw for monte-carlo association.
    verbose : int
        Verbosity level for logging; can be 0 - 3.
    n_hosts : int
        Number of potential hosts to return.
    parallel : boolean
        If True, runs in parallel with multiprocessing. Cannot be used with ipython!
    save : boolean
        If True, saves resulting association table to save_path.
    save_path : str
        Path where the association table should be saved (when save=True).
    log_path : str
        Path where the logfile should be saved. If none, log everything to screen
    cat_cols : boolean
        If True, concatenates catalog columns to resulting DataFrame.
    progress_bar : boolean
        If True, prints a loading bar for the associations.
    cosmology : astropy.cosmology
        Assumed cosmology for the run (defaults to LambdaCDM if unspecified).
    n_processes : int
        Number of parallel processes to run when parallel=True (defaults to, and is capped at, NPROCESS_MAX).
    calc_host_props : boolean
        If True, calculates all host properties (redshift, absmag, and fractional offset) regardless of whether or not
        they're needed for association.
    coord_err_cols : tuple of two strings or None
        Columns with positional uncertainties (defaults to ('ra_err', 'dec_err')).
    best_redshift : boolean, optional
        If True, queries NED for spectroscopic redshift when host is found within 1 arcsec.
        Default is False.
    max_retries : int, optional
        Maximum number of times failed associations of a batch are re-run in
        parallel mode. Default is 3.

    Returns
    -------
    pandas.DataFrame
        The transient dataframe with columns corresponding to the associated transient.

    Raises
    ------
    ValueError
        If ``transient_catalog`` is not a DataFrame, the coordinate columns are
        missing, no usable prior is given, or a non-redshift prior has no likelihood.
    """
    ts = int(time.time())

    if not isinstance(transient_catalog, pd.DataFrame):
        raise ValueError("transient_catalog parameter must be a pandas.DataFrame object.")

    # Randomly shuffle, then renumber the rows 0..N-1; this index is the `idx`
    # used to match association results back to their transient.
    transient_catalog = transient_catalog.sample(frac=1).reset_index(drop=True)

    # The env var is set by the process that launches the worker pool. It
    # survives in this process afterwards, so a later call from the *same* pid
    # must still count as "main" rather than silently dropping to serial mode.
    envkey = 'PYSPAWN_' + os.path.basename(__file__)
    spawner_pid = os.environ.get(envkey)
    is_main = (not spawner_pid) or (spawner_pid == str(os.getpid()))

    if log_path is not None:
        if run_name is not None:
            log_fn = f"{log_path}/Prost_log_{run_name}_{ts}.txt"
        else:
            log_fn = f"{log_path}/Prost_log_{ts}.txt"

        if is_main:
            logger = setup_logger(log_file=log_fn, verbose=verbose, is_main=is_main)
            logger.info(f"Created log file at {log_fn}.")
            os.environ['LOG_PATH_ENV'] = log_fn
        else:
            # Reuse the log file created by the main process.
            log_fn = os.environ.get("LOG_PATH_ENV", None)
            logger = setup_logger(log_file=log_fn, verbose=verbose, is_main=False)
    else:
        log_fn = None
        logger = setup_logger(verbose=verbose, is_main=is_main)

    # Honor a user-supplied cosmology; only fall back to the default when none is given.
    cosmo = cosmology if cosmology is not None else LambdaCDM(H0=70, Om0=0.3, Ode0=0.7)

    # Keep only recognised properties; `None` (the default) is treated as empty.
    possible_keys = ["offset", "absmag", "redshift"]
    priors = {k: v for k, v in (priors or {}).items() if k in possible_keys}
    likes = {k: v for k, v in (likes or {}).items() if k in possible_keys}

    # ensure coordinates are in df
    if coord_cols is None:
        coord_cols = ('ra', 'dec')
    if (coord_cols[0] not in transient_catalog.columns.values) or (coord_cols[1] not in transient_catalog.columns.values):
        raise ValueError("Could not find coordinate data in table. Specify RA and Dec columns with the argument 'coord_cols'.")

    if coord_err_cols is None:
        coord_err_cols = ("ra_err", "dec_err")

    # make sure name is in the DF -- if not, use index of df
    if name_col is None:
        name_col = 'name'
    if name_col not in transient_catalog.columns.values:
        logger.warning("Could not find column for transient names. Creating dummy names from dataframe index instead.")
        # Point name_col at the column actually created, so later lookups succeed.
        name_col = 'name'
        transient_catalog[name_col] = ['Transient_%i'%x for x in transient_catalog.index]

    # ensure redshift is in df (if using redshift)
    if redshift_col is None:
        redshift_col = 'redshift'
    if (redshift_col not in transient_catalog.columns.values) and ('redshift' in priors):
        logger.warning("Using redshift for association but no redshift column found for transient. Association may be prior dominated.")

    # Validate that at least one prior remains
    if not priors:
        raise ValueError(f"ERROR: Please set a prior function for at least one of {possible_keys}.")

    if is_main:
        logger.info(f"Conditioning association on the following properties: {list(priors.keys())}")

    for key in priors:
        if (key != 'redshift') and (key not in likes):
            raise ValueError(f"ERROR: Please set a likelihood function for {key}.")

    # always load GLADE -- we now use it for spec-zs.
    pkg = pkg_resources.files("astro_prost")
    pkg_data_file = pkg / "data" / "GLADE+_HyperLedaSizes_mod_withz.csv.gz"

    try:
        with pkg_resources.as_file(pkg_data_file) as csvfile:
            glade_catalog = pd.read_csv(csvfile, compression="gzip", low_memory=False)
        logger.info("Loaded GLADE+ catalog.")
    except FileNotFoundError:
        logger.warning("Could not find GLADE+ catalog.")
        glade_catalog = None

    results = {}

    # One positional-argument tuple per transient, in associate_transient's parameter order.
    events = [
        (
            idx,
            row,
            glade_catalog,
            n_samples,
            priors,
            likes,
            cosmo,
            catalogs,
            cat_priority,
            name_col,
            coord_cols,
            redshift_col,
            cat_cols,
            log_fn,
            n_hosts,
            calc_host_props,
            verbose,
            coord_err_cols,
            False,  # strict_checking
            True,   # warn_on_fallback
            False,  # plot_match
            best_redshift
        )
        for idx, row in transient_catalog.iterrows()
    ]

    if parallel and is_main:
        os.environ[envkey] = str(os.getpid())

        if n_processes is None:
            n_processes = NPROCESS_MAX
        elif n_processes > NPROCESS_MAX:
            logger.info(f"WARNING! n_processes > {NPROCESS_MAX}. Dropping down.")
            n_processes = NPROCESS_MAX
        # ProcessPoolExecutor rejects max_workers <= 0.
        n_processes = max(int(n_processes), 1)

        logger.info(f"Parallelizing {len(transient_catalog)} associations across {n_processes} processes.")

        # Limit batch size when parallelizing
        batch_size = max(min(int(len(transient_catalog) / max(n_processes, 1)), 1000), 10)
        total_batches = int(np.ceil(len(events) / batch_size))

        for batch_num, batch in enumerate(chunks(events, batch_size), start=1):
            logger.info(f"Processing batch {batch_num}/{total_batches} with {len(batch)} events.")
            results_per_batch = {}

            with ProcessPoolExecutor(max_workers=n_processes) as executor:
                futures = {executor.submit(safe_associate_transient, *event): event[0] for event in batch}

                for future in tqdm(as_completed(futures), total=len(futures), desc=f"Batch {batch_num}", disable=not progress_bar):
                    try:
                        results_per_batch[futures[future]] = future.result()
                    except Exception as e:
                        logger.error(f"Unhandled error for event {futures[future]}: {e}", exc_info=True)
                        results_per_batch[futures[future]] = None

            results.update(results_per_batch)  # Merge into main results

            if save:
                # consolidate_results raises when every result is None; skip the
                # intermediate save then so the retries below still get a chance.
                if any(res is not None for res in results_per_batch.values()):
                    logger.info("Saving intermediate batch results...")
                    transient_catalog_batch = consolidate_results(results_per_batch, transient_catalog)
                    save_results(transient_catalog_batch, run_name, save_path, drop_unassociated=True)
                else:
                    logger.warning("No successful associations in this batch; skipping intermediate save.")

            gc.collect()

            # Retry logic for failed associations
            for retry in range(max_retries):
                failed_ids = {event_id for event_id, res in results_per_batch.items() if res is None}
                if not failed_ids:
                    logger.info("All associations succeeded; no more retries needed.")
                    break

                logger.info(f"Retry attempt {retry+1}: Rerunning {len(failed_ids)} failed associations.")
                failed_events = [event for event in batch if event[0] in failed_ids]

                with ProcessPoolExecutor(max_workers=n_processes) as executor:
                    new_futures = {executor.submit(safe_associate_transient, *event): event[0] for event in failed_events}

                    for future in tqdm(as_completed(new_futures), total=len(new_futures), desc="Retrying events", disable=not progress_bar):
                        try:
                            results_per_batch[new_futures[future]] = future.result()
                        except Exception as e:
                            logger.error(f"Retry failed for event {new_futures[future]}: {e}", exc_info=True)
                            results_per_batch[new_futures[future]] = None

                # Merge new results into main results
                results.update(results_per_batch)
            else:
                # Retries exhausted without an early exit: warn only if failures remain.
                if any(res is None for res in results_per_batch.values()):
                    logger.warning("Some associations still failed after maximum retries.")

    else:  # Serial execution mode
        iterable = tqdm(
            events,
            total=len(events),
            desc="Associating (serial)",
            disable=not progress_bar)

        results = {}
        for event in iterable:
            # event[0] is the transient's row index (0..N-1 after the reset above).
            results[event[0]] = associate_transient(*event)

    if (not parallel) or (os.environ.get(envkey) == str(os.getpid())):
        transient_catalog = consolidate_results(results, transient_catalog)

        # save final results
        if save:
            save_results(transient_catalog, run_name, save_path)

    return transient_catalog
[docs]
def safe_associate_transient(*args, **kwargs):
    """Safely executes `associate_transient` while handling errors.

    Any exception raised by the association is logged (with traceback) to the
    "Prost_logger" logger and converted into a `None` return value.

    Parameters
    ----------
    *args : tuple
        Positional arguments to be passed directly to `associate_transient`.
        The first argument (`args[0]`) is expected to be the transient's catalog index.
    **kwargs : dict
        Keyword arguments passed to `associate_transient`.

    Returns
    -------
    dict or None
        The output of `associate_transient` if successful, otherwise `None`.
    """
    logger = logging.getLogger("Prost_logger")
    try:
        return associate_transient(*args, **kwargs)
    except Exception as e:
        # The index may have been passed by keyword (or not at all); looking it
        # up must never raise from inside the handler.
        event_id = args[0] if args else kwargs.get("idx", "<unknown>")
        logger.exception(f"Error processing event {event_id}: {e}")
        return None