Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 10 additions & 24 deletions test/ext_reactant/reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,21 @@ end

(Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"),

# all arguments must have at least the same length of the firs one
# a = (Conv((3, 3), 2 => 3),)
# b = ((σ = nothing, weight = Float32[-0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815;;;; -0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815;;;; -0.169722 -0.12912463 0.026297366; -0.08920034 -0.11879107 -0.30971745; -0.11957143 0.3129449 0.32124594;;; 0.011128465 0.12124362 0.096895896; -0.29864514 -0.053307496 0.055420622; -0.30712044 0.2959723 0.5099815], bias = Float32[0.33333334, 0.33333334, 0.33333334], stride = nothing, pad = nothing, dilation = nothing, groups = nothing),)
# (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),

# all arguments must have at least the same length of the firs one
# a = (Chain(Conv((3, 3), 2 => 3), Conv((3, 3), 3 => 1, tanh)),)
# b = ((layers = ((σ = nothing, weight = Float32[0.2703631 0.15815677 0.2918554; 0.20036785 0.43450722 0.3525422; 0.3541182 0.32077286 0.44091386;;; 0.3233156 0.08538988 0.25763267; 0.413441 0.66042584 0.16991; 0.36993486 0.5990643 0.10123589;;;; 0.45728725 0.500834 0.46808332; 0.3662355 0.35068494 0.27277413; 0.44974697 0.47245422 0.10595817;;; 0.36255562 0.6111583 0.52779496; 0.27237993 0.25857046 0.33643073; 0.6679214 0.066386 0.32072845;;;; -0.4879305 -0.59246373 -0.59834677; -0.55097836 -0.5006755 -0.4233263; -0.72177917 -0.65806544 -0.38224664;;; -0.4765812 -0.6856963 -0.5864509; -0.6547631 -0.55094117 -0.38632843; -0.74521375 -0.3817107 -0.48642716], bias = Float32[0.7159346, 0.7152501, -1.0509125], stride = nothing, pad = nothing, dilation = nothing, groups = nothing), (σ = nothing, weight = Float32[0.32858944 -0.10135343 -0.25303265; -0.13622479 0.023095237 0.1746222; 0.18829267 -0.5047879 0.07125988;;; 0.023820637 -0.06595295 -0.003393827; -0.111125976 0.0023178488 0.08700531; -0.073591515 0.057915907 0.048598815;;; 0.016056929 -0.5129501 -0.15588683; -0.3756476 -0.09993523 -0.45654622; -0.3688693 -0.33078116 -0.4093926;;;;], bias = Float32[0.77964276], stride = nothing, pad = nothing, dilation = nothing, groups = nothing)),),)
# (Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),

# https://github.com/EnzymeAD/Enzyme-JAX/issues/221
# (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
(Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),

(Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),

(Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),

(Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),

# error: 'stablehlo.multiply' op requires compatible types for all operands and results
# This requires an issue to be opened.
# (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
(SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),

(Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),

# error: inferred shape '[1, 3, 9, 9]' is incompatible with return type of operation 'tensor<1x3x5x5xf32>'
# (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
(ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),

# (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), # Apparent correctness issue
(BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
]

for (model, x, name) in models_xs
Expand All @@ -51,13 +41,9 @@ end
end

models_xs = [
# %23 = "stablehlo.gather"(%22, %0) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<2x10xf32>, tensor<1x2xi64>) -> tensor<1x1xf32>
# (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"), # Zygote comparison test fails on the GPUArraysCore.@allowscalar in scalarfirst, so we globally allow scalar
(LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"), # Zygote comparison test fails on the GPUArraysCore.@allowscalar in scalarfirst, so we globally allow scalar

# Structural mismatch?
# a = (first ∘ MultiHeadAttention(16; nheads=8),)
# b = ((outer = nothing, inner = (nheads = nothing, q_proj = (weight = Float32[6.698509f-5 -0.0022470958 -0.0008736515 -0.0005296775 0.0015831951 0.00042026408 0.00017288685 -0.0014281741 0.002103838 0.0010096495 0.000121268895 -0.0014571806 0.0009639706 0.0007765734 0.0012100701 -0.0013474268; -0.0056395493 -0.00032762758 0.00046284366 -0.0025110473 -0.004034199 0.003250119 -0.008380175 0.004710109 0.003990736 0.00057849655 -0.0012971399 0.0019155457 0.0015314532 -0.0068585267 0.0016599876 -0.004347213; 0.0020082465 0.0012238229 -0.00056229153 -0.0001504221 0.0013426116 -0.0010430823 0.0022235492 -0.0009865433 -0.0018633157 -0.00044081628 0.0006363774 -0.00159968 -0.0014778266 0.002672868 -0.0004847405 0.0016751911; 0.0012528691 -0.0011000141 8.92279f-6 0.00068768614 0.00040906167 -0.00095790887 0.00010065457 -0.00065050577 0.00044544233 0.00051000295 0.000559349 -0.0010677502 -0.0005208101 0.0017521719 -9.7433345f-5 0.00096242613; -0.002746859 -0.006603405 0.0037514854 -0.00030500942 0.0031205479 -0.0005048999 -0.0017187177 0.0033558803 0.00045646066 -0.00036418485 -0.0026161629 -0.0035960171 -0.0025263052 -0.0077548814 -0.0051170154 8.457979f-5; 9.377861f-5 -0.003973993 0.0009713304 -0.0011969728 0.001513645 -0.0006679229 0.001459552 0.00081135886 -0.00059153314 -8.971001f-5 0.00021119814 -0.002439468 -0.00067647133 0.0002226921 -0.00066728727 0.00044864047; -0.00056112965 -0.0027489793 0.0004568826 -0.0014447274 -0.00020321325 0.0007325389 -0.0012505992 0.0013248054 0.0012759498 0.001128925 -1.4844116f-5 -9.816911f-5 -0.00087159855 -0.00011753605 -3.07296f-5 -0.0006746754; 0.0029512886 -0.0020365466 -0.0035659648 0.0018958003 -0.0001820854 -0.0010186834 0.0034461636 -0.0015790367 -0.0008278372 0.0019356965 0.0018923904 -0.0037499343 -0.004022793 0.0053548287 0.0005814631 0.0034943004; -0.002842697 -0.00088533934 -0.0010996269 0.0013907065 -0.0026552787 0.0033542977 0.002812051 -0.0032007345 -0.00019967918 0.00457535 -0.0028456086 0.0018094043 0.0020170186 0.0066173486 0.0034713345 -0.000893154; 0.0013207783 -0.003873941 0.0020811646 -4.9713576f-6 -0.00023950211 -0.001001733 -0.0040789447 -0.0014036176 0.007979447 0.00076982914 0.0032535167 0.00043590157 -4.3034237f-5 0.003959387 0.00057714456 -0.0003050054; 0.001119989 3.8554597f-5 -0.0005721566 0.00015138248 0.0003500976 -0.00064643787 0.0013506713 -0.00030026288 -0.0012833525 0.00036922994 0.000573249 -0.00048966246 -0.0009211753 0.0005839997 -0.0005077214 0.0006908914; -0.0030337535 0.00052082416 0.0010257302 -0.001478913 0.0002373675 0.0023485282 -0.0028962293 0.00060301815 0.0013557511 -0.0013943826 -0.0009102357 0.0016787873 0.0023580615 -0.0013070818 0.0011929763 -0.0017277088; -0.00092684606 -0.0077168294 0.00828396 0.0011758439 -0.0005089047 -0.0017280795 0.0007435059 0.0015110105 0.005371913 -0.0028669743 0.000875769 0.002308998 0.0042599128 0.001395078 -0.0017776758 -0.0013516224; 0.00039975057 -0.0012942385 -0.00014665103 -0.00033949892 -0.0013017454 0.0004817717 -0.00163979 -9.1762435f-5 0.0010266375 0.00039889078 0.0010956797 -0.00023120346 0.0012584099 0.0022373656 0.00086237746 -0.00015442599; -6.5123044f-5 -0.00026330378 -0.00043207427 8.621037f-5 -0.00037062084 0.00031112006 -0.00011293912 0.00050126563 -5.7062964f-5 0.0005467691 -0.0004273539 -0.00040489298 -0.00011429484 -0.0007625115 0.0007899677 -0.00016026688; -0.00069569284 0.0011769373 -0.0005589303 9.382861f-5 -0.0006318726 0.0004896153 0.0002667491 0.0007188753 -0.00041698967 6.143081f-5 -0.0001454085 0.00070576224 0.00019093762 -0.00088875054 -0.00022447937 -0.00013288778], bias = nothing, σ = nothing), k_proj = (weight = Float32[-0.006356695 0.006594197 -0.0015717247 -0.022280285 0.012346115 0.023102699 0.0047403295 -0.0069266427 -0.003254382 -0.005864935 0.0012176007 0.0042949007 0.006377186 0.008844895 0.00053933944 0.0015426228; 0.0017554002 0.006351081 -0.0031089287 0.0021859806 -0.0036134333 -0.0023656015 -0.0038744302 0.0008509146 0.006329448 0.00025891708 0.0030573597 -0.0015745768 -0.0023788416 0.0035228394 -0.0075742626 0.0051124166; 0.00091151806 0.00085750845 -0.0008750014 0.00079839904 0.00021681897 -0.0013127993 -0.00056082703 -2.6382602f-5 0.0011000173 -0.0012440274 -0.00020858315 -0.0015890934 -0.0017463376 -0.00036126093 0.00029166706 0.00027645877; -0.0005505805 -0.0005563506 0.0017027665 -0.001080131 0.0007292733 -0.0005962233 7.001866f-6 -0.0007215422 -0.00069050776 -0.0003195445 -0.0004750489 -0.00089392485 0.00077004 0.0006037653 -0.00033059713 -0.0006749068; -0.0004529958 -0.0052443813 0.0030781312 0.0071990876 0.0069423835 -0.002242495 0.0007727927 -0.003962856 0.0053477697 -0.0043114163 0.0046237335 0.0037209252 3.0608866f-5 0.00022473601 -0.005674615 0.001927928; 0.0014829627 -0.0077749034 0.005445325 0.0143755 -0.0012763313 -0.006293746 0.0067259604 -0.00050140906 -0.0004319277 -0.0037295015 0.005676918 -0.0018777495 -0.005927857 -0.0001434157 -0.010756415 0.003807423; -0.00018013112 0.0018464699 -0.00059436745 0.00031048636 0.00023857887 0.00024704775 0.0013714221 0.0020679014 0.0016252389 0.0004977781 -0.0011249675 -0.0017270674 -0.0012984275 -0.0005885963 0.0009813632 -0.00048815395; 0.0027159492 -0.00048827843 0.0020531586 -0.0054894765 0.0028238844 -0.00074551447 -0.0009845659 -0.00047163272 -0.00587537 0.0006710871 -0.00066500314 -0.0034161597 0.0018020959 0.0009858431 -0.003217934 -0.0010759861; -0.010235805 -0.015958961 0.0008485664 0.005028783 0.011771609 -0.0027914552 0.0021088992 -0.008706864 0.010346566 -0.0069462974 0.007557606 0.0009011793 0.010105434 0.009638056 -0.00057110406 0.0004567712; 0.0053732386 0.013016653 -0.0013767531 -0.009631875 -0.00032666832 0.004662013 -0.0057470086 -0.0069494033 -0.0026918105 0.007570288 -0.006892863 0.009569347 0.0037581015 0.004090113 0.007595024 -0.0048846523; 8.890491f-5 0.001371811 -0.0012469768 9.911141f-5 -0.0026242672 -0.0015197316 -0.0016033922 0.0019310054 0.00067747175 0.0035100468 0.0016518324 0.004094913 -0.0016921153 -0.0020144857 0.0013243607 -0.00044571012; 0.0015558703 0.0008978864 0.0010150168 -0.0009690616 -0.0009817334 -0.0009617688 0.0006837076 0.0014497713 -0.0012777236 -3.2045325f-5 -0.000682903 -0.0013235736 0.0012849444 -0.00017924164 0.00090924697 -0.000503184; 0.008791277 0.0005746033 -0.0079812175 -0.0011261118 -0.0009937619 -0.0010393955 -0.014725223 -0.008987852 -0.0050395797 0.009706033 0.0038135073 0.0047515035 0.0112827085 0.005382823 -0.004363266 -0.004263296; 0.013926888 -0.009588276 -0.0043002265 -0.015295094 -0.009948157 -0.008211927 -0.012872131 -0.0021384154 0.0073758634 0.017202897 -0.0030166083 -0.008871876 -0.0003680687 -0.002667496 0.0031453061 -0.015538573; 0.0015023113 0.00089796283 -0.0016161522 0.0026307243 0.0005870748 -0.00076024065 -0.0016193525 0.0001049305 -0.0006210599 -0.00083987165 0.0011756913 0.001613532 0.0015301305 0.00088819914 0.00082088343 0.0022942058; 0.00037856336 -3.0725616f-5 -0.0011789649 -0.0007305796 0.001611944 -0.00038102557 -9.7439326f-5 -0.0027701438 -0.000904333 -0.0011518413 0.0009823457 -0.0020134987 -0.00035934968 0.0006067025 -0.00025376517 -0.00022070059], bias = nothing, σ = nothing), v_proj = (weight = Float32[-0.013587591 0.0021811177 0.0010767771 -0.005041313 -0.0071066353 0.0125319455 -0.01503459 0.010944039 0.0040554493 -0.0017099766 -0.0025058363 0.006463745 0.007730677 -0.012512911 -0.0003947781 -0.007308612; -0.035211265 0.0056522083 0.002790391 -0.013064205 -0.018416336 0.032475628 -0.03896106 0.028360687 0.010509401 -0.0044312854 -0.006493693 0.01675033 0.020033496 -0.032426313 -0.0010230421 -0.018939745; 0.0072788447 0.0006080089 -0.0031948441 0.0022210581 0.0033583418 -0.0061120186 0.0056183804 -0.004731141 -0.0023898373 -0.0014288637 0.0026998194 -0.0060799946 -0.0053009554 0.0063893115 -0.001436585 0.0041905832; -0.0065591894 -0.00054789596 0.0028789737 -0.0020014634 -0.003026305 0.0055077276 -0.0050628963 0.004263376 0.002153556 0.0012875927 -0.0024328907 0.0054788697 0.004776853 -0.0057576043 0.0012945511 -0.0037762632; 0.018878516 0.0020030173 -0.008106283 0.0068584397 0.012459779 -0.011597034 0.020603744 -0.018353807 -0.014817484 0.002930419 0.0030872847 -0.009393913 -0.011461486 0.022462018 0.0023007696 0.011095487; -0.005244327 -0.0005564252 0.0022518726 -0.0019052284 -0.0034612457 0.00322158 -0.005723584 0.0050985673 0.004116201 -0.0008140509 -0.0008576271 0.0026095668 0.003183926 -0.0062398 -0.00063913903 -0.003082253; -0.0055147605 0.00061145355 0.0010943866 -0.005482203 -0.004378011 0.00501587 -0.008413549 0.0053443126 0.0034800244 0.0009773258 -0.0027125343 0.0018200793 0.0030731459 -0.005387559 0.0009186775 -0.0043234765; 0.011128688 -0.0012339031 -0.0022084522 0.011062993 0.008834751 -0.010121938 0.016978398 -0.010784733 -0.0070226304 -0.0019722276 0.005473847 -0.0036728885 -0.0062015536 0.010872002 -0.0018538754 0.008724701; -0.037833333 -0.0072354516 0.012748811 -0.022030998 -0.00609467 0.025423791 -0.029771697 0.021657798 0.01924544 -0.003315973 -0.009235187 0.0063374955 0.024534397 -0.028743971 0.00036492813 -0.023139128; -0.005001701 -0.00095655216 0.0016854381 -0.0029125751 -0.0008057365 0.0033611152 -0.0039359233 0.002863237 0.0025443155 -0.0004383828 -0.0012209245 0.00083783915 0.0032435346 -0.0038000545 4.8244667f-5 -0.0030590747; -0.005808082 0.0016860168 0.0031004879 -0.0043447237 -0.0031849432 0.0054290188 -0.009925432 0.004378885 0.005713183 -0.00037632295 -0.0038249025 0.0032741027 0.0064315083 -0.0065353923 0.0010171096 -0.0062728394; 0.0084031 -0.002439319 -0.0044857683 0.0062859254 0.0046079587 -0.007854676 0.014360062 -0.0063353474 -0.008265803 0.0005444638 0.005533848 -0.0047369534 -0.00930507 0.009455371 -0.0014715494 0.00907551; 0.0042213206 -0.00050243916 -0.0014028315 0.004047436 0.005755574 -0.0023939763 0.0076993443 -0.005163418 -0.00347788 0.00012065181 0.0018205052 -0.0019679202 -0.0038561192 0.007557435 0.000120696626 0.0045588817; 0.014316402 -0.0017039992 -0.0047576344 0.013726682 0.019519746 -0.008119053 0.026111946 -0.017511478 -0.011795062 0.0004091852 0.006174153 -0.0066741006 -0.013077837 0.025630666 0.00040933752 0.015461222; 0.007062415 0.0011600928 0.00010689935 0.0033846078 0.004499383 -0.005305766 0.0044237543 -0.005939392 -0.0036483253 0.0004251188 0.0007405997 -0.002122565 -0.0037322687 0.005132367 -0.0025776464 0.0037362974; 0.003025119 0.0004969145 4.5789613f-5 0.0014497649 0.0019272682 -0.0022726753 0.0018948735 -0.002544083 -0.0015627261 0.00018209587 0.00031722884 -0.0009091807 -0.0015986823 0.0021984011 -0.0011041105 0.0016004075], bias = nothing, σ = nothing), attn_drop = nothing, out_proj = (weight = Float32[0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809; 0.0046252874 0.0047164033 0.006615041 0.0087485025 0.0036523344 0.0015425641 -0.0025191456 -0.0049306857 0.004231052 0.0056748213 0.013629846 -0.004974689 0.005989329 0.0102740945 -0.009758169 -0.013934809], bias = nothing, σ = nothing))),)
# (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"), # Zygote comparison test fails on the GPUArraysCore.@allowscalar in scalarfirst, so we globally allow scalar
(first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"), # Zygote comparison test fails on the GPUArraysCore.@allowscalar in scalarfirst, so we globally allow scalar
]

Reactant.allowscalar(true)
Expand Down
5 changes: 2 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@ using Zygote: Zygote
# ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true"
# ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
# ENV["FLUX_TEST_ENZYME"] = "false"
ENV["FLUX_TEST_REACTANT"] = "false"
# ENV["FLUX_TEST_REACTANT"] = "true"

const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true"
const FLUX_TEST_CPU = get(ENV, "FLUX_TEST_CPU", "true") == "true"

# Reactant will automatically select a GPU backend, if available, and TPU backend, if available.
# Otherwise it will fall back to CPU.
const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT",
VERSION < v"1.12-" && !Sys.iswindows() ? "true" : "false") == "true"
const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT", "true") == "true"

if FLUX_TEST_ENZYME || FLUX_TEST_REACTANT
Pkg.add("Enzyme")
Expand Down
Loading
Loading