Bug in get_domain_of_finite_discrete_rv of Categorical #331

Open
ricardoV94 opened this issue Apr 12, 2024 · 0 comments
Labels
bug Something isn't working good first issue Good for newcomers marginalization

Comments

@ricardoV94
Member

Reported by @jessegrabowski

with MarginalModel(coords=coords) as m:
    x_data = pm.ConstantData('x', df.x, dims=['obs_idx'])
    y_data = pm.ConstantData('y', df.y, dims=['obs_idx'])

    X = pt.concatenate([pt.ones_like(x_data[:, None]), x_data[:, None], x_data[:, None] ** 2], axis=-1)

    mu = pm.Normal('mu', dims=['group'])
    beta_p = pm.Normal('beta_p', dims=['params', 'group'])
    logit_p_group = X @ beta_p
    group_idx = pm.Categorical('group_idx', logit_p=logit_p_group, dims=['obs_idx'])
    sigma = pm.Exponential('sigma', 1)

    mu = pt.switch(pt.lt(group_idx, 1), 
                   mu_trend,
                   pt.switch(pt.lt(group_idx, 2), 
                             p_x[:, 0], 
                             p_x[:, 1])
                  )
    
    y_hat = pm.Normal('y_hat', 
                      mu = mu,
                      sigma = sigma,
                      observed=y_data,
                      dims=['obs_idx'])

m.marginalize(["group_idx"])
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pymc_experimental/model/marginal_model.py:655, in get_domain_of_finite_discrete_rv(rv)
    653 elif isinstance(op, Categorical):
    654     p_param = rv.owner.inputs[3]
--> 655     return tuple(range(pt.get_vector_length(p_param)))
    656 elif isinstance(op, DiscreteUniform):
    657     lower, upper = constant_fold(rv.owner.inputs[3:])

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/tensor/__init__.py:82, in get_vector_length(v)
     79 v = as_tensor_variable(v)
     81 if v.type.ndim != 1:
---> 82     raise TypeError(f"Argument must be a vector; got {v.type}")
     84 static_shape: Optional[int] = v.type.shape[0]
     85 if static_shape is not None:

TypeError: Argument must be a vector; got Matrix(float64, shape=(256, 3))

Instead of taking the vector length of p_param (which assumes p is always a vector), we should be constant folding p_param.shape[-1], since the number of categories is the size of the last axis even when p is batched.

@ricardoV94 ricardoV94 added bug Something isn't working good first issue Good for newcomers marginalization labels Apr 12, 2024