Source code for wale.CriticalPoints

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline, UnivariateSpline

from .RateFunction import (
    get_psi_2cell,
    get_psi_derivative_delta1,
    get_psi_derivative_delta2,
)


class CriticalPointsFinder:
    r"""
    A class designed to identify critical points where the rate function's
    convexity changes in a cosmological context. This is achieved by analyzing
    the Hessian matrix of the rate function across a grid of values and
    locating zero crossings of its determinant, which mark changes in
    convexity.

    The rate function :math:`I(x)` characterizes the exponential decay rate of
    the probabilities of certain outcomes as the system size increases. The
    rate function is required to be convex, which ensures that the study of
    rare events through large deviation principles can be approached
    effectively through minimization techniques.

    **Cumulant Generating Function and Legendre-Fenchel Transform**

    The CGF, denoted by :math:`\Lambda(\theta)`, is foundational for deriving
    the rate function through the Legendre-Fenchel transform. This transform
    connects the CGF and the rate function as follows:

    .. math::

        I(x) = \sup_{\theta} \{ \theta x - \Lambda(\theta) \}

    This equation ensures that the rate function :math:`I(x)` is convex,
    inheriting this property from the convex CGF :math:`\Lambda(\theta)`. The
    supremum over :math:`\theta` makes :math:`I(x)` the pointwise supremum of
    the affine functions :math:`x \mapsto \theta x - \Lambda(\theta)`, and a
    pointwise supremum of affine functions is always convex.

    **Convexity of the Rate Function**

    The convexity of the rate function :math:`I(x)` implies the following
    inequality for any two points :math:`x_1` and :math:`x_2` in its domain
    and any :math:`\lambda \in [0, 1]`:

    .. math::

        I(\lambda x_1 + (1 - \lambda) x_2) \leq \lambda I(x_1) + (1 - \lambda) I(x_2)

    This inequality is the defining property of convexity and is central to
    analyzing rare events in large deviation theory. In this class, the
    determinant of the Hessian of the rate function is used to locate the
    points where it vanishes; these points identify the values of
    :math:`\lambda` used in subsequent calculations.
    """

    def __init__(self, variables, lw, z, chis, ngrid=50, plot=False):
        """
        Initializes the CriticalPointsFinder with cosmology and variance
        objects, and optionally configures plotting.

        Parameters:
            variables (VariablesGenerator): An instance containing all
                necessary cosmological parameters and variables.
            lw (array-like): Lensing weights for the selected redshift slices.
            z (array-like): Redshifts at which critical points are computed.
            chis (array-like): Comoving distances corresponding to the
                redshift slices.
            ngrid (int): The number of grid points to use for delta value
                calculations.
            plot (bool): Flag to enable plotting of critical points.
        """
        self.variables = variables
        self.plot = plot
        print(
            f"Setting ngrid = {ngrid}. Increase this for more accuracy, but note that computation becomes slower!"
        )
        self.delta1_vals = np.linspace(-0.99, 1.99, ngrid)
        self.delta2_vals = np.linspace(-0.99, 1.99, ngrid)
        self.D1, self.D2 = np.meshgrid(
            self.delta1_vals, self.delta2_vals, indexing="ij"
        )
        self.lw = lw
        self.z = z
        self.chis = chis
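As an aside, the Legendre-Fenchel relation in the class docstring is easy to check numerically. A minimal standalone sketch (not part of this module): for a standard Gaussian the CGF is :math:`\Lambda(\theta) = \theta^2/2`, so the exact rate function is :math:`I(x) = x^2/2`, and a brute-force supremum over a grid of :math:`\theta` recovers it.

import numpy as np

theta = np.linspace(-10.0, 10.0, 2001)  # grid over which the supremum is taken
cgf = 0.5 * theta**2                    # Lambda(theta) for a standard Gaussian
xs = np.linspace(-3.0, 3.0, 7)
rate = np.array([np.max(x * theta - cgf) for x in xs])
print(np.allclose(rate, 0.5 * xs**2, atol=1e-4))  # True: I(x) = x**2 / 2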
    def get_hessian(self, x):
        """Calculates the Hessian matrix of a 2-D array of function values.

        Uses nested ``np.gradient`` calls, so derivatives are taken with
        respect to grid indices (unit spacing). This rescales the Hessian by a
        positive constant and therefore does not move the zero crossings of
        its determinant.
        """
        x_grad = np.gradient(x)
        hessian = np.empty((x.ndim, x.ndim) + x.shape, dtype=x.dtype)
        for k, grad_k in enumerate(x_grad):
            # Differentiate the k-th first derivative to fill row k of the Hessian.
            tmp_grad = np.gradient(grad_k)
            for l, grad_kl in enumerate(tmp_grad):
                hessian[k, l, :, :] = grad_kl
        return hessian
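A standalone sanity check of this finite-difference Hessian (illustrative grid and coefficients, not from the module): on a quadratic surface the nested np.gradient calls recover a constant Hessian, scaled by spacing**2 per axis because the derivatives are taken in index units. That factor is positive, so the sign of the determinant, which is all the zero-crossing search below uses, is unaffected.

import numpy as np

u = np.linspace(-1.0, 1.0, 41)                 # grid spacing du = 0.05
U, V = np.meshgrid(u, u, indexing="ij")
f = 3.0 * U**2 + 2.0 * V**2 + U * V            # true Hessian [[6, 1], [1, 4]]

g = np.gradient(f)                             # first derivatives in index units
h00 = np.gradient(g[0])[0][2:-2, 2:-2]         # interior second derivatives,
h01 = np.gradient(g[0])[1][2:-2, 2:-2]         # where central differences are exact
print(np.allclose(h00, 6.0 * 0.05**2), np.allclose(h01, 1.0 * 0.05**2))  # True True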
    def find_zero_crossing_point(
        self, x1, y1, x2, y2, determinant_value1, determinant_value2
    ):
        """Finds the zero crossing point between two points based on the determinant values."""
        t = abs(determinant_value1) / (
            abs(determinant_value1) + abs(determinant_value2)
        )
        zero_crossing_x = x1 + t * (x2 - x1)
        zero_crossing_y = y1 + t * (y2 - y1)
        return zero_crossing_x, zero_crossing_y
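A worked instance of this interpolation (numbers invented): determinant values -1.0 and 3.0 at two neighbouring points give t = 1/(1 + 3) = 0.25, placing the crossing a quarter of the way along the segment, exactly where the straight line from -1.0 to 3.0 passes through zero.

d1, d2 = -1.0, 3.0                   # determinant values at two neighbouring points
t = abs(d1) / (abs(d1) + abs(d2))    # 1 / (1 + 3) = 0.25
x = 0.0 + t * (0.1 - 0.0)            # crossing between x1 = 0.0 and x2 = 0.1
print(t, x)                          # 0.25 0.025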
    def find_zero_crossings(self, determinant):
        """Identifies zero crossings in the determinant grid by scanning for
        sign changes between neighbouring points."""
        zero_crossings = []
        for i in range(determinant.shape[0] - 1):
            for j in range(determinant.shape[1] - 1):
                # Sign change between column neighbours (i, j) and (i, j + 1).
                if determinant[i, j] * determinant[i, j + 1] <= 0:
                    newx, newy = self.find_zero_crossing_point(
                        self.D1[i, j],
                        self.D2[i, j],
                        self.D1[i, j + 1],
                        self.D2[i, j + 1],
                        determinant[i, j],
                        determinant[i, j + 1],
                    )
                    zero_crossings.append((newx, newy))
                # Sign change between row neighbours (i, j) and (i + 1, j).
                if determinant[i, j] * determinant[i + 1, j] <= 0:
                    newx, newy = self.find_zero_crossing_point(
                        self.D1[i, j],
                        self.D2[i, j],
                        self.D1[i + 1, j],
                        self.D2[i + 1, j],
                        determinant[i, j],
                        determinant[i + 1, j],
                    )
                    zero_crossings.append((newx, newy))
        return zero_crossings
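The same edge scan can be illustrated on a synthetic determinant with a known zero set (a standalone sketch, not the cosmological determinant): for det(d1, d2) = d1**2 + d2**2 - 1, the flagged grid edges trace the unit circle.

import numpy as np

vals = np.linspace(-1.5, 1.5, 61)
D1, D2 = np.meshgrid(vals, vals, indexing="ij")
det = D1**2 + D2**2 - 1.0
row_edges = det[:-1, :] * det[1:, :] <= 0   # sign change between row neighbours
radii = np.hypot(D1[:-1, :][row_edges], D2[:-1, :][row_edges])
print(np.allclose(radii, 1.0, atol=0.1))    # True: crossings hug the unit circle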
    def get_critical_points(self, variance, lw, z, chi_value):
        """Calculates critical points for the given redshift z and plots them
        if requested."""
        recal_value = self.variables.recal_value
        theta1 = self.variables.theta1_radian
        theta2 = self.variables.theta2_radian
        deld = 1e-8  # finite-difference step for the psi derivatives

        # Evaluate the two-cell rate function on the (delta1, delta2) grid.
        rate_function = np.vectorize(
            lambda d1, d2: get_psi_2cell(
                variance, chi_value, recal_value, z, d1, d2, theta1, theta2
            )
        )(self.D1, self.D2)

        # Zero crossings of the Hessian determinant mark convexity changes.
        hessian = np.array(self.get_hessian(rate_function))
        determinants = (
            hessian[0, 0, :, :] * hessian[1, 1, :, :]
            - hessian[0, 1, :, :] * hessian[1, 0, :, :]
        )
        zero_crossings = np.array(self.find_zero_crossings(determinants))

        # Derivatives of the rate function at each crossing, rescaled by
        # h / (lensing weight).
        drf1, drf2 = [], []
        for x, y in zero_crossings:
            drf1.append(
                get_psi_derivative_delta1(
                    deld, variance, chi_value, recal_value, z, x, y, theta1, theta2
                )
                * self.variables.cosmo.h
                / lw
            )
            drf2.append(
                get_psi_derivative_delta2(
                    deld, variance, chi_value, recal_value, z, x, y, theta1, theta2
                )
                * self.variables.cosmo.h
                / lw
            )
        drf1, drf2 = np.array(drf1), np.array(drf2)

        # Sort by drf1 so the spline through (drf1, drf2) is well defined.
        sorted_indices = np.argsort(drf1[:, 0])
        sorted_drf1 = drf1[sorted_indices, 0]
        sorted_drf2 = drf2[sorted_indices, 0]
        drf_spline = CubicSpline(sorted_drf1, sorted_drf2)

        # Resample, then find the roots of drf1 + drf2 with an interpolating spline.
        drf1_new = np.linspace(-1000, 3000, 200)
        drf2_new = drf_spline(drf1_new)
        sum_derivatives = drf1_new + drf2_new
        spline1 = UnivariateSpline(drf1_new, sum_derivatives, s=0)
        critical_points1 = spline1.roots()

        if self.plot:
            plt.plot(drf1_new, sum_derivatives, label=z)
            plt.scatter(critical_points1, spline1(critical_points1), color="r")
            plt.xlim(-1000, 2000)
            plt.ylim(-1000, 2000)
            plt.grid(visible=True, which="both", axis="both")
            plt.legend()

        return [-x for x in critical_points1]
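The spline-and-roots step at the end of this method can be exercised in isolation. A minimal sketch with fabricated derivative data, chosen so that drf1 + drf2 is linear with a single root at 3:

import numpy as np
from scipy.interpolate import CubicSpline, UnivariateSpline

drf1 = np.linspace(-5.0, 5.0, 11)
drf2 = -2.0 * drf1 + 3.0                      # toy relation: sum = -drf1 + 3
spline = CubicSpline(drf1, drf2)
grid = np.linspace(-5.0, 5.0, 200)
total = grid + spline(grid)
print(UnivariateSpline(grid, total, s=0).roots())  # approximately [3.]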
def find_smallest_pair(critical_values):
    """
    Finds the pair of points with the smallest Euclidean distance between
    them from a set of critical values.

    Parameters:
        critical_values (numpy.ndarray): An array of critical points.

    Returns:
        tuple: The pair of points with the smallest distance and their
        Euclidean distance, or ``(None, float("inf"))`` if fewer than two
        points are given.
    """
    num_points = critical_values.shape[0]
    if num_points < 2:
        return None, float("inf")  # No pair exists
    smallest_distance = float("inf")
    smallest_pair = None
    for i in range(num_points - 1):
        for j in range(i + 1, num_points):
            distance = np.linalg.norm(critical_values[i] - critical_values[j])
            if distance < smallest_distance:
                smallest_distance = distance
                smallest_pair = (critical_values[i], critical_values[j])
    return smallest_pair, smallest_distance
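A usage sketch with invented inputs, assuming the pair-plus-distance return shown above and that the function is in scope:

import numpy as np

pts = np.array([0.9, 1.1, 4.0])
pair, dist = find_smallest_pair(pts)
print(pair, dist)  # closest pair is (0.9, 1.1) at distance ~0.2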
def find_critical_points_for_cosmo(
    variables, variance, ngrid_critical=90, plot=False, min_z=1, max_z=4
):
    """
    Finds critical points in the lensing potential based on the provided
    variables.

    Args:
        variables: Object containing the necessary cosmological variables
            (e.g. lensingweights, redshifts, chis).
        variance: Variance or smoothing parameter needed for critical point
            computation.
        ngrid_critical (int, optional): Grid resolution for the critical
            point search. Defaults to 90.
        plot (bool, optional): Whether to plot the results. Defaults to False.
        min_z (int, optional): Index of the first redshift slice to use.
            Defaults to 1.
        max_z (int, optional): Index one past the last redshift slice to use.
            Defaults to 4.

    Returns:
        tuple: Smallest positive and largest negative critical point values,
        or ``(None, None)`` if no critical points are found.
    """
    print("Finding critical points...")
    criticalpoints = CriticalPointsFinder(
        variables,
        ngrid=ngrid_critical,
        lw=variables.lensingweights[min_z:max_z],
        z=variables.redshifts[min_z:max_z],
        chis=variables.chis,
        plot=plot,
    )

    critical_values_list = []
    for i, z_crit in enumerate(criticalpoints.z):
        crit_vals = criticalpoints.get_critical_points(
            variance,
            lw=criticalpoints.lw[i],
            z=z_crit,
            chi_value=criticalpoints.chis[i],
        )
        if crit_vals is not None and len(crit_vals) >= 2:
            critical_values_list.append(crit_vals[:2])

    if not critical_values_list:
        print("  Warning: No critical points found in the specified redshift range.")
        return None, None

    # Flatten the per-redshift lists into a single 1-D array and drop NaNs.
    flat_values = []
    for item in critical_values_list:
        if isinstance(item, (np.ndarray, list)):
            flat_values.extend(np.ravel(item))
        else:
            flat_values.append(item)
    flat_values = np.array(flat_values)
    flat_values = flat_values[~np.isnan(flat_values)]

    # Smallest positive and largest negative critical values across all slices.
    positive_values = flat_values[flat_values > 0]
    negative_values = flat_values[flat_values < 0]
    smallest_positive = np.min(positive_values) if positive_values.size > 0 else None
    largest_negative = np.max(negative_values) if negative_values.size > 0 else None

    print(
        "Smallest positive and largest negative critical points:",
        smallest_positive,
        largest_negative,
    )
    return smallest_positive, largest_negative
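Putting it together, a hedged usage sketch: constructing `variables` (a VariablesGenerator) and `variance` happens outside this module, so the driver call is shown schematically rather than as runnable code.

# Assuming `variables` and `variance` have been built elsewhere:
# smallest_positive, largest_negative = find_critical_points_for_cosmo(
#     variables, variance, ngrid_critical=90, plot=True, min_z=1, max_z=4
# )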