Add Type Checking in Forward Mode #808

Merged
sunxd3 merged 21 commits into main from sunxd/forward_type_check
Dec 2, 2025

sunxd3 (Collaborator) commented Oct 17, 2025

Starts to address #672.

Mostly a conceptual copy-paste of the existing reverse-mode debug mode.

github-actions bot (Contributor) commented

Mooncake.jl documentation for PR #808 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR808/

github-actions bot (Contributor) commented Oct 17, 2025

Performance Ratio:
Ratio of the time to compute the gradient to the time to compute the function.
Warning: results are very approximate!

┌────────────────────────────┬──────────┬──────────┬─────────────┬─────────┬─────────────┬────────┐
│                      Label │   Primal │ Mooncake │ MooncakeFwd │  Zygote │ ReverseDiff │ Enzyme │
│                     String │   String │   String │      String │  String │      String │ String │
├────────────────────────────┼──────────┼──────────┼─────────────┼─────────┼─────────────┼────────┤
│                   sum_1000 │ 100.0 ns │      1.8 │         1.9 │     1.1 │        5.61 │   8.21 │
│                  _sum_1000 │ 941.0 ns │     6.62 │        1.01 │  1460.0 │        34.4 │   1.09 │
│               sum_sin_1000 │  6.55 μs │     2.54 │        1.39 │    1.68 │        10.8 │   2.21 │
│              _sum_sin_1000 │  5.23 μs │     2.96 │        2.18 │   259.0 │        13.1 │   2.47 │
│                   kron_sum │ 259.0 μs │     50.2 │        3.26 │    6.15 │       254.0 │   15.0 │
│              kron_view_sum │ 315.0 μs │     43.3 │        3.56 │    11.3 │       245.0 │   6.99 │
│      naive_map_sin_cos_exp │  2.15 μs │     2.32 │         1.4 │ missing │        7.09 │   2.35 │
│            map_sin_cos_exp │  2.11 μs │      2.7 │        1.46 │    1.58 │        6.13 │   2.94 │
│      broadcast_sin_cos_exp │  2.28 μs │     2.38 │        1.39 │    2.28 │        1.46 │   2.25 │
│                 simple_mlp │ 200.0 μs │      6.2 │         2.9 │     1.8 │        10.8 │   3.37 │
│                     gp_lml │ 244.0 μs │      8.5 │        2.14 │    3.72 │     missing │   5.04 │
│ turing_broadcast_benchmark │  1.76 ms │     5.01 │        3.46 │ missing │        28.3 │   2.72 │
│         large_single_block │ 380.0 ns │     4.53 │        2.03 │  4560.0 │        32.5 │   2.24 │
└────────────────────────────┴──────────┴──────────┴─────────────┴─────────┴─────────────┴────────┘
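
To make the ratio columns concrete, here is a small sketch that converts the kron_sum row back into absolute times. The numbers are copied from the table; the variable names are only illustrative, and the result inherits the "very approximate" caveat above.

```julia
# Values copied from the kron_sum row of the table above.
primal_time_us = 259.0                      # primal run time in microseconds
ratios = (Mooncake=50.2, MooncakeFwd=3.26)  # gradient time / primal time

# Absolute gradient time = ratio * primal time (still in microseconds).
grad_time_us = map(r -> r * primal_time_us, ratios)

println(grad_time_us)
```

Read this way, a ratio of 50.2 on a 259 μs primal corresponds to roughly 13 ms per gradient, while a ratio of 3.26 corresponds to under 1 ms.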

codecov bot commented Oct 17, 2025

Codecov Report

❌ Patch coverage is 3.57143% with 27 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/debug_mode.jl | 0.00% | 27 Missing ⚠️

@sunxd3 sunxd3 marked this pull request as ready for review October 20, 2025 12:05
yebai (Member) commented Oct 20, 2025

Thanks, @sunxd3. I'll take a look at this soon. In the meantime, could you check whether the debugging mode helps us identify the cause of #777 / #632 (comment)?

@yebai yebai requested a review from Copilot October 20, 2025 16:40
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR adds type checking support for forward-mode automatic differentiation (AD) in debug mode, addressing issue #672. It implements DebugFRule as a forward-mode counterpart to the existing DebugRRule, providing runtime validation of tangent types in forward-mode AD operations.

Key changes:

  • Implementation of DebugFRule struct and validation logic for forward-mode AD
  • Comprehensive test suite covering various input types and error conditions
  • Integration with existing test infrastructure
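
As a rough illustration of the kind of runtime validation described above, the following sketch wraps a forward rule and checks tangent types on the way in and out. Everything here (`ToyDual`, `ToyDebugFRule`, `expected_tangent_type`, `check_dual`, `sin_frule`) is a hypothetical stand-in written for this example; it is not Mooncake's `Dual` or `DebugFRule`, whose actual checks live in src/debug_mode.jl.

```julia
# Toy stand-ins; these are NOT Mooncake's definitions.
struct ToyDual{P,T}
    primal::P
    tangent::T
end

# Hypothetical mapping from a primal type to the tangent type expected for it.
expected_tangent_type(::Type{P}) where {P<:AbstractFloat} = P
expected_tangent_type(::Type{Vector{P}}) where {P<:AbstractFloat} = Vector{P}

# Throw an ErrorException if the tangent type does not match the primal type.
function check_dual(d::ToyDual{P,T}) where {P,T}
    E = expected_tangent_type(P)
    T === E || error("tangent type $T does not match expected type $E for primal type $P")
    return nothing
end

# Wrapper that validates inputs and outputs around an inner forward rule.
struct ToyDebugFRule{R}
    rule::R
end

function (r::ToyDebugFRule)(args::ToyDual...)
    foreach(check_dual, args)   # validate every input before calling the rule
    y = r.rule(args...)
    y isa ToyDual || error("rule returned $(typeof(y)), expected a ToyDual")
    check_dual(y)               # validate the output as well
    return y
end

# A hand-written forward rule for sin: the tangent is cos(x) * dx.
sin_frule(x::ToyDual) = ToyDual(sin(x.primal), cos(x.primal) * x.tangent)

debug_sin = ToyDebugFRule(sin_frule)
ok = debug_sin(ToyDual(1.0, 1.0))   # Float64 primal with Float64 tangent passes
```

Calling `debug_sin(ToyDual(1.0, 1.0f0))` raises an `ErrorException` in this sketch, because a `Float32` tangent is paired with a `Float64` primal.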

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File | Description
src/debug_mode.jl | Implements DebugFRule with type checking for Dual inputs/outputs and structural validation
src/test_utils.jl | Updates test_rule to use DebugFRule wrapper when debug_mode=true for forward-mode primitives
test/debug_mode.jl | Adds comprehensive test suite for forward-mode debug functionality covering valid inputs, type mismatches, and integration tests

sunxd3 (Collaborator, Author) commented Nov 21, 2025

thanks Bruno

Technici4n (Collaborator) commented

I wonder why this segfaults on 1.10 😓

sunxd3 (Collaborator, Author) commented Nov 27, 2025

The following test triggered the segfault:

@testset "scalar type mismatch detected" begin
    rule = Mooncake.build_frule(zero_dual(identity), 1.0; debug_mode=true)
    @test_throws ErrorException rule(
        zero_dual(identity), Mooncake.Dual(1.0, Float32(1.0))
    )
end

I think this is related to JuliaLang/julia#51016. Essentially, when DebugFRule receives a Dual{Float64, Float32} but the inner rule was compiled for Dual{Float64, Float64}, the compiler segfaults while trying to specialize the call.

Not certain if Base.inferencebarrier is the best solution here, but it seems to prevent the crash.
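
For readers unfamiliar with `Base.inferencebarrier`, here is a minimal sketch of what it does in isolation. `ToyRule`, `call_direct`, and `call_barrier` are hypothetical names used only for this example. The sketch does not reproduce the segfault; it only shows that the barrier hides the callee's type from inference while leaving the result unchanged.

```julia
# Toy wrapper; `ToyRule` is a hypothetical stand-in, not a Mooncake type.
struct ToyRule{F}
    f::F
end

# Direct call: the compiler can specialize on the concrete types of `r.f` and `x`.
call_direct(r::ToyRule, x...) = r.f(x...)

# Barrier call: `Base.inferencebarrier` hides the type of `r.f` from inference,
# so the call is resolved by dynamic dispatch at run time instead.
call_barrier(r::ToyRule, x...) = Base.inferencebarrier(r.f)(x...)

double = ToyRule(x -> 2x)
direct_result = call_direct(double, 3.0)
barrier_result = call_barrier(double, 3.0)
```

Both calls return the same value. The trade-off is that the barrier version gives up static specialization at that call site, which is presumably acceptable in a debug-only code path.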

Technici4n (Collaborator) commented

Is this maybe caused by the lack of verify_args?

sunxd3 (Collaborator, Author) commented Nov 27, 2025

I think adding verify_args alone wouldn't solve it: it won't prevent the Julia compiler from attempting type inference on the call.

I added verify_args, similar to the reverse-mode one, but also had to wrap the whole debug rule function in a generated block so that we can properly intercept the error.
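
The general idea of checking arguments against the signature a rule was built for can be sketched as follows. This is a toy illustration under stated assumptions: `ToyCompiledRule`, `build_toy_rule`, and `toy_verify_args` are hypothetical names, and Mooncake's actual verify_args may work differently.

```julia
# Toy rule that remembers, in a type parameter, the argument types it was built for.
# All names here are hypothetical; this is not Mooncake's implementation.
struct ToyCompiledRule{Sig<:Tuple,F}
    f::F
end

# Record the types of the example arguments as the expected signature.
build_toy_rule(f, args...) = ToyCompiledRule{typeof(args),typeof(f)}(f)

# Throw an ErrorException if the runtime argument types do not match the signature.
function toy_verify_args(::ToyCompiledRule{Sig}, args::Tuple) where {Sig<:Tuple}
    typeof(args) <: Sig ||
        error("arguments of type $(typeof(args)) do not match signature $Sig")
    return nothing
end

function (r::ToyCompiledRule)(args...)
    toy_verify_args(r, args)   # check first, so a mismatch becomes a catchable error
    return r.f(args...)
end

rule = build_toy_rule(x -> 5x, 5.0)   # built for a single Float64 argument
result = rule(2.0)
```

In this sketch, `rule(2.0f0)` throws an `ErrorException` because `Tuple{Float32}` is not a subtype of `Tuple{Float64}`. As the comment above notes, a check of this kind does not by itself stop the compiler from trying to infer the inner call.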

@static if VERSION < v"1.11-"
y = Base.inferencebarrier(rule.rule)(x...)
else
@static if VERSION < v"1.11-"
Collaborator

Since this is not needed for reverse debug mode, I am a bit skeptical that it's required? 😄

sunxd3 (Collaborator, Author) commented Nov 27, 2025

It's tricky: the reverse debug mode tests did not cover the case where the segfault appears in forward debug mode.

# Unless we explicitly check that the arguments are of the type as expected by the rule,
# this will segfault.
@testset "argument checking" begin
    f = x -> 5x
    rule = build_rrule(f, 5.0; debug_mode=true)
    @test_throws ErrorException rule(zero_fcodual(f), CoDual(0.0f0, 1.0f0))
end

But, e.g., the same segfault will happen with

julia +lts --project -e '
      using Mooncake
      using Mooncake: CoDual, zero_fcodual, build_rrule

      f = x -> 5x
      rule = build_rrule(f, 5.0; debug_mode=true)

      rule(zero_fcodual(f), CoDual(5.0, 1.0))
      '

Collaborator

Ah I see. Funny how simple PRs turn into a rabbit hole sometimes... 😅

@yebai yebai requested a review from Technici4n November 28, 2025 12:44
sunxd3 (Collaborator, Author) commented Nov 28, 2025

Ideally, the coding style of the forward and backward debug modes (tests too) would match better, but I don't want to make more style-only changes now.

@yebai yebai requested a review from penelopeysm December 1, 2025 09:56
yebai (Member) commented Dec 1, 2025

@penelopeysm, can you help take a look?

Technici4n (Collaborator) left a comment

LGTM

end
end
else
@noinline function (rule::DebugFRule)(x::Vararg{Dual,N}) where {N}
Collaborator

Nit: inline?

Collaborator, Author

I am following DebugRRule (line 220) here. I think noinline makes more sense because, when this function fails, we want the error to be correctly attributed to this definition. Frankly, I am not sure whether the inline hint is doing anything.

Collaborator

Sure, either is fine by me

yebai (Member) commented Dec 1, 2025

thanks @sunxd3 and @Technici4n -- please feel free to merge and make new releases!

@sunxd3 sunxd3 merged commit 4071f32 into main Dec 2, 2025
132 of 133 checks passed
@sunxd3 sunxd3 deleted the sunxd/forward_type_check branch December 2, 2025 04:43