From 1f9fb762e11cd8f55b82beeb7d8f4ca7aa85de75 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Tue, 21 Jan 2025 10:02:09 -0800 Subject: [PATCH] updating docs --- mkdocs.yml | 87 ++++++++++++++++++++++++------------------ pyproject.toml | 9 +++++ torch2jax/gradients.py | 2 + 3 files changed, 61 insertions(+), 37 deletions(-) diff --git a/mkdocs.yml b/mkdocs.yml index 1e4670e..89923ba 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,55 +1,68 @@ -site_name: My Docs +site_name: torch2jax +site_description: The documentation for the torch2jax package. +site_author: Robert Dyro +site_url: https://rdyro.github.io/torch2jax/ +repo_url: https://github.com/rdyro/torch2jax/ +repo_name: rdyro/torch2jax + theme: - name: "material" - logo: "img/favicon.svg" - favicon: "img/favicon.svg" + name: material + logo: img/favicon.svg + favicon: img/favicon.svg palette: - - primary: white + - scheme: default + primary: white + toggle: + icon: material/weather-night + name: Switch to dark mode + - scheme: slate + primary: white + toggle: + icon: material/weather-sunny + name: Switch to light mode features: - - navigation.sections # Sections are included in the navigation on the left. + - navigation.sections - toc.integrate - search.suggest - search.highlight - -site_name: torch2jax -site_description: The documentation for the torch2jax package. -site_author: Robert Dyro -site_url: https://rdyro.github.io/torch2jax/ - -repo_url: https://github.com/rdyro/torch2jax/ -repo_name: rdyro/torch2jax + - content.code.copy plugins: - search - autorefs - mkdocstrings: handlers: - python: - options: - inherited_members: true - show_root_heading: true - show_if_no_docstring: true - show_signature_annotations: false - heading_level: 0 - members_order: source + python: + options: + inherited_members: true + show_root_heading: true + show_if_no_docstring: true + show_signature_annotations: false + heading_level: 0 + members_order: source markdown_extensions: - - pymdownx.highlight + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + - pymdownx.inlinehilite + - pymdownx.snippets - pymdownx.superfences - mdx_truly_sane_lists - - pymdownx.tasklist - + - pymdownx.tasklist: + custom_checkbox: true nav: - - Overview: 'index.md' - - Installation: 'installation.md' - - Roadmap: 'roadmap.md' - - Changelog: 'changelog.md' - - Examples: - - BERT: 'examples/bert_example.md' - - ResNet: 'examples/resnet_example.md' - - Data Parallel: 'examples/data_parallel.md' - - Full API: - - torch2jax: 'api/torch2jax.md' - - torch2jax_with_vjp: 'api/torch2jax_with_vjp.md' - - utils: 'api/utils.md' + - Overview: index.md + - Installation: installation.md + - Roadmap: roadmap.md + - Changelog: changelog.md + - Examples: + - BERT: examples/bert_example.md + - ResNet: examples/resnet_example.md + - Data Parallel: examples/data_parallel.md + - Full API: + - torch2jax: api/torch2jax.md + - torch2jax_with_vjp: api/torch2jax_with_vjp.md + - utils: api/utils.md diff --git a/pyproject.toml b/pyproject.toml index 71aed69..a781bb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,15 @@ dependencies = [ "torch", ] +[project.optional-dependencies] +docs = [ + "mkdocs", + "mkdocs-material", + "mkdocstrings[python]", + "pymdown-extensions", + "mdx_truly_sane_lists", +] + [tool.setuptools.package-data] torch2jax = ["cpp/*"] diff --git a/torch2jax/gradients.py b/torch2jax/gradients.py index 3ef6e54..703287f 100644 --- a/torch2jax/gradients.py +++ b/torch2jax/gradients.py @@ -43,6 +43,8 @@ def torch2jax_with_vjp( use_torch_vjp (bool, optional): (Not supported, please use inside `shard_map`) Whether to use custom vjp or the one from torch. False means fallback to `torch.autograd.grad` for more compatibility. Some older external library PyTorch code may need this fallback. Defaults to True (i.e., do not use fallback). + output_sharding_spec: (not supported) sharding spec of the output, use shard_map instead for a device-local + version of this function Returns: Callable: JIT-compatible JAX version of the torch function (VJP defined up to depth `depth`).