Source code for daxs.correctors

"""Classes to deal with different types of measurement correctors."""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from functools import cached_property
from itertools import cycle

import numpy as np
import numpy.typing as npt

from daxs.scans import Scan, Scans

logger = logging.getLogger(__name__)


[docs] class ConcentrationCorrectionError(Exception):
[docs] def __init__(self, message: str): super().__init__(f"{message} The concentration correction failed.")
[docs] class Corrector(ABC): """Base class for measurement correctors."""
[docs] @abstractmethod def apply(self, scans: Scans) -> None: """Apply the correction to the scans."""
[docs] class SimpleConcentrationCorrector(Corrector): """Class to perform simple, length-based, concentration corrections."""
[docs] def __init__(self, scans: Scan | Scans | list[Scan]): """Initialize the simple concentration corrector. Args: scans: Scans used for concentration correction. """ if isinstance(scans, Scans): self.conc_corr_scans = scans elif isinstance(scans, list): self.conc_corr_scans = Scans(scans) else: self.conc_corr_scans = Scans([scans])
[docs] def apply(self, scans: Scans) -> None: logger.info("Applying simple concentration correction.") # When there is a single concentration correction scan and the number # of points in it is equal to the number of scans to be corrected, each # point will be used to correct a scan. if len(self.conc_corr_scans) == 1: [conc_corr_scan] = self.conc_corr_scans if len(scans) == conc_corr_scan.signal.size: for i, scan in enumerate(scans): try: scalars = (conc_corr_scan.signal[i], conc_corr_scan.monitor[i]) except IndexError: scalars = (conc_corr_scan.signal[i],) scan.divide_by_scalars(*scalars) return # When there is a single concentration correction scan and the previous # condition is not met, divide all scans by it, by cycling it. if len(self.conc_corr_scans) == 1: conc_corr_scans = cycle(self.conc_corr_scans) # When the number of scans to be corrected is equal to the number of # concentration correction scans, each scan will be corrected by a # corresponding concentration correction scan. elif len(self.conc_corr_scans) == len(scans): conc_corr_scans = self.conc_corr_scans # No other case is supported. else: raise ConcentrationCorrectionError( "Incompatible number of scans to correct and concentration " "correction scans." ) for scan, conc_corr_scan in zip(scans, conc_corr_scans): try: scan.divide_by_scan(conc_corr_scan) except (TypeError, ValueError) as e: raise ConcentrationCorrectionError( f"The length of the signal or monitor in the scan {scan.label} " "is different than that from the correction scan " f"{conc_corr_scan.label}." ) from e
[docs] class DataDrivenConcentrationCorrector(Corrector): """Class to perform concentration corrections using data from specified mappings."""
[docs] def __init__(self, scans: Scan | Scans | list[Scan], data_mappings: dict[str, str]): """Initialize the data-driven concentration corrector. Args: scans: Scans used for concentration correction. data_mappings: Mappings between scan attributes and paths in the raw data. """ if isinstance(scans, Scans): self.conc_corr_scans = scans elif isinstance(scans, list): self.conc_corr_scans = Scans(scans) else: self.conc_corr_scans = Scans([scans]) self.data_mappings = data_mappings
@cached_property def conc_corr_points(self) -> npt.NDArray[np.float64]: """Array of points used to determine the concentration correction indices. Returns: Array of points used for concentration correction. Raises: ValueError: If the concentration correction counters do not have the same length. """ points = [] for path in self.data_mappings.values(): points_at_path = [] for scan in self.conc_corr_scans: if scan.filename is None or scan.index is None: raise ConcentrationCorrectionError( "The concentration correction scans must have a filename " "and index defined." ) points_at_path.extend( scan.read_data_at_paths(scan.filename, scan.index, path) ) points.append(points_at_path) try: return np.asarray(points, dtype=np.float64).T except ValueError as e: raise ConcentrationCorrectionError( "The concentration correction counters must have the same length." ) from e
[docs] def find_conc_corr_indices(self, scan: Scan) -> list[int]: """Determine the indices of the concentration correction data for the scan. Args: scan: Scan for which the concentration correction data need to be found. Returns: Indices of the concentration correction data for the points in the scan. """ # Get concentration correction points. conc_corr_points = self.conc_corr_points # Get the scan data at the same keys as the concentration correction data. data_points = [] for key in self.data_mappings: try: data_points.append(scan.data[key]) except KeyError as e: raise ConcentrationCorrectionError( f"The data in scan {scan.label} does not have the key {key} among" "the source data paths. Make sure the source data paths are" "correctly set." ) from e data_points = np.asarray(data_points) # Add a new axis if the data points are 1D. if data_points.ndim == 1: data_points = data_points[:, None] # Transpose the data points to have shape (N, p) [N points, p paths]. data_points = data_points.T # Calculate distances between each data point and each concentration correction # point. # data_points has shape (N, p) [N points, p paths] # data_points[:, None, :] has shape (N, 1, p) # conc_corr_points has shape (M, p) [M points, p paths] # conc_corr_points[None, :, :] has shape (1, M, p) [M points] # The subtraction uses broadcasting to yield an array of shape (N, M, p), # np.linalg.norm(..., axis=2) computes the Euclidean distance resulting in an # array of shape [N, M]. distances = np.linalg.norm( data_points[:, None, :] - conc_corr_points[None, :, :], axis=2, ) threshold = np.finfo(np.float64).eps # For each data point, find the index where the distance is smaller # than the threshold. If the number of indices is different than 1, raise an # error. mask = distances < threshold if np.any(mask.sum(axis=1) == 0): raise ConcentrationCorrectionError( f"No concentration correction points were found for scan {scan.label}." ) if np.any(mask.sum(axis=1) > 1): raise ConcentrationCorrectionError( f"Multiple concentration correction points were found for " f"scan {scan.label}." ) _, indices = np.where(mask) return indices.tolist()
[docs] def create_conc_corr_scan(self, indices: list[int]) -> Scan: """Create a scan with the concentration correction data. Args: indices: Indices of the concentration correction data to be used. Returns: Scan with the concentration correction data at the specified indices. """ signal = np.concatenate([scan.signal for scan in self.conc_corr_scans]) x = np.ones_like(signal[indices]) scan = Scan(x, signal[indices]) try: monitor = np.concatenate([scan.monitor for scan in self.conc_corr_scans]) scan.monitor = monitor[indices] except IndexError: pass return scan
[docs] def apply(self, scans: Scans) -> None: """Apply the concentration correction using data from the specified paths. Args: scans: Scans to be corrected. """ logger.info("Applying data-informed concentration correction.") for scan in scans: indices = self.find_conc_corr_indices(scan) conc_corr_scan = self.create_conc_corr_scan(indices) scan.divide_by_scan(conc_corr_scan)
[docs] class DeadTimeCorrector(Corrector): """Class to perform dead time corrections."""