From 16b98370ae20663a054d87ff8966f824722a324e Mon Sep 17 00:00:00 2001 From: atheendre130505 Date: Fri, 31 Oct 2025 11:30:49 +0530 Subject: [PATCH 1/2] [BUG] Replace deprecated batched_dot with pt.sum in KroneckerNormal - Fixes Issue #7878 - Replace pt.batched_dot(sqrt_quad.T, sqrt_quad.T) with pt.sum(sqrt_quad.T ** 2, axis=-1) - Computes squared norm per sample using modern PyTensor operations - Eliminates deprecation warnings and ensures future compatibility --- pymc/distributions/multivariate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index f76a98546e..46b9165645 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2124,8 +2124,8 @@ def logp(value, rng, size, mu, sigma, *covs): sqrt_quad = sqrt_quad / pt.sqrt(eigs[:, None]) logdet = pt.sum(pt.log(eigs)) - # Square each sample - quad = pt.batched_dot(sqrt_quad.T, sqrt_quad.T) + # Square each sample - compute squared norm for each sample + quad = pt.sum(sqrt_quad.T ** 2, axis=-1) if onedim: quad = quad[0] From da5a75f0b7d8c7c6de0d19b3e59e618024b7be63 Mon Sep 17 00:00:00 2001 From: atheendre130505 Date: Sun, 2 Nov 2025 13:16:05 +0530 Subject: [PATCH 2/2] Format code with ruff --- pymc/distributions/multivariate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 46b9165645..9435b40fa7 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2125,7 +2125,7 @@ def logp(value, rng, size, mu, sigma, *covs): logdet = pt.sum(pt.log(eigs)) # Square each sample - compute squared norm for each sample - quad = pt.sum(sqrt_quad.T ** 2, axis=-1) + quad = pt.sum(sqrt_quad.T**2, axis=-1) if onedim: quad = quad[0]