Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -797,8 +797,8 @@ let rec trans_sizedtype_decl declc tr name st =
| SVector (mem_pattern, s) ->
let fn =
match (declc.transform_action, tr) with
| Constrain, Transformation.Simplex ->
Internal_fun.FnValidateSizeSimplex
| Constrain, (Transformation.Simplex | SumToZero) ->
Internal_fun.FnValidateSizePositive
| Constrain, UnitVector -> FnValidateSizeUnitVector
| _ -> FnValidateSize in
let l, s = grab_size fn n s in
Expand All @@ -813,8 +813,17 @@ let rec trans_sizedtype_decl declc tr name st =
let l, s = grab_size FnValidateSize n s in
(l, SizedType.SComplexVector s)
| SMatrix (mem_pattern, r, c) ->
let l1, r = grab_size FnValidateSize n r in
let l2, c = grab_size FnValidateSize (n + 1) c in
let fn1, fn2 =
match (declc.transform_action, tr) with
| Constrain, Transformation.SumToZero ->
( Internal_fun.FnValidateSizePositive
, Internal_fun.FnValidateSizePositive )
| Constrain, StochasticColumn ->
(FnValidateSizePositive, FnValidateSize)
| Constrain, StochasticRow -> (FnValidateSize, FnValidateSizePositive)
| _ -> (FnValidateSize, FnValidateSize) in
let l1, r = grab_size fn1 n r in
let l2, c = grab_size fn2 (n + 1) c in
let cf_cov =
match (declc.transform_action, tr) with
| Constrain, CholeskyCov ->
Expand Down
4 changes: 2 additions & 2 deletions src/middle/Internal_fun.ml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type 'expr t =
; mem_pattern: Mem_pattern.t }
| FnWriteParam of {unconstrain_opt: 'expr Transformation.t option; var: 'expr}
| FnValidateSize
| FnValidateSizeSimplex
| FnValidateSizePositive
| FnValidateSizeUnitVector
| FnCheck of {trans: 'expr Transformation.t; var_name: string; var: 'expr}
| FnPrint
Expand Down Expand Up @@ -54,7 +54,7 @@ let pp (pp_expr : 'a Fmt.t) ppf internal =
*)
let can_side_effect = function
| FnReadParam _ | FnReadData | FnReadDeserializer | FnWriteParam _
|FnValidateSize | FnValidateSizeSimplex | FnValidateSizeUnitVector
|FnValidateSize | FnValidateSizePositive | FnValidateSizeUnitVector
|FnReadWriteEventsOpenCL _ ->
true
| FnLength | FnMakeArray | FnMakeRowVec | FnNegInf | FnPrint | FnReject
Expand Down
2 changes: 1 addition & 1 deletion src/stan_math_backend/Lower_stmt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ let check_to_string = function
let math_fn_translations = function
| Internal_fun.FnValidateSize ->
Some "stan::math::validate_non_negative_index"
| FnValidateSizeSimplex -> Some "stan::math::validate_positive_index"
| FnValidateSizePositive -> Some "stan::math::validate_positive_index"
| FnValidateSizeUnitVector -> Some "stan::math::validate_unit_vector_index"
| FnReadWriteEventsOpenCL x -> Some (x ^ ".wait_for_read_write_events")
| _ -> None
Expand Down
423 changes: 256 additions & 167 deletions test/integration/good/code-gen/cpp.expected

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions test/integration/good/code-gen/mir.expected
Original file line number Diff line number Diff line change
Expand Up @@ -6256,7 +6256,7 @@
((pattern (Var N)) (meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta <opaque>))
((pattern
(NRFunApp (CompilerInternal FnValidateSizeSimplex)
(NRFunApp (CompilerInternal FnValidateSizePositive)
(((pattern (Lit Str p_simplex))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Str N)) (meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
Expand All @@ -6270,7 +6270,7 @@
((pattern (Var N)) (meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta <opaque>))
((pattern
(NRFunApp (CompilerInternal FnValidateSizeSimplex)
(NRFunApp (CompilerInternal FnValidateSizePositive)
(((pattern (Lit Str p_1d_simplex))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Str N)) (meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
Expand Down Expand Up @@ -6298,7 +6298,7 @@
((pattern (Var K)) (meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta <opaque>))
((pattern
(NRFunApp (CompilerInternal FnValidateSizeSimplex)
(NRFunApp (CompilerInternal FnValidateSizePositive)
(((pattern (Lit Str p_3d_simplex))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Str N)) (meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
Expand Down
8 changes: 5 additions & 3 deletions test/integration/good/code-gen/stochastic_matrices.stan
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ data {
}

transformed data {
int ROWS = 10;
int COLS = 10;
column_stochastic_matrix[10, 10] td_csm = d_csm;
row_stochastic_matrix[10, 10] td_rsm = d_rsm;
array[2, 2] row_stochastic_matrix[10, 10] td_arsm;
}
parameters {
column_stochastic_matrix[10, 10] p_csm;
row_stochastic_matrix[10, 10] p_rsm;
array[2, 2] row_stochastic_matrix[10, 10] p_arsm;
column_stochastic_matrix[ROWS, COLS] p_csm;
row_stochastic_matrix[ROWS, COLS] p_rsm;
array[2, 2] row_stochastic_matrix[ROWS, COLS] p_arsm;
}

transformed parameters {
Expand Down
10 changes: 6 additions & 4 deletions test/integration/good/code-gen/sum_to_zero.stan
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ data {
}

transformed data {
int v = 10;
int N = 4;
sum_to_zero_vector[10] td_stzv = d_stzv;
array[2, 3] sum_to_zero_vector[10] td_astzv;
sum_to_zero_matrix[4,5] td_stzm = d_stzm;
array[2, 3] sum_to_zero_matrix[4,5] td_astzm;
}
parameters {
sum_to_zero_vector[10] p_stzv;
array[2, 3] sum_to_zero_vector[10] p_astzv;
sum_to_zero_matrix[4,5] p_stzm;
array[2, 3] sum_to_zero_matrix[4,5] p_astzm;
sum_to_zero_vector[v] p_stzv;
array[2, 3] sum_to_zero_vector[v] p_astzv;
sum_to_zero_matrix[N,N+1] p_stzm;
array[2, 3] sum_to_zero_matrix[N,N+1] p_astzm;
}

transformed parameters {
Expand Down
6 changes: 3 additions & 3 deletions test/integration/good/code-gen/transformed_mir.expected
Original file line number Diff line number Diff line change
Expand Up @@ -9309,7 +9309,7 @@
((pattern (Var N)) (meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta <opaque>))
((pattern
(NRFunApp (CompilerInternal FnValidateSizeSimplex)
(NRFunApp (CompilerInternal FnValidateSizePositive)
(((pattern (Lit Str p_simplex))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Str N)) (meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
Expand All @@ -9323,7 +9323,7 @@
((pattern (Var N)) (meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta <opaque>))
((pattern
(NRFunApp (CompilerInternal FnValidateSizeSimplex)
(NRFunApp (CompilerInternal FnValidateSizePositive)
(((pattern (Lit Str p_1d_simplex))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Str N)) (meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
Expand Down Expand Up @@ -9351,7 +9351,7 @@
((pattern (Var K)) (meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta <opaque>))
((pattern
(NRFunApp (CompilerInternal FnValidateSizeSimplex)
(NRFunApp (CompilerInternal FnValidateSizePositive)
(((pattern (Lit Str p_3d_simplex))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Str N)) (meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
Expand Down