-
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
127 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import jax.numpy as jnp | ||
from jax import random | ||
from typing import Tuple | ||
|
||
def fastica(X: jnp.ndarray, | ||
n_components: jnp.ndarray, | ||
max_iter: int = 1000, | ||
tol: float = 1e-4) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: | ||
""" | ||
Perform Independent Component Analysis (ICA) on the input data using the FastICA algorithm. | ||
Parameters: | ||
X : jax.numpy.ndarray | ||
The input data matrix, where each row represents a data point, and each column represents a different signal. | ||
The input data should be a 2D jax.numpy array with shape (n_samples, n_features). | ||
n_components : int | ||
The number of independent components to extract. This should be less than or equal to the number of features in the input data. | ||
max_iter : int, optional | ||
The maximum number of iterations for the optimization process. The default value is 1000 iterations. | ||
tol : float, optional | ||
The tolerance for convergence. The optimization process stops when the maximum absolute change in the diagonal elements of the | ||
unmixing matrix from one iteration to the next is less than this tolerance. The default value is 1e-4. | ||
Returns: | ||
S : jax.numpy.ndarray | ||
The separated independent components. This is a 2D jax.numpy array with shape (n_components, n_samples), where each row represents | ||
a different independent component, and each column represents a data point. | ||
W : jax.numpy.ndarray | ||
The unmixing matrix. This is a 2D jax.numpy array with shape (n_components, n_features), representing the estimated inverse of the | ||
mixing matrix. It is used to transform the input data back into the independent components. | ||
whitening_matrix : jax.numpy.ndarray | ||
The whitening matrix used to whiten the input data. This is a 2D jax.numpy array with shape (n_features, n_features), used to decorrelate | ||
the input data and make its covariance matrix the identity matrix. | ||
Description: | ||
The FastICA algorithm aims to separate the mixed input signals into statistically independent components. The function first whitens the input | ||
data to decorrelate it and normalize its variance. Then, it initializes a random unmixing matrix and uses an optimization process to find | ||
the optimal unmixing matrix that maximizes the independence of the source signals. | ||
The optimization process involves iteratively updating the unmixing matrix based on the non-linear function (`tanh` in this case) applied | ||
to the transformed data (`WX`). The process stops when the unmixing matrix converges according to the specified tolerance (`tol`) or when the | ||
maximum number of iterations (`max_iter`) is reached. | ||
Once the optimal unmixing matrix is found, the function applies it to the whitened data to obtain the separated independent components. | ||
Example usage: | ||
# Set random seed | ||
jax.random.PRNGKey(42) | ||
# Generate synthetic source signals | ||
n_samples = 2000 | ||
time = jnp.linspace(0, 8, n_samples) | ||
s1 = jnp.sin(2 * time) | ||
s2 = jnp.sign(jnp.sin(3 * time)) | ||
# Combine the sources with a mixing matrix | ||
A = jnp.array([[1, 1], [0.5, 2]]) | ||
X = jnp.dot(A, jnp.array([s1, s2])) | ||
# Perform ICA | ||
n_components = 2 | ||
S, W, whitening_matrix = fastica(X.T, n_components) | ||
# Plot the results | ||
plt.figure(figsize=(12, 8)) | ||
plt.subplot(3, 1, 1) | ||
plt.title('Original Source Signals') | ||
plt.plot(time, s1, label='Source 1 (Sine Wave)') | ||
plt.plot(time, s2, label='Source 2 (Square Wave)') | ||
plt.legend() | ||
plt.subplot(3, 1, 2) | ||
plt.title('Mixed Signals') | ||
plt.plot(time, X[0], label='Mixed Signal 1') | ||
plt.plot(time, X[1], label='Mixed Signal 2') | ||
plt.legend() | ||
plt.subplot(3, 1, 3) | ||
plt.title('Separated Signals (Using ICA)') | ||
plt.plot(time, S[0], label='Separated Signal 1') | ||
plt.plot(time, S[1], label='Separated Signal 2') | ||
plt.legend() | ||
plt.tight_layout() | ||
plt.show() | ||
""" | ||
# Calculate the covariance matrix and perform eigenvalue decomposition | ||
cov_matrix = jnp.cov(X, rowvar=False) | ||
eigenvalues, eigenvectors = jnp.linalg.eigh(cov_matrix) | ||
|
||
# Sort the eigenvalues and eigenvectors | ||
idx = jnp.argsort(eigenvalues)[::-1] | ||
eigenvalues = eigenvalues[idx] | ||
eigenvectors = eigenvectors[:, idx] | ||
|
||
# Create the whitening matrix | ||
D = jnp.diag(1.0 / jnp.sqrt(eigenvalues)) | ||
whitening_matrix = jnp.dot(eigenvectors, D) | ||
X_whitened = jnp.dot(X, whitening_matrix) | ||
|
||
# Initialize unmixing matrix with random values | ||
rng = random.PRNGKey(0) # Set a seed for reproducibility | ||
W = random.normal(rng, (n_components, n_components)) | ||
|
||
# Perform FastICA algorithm | ||
for _ in range(max_iter): | ||
WX = jnp.dot(X_whitened, W.T) | ||
g = jnp.tanh(WX) | ||
g_prime = 1 - g ** 2 | ||
W_new = (jnp.dot(X_whitened.T, g) / X.shape[0]) - jnp.diag(g_prime.mean(axis=0)).dot(W) | ||
|
||
# Orthogonalize the unmixing matrix | ||
W_new, _ = jnp.linalg.qr(W_new) | ||
|
||
# Check for convergence | ||
if jnp.max(jnp.abs(jnp.abs(jnp.diag(jnp.dot(W_new, W.T))) - 1)) < tol: | ||
W = W_new | ||
break | ||
|
||
W = W_new | ||
|
||
# Calculate the separated independent components | ||
S = jnp.dot(W, X_whitened.T) | ||
|
||
return S, W, whitening_matrix |