This repository has been archived by the owner on Jun 9, 2020. It is now read-only.
forked from stan-dev/stanc3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Semantic_check.ml
1773 lines (1625 loc) · 65 KB
/
Semantic_check.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
(** Semantic validation of the AST *)
(* Idea: check many of the things related to identifiers that are hard to check
during parsing and are in fact irrelevant to building up the parse tree. *)
open Core_kernel
open Symbol_table
open Middle
open Ast
open Errors
(** Validation applicative instantiated with [Semantic_error]; every check in
    this file returns a [Validate.t]. Errors from combined checks are collected
    (see [to_exn], which renders them as a list). *)
module Validate = Common.Validation.Make (Semantic_error)
(* There is a semantic checking function for each AST node; it calls the
checking functions for its children left to right. *)
(* The top-level function semantic_check_program checks the AST while
operating on (1) the global symbol table [vm] and (2) a record of type
context_flags_record that passes context information down the AST. *)
(** [check_of_compatible_return_type rt1 srt2] decides whether a function
    declared with return type [rt1] may have a body whose statements yield
    the statement return type [srt2].
    - A [void] function accepts [NoReturnType], a [void] return (complete or
      incomplete), and [AnyReturnType].
    - A non-void function accepts [AnyReturnType], or a complete return of
      the same type, where an [int] result also satisfies a declared [real]. *)
let check_of_compatible_return_type rt1 srt2 =
  UnsizedType.(
    match rt1 with
    | Void -> (
      match srt2 with
      | NoReturnType | Incomplete Void | Complete Void | AnyReturnType -> true
      | _ -> false )
    | ReturnType declared -> (
      match srt2 with
      | AnyReturnType -> true
      | Complete (ReturnType returned) -> (
        match (declared, returned) with
        | UReal, UInt -> true
        | _ -> declared = returned )
      | _ -> false ))
(** Origin blocks, to keep track of where variables are declared.
    [MathLibrary] marks names that come from the Stan Math library rather than
    from a declaration in the program (see [semantic_check_variable]). The
    other constructors presumably name the Stan program blocks (functions,
    data, transformed data, parameters, transformed parameters, model,
    generated quantities). [calculate_autodifftype] treats [Param], [TParam],
    [Model] and [Functions] as autodiff-capable origins. *)
type originblock =
| MathLibrary
| Functions
| Data
| TData
| Param
| TParam
| Model
| GQuant
(** Global switch, initially [true]. NOTE(review): the name suggests it
    requires every declared user function to have a definition; it is not
    read in this part of the file — confirm against its readers. *)
let check_that_all_functions_have_definition = ref true

(** Name of the model being compiled; [semantic_check_identifier] rejects
    user identifiers equal to it. *)
let model_name = ref ""

(** The global symbol table used by all checks in this file. *)
let vm = Symbol_table.initialize ()
(* Record structure holding flags and other markers about the context, passed
down the AST and used for error reporting. *)
type context_flags_record =
{ current_block: originblock (* block currently being checked *)
; in_fun_def: bool (* presumably: checking a user function body; read by semantic_check_fn_rng *)
; in_returning_fun_def: bool (* NOTE(review): presumably a non-void function body; not read in this section *)
; in_rng_fun_def: bool (* inside a function definition where _rng calls are allowed (semantic_check_fn_rng) *)
; in_lp_fun_def: bool (* allows target access and _lp calls outside the model block *)
; loop_depth: int } (* NOTE(review): presumably loop nesting depth, e.g. for break/continue; not read in this section *)
(* Some helper functions *)
(** [true] iff some string occurs more than once in [l]. *)
let dup_exists l =
  Option.is_some (List.find_a_dup l ~compare:String.compare)
(** The unsized type recorded in a typed expression's metadata. *)
let type_of_expr_typed {emeta= {type_; _}; _} = type_
(** [true] iff [ut] is [int] or an array (of any depth) of [int]. *)
let rec unsizedtype_contains_int ut =
  match ut with
  | UnsizedType.UArray elt -> unsizedtype_contains_int elt
  | _ -> ut = UnsizedType.UInt
(** Drop the size expressions from a sized type, keeping only its shape. *)
let rec unsizedtype_of_sizedtype st =
  match st with
  | SizedType.SArray (elt, _) ->
      UnsizedType.UArray (unsizedtype_of_sizedtype elt)
  | SMatrix _ -> UMatrix
  | SRowVector _ -> URowVector
  | SVector _ -> UVector
  | SReal -> UReal
  | SInt -> UInt
(** Least upper bound of a list of autodiff levels under
    [UnsizedType.compare_autodifftype]; [DataOnly] for the empty list. *)
let lub_ad_type ads =
  List.fold_right ads ~init:UnsizedType.DataOnly ~f:(fun ad acc ->
      if UnsizedType.compare_autodifftype ad acc < 0 then acc else ad )
(** Autodiff level of a value originating in block [at] with unsized type
    [ut]: [AutoDiffable] when [at] is [Param], [TParam], [Model] or
    [Functions] and [ut] contains no [int]; [DataOnly] otherwise. *)
let calculate_autodifftype at ut =
  let autodiff_origin =
    match at with Param | TParam | Model | Functions -> true | _ -> false
  in
  if autodiff_origin && not (unsizedtype_contains_int ut) then
    UnsizedType.AutoDiffable
  else UnsizedType.DataOnly
(* Predicates on the unsized type of a typed expression. *)
let has_int_type ue =
  match ue.emeta.type_ with UnsizedType.UInt -> true | _ -> false

let has_int_array_type ue =
  match ue.emeta.type_ with UnsizedType.UArray UInt -> true | _ -> false

let has_int_or_real_type ue =
  match ue.emeta.type_ with UnsizedType.UInt | UReal -> true | _ -> false
(** All spellings of a function name that must be treated as the same
    identifier: the given name first, followed by its alternative
    density/distribution-function suffix forms ([_lpdf]/[_lpmf]/[_log],
    [_lcdf]/[_cdf_log], [_lccdf]/[_ccdf_log]). Every variant keeps [id]'s
    location. [multiply_log] and [binomial_coefficient_log] have no
    variants. *)
let probability_distribution_name_variants id =
  let name = id.name in
  (* Tried in order; the first matching suffix wins, so [_cdf_log] and
     [_ccdf_log] must come before the bare [_log]. *)
  let suffix_rewrites =
    [ ("_lpmf", ["_lpdf"; "_log"])
    ; ("_lpdf", ["_lpmf"; "_log"])
    ; ("_lcdf", ["_cdf_log"])
    ; ("_lccdf", ["_ccdf_log"])
    ; ("_cdf_log", ["_lcdf"])
    ; ("_ccdf_log", ["_lccdf"])
    ; ("_log", ["_lpmf"; "_lpdf"]) ]
  in
  let exempt = ["multiply_log"; "binomial_coefficient_log"] in
  let spellings =
    if List.exists exempt ~f:(String.equal name) then [name]
    else
      match
        List.find suffix_rewrites ~f:(fun (suffix, _) ->
            String.is_suffix name ~suffix )
      with
      | None -> [name]
      | Some (suffix, alternatives) ->
          let stem = String.drop_suffix name (String.length suffix) in
          name :: List.map alternatives ~f:(fun alt -> stem ^ alt)
  in
  List.map spellings ~f:(fun n -> {name= n; id_loc= id.id_loc})
(** Least upper bound of two function return types. Equal types bound
    themselves, [int] and [real] combine to [real], and anything else is a
    [mismatched_return_types] error at [loc]. *)
let lub_rt loc rt1 rt2 =
  if rt1 = rt2 then Validate.ok rt2
  else
    match (rt1, rt2) with
    | UnsizedType.ReturnType UReal, UnsizedType.ReturnType UInt
     |ReturnType UInt, ReturnType UReal ->
        Validate.ok (UnsizedType.ReturnType UReal)
    | _ -> Semantic_error.mismatched_return_types loc rt1 rt2 |> Validate.error
(** Ensure [id] may be declared (no shadowing). A user identifier may share
    its name with a Stan Math function only when the user identifier is not a
    nullary function and the library name does have a nullary signature. Any
    clash with a name already in [vm] is an error. *)
let check_fresh_variable_basic id is_nullary_function =
  let clashes_with_stan_math =
    Stan_math_signatures.is_stan_math_function_name id.name
    && ( is_nullary_function
       || Option.is_none
            (Stan_math_signatures.stan_math_returntype id.name []) )
  in
  if clashes_with_stan_math then
    Semantic_error.ident_is_stanmath_name id.id_loc id.name |> Validate.error
  else if Option.is_some (Symbol_table.look vm id.name) then
    Semantic_error.ident_in_use id.id_loc id.name |> Validate.error
  else Validate.ok ()
(** [check_fresh_variable_basic] applied to every spelling produced by
    [probability_distribution_name_variants], with all errors collected. *)
let check_fresh_variable id is_nullary_function =
  probability_distribution_name_variants id
  |> List.fold ~init:(Validate.ok ()) ~f:(fun acc variant ->
         Validate.apply_const acc
           (check_fresh_variable_basic variant is_nullary_function) )
(** Least upper bound of the autodiff levels of typed expressions
    ([DataOnly] for an empty list). *)
let lub_ad_e exprs =
  lub_ad_type (List.map exprs ~f:(fun e -> e.emeta.ad_level))
(* == SEMANTIC CHECK OF PROGRAM ELEMENTS ==================================== *)
(* Assignment operators need no checking of their own; returned unchanged. *)
let semantic_check_assignmentoperator assop = Validate.ok assop

(* Autodiff annotations need no checking either; returned unchanged. *)
let semantic_check_autodifftype adtype = Validate.ok adtype
(* Unsized types carry nothing to reject on their own; these functions only
   walk function and array types so every component is visited. *)
let rec semantic_check_unsizedtype : UnsizedType.t -> unit Validate.t =
  function
  | UFun (args, rt) ->
      (* Start from the return type's result and fold in each argument's
         autodiff annotation and type, collecting any errors. *)
      let check_arg acc (adtype, argtype) =
        Validate.apply_const
          (Validate.apply_const acc (semantic_check_autodifftype adtype))
          (semantic_check_unsizedtype argtype)
      in
      List.fold args ~init:(semantic_check_returntype rt) ~f:check_arg
  | UArray elt -> semantic_check_unsizedtype elt
  | _ -> Validate.ok ()

and semantic_check_returntype : UnsizedType.returntype -> unit Validate.t =
  function
  | Void -> Validate.ok ()
  | ReturnType ut -> semantic_check_unsizedtype ut
(* -- Identifiers ----------------------------------------------------------- *)
(** Names that user programs may not declare: Stan literals and reserved
    words, version macros, and C++ keywords. Each name is listed once;
    "true" and "false" were previously duplicated. *)
let reserved_keywords =
  [ "true"; "false"; "repeat"; "until"; "then"; "var"; "fvar"; "STAN_MAJOR"
  ; "STAN_MINOR"; "STAN_PATCH"; "STAN_MATH_MAJOR"; "STAN_MATH_MINOR"
  ; "STAN_MATH_PATCH"; "alignas"; "alignof"; "and"; "and_eq"; "asm"; "auto"
  ; "bitand"; "bitor"; "bool"; "break"; "case"; "catch"; "char"; "char16_t"
  ; "char32_t"; "class"; "compl"; "const"; "constexpr"; "const_cast"
  ; "continue"; "decltype"; "default"; "delete"; "do"; "double"; "dynamic_cast"
  ; "else"; "enum"; "explicit"; "export"; "extern"; "float"; "for"
  ; "friend"; "goto"; "if"; "inline"; "int"; "long"; "mutable"; "namespace"
  ; "new"; "noexcept"; "not"; "not_eq"; "nullptr"; "operator"; "or"; "or_eq"
  ; "private"; "protected"; "public"; "register"; "reinterpret_cast"; "return"
  ; "short"; "signed"; "sizeof"; "static"; "static_assert"; "static_cast"
  ; "struct"; "switch"; "template"; "this"; "thread_local"; "throw"
  ; "try"; "typedef"; "typeid"; "typename"; "union"; "unsigned"; "using"
  ; "virtual"; "void"; "volatile"; "wchar_t"; "while"; "xor"; "xor_eq" ]
(** Reject an identifier that equals the model name, ends in a double
    underscore, or is listed in [reserved_keywords]. *)
let semantic_check_identifier id =
  let is_reserved name =
    String.is_suffix name ~suffix:"__"
    || List.exists reserved_keywords ~f:(String.equal name)
  in
  if String.equal id.name !model_name then
    Semantic_error.ident_is_model_name id.id_loc id.name |> Validate.error
  else if is_reserved id.name then
    Semantic_error.ident_is_keyword id.id_loc id.name |> Validate.error
  else Validate.ok ()
(* -- Operators ------------------------------------------------------------- *)
(* Operators have no standalone restrictions; always succeeds. *)
let semantic_check_operator _op = Validate.ok ()
(* == Expressions =========================================================== *)
(** The (autodiff level, unsized type) pair of a typed expression, which is
    the form used to match function signatures. *)
let arg_type {emeta; _} = (emeta.ad_level, emeta.type_)

(** [arg_type] of each expression, in order. *)
let get_arg_types es = List.map es ~f:arg_type
(* -- Function application -------------------------------------------------- *)
(** In a [map_rect] call, the mapped function (the first argument, when it is
    a plain variable) may not be an [_lp] or [_rng] function. *)
let semantic_check_fn_map_rect ~loc id es =
  let forbidden name =
    List.exists ["_lp"; "_rng"] ~f:(fun suffix -> String.is_suffix name ~suffix)
  in
  match es with
  | {expr= Variable mapped_fn; _} :: _
    when String.equal id.name "map_rect" && forbidden mapped_fn.name ->
      Semantic_error.invalid_map_rect_fn loc mapped_fn.name |> Validate.error
  | _ -> Validate.ok ()
(** Used for ordinary function applications: names ending in [_lpdf],
    [_lpmf], [_lcdf] or [_lccdf] must instead be called with conditional
    notation, so a plain call to one is a [conditioning_required] error. *)
let semantic_check_fn_conditioning ~loc id =
  let needs_conditioning =
    List.exists ["_lpdf"; "_lpmf"; "_lcdf"; "_lccdf"] ~f:(fun suffix ->
        String.is_suffix id.name ~suffix )
  in
  if needs_conditioning then
    Semantic_error.conditioning_required loc |> Validate.error
  else Validate.ok ()
(** Functions whose names end in [_lp] may be called only in the model block
    or inside an [_lp] function definition ([cf.in_lp_fun_def]). *)
let semantic_check_fn_target_plus_equals cf ~loc id =
  let is_lp_fn = String.is_suffix id.name ~suffix:"_lp" in
  let lp_allowed_here = cf.in_lp_fun_def || cf.current_block = Model in
  if is_lp_fn && not lp_allowed_here then
    Semantic_error.target_plusequals_outisde_model_or_logprob loc
    |> Validate.error
  else Validate.ok ()
(** [_rng] functions may not be called in the [TParam] or [Model] blocks, nor
    inside a function definition unless [cf.in_rng_fun_def] is set. *)
let semantic_check_fn_rng cf ~loc id =
  let rng_forbidden_here =
    (cf.in_fun_def && not cf.in_rng_fun_def)
    || (match cf.current_block with TParam | Model -> true | _ -> false)
  in
  if String.is_suffix id.name ~suffix:"_rng" && rng_forbidden_here then
    Semantic_error.invalid_rng_fn loc |> Validate.error
  else Validate.ok ()
(** Build a function-application expression node: [CondDistApp] for
    conditional notation, [FunApp] otherwise. *)
let mk_fun_app ~is_cond_dist (kind, name, args) =
  match is_cond_dist with
  | true -> CondDistApp (kind, name, args)
  | false -> FunApp (kind, name, args)
(* Regular function application *)
(** Type a call to a user-defined returning function [id] with checked
    arguments [es], looking [id] up in [vm]. Errors cover undeclared names,
    non-function names, [void] functions used as values, and argument types
    that do not match the declared signature. *)
let semantic_check_fn_normal ~is_cond_dist ~loc id es =
  match Symbol_table.look vm id.name with
  | None ->
      Semantic_error.returning_fn_expected_undeclaredident_found loc id.name
      |> Validate.error
  | Some (_, UnsizedType.UFun (_, Void)) ->
      Semantic_error.returning_fn_expected_nonreturning_found loc id.name
      |> Validate.error
  | Some (_, UFun (listedtypes, (ReturnType ut as rt))) ->
      if
        UnsizedType.check_compatible_arguments_mod_conv id.name listedtypes
          (get_arg_types es)
      then
        mk_typed_expression
          ~expr:(mk_fun_app ~is_cond_dist (UserDefined, id, es))
          ~ad_level:(lub_ad_e es) ~type_:ut ~loc
        |> Validate.ok
      else
        List.map es ~f:type_of_expr_typed
        |> Semantic_error.illtyped_userdefined_fn_app loc id.name listedtypes
             rt
        |> Validate.error
  | Some _ ->
      (* The name is bound, but not to a function. *)
      Semantic_error.returning_fn_expected_nonfn_found loc id.name
      |> Validate.error
(* Stan-Math function application *)
(** Type a call to a Stan Math library function by looking up the return
    type for the checked argument types. A [void] result or no matching
    signature is an error. *)
let semantic_check_fn_stan_math ~is_cond_dist ~loc id es =
  let signature_result =
    Stan_math_signatures.stan_math_returntype id.name (get_arg_types es)
  in
  match signature_result with
  | None ->
      List.map es ~f:type_of_expr_typed
      |> Semantic_error.illtyped_stanlib_fn_app loc id.name
      |> Validate.error
  | Some UnsizedType.Void ->
      Semantic_error.returning_fn_expected_nonreturning_found loc id.name
      |> Validate.error
  | Some (UnsizedType.ReturnType ut) ->
      mk_typed_expression
        ~expr:(mk_fun_app ~is_cond_dist (StanLib, id, es))
        ~ad_level:(lub_ad_e es) ~type_:ut ~loc
      |> Validate.ok
(** Decide whether a call to [id] with arguments [es] targets the Stan Math
    library or a user-defined function. The whole application is inspected,
    not just the name, because user-defined functions may shadow StanLib
    constants: a call is [StanLib] if a library signature accepts the
    arguments, or if the name is a library name not bound in [vm]. *)
let fn_kind_from_application id es =
  let library_signature_matches =
    Option.is_some
      (Stan_math_signatures.stan_math_returntype id.name (get_arg_types es))
  in
  let unshadowed_library_name =
    Option.is_none (Symbol_table.look vm id.name)
    && Stan_math_signatures.is_stan_math_function_name id.name
  in
  if library_signature_matches || unshadowed_library_name then StanLib
  else UserDefined
(** Type a returning function application by dispatching on the function
    kind from [fn_kind_from_application]. *)
let semantic_check_fn ~is_cond_dist ~loc id es =
  let check =
    match fn_kind_from_application id es with
    | StanLib -> semantic_check_fn_stan_math
    | UserDefined -> semantic_check_fn_normal
  in
  check ~is_cond_dist ~loc id es
(* -- Ternary If ------------------------------------------------------------ *)
(** Type [pe ? te : fe]. The condition must be [int] and the branches must
    have a common type, which becomes the result type. *)
let semantic_check_ternary_if loc (pe, te, fe) =
  let result_type =
    if pe.emeta.type_ = UInt then
      UnsizedType.common_type (te.emeta.type_, fe.emeta.type_)
    else None
  in
  match result_type with
  | Some type_ ->
      mk_typed_expression
        ~expr:(TernaryIf (pe, te, fe))
        ~ad_level:(lub_ad_e [pe; te; fe])
        ~type_ ~loc
      |> Validate.ok
  | None ->
      Semantic_error.illtyped_ternary_if loc pe.emeta.type_ te.emeta.type_
        fe.emeta.type_
      |> Validate.error
(* -- Binary (Infix) Operators ---------------------------------------------- *)
(** Type [le op re] using the Stan Math operator signatures. A missing or
    [void] signature is an [illtyped_binary_op] error. *)
let semantic_check_binop loc op (le, re) =
  match
    Stan_math_signatures.operator_stan_math_return_type op
      (List.map [le; re] ~f:arg_type)
  with
  | Some (UnsizedType.ReturnType type_) ->
      mk_typed_expression
        ~expr:(BinOp (le, op, re))
        ~ad_level:(lub_ad_e [le; re])
        ~type_ ~loc
      |> Validate.ok
  | Some UnsizedType.Void | None ->
      Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_
      |> Validate.error
(** Return the value of a successful validation, or raise [Failure] whose
    message renders every accumulated semantic error, separated by Fmt
    [cut] break hints. *)
let to_exn v =
  let render = Fmt.(to_to_string @@ list ~sep:cut Semantic_error.pp) in
  match Validate.to_result v with
  | Ok x -> x
  | Error errs -> failwith (render errs)

(** Raising variant of [semantic_check_binop]. *)
let semantic_check_binop_exn loc op operands =
  to_exn (semantic_check_binop loc op operands)
(* -- Prefix Operators ------------------------------------------------------ *)
(** Type [op e] for a prefix operator via the Stan Math operator
    signatures. *)
let semantic_check_prefixop loc op e =
  match Stan_math_signatures.operator_stan_math_return_type op [arg_type e] with
  | Some (UnsizedType.ReturnType type_) ->
      mk_typed_expression
        ~expr:(PrefixOp (op, e))
        ~ad_level:(lub_ad_e [e])
        ~type_ ~loc
      |> Validate.ok
  | Some UnsizedType.Void | None ->
      Semantic_error.illtyped_prefix_op loc op e.emeta.type_ |> Validate.error
(* -- Postfix operators ----------------------------------------------------- *)
(** Type [e op] for a postfix operator via the Stan Math operator
    signatures. *)
let semantic_check_postfixop loc op e =
  match Stan_math_signatures.operator_stan_math_return_type op [arg_type e] with
  | Some (UnsizedType.ReturnType type_) ->
      mk_typed_expression
        ~expr:(PostfixOp (e, op))
        ~ad_level:(lub_ad_e [e])
        ~type_ ~loc
      |> Validate.ok
  | Some UnsizedType.Void | None ->
      Semantic_error.illtyped_postfix_op loc op e.emeta.type_ |> Validate.error
(* -- Variables ------------------------------------------------------------- *)
(** Type a variable reference. Names bound in [vm] take their recorded type
    and origin. Unbound Stan Math function names are typed as
    [UMathLibraryFunction]. Any other unbound name is out of scope. *)
let semantic_check_variable loc id =
  match Symbol_table.look vm id.name with
  | Some (originblock, type_) ->
      mk_typed_expression ~expr:(Variable id)
        ~ad_level:(calculate_autodifftype originblock type_)
        ~type_ ~loc
      |> Validate.ok
  | None when Stan_math_signatures.is_stan_math_function_name id.name ->
      mk_typed_expression ~expr:(Variable id)
        ~ad_level:(calculate_autodifftype MathLibrary UMathLibraryFunction)
        ~type_:UMathLibraryFunction ~loc
      |> Validate.ok
  | None -> Semantic_error.ident_not_in_scope loc id.name |> Validate.error
(* -- Conditioned Distribution Application ---------------------------------- *)
(** Conditional notation is allowed only for names ending in [_lpdf],
    [_lpmf], [_lcdf] or [_lccdf]. *)
let semantic_check_conddist_name ~loc id =
  let allowed_suffixes = ["_lpdf"; "_lpmf"; "_lcdf"; "_lccdf"] in
  match
    List.find allowed_suffixes ~f:(fun suffix ->
        String.is_suffix id.name ~suffix )
  with
  | Some _ -> Validate.ok ()
  | None -> Semantic_error.conditional_notation_not_allowed loc |> Validate.error
(* -- Array Expressions ----------------------------------------------------- *)
(* Array expressions must be of uniform type (or a mix of int and real). Each
   element is compared with the first, in either direction, modulo int-to-real
   conversion at any array depth. An empty list is an error. *)
let semantic_check_array_expr_type ~loc es =
  match es with
  | [] -> Semantic_error.empty_array loc |> Validate.error
  | first :: _ ->
      let ty = first.emeta.type_ in
      let compatible x =
        UnsizedType.check_of_same_type_mod_array_conv "" x.emeta.type_ ty
        || UnsizedType.check_of_same_type_mod_array_conv "" ty x.emeta.type_
      in
      if List.for_all es ~f:compatible then Validate.ok ()
      else Semantic_error.mismatched_array_types loc |> Validate.error
(** Type an array expression from its already-checked elements [es].

    The element type is the least common type of all elements, with [int]
    promoted to [real] at matching array depth. So [{1, 2.0}] has type
    [real[]] and [{{1}, {2.0}}] has type [real[,]] (the nesting is kept).
    Elements with no common type fall back to [real[]]; that mismatch is
    reported separately by [semantic_check_array_expr_type]. An empty list is
    an [empty_array] error. *)
let semantic_check_array_expr ~loc es =
  (* Least common type of two element types, or [None] if incompatible. *)
  let rec promote t1 t2 =
    match (t1, t2) with
    | UnsizedType.UInt, UnsizedType.UReal | UReal, UInt ->
        Some UnsizedType.UReal
    | UArray inner1, UArray inner2 ->
        Option.map (promote inner1 inner2) ~f:(fun t -> UnsizedType.UArray t)
    | _ when t1 = t2 -> Some t1
    | _ -> None
  in
  match List.map ~f:type_of_expr_typed es with
  | [] -> Semantic_error.empty_array loc |> Validate.error
  | ty :: rest ->
      let element_type =
        List.fold rest ~init:(Some ty) ~f:(fun acc t ->
            Option.bind acc ~f:(fun common -> promote common t) )
      in
      let type_ =
        UnsizedType.UArray
          (Option.value element_type ~default:UnsizedType.UReal)
      in
      mk_typed_expression ~expr:(ArrayExpr es) ~ad_level:(lub_ad_e es) ~type_
        ~loc
      |> Validate.ok
(* -- Row Vector Expression ------------------------------------------------- *)
(** Type a row-vector literal from its checked elements. All [int]/[real]
    elements give a [row_vector] (including the empty case), all [row_vector]
    elements give a [matrix], and anything else is an error. *)
let semantic_check_rowvector ~loc es =
  let ad_level = lub_ad_e es in
  let element_types = List.map es ~f:type_of_expr_typed in
  let is_scalar = function UnsizedType.UInt | UReal -> true | _ -> false in
  let is_row_vector = function UnsizedType.URowVector -> true | _ -> false in
  if List.for_all element_types ~f:is_scalar then
    mk_typed_expression ~expr:(RowVectorExpr es) ~ad_level ~type_:URowVector
      ~loc
    |> Validate.ok
  else if List.for_all element_types ~f:is_row_vector then
    mk_typed_expression ~expr:(RowVectorExpr es) ~ad_level ~type_:UMatrix ~loc
    |> Validate.ok
  else Semantic_error.invalid_row_vector_types loc |> Validate.error
(* -- Indexed Expressions --------------------------------------------------- *)
(* Curried tuple constructors, handy as [liftA2]/[liftA3] arguments. *)
let tuple2 first second = (first, second)

let tuple3 first second third = (first, second, third)
(** Pair an index with a type tag: a [Single] index is tagged with its
    expression's type, every other index form with [UInt]. *)
let index_with_type idx =
  let tag = match idx with Single e -> e.emeta.type_ | _ -> UnsizedType.UInt in
  (idx, tag)
(** Infer the unsized type produced by applying [indices] to a value of type
    [ut]. A [Single] index whose expression is [int] removes one dimension
    (array element, vector/row-vector entry, or matrix row). Any other index
    (all, ranges, or an [int[]] multi-index) keeps the dimension. A
    [not_indexable] error at [loc] is returned when an index is applied to a
    type that cannot be indexed further. *)
let inferred_unsizedtype_of_indexed ~loc ut indices =
(* [aux] is written in continuation-passing style: [k] re-wraps the final type
   in one [UArray] for every array dimension that was kept (not single-indexed)
   on the way down. Errors are returned directly, bypassing [k]. *)
let rec aux k ut xs =
match (ut, xs) with
(* A matrix indexed by [multi-index, int] yields a vector (part of a column). *)
| UnsizedType.UMatrix, [(All, _); (Single _, UnsizedType.UInt)]
|UMatrix, [(Upfrom _, _); (Single _, UInt)]
|UMatrix, [(Downfrom _, _); (Single _, UInt)]
|UMatrix, [(Between _, _); (Single _, UInt)]
|UMatrix, [(Single _, UArray UInt); (Single _, UInt)] ->
k @@ Validate.ok UnsizedType.UVector
| _, [] -> k @@ Validate.ok ut
| _, next :: rest -> (
match next with
(* Single int index: drop one dimension. *)
| Single _, UInt -> (
match ut with
| UArray inner_ty -> aux k inner_ty rest
| UVector | URowVector -> aux k UReal rest
| UMatrix -> aux k URowVector rest
| _ -> Semantic_error.not_indexable loc ut |> Validate.error )
(* Multi-index: keep the dimension. For arrays, remember to re-wrap. *)
| _ -> (
match ut with
| UArray inner_ty ->
let k' =
Fn.compose k (Validate.map ~f:(fun t -> UnsizedType.UArray t))
in
aux k' inner_ty rest
| UVector | URowVector | UMatrix -> aux k ut rest
| _ -> Semantic_error.not_indexable loc ut |> Validate.error ) )
in
aux Fn.id ut (List.map ~f:index_with_type indices)
(** Raising variant of [inferred_unsizedtype_of_indexed]; see [to_exn]. *)
let inferred_unsizedtype_of_indexed_exn ~loc ut indices =
  to_exn (inferred_unsizedtype_of_indexed ~loc ut indices)
(** Autodiff level of an indexed expression: the least upper bound of the
    indexed value's level [at] and the levels of all index expressions
    ([All] contributes [DataOnly]). *)
let inferred_ad_type_of_indexed at uindices =
  let index_level = function
    | All -> UnsizedType.DataOnly
    | Single e | Upfrom e | Downfrom e -> lub_ad_type [at; e.emeta.ad_level]
    | Between (lo, hi) ->
        lub_ad_type [at; lo.emeta.ad_level; hi.emeta.ad_level]
  in
  lub_ad_type (at :: List.map uindices ~f:index_level)
(* Type-check the indexed expression [e] applied to [indices]. The base
   expression and all indices are checked together (via [liftA2]/[sequence]);
   the result type comes from [inferred_unsizedtype_of_indexed] and the
   autodiff level from [inferred_ad_type_of_indexed]. *)
let rec semantic_check_indexed ~loc ~cf e indices =
  Validate.(
    indices
    |> List.map ~f:(semantic_check_index cf)
    |> sequence
    |> liftA2 tuple2 (semantic_check_expression cf e)
    >>= fun (ue, uindices) ->
    let at = inferred_ad_type_of_indexed ue.emeta.ad_level uindices in
    uindices
    |> inferred_unsizedtype_of_indexed ~loc ue.emeta.type_
    |> map ~f:(fun ut ->
           mk_typed_expression
             ~expr:(Indexed (ue, uindices))
             ~ad_level:at ~type_:ut ~loc ))

(* Check one index. [All] needs nothing; a [Single] index must be an int or
   int-array expression; every range bound must be an int expression. *)
and semantic_check_index cf = function
  | All -> Validate.ok All
  (* Check that indexes have int (container) type *)
  | Single e ->
      Validate.(
        semantic_check_expression cf e
        >>= fun ue ->
        if has_int_type ue || has_int_array_type ue then ok @@ Single ue
        else
          Semantic_error.int_intarray_or_range_expected ue.emeta.loc
            ue.emeta.type_
          |> error)
  | Upfrom e ->
      semantic_check_expression_of_int_type cf e "Range bound"
      |> Validate.map ~f:(fun e -> Upfrom e)
  | Downfrom e ->
      semantic_check_expression_of_int_type cf e "Range bound"
      |> Validate.map ~f:(fun e -> Downfrom e)
  | Between (e1, e2) ->
      let le = semantic_check_expression_of_int_type cf e1 "Range bound"
      and ue = semantic_check_expression_of_int_type cf e2 "Range bound" in
      Validate.liftA2 (fun l u -> Between (l, u)) le ue

(* -- Top-level expressions ------------------------------------------------- *)
(* Type-check an untyped expression under context [cf], returning the typed
   expression or the collected semantic errors. Sub-expressions are checked
   first, then the node itself. *)
and semantic_check_expression cf ({emeta; expr} : Ast.untyped_expression) :
    Ast.typed_expression Validate.t =
  match expr with
  | TernaryIf (e1, e2, e3) ->
      let pe = semantic_check_expression cf e1
      and te = semantic_check_expression cf e2
      and fe = semantic_check_expression cf e3 in
      Validate.(liftA3 tuple3 pe te fe >>= semantic_check_ternary_if emeta.loc)
  | BinOp (e1, op, e2) ->
      let le = semantic_check_expression cf e1
      and re = semantic_check_expression cf e2
      (* Integer division truncates, so print an informational note (to
         stdout via Fmt.pr) when both operands are int. int / real is
         ordinary real division and gets no note. The operand pair is
         returned unchanged. *)
      and warn_int_division (x, y) =
        match (x.emeta.type_, y.emeta.type_, op) with
        | UInt, UInt, Divide ->
            Fmt.pr
              "@[<hov>Info: Found int division at %s:@ @[<hov 2>%a@]@,%s@,@]"
              (Location_span.to_string x.emeta.loc)
              Pretty_printing.pp_expression {expr; emeta}
              "Positive values rounded down, negative values rounded up or \
               down in platform-dependent way." ;
            (x, y)
        | _ -> (x, y)
      in
      Validate.(
        liftA2 tuple2 le re |> map ~f:warn_int_division
        |> apply_const (semantic_check_operator op)
        >>= semantic_check_binop emeta.loc op)
  | PrefixOp (op, e) ->
      Validate.(
        semantic_check_expression cf e
        |> apply_const (semantic_check_operator op)
        >>= semantic_check_prefixop emeta.loc op)
  | PostfixOp (e, op) ->
      Validate.(
        semantic_check_expression cf e
        |> apply_const (semantic_check_operator op)
        >>= semantic_check_postfixop emeta.loc op)
  | Variable id ->
      semantic_check_variable emeta.loc id
      |> Validate.apply_const (semantic_check_identifier id)
  | IntNumeral s ->
      mk_typed_expression ~expr:(IntNumeral s) ~ad_level:DataOnly ~type_:UInt
        ~loc:emeta.loc
      |> Validate.ok
  | RealNumeral s ->
      mk_typed_expression ~expr:(RealNumeral s) ~ad_level:DataOnly
        ~type_:UReal ~loc:emeta.loc
      |> Validate.ok
  | FunApp (_, id, es) ->
      semantic_check_funapp ~is_cond_dist:false id es cf emeta
  | CondDistApp (_, id, es) ->
      semantic_check_funapp ~is_cond_dist:true id es cf emeta
  | GetLP ->
      (* The accumulated target may only be read in the model block, the
         transformed parameters block, or inside an _lp function. *)
      if
        not
          ( cf.in_lp_fun_def || cf.current_block = Model
          || cf.current_block = TParam )
      then
        Semantic_error.target_plusequals_outisde_model_or_logprob emeta.loc
        |> Validate.error
      else
        mk_typed_expression ~expr:GetLP
          ~ad_level:(calculate_autodifftype cf.current_block UReal)
          ~type_:UReal ~loc:emeta.loc
        |> Validate.ok
  | GetTarget ->
      (* Same placement rule as GetLP. *)
      if
        not
          ( cf.in_lp_fun_def || cf.current_block = Model
          || cf.current_block = TParam )
      then
        Semantic_error.target_plusequals_outisde_model_or_logprob emeta.loc
        |> Validate.error
      else
        mk_typed_expression ~expr:GetTarget
          ~ad_level:(calculate_autodifftype cf.current_block UReal)
          ~type_:UReal ~loc:emeta.loc
        |> Validate.ok
  | ArrayExpr es ->
      Validate.(
        es
        |> List.map ~f:(semantic_check_expression cf)
        |> sequence
        >>= fun ues ->
        semantic_check_array_expr ~loc:emeta.loc ues
        |> apply_const (semantic_check_array_expr_type ~loc:emeta.loc ues))
  | RowVectorExpr es ->
      Validate.(
        es
        |> List.map ~f:(semantic_check_expression cf)
        |> sequence
        >>= semantic_check_rowvector ~loc:emeta.loc)
  | Paren e ->
      semantic_check_expression cf e
      |> Validate.map ~f:(fun ue ->
             mk_typed_expression ~expr:(Paren ue) ~ad_level:ue.emeta.ad_level
               ~type_:ue.emeta.type_ ~loc:emeta.loc )
  | Indexed (e, indices) -> semantic_check_indexed ~loc:emeta.loc ~cf e indices

(* Check a function application, using conditional notation when
   [is_cond_dist]. The arguments are checked first, then the call is typed by
   [semantic_check_fn], and the name-based rules run with errors collected:
   identifier validity, the map_rect argument, conditioning suffixes, and
   _lp/_rng placement. *)
and semantic_check_funapp ~is_cond_dist id es cf emeta =
  let name_check =
    if is_cond_dist then semantic_check_conddist_name
    else semantic_check_fn_conditioning
  in
  Validate.(
    es
    |> List.map ~f:(semantic_check_expression cf)
    |> sequence
    >>= fun ues ->
    semantic_check_fn ~is_cond_dist ~loc:emeta.loc id ues
    |> apply_const (semantic_check_identifier id)
    |> apply_const (semantic_check_fn_map_rect ~loc:emeta.loc id ues)
    |> apply_const (name_check ~loc:emeta.loc id)
    |> apply_const (semantic_check_fn_target_plus_equals cf ~loc:emeta.loc id)
    |> apply_const (semantic_check_fn_rng cf ~loc:emeta.loc id))

(* Check [e] and require type int; [name] describes it in the error message. *)
and semantic_check_expression_of_int_type cf e name =
  Validate.(
    semantic_check_expression cf e
    >>= fun ue ->
    if has_int_type ue then ok ue
    else Semantic_error.int_expected ue.emeta.loc name ue.emeta.type_ |> error)

(* Check [e] and require type int or real; [name] describes it in the error
   message. *)
and semantic_check_expression_of_int_or_real_type cf e name =
  Validate.(
    semantic_check_expression cf e
    >>= fun ue ->
    if has_int_or_real_type ue then ok ue
    else
      Semantic_error.int_or_real_expected ue.emeta.loc name ue.emeta.type_
      |> error)
(* -- Sized Types ----------------------------------------------------------- *)
(** Check every size expression in a sized type (each must be [int]),
    rebuilding the type with the typed size expressions. *)
let rec semantic_check_sizedtype cf st =
  let int_size e what = semantic_check_expression_of_int_type cf e what in
  match st with
  | SizedType.SInt -> Validate.ok SizedType.SInt
  | SReal -> Validate.ok SizedType.SReal
  | SVector e ->
      Validate.map (int_size e "Vector sizes") ~f:(fun ue ->
          SizedType.SVector ue )
  | SRowVector e ->
      Validate.map (int_size e "Row vector sizes") ~f:(fun ue ->
          SizedType.SRowVector ue )
  | SMatrix (rows, cols) ->
      (* Sequential lets keep rows checked before cols. *)
      let urows = int_size rows "Matrix sizes" in
      let ucols = int_size cols "Matrix sizes" in
      Validate.liftA2 (fun r c -> SizedType.SMatrix (r, c)) urows ucols
  | SArray (elt, len) ->
      let uelt = semantic_check_sizedtype cf elt in
      let ulen = int_size len "Array sizes" in
      Validate.liftA2 (fun t n -> SizedType.SArray (t, n)) uelt ulen
(* -- Transformations ------------------------------------------------------- *)
(** Check the bound, offset and multiplier expressions of a variable
    transformation (each must be [int] or [real]). Transformations without
    expressions pass unchanged. *)
let semantic_check_transformation cf transformation =
  let bound e what = semantic_check_expression_of_int_or_real_type cf e what in
  match transformation with
  | Program.Identity -> Validate.ok Program.Identity
  | Lower e ->
      Validate.map (bound e "Lower bound") ~f:(fun ue -> Program.Lower ue)
  | Upper e ->
      Validate.map (bound e "Upper bound") ~f:(fun ue -> Program.Upper ue)
  | LowerUpper (e1, e2) ->
      let lower = bound e1 "Lower bound" in
      let upper = bound e2 "Upper bound" in
      Validate.liftA2 (fun lo hi -> Program.LowerUpper (lo, hi)) lower upper
  | Offset e ->
      Validate.map (bound e "Offset") ~f:(fun ue -> Program.Offset ue)
  | Multiplier e ->
      Validate.map (bound e "Multiplier") ~f:(fun ue -> Program.Multiplier ue)
  | OffsetMultiplier (e1, e2) ->
      let offset = bound e1 "Offset" in
      let multiplier = bound e2 "Multiplier" in
      Validate.liftA2
        (fun o m -> Program.OffsetMultiplier (o, m))
        offset multiplier
  | Ordered -> Validate.ok Program.Ordered
  | PositiveOrdered -> Validate.ok Program.PositiveOrdered
  | Simplex -> Validate.ok Program.Simplex
  | UnitVector -> Validate.ok Program.UnitVector
  | CholeskyCorr -> Validate.ok Program.CholeskyCorr
  | CholeskyCov -> Validate.ok Program.CholeskyCov
  | Correlation -> Validate.ok Program.Correlation
  | Covariance -> Validate.ok Program.Covariance
(* -- Printables ------------------------------------------------------------ *)
(** Check an item of a print/reject statement. Strings pass unchanged, and
    expressions may not have a function type. *)
let semantic_check_printable cf = function
  | PString s -> Validate.ok (PString s)
  | PExpr e ->
      let reject_functions (ue : Ast.typed_expression) =
        match ue.emeta.type_ with
        | UnsizedType.UFun _ | UMathLibraryFunction ->
            Semantic_error.not_printable ue.emeta.loc |> Validate.error
        | _ -> Validate.ok (PExpr ue)
      in
      Validate.(semantic_check_expression cf e >>= reject_functions)
(* -- Truncations ----------------------------------------------------------- *)
(** Check the bound expressions of a truncation (each must be [int] or
    [real]). *)
let semantic_check_truncation cf truncation =
  let bound e =
    semantic_check_expression_of_int_or_real_type cf e "Truncation bound"
  in
  match truncation with
  | NoTruncate -> Validate.ok NoTruncate
  | TruncateUpFrom e ->
      Validate.map (bound e) ~f:(fun ue -> TruncateUpFrom ue)
  | TruncateDownFrom e ->
      Validate.map (bound e) ~f:(fun ue -> TruncateDownFrom ue)
  | TruncateBetween (e1, e2) ->
      let lower = bound e1 in
      let upper = bound e2 in
      Validate.liftA2 (fun lo hi -> TruncateBetween (lo, hi)) lower upper
(* == Statements ============================================================ *)
(* -- Non-returning function application ------------------------------------ *)
(** Placement rule for [_lp] calls in statement position. It is exactly the
    rule applied to returning applications, so delegate to it. *)
let semantic_check_nrfn_target ~loc ~cf id =
  semantic_check_fn_target_plus_equals cf ~loc id
(** Type a call statement to a user-defined [void] function [id], looking it
    up in [vm]. Errors cover undeclared names, non-function names, returning
    functions used as statements, and mismatched argument types. *)
let semantic_check_nrfn_normal ~loc id es =
  match Symbol_table.look vm id.name with
  | Some (_, UnsizedType.UFun (listedtypes, Void)) ->
      if
        UnsizedType.check_compatible_arguments_mod_conv id.name listedtypes
          (get_arg_types es)
      then
        mk_typed_statement
          ~stmt:(NRFunApp (UserDefined, id, es))
          ~return_type:NoReturnType ~loc
        |> Validate.ok
      else
        List.map es ~f:type_of_expr_typed
        |> Semantic_error.illtyped_userdefined_fn_app loc id.name listedtypes
             UnsizedType.Void
        |> Validate.error
  | Some (_, UFun (_, ReturnType _)) ->
      Semantic_error.nonreturning_fn_expected_returning_found loc id.name
      |> Validate.error
  | Some _ ->
      Semantic_error.nonreturning_fn_expected_nonfn_found loc id.name
      |> Validate.error
  | None ->
      Semantic_error.nonreturning_fn_expected_undeclaredident_found loc
        id.name
      |> Validate.error
(* Check a non-returning call to the Stan math library function [id]. The
   signature lookup for the argument types of [es] must yield [Void]:
   - [ReturnType _] means the function returns a value where a statement was
     expected;
   - [None] means no library signature accepts these arguments. *)
let semantic_check_nrfn_stan_math ~loc id es =
  let lookup =
    Stan_math_signatures.stan_math_returntype id.name (get_arg_types es)
  in
  match lookup with
  | None ->
      let arg_types = List.map ~f:type_of_expr_typed es in
      Validate.error
        (Semantic_error.illtyped_stanlib_fn_app loc id.name arg_types)
  | Some (UnsizedType.ReturnType _) ->
      Validate.error
        (Semantic_error.nonreturning_fn_expected_returning_found loc id.name)
  | Some UnsizedType.Void ->
      Validate.ok
        (mk_typed_statement
           ~stmt:(NRFunApp (StanLib, id, es))
           ~return_type:NoReturnType ~loc)
(* Route a non-returning call to the Stan-library checker or the
   user-defined checker, as decided by [fn_kind_from_application]. *)
let semantic_check_nr_fnkind ~loc id es =
  let checker =
    match fn_kind_from_application id es with
    | StanLib -> semantic_check_nrfn_stan_math
    | UserDefined -> semantic_check_nrfn_normal
  in
  checker ~loc id es
(* Check a non-returning function-application statement. The steps are:
   1. Type-check every argument.
   2. Validate the identifier.
   3. Enforce the "_lp" placement rule.
   Steps 2 and 3 are attached to the typed arguments with [apply_const]. The
   typed arguments are then passed to [semantic_check_nr_fnkind] through
   [>>=]. *)
let semantic_check_nr_fn_app ~loc ~cf id es =
  let open Validate in
  let typed_args = sequence (List.map ~f:(semantic_check_expression cf) es) in
  let checked_args =
    apply_const
      (apply_const typed_args (semantic_check_identifier id))
      (semantic_check_nrfn_target ~loc ~cf id)
  in
  checked_args >>= fun ues -> semantic_check_nr_fnkind ~loc id ues
(* -- Assignment ------------------------------------------------------------ *)
(* Reject assignment to an identifier that [Symbol_table.get_read_only]
   reports as read-only. *)
let semantic_check_assignment_read_only ~loc id =
  if not (Symbol_table.get_read_only vm id.name) then Validate.ok ()
  else Validate.error (Semantic_error.cannot_assign_to_read_only loc id.name)
(* Variables from previous blocks are read-only.
   In particular, data and parameters can never be assigned to.
*)
(* Reject assignment to a global variable when [block] (the block the caller
   resolved for [id]) differs from the block currently being checked.
   Non-global identifiers always pass this check. *)
let semantic_check_assignment_global ~loc ~cf ~block id =
  let is_foreign_global =
    Symbol_table.is_global vm id.name && not (block = cf.current_block)
  in
  if is_foreign_global then
    Validate.error (Semantic_error.cannot_assign_to_global loc id.name)
  else Validate.ok ()
(* Build an [Assignment] statement. The typed left-hand-side expression is
   converted back to an lvalue with [Ast.lvalue_of_expr]. *)
let mk_assignment_from_indexed_expr assop lhs rhs =
  let assign_lhs = Ast.lvalue_of_expr lhs in
  Assignment {assign_lhs; assign_op= assop; assign_rhs= rhs}
(* Type-check an assignment whose operands are already typed.
   - Plain assignment ([Assign] / [ArrowAssign]) requires the two operand
     types to agree modulo array conversion.
   - A compound operator [OperatorAssign op] requires the math library to
     provide a [Void] signature for [op] at the two operand types.
   Every failure is reported as [Semantic_error.illtyped_assignment]. *)
let semantic_check_assignment_operator ~loc assop lhs rhs =
  let lhs_type = lhs.emeta.type_ in
  let rhs_type = rhs.emeta.type_ in
  let err = Semantic_error.illtyped_assignment loc assop lhs_type rhs_type in
  (* A thunk, so the lvalue conversion only happens on the success path. *)
  let well_typed () =
    mk_typed_statement ~return_type:NoReturnType ~loc
      ~stmt:(mk_assignment_from_indexed_expr assop lhs rhs)
    |> Validate.ok
  in
  match assop with
  | Assign | ArrowAssign ->
      if UnsizedType.check_of_same_type_mod_array_conv "" lhs_type rhs_type
      then well_typed ()
      else Validate.error err
  | OperatorAssign op -> (
      let operand_types = List.map ~f:arg_type [lhs; rhs] in
      match
        Stan_math_signatures.assignmentoperator_stan_math_return_type op
          operand_types
      with
      | Some Void -> well_typed ()
      | Some (ReturnType _) | None -> Validate.error err )
(* Type-check the assignment statement [assign_lhs assign_op assign_rhs].

   Four inputs are checked independently and their results combined:
   - the left-hand side;
   - the operator;
   - the right-hand side;
   - the block of the assigned identifier, looked up in the symbol table.
     If the identifier is absent, it falls back to [MathLibrary] for Stan
     math function names and otherwise reports "not in scope".

   When all four succeed, the operator typing, global-variable and read-only
   checks run, and their results are combined with [apply_const]. *)
let semantic_check_assignment ~loc ~cf assign_lhs assign_op assign_rhs =
  let assign_id = Ast.id_of_lvalue assign_lhs in
  let checked_lhs = semantic_check_expression cf (expr_of_lvalue assign_lhs) in
  let checked_op = semantic_check_assignmentoperator assign_op in
  let checked_rhs = semantic_check_expression cf assign_rhs in
  let origin_block =
    match Symbol_table.look vm assign_id.name with
    | Some (block, _) -> Validate.ok block
    | None ->
        if Stan_math_signatures.is_stan_math_function_name assign_id.name
        then Validate.ok MathLibrary
        else
          Validate.error
            (Semantic_error.ident_not_in_scope loc assign_id.name)
  in
  Validate.(
    liftA2
      (fun operands block -> (operands, block))
      (liftA3 (fun l op r -> (l, op, r)) checked_lhs checked_op checked_rhs)
      origin_block
    >>= fun ((lhs, assop, rhs), block) ->
    let stmt = semantic_check_assignment_operator ~loc assop lhs rhs in
    let not_foreign_global =
      semantic_check_assignment_global ~loc ~cf ~block assign_id
    in
    let not_read_only = semantic_check_assignment_read_only ~loc assign_id in
    apply_const (apply_const stmt not_foreign_global) not_read_only)
(* -- Target plus-equals / Increment log-prob ------------------------------- *)
(* Reject a log-density increment whose expression has a function type
   ([UFun] or [UMathLibraryFunction]). All other types are accepted by this
   check. *)
let semantic_check_target_pe_expr_type ~loc e =
  let ty = e.emeta.type_ in
  match ty with
  | UFun _ | UMathLibraryFunction ->
      Validate.error (Semantic_error.int_or_real_container_expected loc ty)
  | _ -> Validate.ok ()
(* Log-density increments are only permitted when [cf.in_lp_fun_def] holds
   or the current block is [Model]. Anywhere else they are reported with
   [Semantic_error.target_plusequals_outisde_model_or_logprob]. *)
let semantic_check_target_pe_usage ~loc ~cf =
  let permitted = cf.in_lp_fun_def || cf.current_block = Model in
  if not permitted then
    Validate.error
      (Semantic_error.target_plusequals_outisde_model_or_logprob loc)
  else Validate.ok ()
(* Shared implementation of the two log-density increment statements.
   The steps are:
   1. Type-check [e].
   2. Require a legal placement via [semantic_check_target_pe_usage].
   3. Reject function-typed expressions via
      [semantic_check_target_pe_expr_type].
   4. On success, wrap the typed expression in the statement built by
      [mk_stmt].
   The two public entry points below previously duplicated this body
   verbatim; keeping a single copy ensures they cannot drift apart. *)
let semantic_check_target_increment ~loc ~cf ~mk_stmt e =
  Validate.(
    semantic_check_expression cf e
    |> apply_const (semantic_check_target_pe_usage ~loc ~cf)
    >>= fun ue ->
    semantic_check_target_pe_expr_type ~loc ue
    |> map ~f:(fun _ ->
           mk_typed_statement ~stmt:(mk_stmt ue) ~return_type:NoReturnType
             ~loc ))

(* Check a [target += e] statement, producing [TargetPE]. *)
let semantic_check_target_pe ~loc ~cf e =
  semantic_check_target_increment ~loc ~cf e ~mk_stmt:(fun ue -> TargetPE ue)

(* Check an increment_log_prob(e) statement, producing [IncrementLogProb]. *)
let semantic_check_incr_logprob ~loc ~cf e =
  semantic_check_target_increment ~loc ~cf e ~mk_stmt:(fun ue ->
      IncrementLogProb ue )
(* -- Tilde (Sampling notation) --------------------------------------------- *)
(* In sampling notation the distribution must be named without its "_lpdf" /
   "_lpmf" suffix. A suffixed name is reported at the identifier's own
   location ([id.id_loc]). *)
let semantic_check_sampling_pdf_pmf id =
  let has_density_suffix =
    String.is_suffix id.name ~suffix:"_lpdf"
    || String.is_suffix id.name ~suffix:"_lpmf"
  in
  if has_density_suffix then
    Validate.error (Semantic_error.invalid_sampling_pdf_or_pmf id.id_loc)
  else Validate.ok ()
let semantic_check_sampling_cdf_ccdf ~loc id =
Validate.(
if
String.(
is_suffix id.name ~suffix:"_cdf" || is_suffix id.name ~suffix:"_ccdf")
then error @@ Semantic_error.invalid_sampling_cdf_or_ccdf loc id.name