Updated C wrapper wrt. Torch v1.10#61
Conversation
acb50ff to
6d3200a
Compare
|
What's the status now? |
|
It's been a idle for a while, but status is summarised by these comments: |
6d3200a to
6afc28e
Compare
|
See #54 for status. |
6afc28e to
f407437
Compare
Copied from https://github.com/LaurentMazare/ocaml-torch/tree/a6499811f40282a071179d4306afbbb6023dcc4a/src/gen/gen.ml Also updated dune-project accordingly.
6f18889 to
beb24f2
Compare
41f53f2 to
1051fc8
Compare
A. Non-void, non-* methods.
Search/replace:
1. torch_api.h: ^(?!void)(\w+) (at.+)\( -> int $2($1 *,
2. torch_api.h: , \);$ -> );
3. torch_api.cpp: ^(?!void)(\w+) (at.+)\( -> int $2($1 *out__,
4. torch_api.cpp: , \) \{$ -> ) {
B. void-methods.
Search/replace
1. torch_api.{h,cpp}: ^void (at.+)\( -> int $1(
C. *-methods.
Search/replace
1. torch.api.h: ^(\w+ \*)(at.+)\( -> int $2($1*,
2. torch_api.cpp: ^(\w+ \*)(at.+)\( -> int $2($1*out__,
D. Implemented return status code
Replaced
```
^(\s*) return new (.+)
\s*\)
\s*return nullptr;
```
with
```
$1 out__[0] = new $2
$1 return 0;
$1)
$1return 1;
```
E. Implemented return status code
1. Replaced
```
^(\s*)PROTECT\(return new (.+)\)
\s*return nullptr;
```
with
```
$1PROTECT(
$1 out__[0] = new $2
$1 return 0;
$1)
$1return 1;
```
F. Implemented return status code
1. Replaced
```
^(\s*)PROTECT\(return (.+)\)
```
with
```
$1PROTECT(
$1 out__[0] = $2
$1 return 0;
$1)
```
G. Implemented return status code
1. Replaced
```
^(\s*) return (.+)
\s*\)
\s*return nullptr;
```
with
```
$1 out__[0] = $2
$1 return 0;
$1)
$1return 1;
```
H. Restored error handling
Handled caml_failwith by search/replace: Replaced:
```
$
^(\s*) caml_failwith\((.+)
```
with:
```
{
$1 myerr = strdup($2
$1 return 1;
$1}
```
I. Replaced
```
^(\s+)PROTECT\(
return (.+)
\)
```
with
```
$1PROTECT(
$1 out__[0] = $2
$1 return 0;
$1)
```
J. Manual implement return status code
K. Changed return code from -1 to 1 to reduce diff
L. Fixed a couple of warnings
090f09e to
aa302ad
Compare
Also, made CUDA build optional. int at_empty_cache(); int at_no_grad(int flag); int at_sync(); int at_from_blob(tensor *, void *data, int64_t *dims, int ndims, int64_t *strides, int nstrides, int dev);
Also: * Dev. container: Updated for Torch 1.10.2 * Added /build to .gitignore
0ef32d1 to
6483d63
Compare
Factored out scripts - for re-use in CI and dev. container.
|
Perhaps it would make sense to start merging some of the excellent changes here? |
|
Yes! :-) Please give Edit: I'll try to go over it as well and try to make a re-cap of the changes. |
|
This is the current main diff which covers the hand-written part ( |
|
The overall aim was to to update for Torch v1.10.2 - but also to make it easier to apply a diff of changes for subsequent version updates... |
|
Recap: General
Modified function definitionsint at_float_vec(double *values, int value_len, int type);int at_int_vec(int64_t *values, int value_len, int type);int at_grad_set_enabled(int);int at_int64_value_at_indexes(double *i, tensor, int *indexes, int indexes_len);tensor at_load(char *filename);int ato_adam(optimizer *, double learning_rate,
double beta1,
double beta2,
double weight_decay);int atm_load(char *, module *);Added function definitionsint at_is_sparse(int *, tensor)
int at_device(int *, tensor)int at_stride(tensor, int *)int at_autocast_clear_cache();
int at_autocast_decrement_nesting(int *);
int at_autocast_increment_nesting(int *);
int at_autocast_is_enabled(int *);
int at_autocast_set_enabled(int *, int b);int at_to_string(char **, tensor, int line_size)int at_get_num_threads(int *);
int at_set_num_threads(int n_threads);int ati_none(ivalue *);
int ati_bool(ivalue *, int);
int ati_string(ivalue *, char *);
int ati_tuple(ivalue *, ivalue *, int);
int ati_generic_list(ivalue *, ivalue *, int);
int ati_generic_dict(ivalue *, ivalue *, int);
int ati_int_list(ivalue *, int64_t *, int);
int ati_double_list(ivalue *, double *, int);
int ati_bool_list(ivalue *, char *, int);
int ati_string_list(ivalue *, char **, int);
int ati_tensor_list(ivalue *, tensor *, int);int ati_to_string(char **, ivalue);
int ati_to_bool(int *, ivalue);
int ati_length(int *, ivalue);
int ati_to_generic_list(ivalue, ivalue *, int);
int ati_to_generic_dict(ivalue, ivalue *, int);
int ati_to_int_list(ivalue, int64_t *, int);
int ati_to_double_list(ivalue, double *, int);
int ati_to_bool_list(ivalue, char *, int);
int ati_to_tensor_list(ivalue, tensor *, int); |
|
@DhairyaLGandhi Do you know if anyone is available for reviewing these changes? I'm at JuliaCon, FYI. |
Updates the C wrapper based on ocaml-torch @ 0.14 - matching Torch v1.10 (current JLL-build)
Contributes to #54 - follow-up for #56
Notable included changes:
torch_api.{cpp, h}buildkite stepsGitHub Actions workflow for building C wrapperThe last two changes could be moved to a separate PR (to reduce number of changes in this PR).To-do:
torch_api.{cpp, h}Changed torch_api.cpp to reduce diffshould be removed before merging: It is meant to reduce the diff when reviewing - it's only a bunch of indentation changes etc. to make the diff smaller.