Source code for golconda.adjustcls

# from utils.config import *
import numpy as np
import jax
import jax.numpy as jnp
import pyfftw
from jax.numpy.fft import rfft2, irfft2, rfftfreq, fftfreq, ifft2, fftn
from scipy.interpolate import interp1d
[docs] class PowerSpectrumAdjuster: def __init__(self, map_shape, pixel_size): """ Initialize the precomputed values for power spectrum adjustment. Parameters: - map_shape: tuple, the shape of the input map (assumes square maps). - pixel_size: float, the pixel size in arcmin/pixel. """ self.map_shape = map_shape self.pixel_size = pixel_size self.pix_to_rad = pixel_size / 60. * np.pi / 180. # Convert to radians/pixel self.N = map_shape[0] # Assume square maps self.center = self.N / 2 # Enable pyFFTW plan cache for efficiency pyfftw.interfaces.cache.enable() # Precompute 2D ell grid and radial bins self._precompute_grids() def _precompute_grids(self): """ Precompute the 2D ell grid and radial bins. """ inds = jnp.fft.fftfreq(self.N, d=1. / self.N) X, Y = jnp.meshgrid(inds, inds) self.R = jnp.sqrt(X**2 + Y**2) self.ell2d = self.R * (2 * np.pi / self.pix_to_rad) / 360. # Precompute radial bin indices (k) y, x = np.indices(self.map_shape) self.radial_bins = np.sqrt((x - self.center)**2 + (y - self.center)**2).astype(np.int32)
[docs] def compute_power_spectrum(self, image): """ Compute the 1D power spectrum from a given 2D image. Parameters: - image: 2D numpy array, the input image. Returns: - ell: 1D numpy array, the multipole moments. - power_spectrum_1d: 1D numpy array, the power spectrum values in [rad^2]. """ # FFT using pyFFTW fft_object = pyfftw.builders.fft2(image, threads=8) fft_map = fft_object() # Normalize the FFT by N^2 fft_map_normalized = fft_map / (self.N**2) # Shift the zero frequency to the center data_ft_shifted = jnp.fft.fftshift(np.abs(fft_map_normalized)) # Compute the 2D power spectrum power_spectrum_2d = (data_ft_shifted**2) # Radial profile using precomputed bins tbin = jnp.bincount(self.radial_bins.ravel(), power_spectrum_2d.ravel()) nr = jnp.bincount(self.radial_bins.ravel()) radial_profile = jnp.where(nr > 0, tbin / nr, 0) # Nyquist limit nyquist = self.N // 2 powerspectrum = radial_profile[:nyquist] # 1D Power spectrum in [rad^2] power_spectrum_1d = powerspectrum * (self.pix_to_rad**2) # Multipole moments (ell) ks = jnp.arange(power_spectrum_1d.shape[0]) # ell = 2. * np.pi * ks / (self.pix_to_rad * self.N) ell = 4. * np.pi * np.pi * ks return ell, power_spectrum_1d
[docs] def adjust_map_cls(self, input_map, target_cls, target_ells): """ Adjust the input map to match the target power spectrum. Parameters: - input_map: 2D numpy array, the input map. - target_cls: 1D numpy array, the target power spectrum values. - target_ells: 1D numpy array, the target multipole moments. Returns: - adjusted_map: 2D numpy array, the adjusted map. """ # Ensure the input map has zero mean # input_map = input_map - np.mean(input_map) # Compute the 2D Fourier grid (if not precomputed) # if not hasattr(self, 'ell2d'): ell_x = 2 * np.pi * np.fft.fftfreq(input_map.shape[0], d=self.pix_to_rad) ell_y = 2 * np.pi * np.fft.fftfreq(input_map.shape[1], d=self.pix_to_rad) ell_x, ell_y = np.meshgrid(ell_x, ell_y) ell2d = np.sqrt(ell_x**2 + ell_y**2) # Interpolate target Cl spectrum onto 2D grid cl_interp = interp1d(target_ells, target_cls, bounds_error=False, fill_value=target_cls[-1]) CLTarget2d = cl_interp(ell2d) # Compute input map's Cl and interpolate ell_input, power_spectrum_1d_input = self.compute_power_spectrum(input_map) cl_interp_input = interp1d(ell_input, power_spectrum_1d_input, bounds_error=False, fill_value=power_spectrum_1d_input[-1]) CLInput2d = cl_interp_input(ell2d) # Compute scaling factors epsilon = 1e-25 scaling_log = 0.5 * (np.log(CLTarget2d) - np.log(CLInput2d + epsilon)) scaling = np.exp(scaling_log) # FFT the input map using pyFFTW fft_object = pyfftw.builders.fft2(input_map, threads=8) fft_map = fft_object() current_amplitudes = np.abs(fft_map) phases = np.angle(fft_map) # Scale Fourier amplitudes to match the target Cl adjusted_amplitudes = current_amplitudes * scaling adjusted_fft_map = adjusted_amplitudes * np.exp(1j * phases) # IFFT to get the adjusted map ifft_object = pyfftw.builders.ifft2(adjusted_fft_map, threads=8) adjusted_map = ifft_object().real return adjusted_map
[docs] class PowerSpectrum: """ Class for calculating power spectrum and generating fields with target power spectrum. Args: map (ndarray): Input map. pixelsize (float): Size of each pixel in degrees. Attributes: pixelsize (float): Size of each pixel in degrees. map_size (int): Size of the input map. ell_min (float): Minimum value of ell. ell_max (float): Maximum value of ell. deltaell (int): Interval between ell values. nbinsell (int): Number of ell bins. pixel_size_rad_perpixel (float): Pixel size in radians per pixel. lpix (float): Value of lpix. lx (ndarray): Array of lx values. ly (ndarray): Array of ly values. l (ndarray): Array of l values. ell_edges (ndarray): Array of ell bin edges. Methods: calculate_Cls(map): Calculates the power spectrum (Cls) of the input map. generate_field_with_target_cls(input_field, target_cls, target_ells): Generates a field with the target power spectrum. """ def __init__(self, map, pixelsize): self.pixelsize = pixelsize self.map_size = map.shape[0] self.ell_min = 180. / (self.map_size * np.deg2rad(self.pixelsize)) self.ell_max = 90 / np.deg2rad(self.pixelsize) self.deltaell = 200 self.nbinsell = int((self.ell_max - self.ell_min) / self.deltaell) self.pixel_size_rad_perpixel = np.pi * self.pixelsize / (180. * 60.) self.lpix = 2. * np.pi / self.pixel_size_rad_perpixel / 360. self.lx = rfftfreq(self.map_size) * self.map_size * self.lpix self.ly = fftfreq(self.map_size) * self.map_size * self.lpix self.l = jnp.sqrt(self.lx[np.newaxis, :] ** 2 + self.ly[:, np.newaxis] ** 2) self.ell_edges = jnp.linspace(self.ell_min, self.ell_max, num=self.nbinsell + 1)
[docs] def calculate_Cls(self, map): """ Calculates the power spectrum (Cls) of the input map. Args: map (ndarray): Input map. Returns: tuple: A tuple containing the ell edges, ell bins, and Cls values. """ map_ft = rfft2(map) power_spectrum = jnp.abs(map_ft) ** 2 # Digitize frequencies into bins bin_idx = jnp.digitize(self.l.ravel(), self.ell_edges) - 1 valid_mask = (bin_idx >= 0) & (bin_idx < self.nbinsell) # Aggregate power and counts per bin power_l = jnp.zeros(self.nbinsell) hits = jnp.zeros(self.nbinsell) power_l = power_l.at[bin_idx[valid_mask]].add(power_spectrum.ravel()[valid_mask]) hits = hits.at[bin_idx[valid_mask]].add(1) # Compute Cls cls_values = jnp.where(hits > 0, power_l / hits, 0.0) ell_bins = 0.5 * (self.ell_edges[1:] + self.ell_edges[:-1]) normalization = (jnp.deg2rad(self.pixelsize * self.map_size) / self.map_size ** 2) ** 2 return self.ell_edges, ell_bins, cls_values * normalization
[docs] def generate_field_with_target_cls(self, input_field, target_cls, target_ells): """ Generates a field with the target power spectrum. Args: input_field (ndarray): Input field. target_cls (ndarray): Target power spectrum values. target_ells (ndarray): Corresponding ell values for the target power spectrum. Returns: ndarray: Generated field with the target power spectrum. Raises: AssertionError: If the lengths of target_cls and target_ells are not equal. AssertionError: If any value in target_cls is negative. """ assert len(target_cls) == len(target_ells), "target_cls and target_ells must have the same length" assert jnp.all(target_cls >= 0), "All target_cls values must be non-negative" field_ft = rfft2(input_field) _, ell_bins, field_cls = self.calculate_Cls(input_field) # Interpolate Cls field_cls_interp = interp1d(ell_bins, field_cls, kind="linear", bounds_error=False, fill_value=field_cls[-1]) target_cls_interp = interp1d(target_ells, target_cls, kind="linear", bounds_error=False, fill_value=target_cls[-1]) Cl_field = field_cls_interp(self.l) Cl_target = target_cls_interp(self.l) # Adjust the amplitude adjustment_factor = jnp.sqrt(Cl_target / (Cl_field + 1e-20)) adjusted_amplitude = jnp.abs(field_ft) * adjustment_factor # Preserve phase and inverse transform adjusted_field_ft = adjusted_amplitude * jnp.exp(1j * jnp.angle(field_ft)) return irfft2(adjusted_field_ft).real
# if __name__ == "__main__": # # Generate a random input map # input_map = np.random.rand(256, 256) # pixelsize = 0.5 # Example pixel size in degrees # # Instantiate the PowerSpectrum class # power_spectrum_calculator = PowerSpectrum(input_map, pixelsize) # # Compute the power spectrum # ell_edges, ell_bins, cls_values = power_spectrum_calculator.calculate_Cls(input_map) # print("Power spectrum calculated:", cls_values) # # Generate a field with a target power spectrum # target_ells = np.linspace(ell_edges.min(), ell_edges.max(), len(cls_values)) # target_cls = cls_values * 1.2 # Example modification of Cls values # generated_field = power_spectrum_calculator.generate_field_with_target_cls(input_map, target_cls, target_ells) # print("Generated field with target power spectrum.")
[docs] def fourier_coordinate(x, y, map_size): return (map_size // 2 + 1) * x + y
[docs] def calculate_Cls_(map, angle, ell_min, ell_max, n_bins): """ map: the image from which the angular power spectra (Cls) has to be calculated angle: side angle in the units of degree ell_min: the minimum multipole moment to get the Cls ell_max: the maximum multipole moment to get the Cls n_bins: number of bins in the ells """ ell_min = jnp.array(ell_min) ell_max = jnp.array(ell_max) # n_bins = jnp.array(n_bins, int) # Calculate the Fourier Transforms map_ft = rfft2(map) ## rfft2 map_ft = map_ft.flatten() ell_edges = jnp.linspace(ell_min, ell_max, num=n_bins+1) # Define pixel physical size in Fourier space pixel_size = angle*60 / map.shape[0] pixel_size_rad_perpixel = np.pi * pixel_size / 180. / 60. lpix = 2. * np.pi / pixel_size_rad_perpixel / 360. # Initialize arrays to store power and hits for each ell bin power_l = jnp.zeros(n_bins) hits = jnp.zeros(n_bins) def loop_body(j, val): i, power_l, hits = val lx = jnp.minimum(i, map.shape[1] - i) * lpix ly = j * lpix l = jnp.sqrt(lx**2. + ly**2.) pixid = fourier_coordinate(i, j, map.shape[1]) bin_idx = jnp.digitize(l, ell_edges) - 1 power_l = power_l.at[bin_idx].add(jnp.abs(map_ft[pixid]**2.)) hits = hits.at[bin_idx].add(1) return i, power_l, hits def outer_loop_body(i, val): _, power_l, hits = val _, power_l, hits = jax.lax.fori_loop(0, map.shape[0], loop_body, (i, power_l, hits)) return i, power_l, hits _, power_l, hits = jax.lax.fori_loop(0, map.shape[1], outer_loop_body, (0, power_l, hits)) # Calculate Cls based on the accumulated power and hits cls_values = jnp.where(hits > 0, power_l / hits, 0.0) # Ensure no division by zero # cls_values = power_l/hits cls_values = jnp.maximum(cls_values, 0) # Clip negative values to zero if any ell_bins = 0.5 * (ell_edges[1:] + ell_edges[:-1]) normalization = (jnp.deg2rad(angle) / (map.shape[0] * map.shape[0]))**2 return jnp.array(ell_edges), jnp.array(ell_bins), jnp.array(cls_values * normalization)
# @partial(jax.jit, static_argnums=(1, 4, 5))
[docs] def generate_field_with_target_cls_(input_field, angle, target_cls, target_ells, ell_max=40000, n_bins=50): """ Adjusts the power spectrum of an input field to match a target power spectrum. Parameters: - input_field (jax.numpy.ndarray): The input field as a 2D JAX array. - angle (float): The field of view in degrees. - target_cls (jax.numpy.ndarray): Target power spectrum values. - target_ells (jax.numpy.ndarray): Target multipole moments associated with target_cls. - ell_max (int): Maximum multipole moment to consider. - n_bins (int): Number of bins to use for the power spectrum calculation. Returns: - jax.numpy.ndarray: The adjusted input field, transformed to match the target power spectrum. """ shape = input_field.shape assert len(target_cls) == len(target_ells), "target_cls and target_ells must have the same length" assert jnp.all(target_cls >= 0), "All target_cls values must be non-negative" # Fourier transform of the input field field_ft = rfft2(input_field) ## # Calculate the Cls for the input field ell_edges, ell_bins, field_cls = calculate_Cls(input_field, angle, 0, ell_max, n_bins) map = input_field # Compute lpix and l values for FFT pixels pixel_size = angle*60 / map.shape[0] pixel_size_rad_perpixel = np.pi * pixel_size / 180. / 60. lpix = 2. * np.pi / pixel_size_rad_perpixel / 360. lx = rfftfreq(shape[0]) * shape[0] * lpix ## ly = fftfreq(shape[1]) * shape[1] * lpix l = jnp.sqrt(lx[np.newaxis, :]**2 + ly[:, np.newaxis]**2) # Interpolate Cls for the input field and the target field_cls_interp = interp1d(ell_bins, field_cls, kind="linear", bounds_error=False, fill_value=target_cls[-1]) Cl_field = field_cls_interp(l) target_cls_interp = interp1d(target_ells, target_cls, kind="linear", bounds_error=False, fill_value=target_cls[-1]) Cl_target = target_cls_interp(l) # Adjust the amplitude based on the target Cls adjustment_factor = jnp.sqrt(Cl_target / Cl_field) adjusted_amplitude = jnp.abs(field_ft) * adjustment_factor # Recombine adjusted amplitude with original phase adjusted_field_ft = adjusted_amplitude * jnp.exp(1j * jnp.angle(field_ft)) # print(adjusted_field_ft.shape) # Inverse Fourier Transform to get the adjusted field in real space adjusted_field = irfft2(adjusted_field_ft) ## # print(shape, adjusted_field.shape, field_ft.shape) return adjusted_field.real