Skip to content

Commit

Permalink
updating docs
Browse files Browse the repository at this point in the history
  • Loading branch information
rdyro committed Jan 21, 2025
1 parent 571f018 commit 1f9fb76
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 37 deletions.
87 changes: 50 additions & 37 deletions mkdocs.yml
Original file line number Diff line number Diff line change
@@ -1,55 +1,68 @@
site_name: My Docs
site_name: torch2jax
site_description: The documentation for the torch2jax package.
site_author: Robert Dyro
site_url: https://rdyro.github.io/torch2jax/
repo_url: https://github.com/rdyro/torch2jax/
repo_name: rdyro/torch2jax

theme:
name: "material"
logo: "img/favicon.svg"
favicon: "img/favicon.svg"
name: material
logo: img/favicon.svg
favicon: img/favicon.svg
palette:
- primary: white
- scheme: default
primary: white
toggle:
icon: material/weather-night
name: Switch to dark mode
- scheme: slate
primary: white
toggle:
icon: material/weather-sunny
name: Switch to light mode
features:
- navigation.sections # Sections are included in the navigation on the left.
- navigation.sections
- toc.integrate
- search.suggest
- search.highlight

site_name: torch2jax
site_description: The documentation for the torch2jax package.
site_author: Robert Dyro
site_url: https://rdyro.github.io/torch2jax/

repo_url: https://github.com/rdyro/torch2jax/
repo_name: rdyro/torch2jax
- content.code.copy

plugins:
- search
- autorefs
- mkdocstrings:
handlers:
python:
options:
inherited_members: true
show_root_heading: true
show_if_no_docstring: true
show_signature_annotations: false
heading_level: 0
members_order: source
python:
options:
inherited_members: true
show_root_heading: true
show_if_no_docstring: true
show_signature_annotations: false
heading_level: 0
members_order: source

markdown_extensions:
- pymdownx.highlight
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
pygments_lang_class: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences
- mdx_truly_sane_lists
- pymdownx.tasklist

- pymdownx.tasklist:
custom_checkbox: true

nav:
- Overview: 'index.md'
- Installation: 'installation.md'
- Roadmap: 'roadmap.md'
- Changelog: 'changelog.md'
- Examples:
- BERT: 'examples/bert_example.md'
- ResNet: 'examples/resnet_example.md'
- Data Parallel: 'examples/data_parallel.md'
- Full API:
- torch2jax: 'api/torch2jax.md'
- torch2jax_with_vjp: 'api/torch2jax_with_vjp.md'
- utils: 'api/utils.md'
- Overview: index.md
- Installation: installation.md
- Roadmap: roadmap.md
- Changelog: changelog.md
- Examples:
- BERT: examples/bert_example.md
- ResNet: examples/resnet_example.md
- Data Parallel: examples/data_parallel.md
- Full API:
- torch2jax: api/torch2jax.md
- torch2jax_with_vjp: api/torch2jax_with_vjp.md
- utils: api/utils.md
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ dependencies = [
"torch",
]

[project.optional-dependencies]
docs = [
"mkdocs",
"mkdocs-material",
"mkdocstrings[python]",
"pymdown-extensions",
"mdx_truly_sane_lists",
]

[tool.setuptools.package-data]
torch2jax = ["cpp/*"]

Expand Down
2 changes: 2 additions & 0 deletions torch2jax/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def torch2jax_with_vjp(
use_torch_vjp (bool, optional): (Not supported, please use inside `shard_map`) Whether to use custom vjp or the
one from torch. False means fallback to `torch.autograd.grad` for more compatibility. Some older external
library PyTorch code may need this fallback. Defaults to True (i.e., do not use fallback).
output_sharding_spec: (not supported) sharding spec of the output, use shard_map instead for a device-local
version of this function
Returns:
Callable: JIT-compatible JAX version of the torch function (VJP defined up to depth `depth`).
Expand Down

0 comments on commit 1f9fb76

Please sign in to comment.