Commit 97252ba
fixed docs
HMUNACHI committed Feb 12, 2024
1 parent feaecbf commit 97252ba
Showing 18 changed files with 1,777 additions and 2,215 deletions.
194 changes: 117 additions & 77 deletions nanodl/__src/models/diffusion.py
@@ -1,52 +1,3 @@
"""
Example usage:
```
import jax
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import DiffusionModel, DiffusionDataParallelTrainer
image_size = 32
block_depth = 2
batch_size = 8
widths = [32, 64, 128]
key = jax.random.PRNGKey(0)
input_shape = (101, image_size, image_size, 3)
images = jax.random.normal(key, input_shape)
# Use your own images
dataset = ArrayDataset(images)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        drop_last=False)
# Create diffusion model
diffusion_model = DiffusionModel(image_size, widths, block_depth)
params = diffusion_model.init(key, images)
pred_noises, pred_images = diffusion_model.apply(params, images)
print(pred_noises.shape, pred_images.shape)
# Training on your data
# Note: the saved params often differ from the live training weights; use the saved params for generation
trainer = DiffusionDataParallelTrainer(diffusion_model,
                                       input_shape=images.shape,
                                       weights_filename='params.pkl',
                                       learning_rate=1e-4)
trainer.train(dataloader, 10, dataloader)
print(trainer.evaluate(dataloader))
# Generate some samples
params = trainer.load_params('params.pkl')
generated_images = diffusion_model.apply({'params': params},
                                         num_images=5,
                                         diffusion_steps=5,
                                         method=diffusion_model.generate)
print(generated_images.shape)
```
"""

import jax
import flax
import time
@@ -60,15 +11,18 @@

class SinusoidalEmbedding(nn.Module):
"""
Sinusoidal Embedding for images.
Implements sinusoidal embeddings as a layer in a neural network using JAX.
This class generates sinusoidal embeddings for a given input tensor. The embeddings are
created using a range of frequencies determined by the minimum and maximum frequency parameters.
This layer generates sinusoidal embeddings based on input positions and a range of frequencies, producing embeddings that capture positional information in a continuous manner. It's particularly useful in models where the notion of position is crucial, such as in generative models for images and audio.
Attributes:
embedding_dims (int): The dimensionality of the embedding.
embedding_min_frequency (float): The minimum frequency for the sinusoidal embedding.
embedding_max_frequency (float): The maximum frequency for the sinusoidal embedding.
embedding_dims (int): The dimensionality of the output embeddings.
embedding_min_frequency (float): The minimum frequency used in the sinusoidal embedding.
embedding_max_frequency (float): The maximum frequency used in the sinusoidal embedding.
Methods:
setup(): Initializes the layer by computing the angular speeds for the sinusoidal functions based on the specified frequency range.
__call__(x: jnp.ndarray): Generates the sinusoidal embeddings for the input positions.
"""
embedding_dims: int
embedding_min_frequency: float
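As an aside for readers of this diff: below is a minimal Flax sketch of how a sinusoidal-embedding layer like this one is commonly written, assuming log-spaced frequencies and a trailing singleton input axis. The class name, defaults, and frequency spacing are illustrative assumptions, not the repository's exact code.

```py
import jax.numpy as jnp
from flax import linen as nn

class SinusoidalEmbeddingSketch(nn.Module):
    """Hypothetical stand-in for nanodl's SinusoidalEmbedding."""
    embedding_dims: int = 32
    embedding_min_frequency: float = 1.0
    embedding_max_frequency: float = 1000.0

    def setup(self):
        # Log-spaced frequencies spanning [min, max], one per sin/cos pair.
        frequencies = jnp.exp(
            jnp.linspace(
                jnp.log(self.embedding_min_frequency),
                jnp.log(self.embedding_max_frequency),
                self.embedding_dims // 2,
            )
        )
        self.angular_speeds = 2.0 * jnp.pi * frequencies

    def __call__(self, x):
        # x carries a trailing singleton axis (e.g. noise variances of shape
        # (batch, 1, 1, 1)); broadcasting yields (..., embedding_dims // 2)
        # per branch, and the concatenation gives embedding_dims features.
        return jnp.concatenate(
            [jnp.sin(self.angular_speeds * x), jnp.cos(self.angular_speeds * x)],
            axis=-1,
        )
```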
@@ -87,6 +41,17 @@ def __call__(self, x):


class UNetResidualBlock(nn.Module):
"""
Implements a residual block within a U-Net architecture using JAX.
This module defines a residual block with convolutional layers and normalization, followed by a residual connection. It's a fundamental building block in constructing deeper and more complex U-Net architectures for tasks like image segmentation and generation.
Attributes:
width (int): The number of output channels for the convolutional layers within the block.
Methods:
__call__(x: jnp.ndarray): Processes the input tensor through the residual block and returns the result.
"""
width: int

@nn.compact
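A hedged sketch of the residual-block pattern this docstring describes, reusing the imports from the previous sketch; the LayerNorm and swish choices are assumptions, not necessarily what nanodl uses.

```py
class ResidualBlockSketch(nn.Module):
    """Hypothetical stand-in for nanodl's UNetResidualBlock."""
    width: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # 1x1-project the input when channel counts differ so the skip add is valid.
        residual = x if x.shape[-1] == self.width else nn.Conv(self.width, (1, 1))(x)
        x = nn.LayerNorm()(x)
        x = nn.swish(nn.Conv(self.width, (3, 3), padding="SAME")(x))
        x = nn.Conv(self.width, (3, 3), padding="SAME")(x)
        return x + residual
```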
@@ -115,13 +80,17 @@ def __call__(self,

class UNetDownBlock(nn.Module):
"""
Downsampling block for U-Net architecture.
Implements a down-sampling block in a U-Net architecture using JAX.
This block applies a series of residual blocks followed by average pooling to downsample the input.
This module consists of a sequence of residual blocks followed by an average pooling operation to reduce the spatial dimensions. It's used to capture higher-level features at reduced spatial resolutions in the encoding pathway of a U-Net.
Attributes:
width (int): The number of channels in the residual blocks.
block_depth (int): The number of residual blocks in the down block.
width (int): The number of output channels for the convolutional layers within the block.
block_depth (int): The number of residual blocks to include in the down-sampling block.
Methods:
setup(): Initializes the sequence of residual blocks.
__call__(x: jnp.ndarray): Processes the input tensor through the down-sampling block and returns the result.
"""
width: int
block_depth: int
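Under the same caveats, the down-sampling block reduces to a few residual blocks plus average pooling; the `skips` list here is an assumed mechanism for handing encoder activations to the decoder, sketched for illustration.

```py
class DownBlockSketch(nn.Module):
    """Hypothetical stand-in for nanodl's UNetDownBlock."""
    width: int
    block_depth: int

    @nn.compact
    def __call__(self, x, skips):
        for _ in range(self.block_depth):
            x = ResidualBlockSketch(self.width)(x)
            skips.append(x)  # saved for the matching up-sampling block
        # Halve the spatial resolution on the way down the encoder.
        return nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
```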
@@ -139,14 +108,17 @@ def __call__(self,

class UNetUpBlock(nn.Module):
"""
Upsampling block for U-Net architecture.
Implements an up-sampling block in a U-Net architecture using JAX.
This block applies bilinear upsampling to the input and concatenates it with a skip connection.
It then applies a series of residual blocks to the concatenated tensor.
This module consists of a sequence of residual blocks and a bilinear up-sampling operation to increase the spatial dimensions. It's used in the decoding pathway of a U-Net to progressively recover spatial resolution and detail in the output image.
Attributes:
width (int): The number of channels in the residual blocks.
block_depth (int): The number of residual blocks in the up block.
width (int): The number of output channels for the convolutional layers within the block.
block_depth (int): The number of residual blocks to include in the up-sampling block.
Methods:
setup(): Initializes the sequence of residual blocks.
__call__(x: jnp.ndarray, skip: jnp.ndarray): Processes the input tensor and a skip connection from the encoding pathway through the up-sampling block and returns the result.
"""
width: int
block_depth: int
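The mirror-image up-sampling block, again as a sketch built on the pieces above; bilinear resizing and LIFO consumption of `skips` are assumptions consistent with the docstring, not confirmed implementation details.

```py
import jax

class UpBlockSketch(nn.Module):
    """Hypothetical stand-in for nanodl's UNetUpBlock."""
    width: int
    block_depth: int

    @nn.compact
    def __call__(self, x, skips):
        # Bilinear up-sampling back to the matching encoder resolution.
        b, h, w, c = x.shape
        x = jax.image.resize(x, (b, h * 2, w * 2, c), method="bilinear")
        for _ in range(self.block_depth):
            # Concatenate the saved encoder activation, then refine.
            x = jnp.concatenate([x, skips.pop()], axis=-1)
            x = ResidualBlockSketch(self.width)(x)
        return x
```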
@@ -168,19 +140,21 @@ def __call__(self,

class UNet(nn.Module):
"""
U-Net architecture for image generation.
Implements the U-Net architecture for image processing tasks using JAX.
This class implements a U-Net model, which is commonly used for image-to-image translation tasks.
It consists of an encoder (downsampling path), a bottleneck, and a decoder (upsampling path)
with skip connections.
This model is widely used for tasks such as image segmentation, denoising, and super-resolution. It features a symmetric encoder-decoder structure with skip connections between corresponding layers in the encoder and decoder to preserve spatial information.
Attributes:
image_size (Tuple[int, int]): The size of the input images.
widths (List[int]): The number of channels in each block of the U-Net.
block_depth (int): The depth of each block in the U-Net.
embed_dims (int): The number of dimensions for the sinusoidal embedding.
embed_min_freq (float): The minimum frequency for the sinusoidal embedding.
embed_max_freq (float): The maximum frequency for the sinusoidal embedding.
image_size (Tuple[int, int]): The size of the input images (height, width).
widths (List[int]): The number of output channels for each block in the U-Net architecture.
block_depth (int): The number of residual blocks in each down-sampling and up-sampling block.
embed_dims (int): The dimensionality of the sinusoidal embeddings for encoding positional information.
embed_min_freq (float): The minimum frequency for the sinusoidal embeddings.
embed_max_freq (float): The maximum frequency for the sinusoidal embeddings.
Methods:
setup(): Initializes the U-Net architecture including the sinusoidal embedding layer, down-sampling blocks, residual blocks, and up-sampling blocks.
__call__(noisy_images: jnp.ndarray, noise_variances: jnp.ndarray): Processes noisy images and their associated noise variances through the U-Net and returns the denoised images.
"""
image_size: Tuple[int, int]
widths: List[int]
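Putting the sketches together, a plausible (not authoritative) wiring of the full U-Net described here, with the noise-level embedding broadcast over the spatial grid before the encoder:

```py
from typing import List, Tuple

class UNetSketch(nn.Module):
    """Hypothetical wiring of the pieces sketched above."""
    image_size: Tuple[int, int]
    widths: List[int]
    block_depth: int

    @nn.compact
    def __call__(self, noisy_images, noise_variances):
        # Embed each image's noise level and broadcast it across the grid.
        emb = SinusoidalEmbeddingSketch(32, 1.0, 1000.0)(noise_variances)
        emb = jax.image.resize(
            emb,
            (noisy_images.shape[0], *self.image_size, emb.shape[-1]),
            method="nearest",
        )
        x = nn.Conv(self.widths[0], (1, 1))(noisy_images)
        x = jnp.concatenate([x, emb], axis=-1)

        skips = []
        for width in self.widths[:-1]:            # encoder
            x = DownBlockSketch(width, self.block_depth)(x, skips)
        for _ in range(self.block_depth):         # bottleneck
            x = ResidualBlockSketch(self.widths[-1])(x)
        for width in reversed(self.widths[:-1]):  # decoder
            x = UpBlockSketch(width, self.block_depth)(x, skips)

        # Zero-initialised head: the untrained model predicts zero noise.
        return nn.Conv(3, (1, 1), kernel_init=nn.initializers.zeros)(x)
```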
@@ -224,7 +198,75 @@ def __call__(self,


class DiffusionModel(nn.Module):

"""
Implements a diffusion model for image generation using JAX.
Diffusion models are a class of generative models that learn to denoise images through a gradual process of adding and removing noise. This implementation uses a U-Net architecture for the denoising process and supports custom diffusion schedules.
Attributes:
image_size (int): The size of the generated images.
widths (List[int]): The number of output channels for each block in the U-Net architecture.
block_depth (int): The number of residual blocks in each down-sampling and up-sampling block.
min_signal_rate (float): The minimum signal rate in the diffusion process.
max_signal_rate (float): The maximum signal rate in the diffusion process.
embed_dims (int): The dimensionality of the sinusoidal embeddings for encoding noise levels.
embed_min_freq (float): The minimum frequency for the sinusoidal embeddings.
embed_max_freq (float): The maximum frequency for the sinusoidal embeddings.
Methods:
setup(): Initializes the diffusion model including the U-Net architecture.
diffusion_schedule(diffusion_times: jnp.ndarray): Computes the noise and signal rates for given diffusion times.
denoise(noisy_images: jnp.ndarray, noise_rates: jnp.ndarray, signal_rates: jnp.ndarray): Denoises images given their noise and signal rates.
__call__(images: jnp.ndarray): Applies the diffusion process to a batch of images.
reverse_diffusion(initial_noise: jnp.ndarray, diffusion_steps: int): Reverses the diffusion process to generate images from noise.
generate(num_images: int, diffusion_steps: int): Generates images by reversing the diffusion process from random noise.
Example usage:
```
import jax
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import DiffusionModel, DiffusionDataParallelTrainer
image_size = 32
block_depth = 2
batch_size = 8
widths = [32, 64, 128]
key = jax.random.PRNGKey(0)
input_shape = (101, image_size, image_size, 3)
images = jax.random.normal(key, input_shape)
# Use your own images
dataset = ArrayDataset(images)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        drop_last=False)
# Create diffusion model
diffusion_model = DiffusionModel(image_size, widths, block_depth)
params = diffusion_model.init(key, images)
pred_noises, pred_images = diffusion_model.apply(params, images)
print(pred_noises.shape, pred_images.shape)
# Training on your data
# Note: the saved params often differ from the live training weights; use the saved params for generation
trainer = DiffusionDataParallelTrainer(diffusion_model,
                                       input_shape=images.shape,
                                       weights_filename='params.pkl',
                                       learning_rate=1e-4)
trainer.train(dataloader, 10, dataloader)
print(trainer.evaluate(dataloader))
# Generate some samples
params = trainer.load_params('params.pkl')
generated_images = diffusion_model.apply({'params': params},
                                         num_images=5,
                                         diffusion_steps=5,
                                         method=diffusion_model.generate)
print(generated_images.shape)
```
"""
image_size: int
widths: List[int]
block_depth: int
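For the `diffusion_schedule` referenced in the docstring, a common cosine-style formulation maps a diffusion time in [0, 1] to a noise/signal rate pair whose squares sum to one. The sketch below follows that convention; the function name and default rates are illustrative, not confirmed from the repository.

```py
import jax.numpy as jnp

def diffusion_schedule_sketch(diffusion_times,
                              min_signal_rate=0.02,
                              max_signal_rate=0.95):
    # Angles run from arccos(max_signal_rate) at t=0 to arccos(min_signal_rate) at t=1.
    start_angle = jnp.arccos(max_signal_rate)
    end_angle = jnp.arccos(min_signal_rate)
    angles = start_angle + diffusion_times * (end_angle - start_angle)
    signal_rates = jnp.cos(angles)  # how much clean image survives
    noise_rates = jnp.sin(angles)   # how much noise is mixed in
    # signal_rates**2 + noise_rates**2 == 1, so total variance is preserved:
    # noisy_images = signal_rates * images + noise_rates * noises
    return noise_rates, signal_rates
```

At t = 0 the images are almost entirely signal; at t = 1 they are almost pure noise, which is the state `reverse_diffusion` starts from when generating.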
@@ -446,10 +488,8 @@ def get_ema_weights(self, params, ema=0.999):
new_params = {}
for key, value in params.items():
if isinstance(value, dict):
# Recursively apply the function to nested dictionaries
new_params[key] = self.get_ema_weights(value, ema)
else:
# Multiply the value by ema multiplier
new_params[key] = ema * value + (1 - ema) * value
return new_params
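Worth noting about the hunk above: `ema * value + (1 - ema) * value` simplifies algebraically to `value`, so the update as written is an identity. A conventional EMA blends two distinct parameter trees, the running average and the latest training weights; below is a hedged sketch of that variant, with a hypothetical helper name that is not part of the repository's API.

```py
import jax

def ema_update(ema_params, new_params, ema=0.999):
    # Blend the running average with the freshly updated training weights.
    return jax.tree_util.tree_map(
        lambda avg, new: ema * avg + (1.0 - ema) * new, ema_params, new_params
    )
```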
