[fix] linting
fracape committed Feb 2, 2024
1 parent e610731 commit 300ec44
Showing 2 changed files with 104 additions and 40 deletions.
7 changes: 6 additions & 1 deletion compressai/entropy_models/__init__.py
@@ -27,7 +27,12 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-from .entropy_models import EntropyBottleneck, EntropyModel, GaussianConditional, GaussianMixtureConditional
+from .entropy_models import (
+    EntropyBottleneck,
+    EntropyModel,
+    GaussianConditional,
+    GaussianMixtureConditional,
+)

 __all__ = [
     "EntropyModel",
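For context, this re-export means the mixture model is importable from the package root of compressai.entropy_models. A minimal construction sketch, using only what the diff shows (K defaults to 3; the remaining arguments are forwarded to GaussianConditional):

import torch  # noqa: F401 (imported for parity with the snippets below)
from compressai.entropy_models import GaussianMixtureConditional

# K = number of Gaussian components per channel; scale_table and any extra
# arguments are passed through to the GaussianConditional parent class.
gmc = GaussianMixtureConditional(K=3)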
137 changes: 98 additions & 39 deletions compressai/entropy_models/entropy_models.py
@@ -709,26 +709,34 @@ def build_indexes(self, scales: Tensor) -> Tensor:
             indexes -= (scales <= s).int()
         return indexes

+
 class GaussianMixtureConditional(GaussianConditional):
-    def __init__(self,
-                 K=3,
-                 scale_table: Optional[Union[List, Tuple]] = None,
-                 *args: Any,
-                 **kwargs: Any):
+    def __init__(
+        self,
+        K=3,
+        scale_table: Optional[Union[List, Tuple]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
         super().__init__(scale_table, *args, **kwargs)

         self.K = K

     def _likelihood(
-        self, inputs: Tensor, scales: Tensor, means: Tensor, weights: Tensor
+        self, inputs: Tensor, scales: Tensor, means: Tensor, weights: Tensor
     ) -> Tensor:
         likelihood = torch.zeros_like(inputs)
         M = inputs.size(1)

         for k in range(self.K):
-            likelihood += super()._likelihood(
-                inputs, scales[:, M*k:M*(k+1)], means[:, M*k:M*(k+1)]
-            ) * weights[:, M*k:M*(k+1)]
+            likelihood += (
+                super()._likelihood(
+                    inputs,
+                    scales[:, M * k : M * (k + 1)],
+                    means[:, M * k : M * (k + 1)],
+                )
+                * weights[:, M * k : M * (k + 1)]
+            )

         return likelihood

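Note on the _likelihood change above: the loop computes a K-component Gaussian mixture, p(y) = sum_k w_k * (Phi_k(y + 0.5) - Phi_k(y - 0.5)), with the parameters for component k stored in channel block [M*k, M*(k+1)). A standalone sketch of the same computation (the erf form mirrors _build_cdf below; the tiny stabilizer is an assumption standing in for TINY):

import torch

def mixture_likelihood(y, scales, means, weights, K, tiny=1e-9):
    # Parameter tensors carry K * M channels; block k holds component k's values.
    M = y.size(1)
    total = torch.zeros_like(y)
    for k in range(K):
        s = scales[:, M * k : M * (k + 1)] + tiny
        m = means[:, M * k : M * (k + 1)]
        w = weights[:, M * k : M * (k + 1)]
        upper = 0.5 * (1 + torch.erf((y + 0.5 - m) / (s * 2**0.5)))
        lower = 0.5 * (1 + torch.erf((y - 0.5 - m) / (s * 2**0.5)))
        total += w * (upper - lower)  # weighted interval probability
    return total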
@@ -742,7 +750,9 @@ def forward(
     ) -> Tuple[Tensor, Tensor]:
         if training is None:
             training = self.training
-        outputs = self.quantize(inputs, "noise" if training else "dequantize", means=None)
+        outputs = self.quantize(
+            inputs, "noise" if training else "dequantize", means=None
+        )
         likelihood = self._likelihood(outputs, scales, means, weights)
         if self.use_likelihood_bound:
             likelihood = self.likelihood_lower_bound(likelihood)
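For reference, forward() leans on the quantize() helper inherited from the parent entropy model; with means=None it amounts to additive uniform noise during training (a differentiable stand-in for rounding) and plain rounding otherwise. A sketch of that assumed behavior, not part of this commit:

import torch

def quantize(inputs: torch.Tensor, mode: str) -> torch.Tensor:
    if mode == "noise":  # training path: uniform noise in [-0.5, 0.5)
        return inputs + torch.empty_like(inputs).uniform_(-0.5, 0.5)
    return torch.round(inputs)  # "dequantize" with means=None: plain rounding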
@@ -762,65 +772,111 @@ def _build_cdf(self, scales, means, weights, abs_max):
         means_ = means.unsqueeze(-1).expand(-1, -1, num_samples)
         weights_ = weights.unsqueeze(-1).expand(-1, -1, num_samples)

-        samples = torch.arange(num_samples).to(device).unsqueeze(0).expand(num_latents, -1)
+        samples = (
+            torch.arange(num_samples).to(device).unsqueeze(0).expand(num_latents, -1)
+        )

         pmf = torch.zeros_like(samples).float()
         for k in range(self.K):
-            pmf += (0.5 * (1 + torch.erf((samples + 0.5 - means_[k]) / ((scales_[k] + TINY) * 2 ** 0.5))) - \
-                    0.5 * (1 + torch.erf((samples - 0.5 - means_[k]) / ((scales_[k] + TINY) * 2 ** 0.5)))) * weights_[k]
-
-        cdf_limit = 2 ** self.entropy_coder_precision - 1
+            pmf += (
+                0.5
+                * (
+                    1
+                    + torch.erf(
+                        (samples + 0.5 - means_[k]) / ((scales_[k] + TINY) * 2**0.5)
+                    )
+                )
+                - 0.5
+                * (
+                    1
+                    + torch.erf(
+                        (samples - 0.5 - means_[k]) / ((scales_[k] + TINY) * 2**0.5)
+                    )
+                )
+            ) * weights_[k]
+
+        cdf_limit = 2**self.entropy_coder_precision - 1
         pmf = torch.clamp(pmf, min=1.0 / cdf_limit, max=1.0)
         pmf_scaled = torch.round(pmf * cdf_limit)
         pmf_sum = torch.sum(pmf_scaled, 1, keepdim=True).expand(-1, num_samples)

-        cdf = F.pad(torch.cumsum(pmf_scaled * cdf_limit / pmf_sum, 1).int(), (1, 0), "constant", 0)
+        cdf = F.pad(
+            torch.cumsum(pmf_scaled * cdf_limit / pmf_sum, 1).int(),
+            (1, 0),
+            "constant",
+            0,
+        )
         pmf_quantized = torch.diff(cdf, dim=1)

         # We can't have zeros in PMF because rANS won't be able to encode it.
         # Try to fix this by "stealing" probability from some unlikely symbols.

         pmf_zero_count = num_samples - torch.count_nonzero(pmf_quantized, dim=1)

-        _, pmf_first_stealable_indices = torch.min(torch.where(
-            pmf_quantized > pmf_zero_count.unsqueeze(-1).expand(-1, num_samples),
-            pmf_quantized,
-            torch.tensor(cdf_limit + 1).int()
-        ), dim=1)
+        _, pmf_first_stealable_indices = torch.min(
+            torch.where(
+                pmf_quantized > pmf_zero_count.unsqueeze(-1).expand(-1, num_samples),
+                pmf_quantized,
+                torch.tensor(cdf_limit + 1).int(),
+            ),
+            dim=1,
+        )

         pmf_real_zero_indices = (pmf_quantized == 0).nonzero().transpose(0, 1)
         pmf_quantized[pmf_real_zero_indices[0], pmf_real_zero_indices[1]] += 1

-        pmf_real_steal_indices = torch.cat((torch.arange(num_latents).to(device).unsqueeze(-1),
-                                            pmf_first_stealable_indices.unsqueeze(-1)),
-                                           dim=1).transpose(0, 1)
-        pmf_quantized[pmf_real_steal_indices[0], pmf_real_steal_indices[1]] -= pmf_zero_count
+        pmf_real_steal_indices = torch.cat(
+            (
+                torch.arange(num_latents).to(device).unsqueeze(-1),
+                pmf_first_stealable_indices.unsqueeze(-1),
+            ),
+            dim=1,
+        ).transpose(0, 1)
+        pmf_quantized[
+            pmf_real_steal_indices[0], pmf_real_steal_indices[1]
+        ] -= pmf_zero_count

         cdf = F.pad(torch.cumsum(pmf_quantized, 1).int(), (1, 0), "constant", 0)
         cdf = F.pad(cdf, (0, 1), "constant", cdf_limit + 1)

         return cdf

-
     def reshape_entropy_parameters(self, scales, means, weights, nonzero):
         reshape_size = (scales.size(0), self.K, scales.size(1) // self.K, -1)

-        scales = scales.reshape(*reshape_size)[:, :, nonzero].permute(1, 0, 2, 3).reshape(self.K, -1)
-        means = means.reshape(*reshape_size)[:, :, nonzero].permute(1, 0, 2, 3).reshape(self.K, -1)
-        weights = weights.reshape(*reshape_size)[:, :, nonzero].permute(1, 0, 2, 3).reshape(self.K, -1)
+        scales = (
+            scales.reshape(*reshape_size)[:, :, nonzero]
+            .permute(1, 0, 2, 3)
+            .reshape(self.K, -1)
+        )
+        means = (
+            means.reshape(*reshape_size)[:, :, nonzero]
+            .permute(1, 0, 2, 3)
+            .reshape(self.K, -1)
+        )
+        weights = (
+            weights.reshape(*reshape_size)[:, :, nonzero]
+            .permute(1, 0, 2, 3)
+            .reshape(self.K, -1)
+        )
         return scales, means, weights

-
     def compress(self, y, scales, means, weights):
-        abs_max = max(torch.abs(y.max()).int().item(), torch.abs(y.min()).int().item()) + 1
+        abs_max = (
+            max(torch.abs(y.max()).int().item(), torch.abs(y.min()).int().item()) + 1
+        )
         abs_max = 1 if abs_max < 1 else abs_max

         y_quantized = torch.round(y)
-        zero_bitmap = torch.where(torch.sum(torch.abs(y_quantized), (3, 2)).squeeze(0) == 0, 0, 1)
+        zero_bitmap = torch.where(
+            torch.sum(torch.abs(y_quantized), (3, 2)).squeeze(0) == 0, 0, 1
+        )

         nonzero = torch.nonzero(zero_bitmap).flatten().tolist()
         symbols = y_quantized[:, nonzero] + abs_max
-        cdf = self._build_cdf(*self.reshape_entropy_parameters(scales, means, weights, nonzero), abs_max)
+        cdf = self._build_cdf(
+            *self.reshape_entropy_parameters(scales, means, weights, nonzero), abs_max
+        )

         num_latents = cdf.size(0)

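The "stealing" step in _build_cdf above can be checked on a toy PMF. A minimal sketch assuming a 16-bit coder precision (the real value comes from entropy_coder_precision): every zero bin is bumped to 1, and the bumped mass is taken back from the first bin strictly larger than the deficit, so the PMF total, and therefore the final CDF, is unchanged and rANS never sees a zero-probability symbol.

import torch

precision = 16  # assumed; stands in for self.entropy_coder_precision
cdf_limit = 2**precision - 1
pmf = torch.tensor([[0, 5, 60000, 0, 3, 0]], dtype=torch.int32)

zero_count = (pmf.size(1) - torch.count_nonzero(pmf, dim=1)).to(torch.int32)
# index of the first bin large enough to give away `zero_count` units
_, steal_idx = torch.min(
    torch.where(
        pmf > zero_count.unsqueeze(-1),
        pmf,
        torch.tensor(cdf_limit + 1, dtype=torch.int32),
    ),
    dim=1,
)
pmf[pmf == 0] += 1  # -> [[1, 5, 60000, 1, 3, 1]]
pmf[torch.arange(pmf.size(0)), steal_idx] -= zero_count  # -> [[1, 2, 60000, 1, 3, 1]]
assert pmf.sum() == 60008 and (pmf > 0).all()  # total mass preserved, no zeros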
@@ -829,15 +885,16 @@ def compress(self, y, scales, means, weights):
             torch.arange(num_latents).int().tolist(),
             cdf.cpu().to(torch.int32),
             torch.tensor(cdf.size(1)).repeat(num_latents).int().tolist(),
-            torch.tensor(0).repeat(num_latents).int().tolist()
+            torch.tensor(0).repeat(num_latents).int().tolist(),
         )

         return (rv, abs_max, zero_bitmap), y_quantized

-
     def decompress(self, strings, abs_max, zero_bitmap, scales, means, weights):
         nonzero = torch.nonzero(zero_bitmap).flatten().tolist()
-        cdf = self._build_cdf(*self.reshape_entropy_parameters(scales, means, weights, nonzero), abs_max)
+        cdf = self._build_cdf(
+            *self.reshape_entropy_parameters(scales, means, weights, nonzero), abs_max
+        )

         num_latents = cdf.size(0)

@@ -846,13 +903,15 @@ def decompress(self, strings, abs_max, zero_bitmap, scales, means, weights):
             torch.arange(num_latents).int().tolist(),
             cdf.cpu().to(torch.int32),
             torch.tensor(cdf.size(1)).repeat(num_latents).int().tolist(),
-            torch.tensor(0).repeat(num_latents).int().tolist()
+            torch.tensor(0).repeat(num_latents).int().tolist(),
        )

         symbols = torch.tensor(values) - abs_max
         symbols = symbols.reshape(scales.size(0), -1, scales.size(2), scales.size(3))

-        y_hat = torch.zeros(scales.size(0), zero_bitmap.size(0), scales.size(2), scales.size(3))
+        y_hat = torch.zeros(
+            scales.size(0), zero_bitmap.size(0), scales.size(2), scales.size(3)
+        )
         y_hat[:, nonzero] = symbols.float()

-        return y_hat
+        return y_hat
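Taken together, compress() and decompress() round-trip the rounded latent: channels flagged in zero_bitmap are skipped by the coder and restored as zeros, and everything else is recovered exactly. A hypothetical end-to-end sketch with all shapes assumed (y is [1, M, H, W]; scales, means and weights are [1, K*M, H, W], with weights summing to 1 across the K blocks):

import torch
from compressai.entropy_models import GaussianMixtureConditional

K, M, H, W = 3, 8, 4, 4
gmc = GaussianMixtureConditional(K=K)

y = 3 * torch.randn(1, M, H, W)
scales = torch.rand(1, K * M, H, W) + 0.1
means = torch.randn(1, K * M, H, W)
# build per-component weights that sum to 1 over the K channel blocks
weights = torch.softmax(torch.randn(1, K, M, H, W), dim=1).reshape(1, K * M, H, W)

(string, abs_max, zero_bitmap), y_q = gmc.compress(y, scales, means, weights)
y_hat = gmc.decompress(string, abs_max, zero_bitmap, scales, means, weights)
assert torch.equal(y_hat, y_q)  # lossless on the quantized latent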
