# Source code for wale.FilterFunctions

# from imports import *
import numpy as np

# import scipy.special
from scipy import special as sp
from functools import lru_cache
import mpmath as mp


def top_hat_filter(k, R):
    """
    Calculates the 2D top-hat window function in Fourier space, 2 J1(kR) / (kR).

    Parameters:
        k (float or numpy.ndarray): Wavenumber(s) at which to evaluate the window.
        R (float or numpy.ndarray): Top-hat radius, in units such that ``k * R``
            is dimensionless.

    Returns:
        numpy.float64 or numpy.ndarray: Window values with the shape of ``k * R``.
        Where ``k * R == 0`` the analytic limit 1.0 is returned instead of NaN.
    """
    x = np.asarray(k * R)
    at_origin = x == 0
    # Substitute a harmless denominator at the origin so no 0/0 (and no
    # RuntimeWarning) is ever evaluated; np.where then picks the limit 1.0.
    x_safe = np.where(at_origin, 1.0, x)
    window = np.where(at_origin, 1.0, 2.0 * sp.j1(x_safe) / x_safe)
    # Preserve scalar-in / scalar-out behaviour of the original expression.
    return window[()] if window.ndim == 0 else window
# def top_hat_window(R): # """ # Calculates the top-hat window function for a given radius. # Parameters: # R (float or numpy.ndarray): The scale (or array of scales) at which to calculate the window function. # Returns: # numpy.ndarray: The top-hat window function values at the given scale(s). # """ # return 2.0 * sp.j1(R) / R
def get_W2D_FL(window_radius, map_shape, filter_type, L=505):
    """
    Constructs a 2D Fourier-space window on the fft-shifted frequency grid of a map.

    Parameters:
        window_radius : float
            Filter scale passed to the filter as ``R``. Frequencies are built
            with a sample spacing of 1 (cycles per pixel), so this scale is
            effectively in pixel units.
        map_shape : tuple
            Shape of the map, e.g. (600, 600). The first entry sets the number
            of frequencies along axis 0 and the second (if present) along
            axis 1; a one-element tuple is treated as square.
        filter_type : str
            ``"tophat"`` (uses ``top_hat_filter``) or ``"starlet"``
            (uses ``starlet_filter``).
        L : float, optional
            Physical size of the map (default is 505, as used for SLICS).
            NOTE(review): currently unused -- the sample spacing is fixed at 1,
            so no physical units enter the grid. Confirm whether ``L / N`` was
            intended before relying on physical units.

    Returns:
        numpy.ndarray of shape (map_shape[0], map_shape[1]) holding the window.

    Raises:
        ValueError: if ``filter_type`` is not one of the supported names.
    """
    nx = map_shape[0]
    ny = map_shape[1] if len(map_shape) > 1 else nx
    dx = 1.0  # sample spacing; equals the original `N / N`

    # Generate Fourier frequencies, zero frequency moved to the centre.
    kx = np.fft.fftshift(np.fft.fftfreq(nx, dx))
    ky = np.fft.fftshift(np.fft.fftfreq(ny, dx))
    kx, ky = np.meshgrid(kx, ky, indexing="ij")
    k2 = kx**2 + ky**2

    # Convert to radial angular wavenumber (with 2pi factor).
    k = 2 * np.pi * np.sqrt(k2)

    # After fftshift the zero frequency sits at index n // 2 (even or odd n);
    # nudge it away from 0 so the filters never divide by zero there.
    k[nx // 2, ny // 2] = 1e-7

    if filter_type == "tophat":
        return top_hat_filter(k, window_radius)
    if filter_type == "starlet":
        return starlet_filter(k, window_radius)
    raise ValueError(
        f"Unknown filter_type {filter_type!r}; expected 'tophat' or 'starlet'."
    )
# return uHat_starlet_analytical(k, window_radius)
def b3_1D_ft(x):
    """
    Evaluates the 1D profile ``(sin(x/2) / (x/2)) ** 4``.

    Parameters:
        x (float or numpy.ndarray): Argument(s), e.g. a scaled wavenumber.

    Returns:
        numpy.float64 or numpy.ndarray: Profile values with the shape of ``x``.
        At ``x == 0`` the analytic limit 1.0 is returned instead of NaN.
    """
    half = np.asarray(x) / 2
    is_zero = half == 0
    # Replace the zero denominator before dividing so 0/0 is never evaluated.
    half_safe = np.where(is_zero, 1.0, half)
    ratio = np.where(is_zero, 1.0, np.sin(half_safe) / half_safe)
    profile = ratio**4.0
    # Keep scalar input -> scalar output, like the original expression.
    return profile[()] if profile.ndim == 0 else profile
def b3_2D_ft(x, y):
    """
    Separable 2D profile built from ``b3_1D_ft``.

    Parameters:
        x, y (float or numpy.ndarray): Arguments along the two axes.

    Returns:
        The element-wise product ``b3_1D_ft(x) * b3_1D_ft(y)`` (broadcast
        following NumPy rules).
    """
    along_x = b3_1D_ft(x)
    along_y = b3_1D_ft(y)
    return along_x * along_y
def starlet_filter(k, R):
    """
    Computes the Fourier-space starlet filter.

    Both arguments of ``b3_2D_ft`` receive the same scaled value ``k * R``,
    so the result is ``b3_1D_ft(k * R) * b3_1D_ft(k * R)`` element-wise.
    NOTE(review): the radial wavenumber is used for both axis arguments;
    confirm this isotropic treatment is intended.

    Args:
        k (np.ndarray): 2D array of Fourier frequencies (radial wavenumber).
        R (float): The scale at which to compute the filter.

    Returns:
        np.ndarray: The computed starlet filter in Fourier space.
    """
    scaled = k * R
    return b3_2D_ft(scaled, scaled)
# Fast memoized scalar S function
# Bounded: keys include a continuous float ``b``, so an unbounded cache would
# grow with every distinct argument ever evaluated.
@lru_cache(maxsize=65536)
def S_scalar(n: int, b: float) -> float:
    """
    Memoized scalar S(n, b) used by the analytical starlet transform.

    Base cases: S(0, b) = b * J1(b) and S(-1, b) = b * 1F2(1/2; 1, 3/2; -b^2/4).
    Higher n use the recurrence
    S(n, b) = b^(n+1) J1(b) + n b^n J0(b) - n^2 S(n-2, b).
    These are consistent with S(n, b) being the integral of x^(n+1) J0(x)
    from 0 to b (NOTE(review): inferred from the base cases/recurrence).

    Parameters:
        n (int): Order, must be >= -1.
        b (float): Upper limit / argument.

    Returns:
        float: S(n, b).

    Raises:
        ValueError: if n < -1.
    """
    if n < -1:
        raise ValueError("n cannot be smaller than -1.")
    if n == -1:
        # Only this branch needs arbitrary-precision mpmath.
        return b * float(mp.hyp1f2(0.5, 1, 1.5, -(b**2) / 4))
    J1 = sp.j1(b)
    if n == 0:
        return b * J1
    J0 = sp.j0(b)
    return b ** (n + 1) * J1 + n * b**n * J0 - n**2 * S_scalar(n - 2, b)
# Wrapper to handle arrays
def S(n: int, b):
    """
    Evaluates ``S_scalar(n, .)`` for a scalar or an array ``b``.

    Scalars (including 0-d arrays) go straight to the cached scalar routine.
    Arrays are evaluated element by element through ``np.vectorize`` and the
    result keeps ``b``'s shape. Note that ``np.vectorize`` without ``otypes``
    raises ValueError on size-0 input.
    """
    values = np.asarray(b)
    if values.ndim != 0:
        per_element = np.vectorize(lambda v: S_scalar(n, float(v)))
        return per_element(values)
    return S_scalar(n, float(values))
# Fast uHat_starlet_analytical
def uHat_starlet_analytical(eta, R):
    r"""
    Computes the analytical Hankel transform of the starlet U-filter.

    Args:
        eta (np.ndarray or float): Wavenumber argument; it is multiplied by
            ``R`` before use.
        R (float): Filter scale.

    Returns:
        float or np.ndarray: :math:`\hat{u}` evaluated at ``eta * R``.

    Notes:
        ``eta * R`` is clamped to [2e-2, 100] before evaluation, so inputs
        outside that range return the value at the nearest bound.
        With F(c) = c^3 e^3 S(0, ce) - 3 c^2 e^2 S(1, ce) + 3 c e S(2, ce)
        - S(3, ce) (e the clamped argument), the result is
        2 pi * (-128/9 F(1/2) + 4 F(1) - 1/9 F(2)) / e^5.
    """
    eta = np.asarray(eta) * R
    # Lower bound: the original comment gives this as "stability for small
    # eta". The upper bound also caps large arguments.
    # NOTE(review): confirm the upper clamp at 100 is intended.
    eta_safe = np.clip(eta, 2e-2, 100)

    def _factor(c):
        # One bracketed term F(c) of the closed form, with S evaluated at c * eta.
        b = c * eta_safe
        s0, s1, s2, s3 = (S(order, b) for order in range(4))
        return (
            c**3 * eta_safe**3 * s0
            - 3 * c**2 * eta_safe**2 * s1
            + 3 * c * eta_safe * s2
            - s3
        )

    factor_half = _factor(0.5)
    factor_one = _factor(1)
    factor_two = _factor(2)

    return (
        (2 * np.pi)
        * (-128 / 9 * factor_half + 4 * factor_one - 1 / 9 * factor_two)
        / eta_safe**5
    )