Added FastICA algorithm
HMUNACHI committed May 10, 2024
1 parent 6a951af commit b398e99
Showing 2 changed files with 127 additions and 1 deletion.
2 changes: 1 addition & 1 deletion nanodl/__src/models/vit.py
@@ -5,7 +5,7 @@
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
-from typing import List, Tuple, Any, Optional, Dict, Iterable
+from typing import Tuple, Any, Optional, Iterable

class PatchEmbedding(nn.Module):
"""
126 changes: 126 additions & 0 deletions nanodl/__src/sklearn_gpu/dsp.py
@@ -0,0 +1,126 @@
import jax.numpy as jnp
from jax import random
from typing import Tuple

def fastica(X: jnp.ndarray,
            n_components: int,
            max_iter: int = 1000,
            tol: float = 1e-4) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
Perform Independent Component Analysis (ICA) on the input data using the FastICA algorithm.
Parameters:
X : jax.numpy.ndarray
The input data matrix, where each row represents a data point, and each column represents a different signal.
The input data should be a 2D jax.numpy array with shape (n_samples, n_features).
n_components : int
The number of independent components to extract. This should be less than or equal to the number of features in the input data.
max_iter : int, optional
The maximum number of iterations for the optimization process. The default value is 1000 iterations.
tol : float, optional
The tolerance for convergence. The optimization process stops when the maximum absolute change in the diagonal elements of the
unmixing matrix from one iteration to the next is less than this tolerance. The default value is 1e-4.
Returns:
S : jax.numpy.ndarray
The separated independent components. This is a 2D jax.numpy array with shape (n_components, n_samples), where each row represents
a different independent component, and each column represents a data point.
W : jax.numpy.ndarray
The unmixing matrix. This is a 2D jax.numpy array with shape (n_components, n_features), representing the estimated inverse of the
mixing matrix. It is used to transform the input data back into the independent components.
whitening_matrix : jax.numpy.ndarray
The whitening matrix used to whiten the input data. This is a 2D jax.numpy array with shape (n_features, n_features), used to decorrelate
the input data and make its covariance matrix the identity matrix.
Description:
The FastICA algorithm aims to separate the mixed input signals into statistically independent components. The function first whitens the input
data to decorrelate it and normalize its variance. Then, it initializes a random unmixing matrix and uses an optimization process to find
the optimal unmixing matrix that maximizes the independence of the source signals.
The optimization process involves iteratively updating the unmixing matrix based on the non-linear function (`tanh` in this case) applied
to the transformed data (`WX`). The process stops when the unmixing matrix converges according to the specified tolerance (`tol`) or when the
maximum number of iterations (`max_iter`) is reached.
Once the optimal unmixing matrix is found, the function applies it to the whitened data to obtain the separated independent components.
Example usage:
# Set random seed
jax.random.PRNGKey(42)
# Generate synthetic source signals
n_samples = 2000
time = jnp.linspace(0, 8, n_samples)
s1 = jnp.sin(2 * time)
s2 = jnp.sign(jnp.sin(3 * time))
# Combine the sources with a mixing matrix
A = jnp.array([[1, 1], [0.5, 2]])
X = jnp.dot(A, jnp.array([s1, s2]))
# Perform ICA
n_components = 2
S, W, whitening_matrix = fastica(X.T, n_components)
# Plot the results
plt.figure(figsize=(12, 8))
plt.subplot(3, 1, 1)
plt.title('Original Source Signals')
plt.plot(time, s1, label='Source 1 (Sine Wave)')
plt.plot(time, s2, label='Source 2 (Square Wave)')
plt.legend()
plt.subplot(3, 1, 2)
plt.title('Mixed Signals')
plt.plot(time, X[0], label='Mixed Signal 1')
plt.plot(time, X[1], label='Mixed Signal 2')
plt.legend()
plt.subplot(3, 1, 3)
plt.title('Separated Signals (Using ICA)')
plt.plot(time, S[0], label='Separated Signal 1')
plt.plot(time, S[1], label='Separated Signal 2')
plt.legend()
plt.tight_layout()
plt.show()
"""
    # Center the data; ICA assumes (approximately) zero-mean signals
    X = X - jnp.mean(X, axis=0)

    # Calculate the covariance matrix and perform eigenvalue decomposition
    # (eigh is used because the covariance matrix is symmetric)
    cov_matrix = jnp.cov(X, rowvar=False)
    eigenvalues, eigenvectors = jnp.linalg.eigh(cov_matrix)

    # Sort the eigenvalues and eigenvectors in descending order and keep
    # only the leading n_components directions
    idx = jnp.argsort(eigenvalues)[::-1]
    eigenvalues = eigenvalues[idx][:n_components]
    eigenvectors = eigenvectors[:, idx][:, :n_components]

    # Create the whitening matrix; the whitened data has identity covariance
    D = jnp.diag(1.0 / jnp.sqrt(eigenvalues))
    whitening_matrix = jnp.dot(eigenvectors, D)
    X_whitened = jnp.dot(X, whitening_matrix)

    # Initialize unmixing matrix with random values
    rng = random.PRNGKey(0)  # Set a seed for reproducibility
    W = random.normal(rng, (n_components, n_components))

    # Perform FastICA algorithm
    for _ in range(max_iter):
        # Fixed-point update with the tanh non-linearity:
        # W_new = E[g(WX) X^T] - diag(E[g'(WX)]) W
        WX = jnp.dot(X_whitened, W.T)
        g = jnp.tanh(WX)
        g_prime = 1 - g ** 2
        W_new = (jnp.dot(g.T, X_whitened) / X.shape[0]) - jnp.dot(jnp.diag(g_prime.mean(axis=0)), W)

        # Orthogonalize the unmixing matrix
        W_new, _ = jnp.linalg.qr(W_new)

        # Check for convergence
        if jnp.max(jnp.abs(jnp.abs(jnp.diag(jnp.dot(W_new, W.T))) - 1)) < tol:
            W = W_new
            break

        W = W_new

    # Calculate the separated independent components
    S = jnp.dot(W, X_whitened.T)

    return S, W, whitening_matrix
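
A minimal usage sketch for the new routine: it mixes two synthetic sources and checks that the recovered components are roughly uncorrelated with each other and line up with the originals up to permutation, sign, and scale. It assumes fastica can be imported from nanodl.__src.sklearn_gpu.dsp (the path of the file added above); the public import path that nanodl actually exposes may differ.

import jax.numpy as jnp

# Assumed import path, taken from the file location in this commit; the
# package may re-export fastica under a different module or name.
from nanodl.__src.sklearn_gpu.dsp import fastica

# Two synthetic sources: a sine wave and a square wave
n_samples = 2000
t = jnp.linspace(0, 8, n_samples)
sources = jnp.stack([jnp.sin(2 * t), jnp.sign(jnp.sin(3 * t))])  # (2, n_samples)

# Mix them with a known matrix and pass the (n_samples, n_features) array to fastica
A = jnp.array([[1.0, 1.0], [0.5, 2.0]])
X = jnp.dot(A, sources)                                          # (2, n_samples)
S, W, whitening_matrix = fastica(X.T, n_components=2)

# Recovered components should be roughly uncorrelated with each other ...
print(jnp.corrcoef(S))  # off-diagonal entries close to 0

# ... and each should strongly correlate with exactly one true source
print(jnp.abs(jnp.corrcoef(jnp.concatenate([S, sources]))[:2, 2:]))  # one entry per row close to 1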
