ValueError: Maximum allowed size exceeded when only one value #74

Open
tbrugere opened this issue Oct 27, 2023 · 0 comments

Comments

@tbrugere

In the degenerate case where a and b each contain only one sample, and it is the same sample, SamplesLoss fails with

ValueError: Maximum allowed size exceeded

Minimal example:

In: from torch import Tensor
In: from geomloss import SamplesLoss
In: loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)
In: a = b = Tensor([[1, 2]])
In: loss(a, b)
Out:
sinkhorn_divergence.py:147: RuntimeWarning: divide by zero encountered in log
  p * np.log(diameter), p * np.log(blur), p * np.log(scaling)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[7], line 1
----> 1 loss(a, b)

File ~/.conda/envs/default/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.conda/envs/default/lib/python3.11/site-packages/geomloss/samples_loss.py:265, in SamplesLoss.forward(self, *args)
    262     α, x, β, y = α.unsqueeze(0), x.unsqueeze(0), β.unsqueeze(0), y.unsqueeze(0)
    264 # Run --------------------------------------------------------------------------------
--> 265 values = routines[self.loss][backend](
    266     α,
    267     x,
    268     β,
    269     y,
    270     p=self.p,
    271     blur=self.blur,
    272     reach=self.reach,
    273     diameter=self.diameter,
    274     scaling=self.scaling,
    275     truncate=self.truncate,
    276     cost=self.cost,
    277     kernel=self.kernel,
    278     cluster_scale=self.cluster_scale,
    279     debias=self.debias,
    280     potentials=self.potentials,
    281     labels_x=l_x,
    282     labels_y=l_y,
    283     verbose=self.verbose,
    284 )
    286 # Make sure that the output has the correct shape ------------------------------------
    287 if (
    288     self.potentials
    289 ):  # Return some dual potentials (= test functions) sampled on the input measures

File ~/.conda/envs/default/lib/python3.11/site-packages/geomloss/sinkhorn_samples.py:191, in sinkhorn_tensorized(a, x, b, y, p, blur, reach, diameter, scaling, cost, debias, potentials, **kwargs)
    186 C_yy = cost(y, y.detach()) if debias else None  # (B,M,M) torch Tensor
    188 # Compute the relevant values of the diameter of the configuration,
    189 # target temperature epsilon, temperature schedule across itereations
    190 # and strength of the marginal constraints:
--> 191 diameter, eps, eps_list, rho = scaling_parameters(
    192     x, y, p, blur, reach, diameter, scaling
    193 )
    195 # Use an optimal transport solver to retrieve the dual potentials:
    196 f_aa, g_bb, g_ab, f_ba = sinkhorn_loop(
    197     softmin_tensorized,
    198     log_weights(a),
   (...)
    206     debias=debias,
    207 )

File ~/.conda/envs/default/lib/python3.11/site-packages/geomloss/sinkhorn_divergence.py:163, in scaling_parameters(x, y, p, blur, reach, diameter, scaling)
    161 eps = blur ** p
    162 rho = None if reach is None else reach ** p
--> 163 eps_list = epsilon_schedule(p, diameter, blur, scaling)
    164 return diameter, eps, eps_list, rho

File ~/.conda/envs/default/lib/python3.11/site-packages/geomloss/sinkhorn_divergence.py:146, in epsilon_schedule(p, diameter, blur, scaling)
    116 def epsilon_schedule(p, diameter, blur, scaling):
    117     r"""Creates a list of values for the temperature "epsilon" across Sinkhorn iterations.
    118 
    119     We use an aggressive strategy with an exponential cooling
   (...)
    140         list of float: list of values for the temperature epsilon.
    141     """
    142     eps_list = (
    143         [diameter ** p]
    144         + [
    145             np.exp(e)
--> 146             for e in np.arange(
    147                 p * np.log(diameter), p * np.log(blur), p * np.log(scaling)
    148             )
    149         ]
    150         + [blur ** p]
    151     )
    152     return eps_list

ValueError: Maximum allowed size exceeded
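The failing computation can be replayed with NumPy alone, outside geomloss. This is a minimal sketch that reuses p and blur from the example above and assumes diameter = 0 and scaling = 0.5 (the example does not set scaling; 0.5 is our assumption about the default, and any value other than 1 behaves the same way here):

```python
import warnings

import numpy as np

# p and blur come from the minimal example; scaling = 0.5 is an assumed default.
p, blur, scaling = 2, 0.05, 0.5
# A single sample shared by a and b spans no distance.
diameter = 0.0

with warnings.catch_warnings():
    # np.log(0.0) emits "divide by zero encountered in log" and returns -inf.
    warnings.simplefilter("ignore", RuntimeWarning)
    start = p * np.log(diameter)

error_message = None
try:
    # Same call as sinkhorn_divergence.py:146. With a start of -inf the
    # computed number of steps is infinite, which overflows, and NumPy
    # reports it as a ValueError.
    np.arange(start, p * np.log(blur), p * np.log(scaling))
except ValueError as err:
    error_message = str(err)

print("start =", start)
print("error:", error_message)
```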

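The cause: with a single sample shared by a and b, the diameter of the configuration is 0, so p * np.log(diameter) evaluates to -inf (hence the divide-by-zero warning) and np.arange cannot compute a finite length for the epsilon schedule. This could be fixed by checking for a zero diameter and returning 0 in that case.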
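A minimal sketch of such a check at the point of failure, in plain NumPy: guarded_epsilon_schedule is a hypothetical replacement for the epsilon_schedule function shown in the traceback, and its body otherwise mirrors the list built there (sinkhorn_divergence.py, lines 142 to 151). Falling back to the single target temperature blur ** p when the diameter is 0 is our assumption, not the library's behaviour; returning 0 from the loss itself, as proposed above, would avoid building a schedule at all.

```python
import numpy as np


def guarded_epsilon_schedule(p, diameter, blur, scaling):
    """Temperature schedule for the Sinkhorn iterations (hypothetical variant).

    Mirrors the list built in epsilon_schedule from the traceback, but
    handles a zero diameter instead of calling np.log(0).
    """
    if diameter <= 0:
        # Degenerate configuration (all samples at one point): there is
        # nothing to anneal, so use only the target temperature.
        return [blur ** p]
    return (
        [diameter ** p]
        + [
            np.exp(e)
            for e in np.arange(
                p * np.log(diameter), p * np.log(blur), p * np.log(scaling)
            )
        ]
        + [blur ** p]
    )


print(guarded_epsilon_schedule(2, 0.0, 0.05, 0.5))  # degenerate case
print(guarded_epsilon_schedule(2, 1.0, 0.05, 0.5))  # regular case
```

Since a and b are identical in the example, every entry of the cost matrix is 0, so 0 is also the value the loss should report whichever guard is used.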
