-
Notifications
You must be signed in to change notification settings - Fork 1k
/
dnnl.h
3890 lines (3662 loc) · 176 KB
/
dnnl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*******************************************************************************
* Copyright 2016-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/// @file
/// C API
#ifndef ONEAPI_DNNL_DNNL_H
#define ONEAPI_DNNL_DNNL_H
#include "oneapi/dnnl/dnnl_config.h"
#include "oneapi/dnnl/dnnl_types.h"
#include "oneapi/dnnl/dnnl_version.h"
#ifdef __cplusplus
extern "C" {
#endif
/// @addtogroup dnnl_api
/// @{
/// @addtogroup dnnl_api_primitives
/// @{
/// @addtogroup dnnl_api_primitives_common
/// @{
/// Creates a primitive descriptor iterator.
///
/// @param iterator Output primitive descriptor iterator.
/// @param op_desc Operation descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @param engine Engine to use.
/// @param hint_forward_primitive_desc For backward propagation: primitive
/// descriptor for a respective forward propagation primitive. Pass NULL
/// for forward propagation.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_create(
dnnl_primitive_desc_iterator_t *iterator, const_dnnl_op_desc_t op_desc,
const_dnnl_primitive_attr_t attr, dnnl_engine_t engine,
const_dnnl_primitive_desc_t hint_forward_primitive_desc);
/// Advances the primitive descriptor iterator to point to the next available
/// implementation.
///
/// @param iterator A primitive descriptor iterator to advance.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
/// @returns #dnnl_iterator_ends if no more implementations available.
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_next(
dnnl_primitive_desc_iterator_t iterator);
/// Fetches the current primitive descriptor from a primitive descriptor
/// iterator.
///
/// @note
/// The user is responsible for deleting the resulting primitive
/// descriptor using dnnl_primitive_desc_destroy().
///
/// @param iterator A primitive descriptor iterator.
/// @returns A primitive descriptor.
dnnl_primitive_desc_t DNNL_API dnnl_primitive_desc_iterator_fetch(
const_dnnl_primitive_desc_iterator_t iterator);
/// Destroys a primitive descriptor iterator.
///
/// @param iterator Primitive descriptor iterator to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_destroy(
dnnl_primitive_desc_iterator_t iterator);
/// Creates a primitive descriptor. This function is equivalent to a sequence
/// of #dnnl_primitive_desc_iterator_create() and
/// #dnnl_primitive_desc_iterator_fetch(). In other words, the library will
/// pick the first suitable implementation.
///
/// @param primitive_desc Output primitive descriptor.
/// @param op_desc Operation descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @param engine Engine to use.
/// @param hint_forward_primitive_desc For backward propagation: primitive
/// descriptor for a respective forward propagation primitive. Pass NULL
/// for forward propagation.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, const_dnnl_op_desc_t op_desc,
const_dnnl_primitive_attr_t attr, dnnl_engine_t engine,
const_dnnl_primitive_desc_t hint_forward_primitive_desc);
/// Clones a primitive descriptor. The resulting primitive descriptor must be
/// destroyed separately.
///
/// @param primitive_desc Output primitive descriptor.
/// @param existing_primitive_desc Primitive descriptor to clone.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_clone(
dnnl_primitive_desc_t *primitive_desc,
const_dnnl_primitive_desc_t existing_primitive_desc);
/// Returns a constant reference to the attributes of a primitive descriptor.
///
/// @warning
/// It is an error to destroy the resulting @p attr.
///
/// @warning
/// The lifetime of an @p attr is the same as that of a @p
/// primitive_desc, so it is an error to use the @p attr once the @p
/// primitive_desc has been destroyed.
///
/// @param primitive_desc Primitive descriptor.
/// @param attr Output primitive attributes.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(
const_dnnl_primitive_desc_t primitive_desc,
const_dnnl_primitive_attr_t *attr);
/// Destroys a primitive descriptor.
///
/// @param primitive_desc Primitive descriptor to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(
dnnl_primitive_desc_t primitive_desc);
/// Queries a primitive descriptor for various pieces of information.
///
/// The most common use case is to query a primitive descriptor, created with
/// source, weights, and destination memory descriptors with format tags set
/// to #dnnl_format_tag_any, for the corresponding memory descriptors (in this
/// case the @p what is set to #dnnl_query_src_md, #dnnl_query_weights_md, and
/// #dnnl_query_dst_md respectively) so that it is possible to create memory
/// objects and reorder primitives if necessary.
///
/// Another typical use case is to query a primitive descriptor for workspace
/// memory descriptor (with @p what set to #dnnl_query_workspace_md). If this
/// query returns #dnnl_not_required status, then workspace memory is not
/// required.
///
/// @note
/// When querying for a memory descriptor for a scratchpad, a workspace,
/// or an optional parameter, the query will return a pointer to a zero
/// memory descriptor if the parameter is not needed.
///
/// A few other use cases:
/// - query a primitive descriptor for the underlying operation descriptor
/// (#dnnl_query_convolution_d, #dnnl_query_eltwise_d, #dnnl_query_rnn_d,
/// etc.)
/// - query a primitive descriptor for the implementation information string
/// (#dnnl_query_impl_info_str)
/// - query a primitive descriptor for the number of inputs and outputs
/// (#dnnl_query_num_of_inputs_s32 and #dnnl_query_num_of_outputs_s32
/// respectively)
///
/// @sa dnnl_query_t for more options
///
/// @param primitive_desc Primitive descriptor.
/// @param what Parameter to query.
/// @param index Index of the parameter to query for.
/// @param result Output result. The type depends on the query. For example,
/// it must be a @c dnnl_memory_desc_t* if querying for a memory
/// descriptor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_desc_query(
const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
int index, void *result);
/// Queries primitive descriptor for a memory descriptor.
///
/// @note
/// This function is a convenience version of
/// #dnnl_primitive_desc_query().
///
/// @param primitive_desc Primitive descriptor.
/// @param what Kind of memory descriptor parameter to query for.
/// @param index Index of the parameter to query.
/// @returns A pointer to the requested memory descriptor.
/// @returns A pointer to a zero memory descriptor if the parameter is not
/// needed.
/// @returns NULL in case of any error.
///
const dnnl_memory_desc_t DNNL_API *dnnl_primitive_desc_query_md(
const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
int index);
/// Queries primitive descriptor for a signed 32bit int.
///
/// @note
/// This function is a convenience version of
/// #dnnl_primitive_desc_query().
///
/// @param primitive_desc Primitive descriptor.
/// @param what Kind of the value to query for.
/// @param index Index of the parameter to query.
/// @returns The requested value.
/// @returns 0 in case of any error (in particular if the queried entity is
/// not of type int32_t). Note that 0 may also be the actual returned
/// value.
int DNNL_API dnnl_primitive_desc_query_s32(
const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
int index);
/// Creates a primitive.
///
/// @param primitive Output primitive.
/// @param primitive_desc Primitive descriptor used to create the primitive.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive,
const_dnnl_primitive_desc_t primitive_desc);
/// Executes a primitive.
///
/// @param primitive Primitive to execute.
/// @param stream Stream to use.
/// @param nargs Number of arguments.
/// @param args Array of arguments. Each argument is an
/// <index, #dnnl_memory_t> pair. The index is one of the `DNNL_ARG_*`
/// values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see
/// #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory
/// descriptor as that returned by
/// #dnnl_primitive_desc_query_md(#dnnl_query_exec_arg_md, index).
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
/// @note If any argument in @p args is padded (padded_dims >
/// dims), the primitive execution will assume properly zero-padded
/// input arguments, and produce zero-padded output arguments.
dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive,
dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args);
/// Retrieves a constant reference to the primitive descriptor of a given
/// primitive.
///
/// @warning
/// It is an error to destroy the returned object. It is owned by the
/// primitive. The @c const qualifier of the returned object prevents
/// such attempts.
///
/// @param primitive Primitive to query for the primitive descriptor.
/// @param primitive_desc Output primitive descriptor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(
const_dnnl_primitive_t primitive,
const_dnnl_primitive_desc_t *primitive_desc);
/// Destroys a primitive.
///
/// @param primitive The primitive to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive);
/// @} dnnl_api_primitives_common
/// @addtogroup dnnl_api_attributes
/// @{
/// Creates an empty (default) primitive attributes with all the parameters
/// set to their default values.
///
/// Empty attributes are implied whenever the respective argument is NULL.
///
/// @param attr Output primitive attributes.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr);
/// Clones primitive attributes.
///
/// @param attr Output primitive attributes.
/// @param existing_attr Primitive attributes to clone.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_clone(
dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr);
/// Destroys primitive attributes.
///
/// @param attr Primitive attributes to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr);
/// Returns the floating-point math mode primitive attribute.
///
/// @param attr Primitive attributes.
/// @param mode Output FP math mode.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode(
const_dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t *mode);
/// Sets the floating-point math mode primitive attributes.
///
/// @param attr Primitive attributes.
/// @param mode FP math mode. The possible values are:
/// #dnnl_fpmath_mode_strict (default),
/// #dnnl_fpmath_mode_bf16,
/// #dnnl_fpmath_mode_f16,
/// #dnnl_fpmath_mode_any.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode(
dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t mode);
/// Returns the primitive attributes scratchpad mode.
///
/// @param attr Primitive attributes.
/// @param mode Output scratchpad mode.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(
const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode);
/// Sets primitive attributes scratchpad mode.
///
/// @param attr Primitive attributes.
/// @param mode Scratchpad mode. The possible values are:
/// #dnnl_scratchpad_mode_library (default) and
/// #dnnl_scratchpad_mode_user.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode);
/// Returns primitive attributes output scaling factors correspondence mask
/// and values.
///
/// @warning
/// The @p scales array is an internal part of the primitive attributes
/// @p attr, so it is an error to modify or destroy the @p scales array.
///
/// @warning
/// The lifetime of @p scales array is the same as that of the primitive
/// attributes @p attr to which it belongs, so it is an error to use
/// @p scales after @p attr is destroyed.
///
/// @param attr Primitive attributes.
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p scales
/// vector. The set i-th bit indicates that a dedicated output scaling
/// factor is used for each index along that dimension. The mask value of
/// 0 implies a common output scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of scaling factors.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_scales(
const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
const float **scales);
/// Sets output scaling factors correspondence mask and values.
///
/// @note
/// The order of dimensions does not depend on how elements are laid
/// out in memory. For example:
/// - for a 2D CNN activations tensor the order is always (n, c)
/// - for a 4D CNN activations tensor the order is always (n, c, h, w)
/// - for a 5D CNN weights tensor the order is always
/// (g, oc, ic, kh, kw)
///
/// Example usage:
/// @code
/// int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params
/// float scales[oc] = { ... }; // unique output scales per output channel
/// int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
///
/// dnnl_convolution_desc_t conv_d; // create a convolution descriptor
///
/// dnnl_primitive_attr_t attr;
/// dnnl_primitive_attr_create(&attr); // create primitive attributes
/// dnnl_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales);
///
/// dnnl_primitive_desc_t conv_pd;
/// dnnl_primitive_desc_create(&conv_pd, &conv_d, attr, engine, NULL);
/// @endcode
///
/// @param attr Primitive attributes.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p scales
/// array. The set i-th bit indicates that a dedicated output scaling
/// factor is used for each index along that dimension. The mask value of
/// 0 implies a common output scaling factor for the whole output tensor.
/// @param scales Array of output scaling factors. If the output scaling
/// factors are known at the time of this call, this array must contain @p
/// count values and the following equality must hold:
/// \f[count = \prod\limits_{d \in mask} output.dims[d].\f]
/// Violations can only be detected when the attributes are used to create
/// a primitive descriptor.
/// If the output scaling factors are not known at the time of the call,
/// this array must contain a single #DNNL_RUNTIME_F32_VAL value and the
/// output scaling factors must be passed at execution time as an argument
/// with index #DNNL_ARG_ATTR_OUTPUT_SCALES.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_scales(
dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
const float *scales);
/// Returns primitive attributes scaling factors correspondence mask and values
/// for a given memory argument.
///
/// @warning
/// The output @p scales array is an internal part of the primitive
/// attributes @p attr, so it is an error to modify or destroy the @p
/// scales array.
///
/// @warning
/// The lifetime of the @p scales array is the same as that of the primitive
/// attributes @p attr to which it belongs, so it is an error to use @p
/// scales after @p attr is destroyed.
///
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p
/// scales array. The set i-th bit indicates that a dedicated output scaling
/// factor is used for each index along that dimension. The mask value of 0
/// implies a common scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
// NOTE(review): unlike the other attribute getters in this header (e.g.
// dnnl_primitive_attr_get_zero_points, dnnl_primitive_attr_get_output_scales,
// which take const_dnnl_primitive_attr_t), this getter takes a non-const
// dnnl_primitive_attr_t even though it only reads from the attributes --
// confirm whether the parameter type should be const_dnnl_primitive_attr_t.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scales(
dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask,
const float **scales);
/// Sets primitive attributes scaling factors for primitive operations for a
/// given memory argument.
///
/// @sa dnnl_primitive_attr_set_output_scales
///
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the tensor dimensions and the @p scales array.
/// The set i-th bit indicates that a dedicated scaling factor is used for
/// each index along that dimension. Set the mask to 0 to use a common
/// scaling factor for the whole output tensor.
/// @param scales Constant array of float scaling factors. This array must
/// contain @p count scales and the following equality must hold:
/// \f[count = \prod\limits_{d \in mask} output.dims[d].\f]
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(
dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask,
const float *scales);
/// Returns @p count, correspondence zero point @p mask, and a pointer to a
/// constant int32_t array of @p zero_points for given @p attr and memory
/// argument (index), previously set by dnnl_primitive_attr_set_zero_points.
///
/// @warning
/// The output @p zero_points array is an internal part of the primitive
/// attributes @p attr, so it is an error to modify or destroy the @p
/// zero_points array.
///
/// @warning
/// The lifetime of @p zero_points array is the same as that of the
/// primitive attributes @p attr to which it belongs, so it is an error
/// to use @p zero_points after @p attr is destroyed.
///
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param count Output length of the array of zero points @p zero_points.
/// @param mask Output zero points correspondence mask that defines the
/// correspondence between the output tensor dimensions and the @p
/// zero_points array. The set i-th bit indicates that a dedicated output
/// zero point is used for each index along that dimension. The mask
/// value of 0 implies a common zero point for the whole output tensor.
/// @param zero_points Output pointer to a constant array of int32_t zero
/// points.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_zero_points(
const_dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask,
const int32_t **zero_points);
/// Sets primitive attributes zero points for primitive operations for a given
/// memory argument.
///
/// @sa dnnl_primitive_attr_set_output_scales
///
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param count Length of the array of zero points @p zero_points.
/// @param mask Zero point correspondence mask that defines the
/// correspondence between the tensor dimensions and the @p
/// zero_points array. The set i-th bit indicates that a dedicated
/// zero point is used for each index along that dimension. Set the
/// mask to 0 to use a common zero point for the whole output tensor.
/// @param zero_points Constant array of int32_t zero points. If the zero
/// points are known at the time of this call, this array must contain @p
/// count zero points and the following equality must hold:
/// \f[count = \prod\limits_{d \in mask} output.dims[d].\f]
/// If the zero points are not known at the time of the call, this array
/// must contain a single #DNNL_RUNTIME_S32_VAL and the zero points must
/// be passed at execution time as an argument with index
/// #DNNL_ARG_ATTR_ZERO_POINTS.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(
dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask,
const int32_t *zero_points);
/// Returns primitive attributes post-ops.
///
/// @warning
/// The output @p post_ops points to the internal @p attr field, so it is
/// an error to modify or destroy them. The lifetime of @p post_ops is
/// the same as that of the @p attr it belongs to, so it is an error to
/// use @p post_ops after @p attr has been destroyed.
///
/// @param attr Primitive attributes.
/// @param post_ops Output post-ops.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(
const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops);
/// Sets primitive attributes post-ops.
///
/// @note
/// There is no way to check whether the post-ops would be supported by
/// the target primitive. Any error will be reported by the
/// dnnl_primitive_desc_create() function call.
///
/// @param attr Primitive attributes.
/// @param post_ops Post-ops to set.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(
dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops);
/// Creates empty post-ops sequence.
///
/// @param post_ops Output post-ops.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops);
/// Destroys post-ops.
///
/// @param post_ops Post-ops to destroy.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops);
/// Returns the length of post-ops.
///
/// @param post_ops Post-ops.
/// @returns The number of post-ops entries.
int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops);
/// Returns the kind of a post-op entry.
///
/// @param post_ops Post-ops.
/// @param index Post-op entry index.
/// @returns The kind of the post-op with the specified index.
/// @returns #dnnl_undefined_primitive if there is no post-op at the specified
/// index.
dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(
const_dnnl_post_ops_t post_ops, int index);
/// Appends an accumulation (sum) to post-ops. Prior to accumulating the
/// result, the previous value is multiplied by a scale.
///
/// The kind of this post-op is #dnnl_sum.
///
/// This feature may improve performance for cases like residual learning
/// blocks, where the result of convolution is accumulated to the previously
/// computed activations. The parameter @p scale may be used for the
/// integer-based computations when the result and previous activations have
/// different logical scaling factors.
///
/// In the simplest case where the accumulation is the only post-op, the
/// computations will be:
///
/// dst[:] <- scale * dst[:] + op(...) // instead of dst[:] <- op(...)
///
/// @note
/// This post-op executes in-place and does not change the
/// destination layout.
///
/// @sa dnnl_post_ops_append_sum_v2 for a variant with an explicit
/// accumulation data type.
/// @sa dnnl_post_ops_append_sum_v3 for a variant with a zero point and an
/// explicit accumulation data type.
///
/// @param post_ops Post-ops.
/// @param scale Accumulation scaling factor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum(
dnnl_post_ops_t post_ops, float scale);
/// Appends an accumulation v2 (sum) to post-ops. Prior to accumulating the
/// result, the previous value is multiplied by a scale.
///
/// The kind of this post-op is #dnnl_sum.
///
/// This feature may improve performance for cases like residual learning
/// blocks, where the result of convolution is accumulated to the previously
/// computed activations. The parameter @p scale may be used for the
/// integer-based computations when the result and previous activations have
/// different logical scaling factors.
///
/// In the simplest case where the accumulation is the only post-op, the
/// computations will be:
///
/// dst[:] <- scale * dst[:] + op(...) // instead of dst[:] <- op(...)
///
/// If @p data_type is specified, original dst tensor will be reinterpreted
/// as a tensor with provided data type. Since it is reinterpretation,
/// data_type and dst data type should have the same size.
/// As a result, computations will be:
///
/// dst[:] <- scale * as_data_type(dst[:]) + op(...)
/// // instead of dst[:] <- op(...)
/// @note
/// This post-op executes in-place and does not change the
/// destination layout.
///
/// @param post_ops Post-ops.
/// @param scale Accumulation scaling factor.
/// @param data_type Accumulation data_type.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum_v2(
dnnl_post_ops_t post_ops, float scale, dnnl_data_type_t data_type);
/// Appends an accumulation v3 (sum) to post-ops. Prior to accumulating the
/// result, a zero point is subtracted from the previous value and the
/// difference is multiplied by the scale.
///
/// The kind of this post-op is #dnnl_sum.
///
/// This feature may improve performance for cases such as dequantizing an
/// asymmetrically quantized sum src1 tensor to the f32 domain before
/// performing the sum operation, by subtracting @p zero_point before scaling.
///
/// In the simplest case where accumulation is the only post-op, the
/// computations will be:
///
///     dst[:] <- scale * (dst[:] - zero_point) + op(...)
///                                        // instead of dst[:] <- op(...)
///
/// If @p data_type is specified, the original dst tensor will be reinterpreted
/// as a tensor with the provided data type. Since it is a reinterpretation,
/// @p data_type and the dst data type must have the same size.
/// As a result, the computations will be:
///
///     dst[:] <- scale * (as_data_type(dst[:]) - zero_point) + op(...)
///                                        // instead of dst[:] <- op(...)
/// @note
///     This post-op executes in-place and does not change the
///     destination layout.
///
/// @param post_ops Post-ops.
/// @param scale Accumulation scaling factor.
/// @param zero_point Single scalar int32_t value of the zero point.
/// @param data_type Accumulation data type.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum_v3(dnnl_post_ops_t post_ops,
        float scale, int32_t zero_point, dnnl_data_type_t data_type);
/// Returns the parameters of an accumulation (sum) post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the sum post-op.
/// @param scale Output accumulation scaling factor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @returns #dnnl_invalid_arguments if @p index does not refer to a sum
///     post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(
        const_dnnl_post_ops_t post_ops, int index, float *scale);
/// Returns the parameters of an accumulation (sum) post-op with
/// a data type parameter.
///
/// @param post_ops Post-ops.
/// @param index Index of the sum post-op.
/// @param scale Output accumulation scaling factor.
/// @param data_type Output data type for accumulation.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum_v2(
        const_dnnl_post_ops_t post_ops, int index, float *scale,
        dnnl_data_type_t *data_type);
/// Returns the parameters of an accumulation (sum) post-op with
/// zero point and data type parameters.
///
/// @param post_ops Post-ops.
/// @param index Index of the sum post-op.
/// @param scale Output accumulation scaling factor.
/// @param zero_point Output zero point.
/// @param data_type Output data type for accumulation.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum_v3(
        const_dnnl_post_ops_t post_ops, int index, float *scale,
        int32_t *zero_point, dnnl_data_type_t *data_type);
/// Appends an elementwise post-op.
///
/// The kind of this post operation is #dnnl_eltwise.
///
/// In the simplest case when the elementwise is the only post operation, the
/// computations would be:
///
///     dst[:] <- scale * eltwise_op (op(...)) // instead of dst[:] <- op(...)
///
/// where eltwise_op is configured with the given parameters.
///
/// @param post_ops Post-ops.
/// @param scale Scaling factor.
/// @param alg_kind Elementwise algorithm for the post-op.
/// @param alpha Alpha parameter for the elementwise algorithm.
/// @param beta Beta parameter for the elementwise algorithm.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops,
        float scale, dnnl_alg_kind_t alg_kind, float alpha, float beta);
/// Returns the parameters of an elementwise post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the elementwise post-op.
/// @param scale Output scaling factor.
/// @param alg_kind Output elementwise algorithm kind.
/// @param alpha Output alpha parameter for the elementwise algorithm.
/// @param beta Output beta parameter for the elementwise algorithm.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @returns #dnnl_invalid_arguments if @p index does not refer to an
///     elementwise post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(
        const_dnnl_post_ops_t post_ops, int index, float *scale,
        dnnl_alg_kind_t *alg_kind, float *alpha, float *beta);
/// Appends a depthwise post-op convolution with stride 1.
///
/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
/// weights spatial dimension equal to 1 i.e., kh=kw=1).
///
/// The kind of this post-op is #dnnl_convolution.
///
/// The number of outputs for the primitive remains the same as before. The
/// output size remains the same as for the original primitive due to stride=1.
///
/// The post-op can be defined as:
///
///      dst[:] <- scales * (conv_dw(conv_1x1))
///
/// See @ref dev_guide_attributes_post_ops_depthwise and
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
///
/// @param post_ops Post-ops.
/// @param weights_data_type Weights data type of the depthwise post-op.
/// @param bias_data_type Bias data type of the depthwise post-op.
/// @param dst_data_type Output data type of the depthwise post-op.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s1p1(dnnl_post_ops_t post_ops,
        dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
        dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask,
        const float *scales);
/// Returns the parameters of a depthwise post-op with stride 1.
///
/// @param post_ops Post-ops.
/// @param index Index of the depthwise post-op.
/// @param weights_data_type Output weights data type of the depthwise post-op.
/// @param bias_data_type Output bias data type of the depthwise post-op.
/// @param dst_data_type Output data type of the depthwise post-op.
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s1p1(
        const_dnnl_post_ops_t post_ops, int index,
        dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
        dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask,
        const float **scales);
/// Appends a depthwise post-op convolution with stride 2.
///
/// This post-op can only be fused with a 2D 1x1 convolution (convolution with
/// weights spatial dimension equal to 1 i.e., kh=kw=1).
///
/// The kind of this post-op is #dnnl_convolution.
///
/// The number of outputs for the primitive remains the same as before. The
/// output spatial size can be derived as below:
///
/// output_height = ceil(output_height_1x1_convolution, stride)
/// output_width = ceil(output_width_1x1_convolution, stride)
///
/// The post-op can be defined as:
///
///      dst[:] <- scales * (conv_dw(conv_1x1))
///
/// See @ref dev_guide_attributes_post_ops_depthwise and
/// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
///
/// @param post_ops Post-ops.
/// @param weights_data_type Weights data type of the depthwise post-op.
/// @param bias_data_type Bias data type of the depthwise post-op.
/// @param dst_data_type Output data type of the depthwise post-op.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s2p1(dnnl_post_ops_t post_ops,
        dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
        dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask,
        const float *scales);
/// Returns the parameters of a depthwise post-op with stride 2.
///
/// @param post_ops Post-ops.
/// @param index Index of the depthwise post-op.
/// @param weights_data_type Output weights data type of the depthwise post-op.
/// @param bias_data_type Output bias data type of the depthwise post-op.
/// @param dst_data_type Output data type of the depthwise post-op.
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s2p1(
        const_dnnl_post_ops_t post_ops, int index,
        dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
        dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask,
        const float **scales);
/// Appends a binary post-op.
///
/// The kind of this post operation is #dnnl_binary.
///
/// In the simplest case when the binary is the only post operation, the
/// computations would be:
///
///     dst[:] <- binary_op (dst[:], another_input[:])
///
/// where binary_op is configured with the given parameters. binary_op supports
/// broadcast semantics for a second operand.
///
/// @param post_ops Post-ops.
/// @param alg_kind Binary algorithm for the post-op.
/// @param src1_desc Memory descriptor of a second operand.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src1_desc);
/// Returns the parameters of a binary post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the binary post-op.
/// @param alg_kind Output binary algorithm kind.
/// @param src1_desc Output memory descriptor of a second operand.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @returns #dnnl_invalid_arguments if @p index does not refer to a binary
///     post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(
        const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
        const dnnl_memory_desc_t **src1_desc);
/// Appends a prelu forward post-op.
///
/// The kind of this post-op is #dnnl::primitive::kind::prelu.
///
/// The post-op can be defined as:
///
///      dst[:] <- prelu(dst[:], weights[:])
///      prelu:
///      dst[:] <- dst[:] if dst[:] > 0
///      dst[:] <- dst[:] * weights[:] if dst[:] <= 0
///
///
/// @note
///     The order of dimensions does not depend on how elements are laid
///     out in memory. For example:
///     - for a 2D CNN activations tensor the order is always (n, c)
///     - for a 4D CNN activations tensor the order is always (n, c, h, w)
///     - for a 5D CNN weights tensor the order is always
///        (g, oc, ic, kh, kw)
///
///    The prelu weights tensor is passed in the runtime execution phase. Its
///    data type is implicitly assumed to be f32 using a plain
///    layout (a, ab, acb, acdb, acdeb).
///
/// @param post_ops Post-ops.
/// @param mask Defines the correspondence between the output tensor
///     dimensions and the prelu weights tensor. The set i-th bit indicates
///     that a dedicated weights value is used for each index along that
///     dimension. Set the mask to 0 to use a common weights value
///     for the whole output tensor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_prelu(
        dnnl_post_ops_t post_ops, int mask);
/// Returns the parameters of a prelu post-op.
///
/// @param post_ops Post-ops.
/// @param index Index of the prelu post-op.
/// @param mask Output mask of the prelu post-op.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
        const_dnnl_post_ops_t post_ops, int index, int *mask);
/// @} dnnl_api_attributes
/// @} dnnl_api_primitives
/// @addtogroup dnnl_api_memory
/// @{
/// Initializes a memory descriptor using dimensions and strides.
///
/// @note
///     As always, the logical order of dimensions corresponds to the `abc...`
///     format tag, and the physical meaning of the dimensions depends on both
///     the primitive that consumes the memory and the context of that
///     consumption.
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions.
/// @param dims Array of dimensions.
/// @param data_type Elements data type.
/// @param strides Strides in each dimension.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_strides(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, const dnnl_dims_t strides);
/// Initializes a memory descriptor using dimensions and memory format tag.
///
/// @note
/// As always, the logical order of dimensions corresponds to the `abc...`
/// format tag, and the physical meaning of the dimensions depends on both
/// the primitive that consumes the memory and the context of that
/// consumption.
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions
/// @param dims Array of dimensions.
/// @param data_type Elements data type.
/// @param tag Memory format tag. Can be #dnnl_format_tag_any which would
///     allow a primitive to choose the final memory format. In this case the
/// format_kind field of the memory descriptor would be set to
/// #dnnl_format_kind_any.