# from utils.config import *
import numpy as np
import jax.numpy as jnp
from jax.numpy.fft import fftfreq, ifft2, fftn
from lenspack.image.transforms import starlet2d
from typing import List
import scipy
class WaveletDecomposer:
    """
    Multi-scale decomposition of 2D maps with top-hat or starlet filters.

    The top-hat path works in Fourier space: the input map is transformed once,
    multiplied by top-hat windows of increasing radius, and transformed back.
    The starlet path delegates to ``lenspack.image.transforms.starlet2d``.

    Attributes set by the methods below:
    ------------------------------------
    num_scales : number of decomposition scales.
    precomputed_tophat_filters : cache mapping ``float(window_radius)`` to the
        Fourier-space filter for the current frequency grid.
    image_shape, L, dx, kx, ky, k_squared : frequency-grid state, created by
        ``set_size_dependent_params``.
    input_map, field_ft : current image and its Fourier transform, created by
        ``set_image``.
    """

    def __init__(self, num_scales: int = 3) -> None:
        """
        Initialize the decomposer.

        Parameters:
        -----------
        num_scales : int, optional
            The number of decomposition scales to be used. Defaults to 3.
        """
        self.num_scales = num_scales
        # Filters depend on the frequency grid, so this cache is emptied
        # whenever the grid is rebuilt.
        self.precomputed_tophat_filters = {}

    def set_size_dependent_params(self, image_shape, L=None):
        """
        Build the size-dependent state (pixel size and frequency grids).

        The pixel size is ``L / image_shape[0]`` and is used for both axes.
        If the requested shape and length scale match the grid that is already
        stored, nothing is rebuilt and the filter cache is kept.

        Parameters:
        -----------
        image_shape : sequence of two ints
            The shape of the input image.
        L : number, optional
            The length scale of the map. ``None`` (or any falsy value) falls
            back to ``image_shape[0]``.
        """
        length = L or image_shape[0]
        grid_is_current = (
            hasattr(self, "k_squared")
            and tuple(getattr(self, "image_shape", ())) == tuple(image_shape)
            and getattr(self, "L", None) == length
        )
        self.image_shape = image_shape
        self.L = length
        if grid_is_current:
            # Same grid as before: keep kx/ky/k_squared and the cached filters.
            return

        self.dx = self.L / image_shape[0]
        freq_x = fftfreq(image_shape[0], self.dx)
        freq_y = fftfreq(image_shape[1], self.dx)
        self.kx, self.ky = jnp.meshgrid(freq_x, freq_y, indexing='ij')
        self.k_squared = self.kx**2 + self.ky**2
        # Cached filters were built on the old grid and are no longer valid.
        self.precomputed_tophat_filters.clear()

    def set_image(self, input_map: np.ndarray) -> None:
        """
        Store the input image and its Fourier transform.

        Parameters:
        -----------
        input_map : np.ndarray
            The new input image or map for decomposition.
        """
        self.input_map = input_map
        self.field_ft = fftn(self.input_map)

    @staticmethod
    def top_hat_window_fourier(k):
        """
        Evaluate the top-hat window ``2 * J1(k) / k`` in Fourier space.

        Parameters:
        -----------
        k : array-like
            The (dimensionless) frequency grid in Fourier space.

        Returns:
        --------
        array
            The top-hat filter evaluated at ``k``.
        """
        # Imported here so the method does not rely on ``import scipy`` having
        # loaded the ``special`` submodule.
        from scipy.special import j1

        # Replace exact zeros by a tiny value to avoid 0/0; the window is
        # ~1 there.
        safe_k = jnp.where(k == 0, 1e-7, k)
        return 2.0 * j1(safe_k) / safe_k

    def get_top_hat_filter(self, window_radius: float) -> np.ndarray:
        """
        Retrieve or compute the Fourier-space top-hat filter for a radius.

        Filters are cached per radius in ``precomputed_tophat_filters`` for as
        long as the frequency grid stays the same. One grid-sized array is
        kept per distinct radius.

        Parameters:
        -----------
        window_radius : float
            The radius of the top-hat window, in the same units as ``L``.

        Returns:
        --------
        array
            The top-hat filter for the specified radius, shaped like the
            frequency grid.
        """
        try:
            cache_key = float(window_radius)
        except (TypeError, ValueError):
            # Non-scalar radius: compute directly without caching.
            cache_key = None

        if cache_key is not None and cache_key in self.precomputed_tophat_filters:
            return self.precomputed_tophat_filters[cache_key]

        k = 2 * jnp.pi * jnp.sqrt(self.k_squared)
        filter_window = self.top_hat_window_fourier(k * window_radius)
        if cache_key is not None:
            self.precomputed_tophat_filters[cache_key] = filter_window
        return filter_window

    def get_th_smooth_map(self, map, window_radius: float):
        """
        Smooth a map with a top-hat window of the given radius.

        The frequency grid is set up for ``map.shape`` with the default length
        scale, and ``map`` becomes the stored input image.

        Parameters:
        -----------
        map : np.ndarray
            The 2D map to smooth. (The name shadows the builtin ``map``; it is
            kept so existing keyword callers keep working.)
        window_radius : float
            The radius of the top-hat window.

        Returns:
        --------
        array
            The real part of the smoothed map.
        """
        self.set_size_dependent_params(map.shape)
        self.set_image(map)
        smoothed_ft = self.field_ft * self.get_top_hat_filter(window_radius)
        return ifft2(smoothed_ft).real

    def decompose_with_tophat(self) -> np.ndarray:
        """
        Perform wavelet decomposition using the top-hat filter.

        For each scale ``j = 1..num_scales`` the stored input map is smoothed
        with a top-hat of radius ``2**j``; the coefficient at that scale is the
        difference between the previous and the new smoothed map. Requires
        ``set_size_dependent_params`` and ``set_image`` to have been called.

        Returns:
        --------
        np.ndarray
            Array of ``num_scales + 1`` maps: the wavelet coefficients followed
            by the final coarse-scale image. Summing over the first axis
            recovers the input map.
        """
        previous = self.input_map
        bands = []
        for scale in range(1, self.num_scales + 1):
            smoothed = ifft2(self.field_ft * self.get_top_hat_filter(2**scale)).real
            bands.append(previous - smoothed)
            previous = smoothed
        bands.append(previous)
        return np.stack([np.asarray(band) for band in bands])

    def decompose_with_starlet(self) -> np.ndarray:
        """
        Perform wavelet decomposition using the starlet transform.

        Returns:
        --------
        np.ndarray
            The result of ``starlet2d(self.input_map, self.num_scales)``.
        """
        return starlet2d(self.input_map, self.num_scales)

    def reconstruct(self, coefficients: List[np.ndarray]) -> np.ndarray:
        """
        Reconstruct an image from wavelet coefficients by summing all scales.

        Parameters:
        -----------
        coefficients : List[np.ndarray]
            Coefficient arrays as returned by ``decompose``.

        Returns:
        --------
        np.ndarray
            The sum of the coefficient arrays along the first axis.
        """
        return np.sum(coefficients, axis=0)

    def decompose(self, input_map: np.ndarray, num_scales=1, filter_type: str = 'tophat', recalculate_params: bool = True):
        """
        Perform wavelet decomposition of an image with the chosen filter type.

        Parameters:
        -----------
        input_map : np.ndarray
            The image to decompose.
        num_scales : int, optional
            Number of scales for decomposition. Default is 1. The value is
            stored on the instance; pass ``None`` to keep the instance's
            current ``num_scales``.
        filter_type : str, optional
            The type of filter to use for decomposition ('tophat' or
            'starlet'). Default is 'tophat'.
        recalculate_params : bool, optional
            If True, set up the size-dependent parameters (L, kx, ky) for the
            image shape with the default length scale. Default is True. If
            False, the existing grid is reused; it is still set up when no
            grid exists yet or when its shape differs from the image shape.

        Returns:
        --------
        np.ndarray
            The wavelet coefficients and coarse scale image.

        Raises:
        -------
        ValueError
            If ``filter_type`` is neither 'tophat' nor 'starlet'.
        """
        # Fail fast, before any state is touched or FFT work is done.
        if filter_type not in ('tophat', 'starlet'):
            raise ValueError(f"Unknown filter type: {filter_type}. Supported values are 'tophat' and 'starlet'.")

        if num_scales is not None:
            self.num_scales = num_scales

        grid_shape = getattr(self, "image_shape", None)
        grid_unusable = grid_shape is None or tuple(grid_shape) != tuple(input_map.shape)
        if recalculate_params or grid_unusable:
            self.set_size_dependent_params(input_map.shape)
        self.set_image(input_map)

        if filter_type == 'tophat':
            return self.decompose_with_tophat()
        return self.decompose_with_starlet()
if __name__ == "__main__":
    # Demo: decompose two random 256x256 maps, one per filter type, using a
    # single decomposer instance.
    image_shape = (256, 256)
    input_image_1 = np.random.rand(*image_shape)
    input_image_2 = np.random.rand(*image_shape)

    decomposer = WaveletDecomposer()

    # The first call sets up the size-dependent parameters for the image.
    wavelet_coefficients_1 = decomposer.decompose(
        input_image_1,
        num_scales=5,
        filter_type='tophat',
        recalculate_params=True,
    )
    # The second image has the same shape, so recalculation is not requested.
    wavelet_coefficients_2 = decomposer.decompose(
        input_image_2,
        num_scales=5,
        filter_type='starlet',
        recalculate_params=False,
    )