|  | 
|  | 1 | +use std::ptr; | 
|  | 2 | + | 
|  | 3 | +use rustc_ast::expand::batch_attrs::{BatchAttrs, BatchItem, BatchActivity}; | 
|  | 4 | +use rustc_codegen_ssa::ModuleCodegen; | 
|  | 5 | +use rustc_codegen_ssa::back::write::ModuleConfig; | 
|  | 6 | +use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; | 
|  | 7 | +use rustc_errors::FatalError; | 
|  | 8 | +use rustc_middle::ty::TyCtxt; | 
|  | 9 | +use rustc_session::config::Lto; | 
|  | 10 | +use tracing::{debug, trace}; | 
|  | 11 | + | 
|  | 12 | +use crate::back::write::{llvm_err, llvm_optimize}; | 
|  | 13 | +use crate::builder::Builder; | 
|  | 14 | +use crate::declare::declare_raw_fn; | 
|  | 15 | +use crate::errors::LlvmError; | 
|  | 16 | +use crate::llvm::AttributePlace::Function; | 
|  | 17 | +use crate::llvm::{Metadata, True}; | 
|  | 18 | +use crate::value::Value; | 
|  | 19 | +use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm}; | 
|  | 20 | + | 
|  | 21 | +fn get_params(fnc: &Value) -> Vec<&Value> { | 
|  | 22 | +    unsafe { | 
|  | 23 | +        let param_num = llvm::LLVMCountParams(fnc) as usize; | 
|  | 24 | +        let mut fnc_args: Vec<&Value> = vec![]; | 
|  | 25 | +        fnc_args.reserve(param_num); | 
|  | 26 | +        llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr()); | 
|  | 27 | +        fnc_args.set_len(param_num); | 
|  | 28 | +        fnc_args | 
|  | 29 | +    } | 
|  | 30 | +} | 
|  | 31 | + | 
|  | 32 | +/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another | 
|  | 33 | +/// function with expected naming and calling conventions[^1] which will be | 
|  | 34 | +/// discovered by the enzyme LLVM pass and its body populated with the differentiated | 
|  | 35 | +/// `fn_to_diff`. `outer_fn` is then modified to have a call to the generated | 
|  | 36 | +/// function and handle the differences between the Rust calling convention and | 
|  | 37 | +/// Enzyme. | 
|  | 38 | +/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/> | 
|  | 39 | +// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to | 
|  | 40 | +// cover some assumptions of enzyme/batch, which could lead to UB otherwise. | 
|  | 41 | +fn generate_enzyme_call<'ll, 'tcx>( | 
|  | 42 | +    cx: &context::CodegenCx<'ll, 'tcx>, | 
|  | 43 | +    fn_to_diff: &'ll Value, | 
|  | 44 | +    outer_fn: &'ll Value, | 
|  | 45 | +    attrs: BatchAttrs, | 
|  | 46 | +) { | 
|  | 47 | +    let inputs = attrs.input_activity; | 
|  | 48 | +    let width = attrs.width; | 
|  | 49 | +    let mut ad_name: String = "__enzyme_batch".to_string(); | 
|  | 50 | + | 
|  | 51 | +    // add outer_fn name to ad_name to make it unique, in case users apply batch to multiple | 
|  | 52 | +    // functions. Unwrap will only panic, if LLVM gave us an invalid string. | 
|  | 53 | +    let name = llvm::get_value_name(outer_fn); | 
|  | 54 | +    let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap(); | 
|  | 55 | +    ad_name.push_str(outer_fn_name.to_string().as_str()); | 
|  | 56 | + | 
|  | 57 | +    // Let us assume the user wrote the following function square: | 
|  | 58 | +    // | 
|  | 59 | +    // ```llvm | 
|  | 60 | +    // define double @square(double %x) { | 
|  | 61 | +    // entry: | 
|  | 62 | +    //  %0 = fmul double %x, %x | 
|  | 63 | +    //  ret double %0 | 
|  | 64 | +    // } | 
|  | 65 | +    // ``` | 
|  | 66 | +    // | 
|  | 67 | +    // The user now applies batching to the function square, in which case fn_to_diff will be `square`. | 
|  | 68 | +    // Our macro generates the following placeholder code (slightly simplified): | 
|  | 69 | +    // | 
|  | 70 | +    // ```llvm | 
|  | 71 | +    // define double @dsquare(double %x) { | 
|  | 72 | +    //  ; placeholder code | 
|  | 73 | +    //  return 0.0; | 
|  | 74 | +    // } | 
|  | 75 | +    // ``` | 
|  | 76 | +    // | 
|  | 77 | +    // so our `outer_fn` will be `dsquare`. The unsafe code section below now removes the placeholder | 
|  | 78 | +    // code and inserts an batching call. We also add a declaration for the __enzyme_batch call. | 
|  | 79 | +    // Again, the arguments to all functions are slightly simplified. | 
|  | 80 | +    // ```llvm | 
|  | 81 | +    // declare double @__enzyme_batch_square(...) | 
|  | 82 | +    // | 
|  | 83 | +    // define double @dsquare(double %x0, double %x1, double %x2, double %x3) { | 
|  | 84 | +    // entry: | 
|  | 85 | +    //   %0 = tail call double (...) @__enzyme_batch_square(double (double)* nonnull @square, metadata !"enzyme_width", i64 4, | 
|  | 86 | +    //   metadata !"enzyme_vector", double %x0, double %x1, double %x2, double %x3) | 
|  | 87 | +    //   ret double %0 | 
|  | 88 | +    // } | 
|  | 89 | +    // ``` | 
|  | 90 | +    unsafe { | 
|  | 91 | +        // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input | 
|  | 92 | +        // arguments. We do however need to declare them with their correct return type. | 
|  | 93 | +        // We already figured the correct return type out in our frontend, when generating the outer_fn, | 
|  | 94 | +        // so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet. | 
|  | 95 | +        let fn_ty = llvm::LLVMGlobalGetValueType(outer_fn); | 
|  | 96 | +        let ret_ty = llvm::LLVMGetReturnType(fn_ty); | 
|  | 97 | + | 
|  | 98 | +        // LLVM can figure out the input types on it's own, so we take a shortcut here. | 
|  | 99 | +        let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True); | 
|  | 100 | + | 
|  | 101 | +        //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and | 
|  | 102 | +        // think a bit more about what should go here. | 
|  | 103 | +        let cc = llvm::LLVMGetFunctionCallConv(outer_fn); | 
|  | 104 | +        let ad_fn = declare_raw_fn( | 
|  | 105 | +            cx, | 
|  | 106 | +            &ad_name, | 
|  | 107 | +            llvm::CallConv::try_from(cc).expect("invalid callconv"), | 
|  | 108 | +            llvm::UnnamedAddr::No, | 
|  | 109 | +            llvm::Visibility::Default, | 
|  | 110 | +            enzyme_ty, | 
|  | 111 | +        ); | 
|  | 112 | + | 
|  | 113 | +        // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to | 
|  | 114 | +        // do it's work. | 
|  | 115 | +        let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); | 
|  | 116 | +        attributes::apply_to_llfn(ad_fn, Function, &[attr]); | 
|  | 117 | + | 
|  | 118 | +        // first, remove all calls from fnc | 
|  | 119 | +        let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); | 
|  | 120 | +        let br = llvm::LLVMRustGetTerminator(entry); | 
|  | 121 | +        llvm::LLVMRustEraseInstFromParent(br); | 
|  | 122 | + | 
|  | 123 | +        let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); | 
|  | 124 | +        let mut builder = Builder::build(cx, entry); | 
|  | 125 | + | 
|  | 126 | +        let num_args = llvm::LLVMCountParams(&fn_to_diff); | 
|  | 127 | +        let mut args = Vec::with_capacity(num_args as usize + 1); | 
|  | 128 | +        args.push(fn_to_diff); | 
|  | 129 | + | 
|  | 130 | +        let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap(); | 
|  | 131 | +        let enzyme_vector = cx.create_metadata("enzyme_vector".to_string()).unwrap(); | 
|  | 132 | +        let enzyme_buffer = cx.create_metadata("enzyme_buffer".to_string()).unwrap(); | 
|  | 133 | + | 
|  | 134 | +        trace!("matching batch arguments"); | 
|  | 135 | +        // We now handle the issue that Rust level arguments not always match the llvm-ir level | 
|  | 136 | +        // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on | 
|  | 137 | +        // llvm-ir level. The number of activities matches the number of Rust level arguments, so we | 
|  | 138 | +        // need to match those. | 
|  | 139 | +        // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it | 
|  | 140 | +        // using iterators and peek()? | 
|  | 141 | +        let mut outer_pos: usize = 0; | 
|  | 142 | +        let mut activity_pos = 0; | 
|  | 143 | +        let outer_args: Vec<&llvm::Value> = get_params(outer_fn); | 
|  | 144 | +        while activity_pos < inputs.len() { | 
|  | 145 | +            let activity = inputs[activity_pos]; | 
|  | 146 | +            let (activity, vectorized): (&Metadata, bool) = match activity { | 
|  | 147 | +                BatchActivity::Const => (enzyme_const, false), | 
|  | 148 | +                BatchActivity::Vector => (enzyme_vector, true), | 
|  | 149 | +                BatchActivity::Leaf => (enzyme_buffer, false), | 
|  | 150 | +                BatchActivity::FakeActivitySize => (enzyme_const, false), | 
|  | 151 | +            }; | 
|  | 152 | +            let outer_arg = outer_args[outer_pos]; | 
|  | 153 | +            args.push(cx.get_metadata_value(activity)); | 
|  | 154 | +            args.push(outer_arg); | 
|  | 155 | +            if vectorized { | 
|  | 156 | +                // We know that vectorized args by construction have <width-1> following arguments, | 
|  | 157 | +                // so this can not be out of bounds. | 
|  | 158 | +                let next_outer_arg = outer_args[outer_pos + width - 1]; | 
|  | 159 | +                let next_outer_ty = cx.val_ty(next_outer_arg); | 
|  | 160 | +                // FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since | 
|  | 161 | +                // vectors behind references (&Vec<T>) are already supported. Users can not pass a | 
|  | 162 | +                // Vec by value for reverse mode, so this would only help forward mode batch. | 
|  | 163 | +                let slice = { | 
|  | 164 | +                    if activity_pos + 1 >= inputs.len() { | 
|  | 165 | +                        // If there is no arg following our ptr, it also can't be a slice, | 
|  | 166 | +                        // since that would lead to a ptr, int pair. | 
|  | 167 | +                        false | 
|  | 168 | +                    } else { | 
|  | 169 | +                        let next_activity = inputs[activity_pos + 1]; | 
|  | 170 | +                        // We analyze the MIR types and add this dummy activity if we visit a slice. | 
|  | 171 | +                        next_activity == BatchActivity::FakeActivitySize | 
|  | 172 | +                    } | 
|  | 173 | +                }; | 
|  | 174 | +                if slice { | 
|  | 175 | +                    // A 4x batched slice will have the following two outer_fn arguments: | 
|  | 176 | +                    // (..., ptr0, int0, ptr1, int1, ...). We add the following llvm-ir to our __enzyme call: | 
|  | 177 | +                    // (..., metadata! enzyme_vector, ptr0, ptr1, ptr2, ptr3, int1, ...). | 
|  | 178 | +                    // FIXME(ZuseZ4): We will upstream a safety check later which asserts that | 
|  | 179 | +                    // int2 >= int1, which means the shadow args are equally large | 
|  | 180 | + | 
|  | 181 | +                    args.push(cx.get_metadata_value(enzyme_const)); | 
|  | 182 | +                    // Now we verify that we have width pairs of (ptr/int) | 
|  | 183 | +                    for i in 0..width { | 
|  | 184 | +                        let next_outer_arg = outer_args[outer_pos + 2 * i]; | 
|  | 185 | +                        let next_outer_ty = cx.val_ty(next_outer_arg); | 
|  | 186 | +                        assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer); | 
|  | 187 | +                        let next_outer_arg2 = outer_args[outer_pos + 2 * i + 1]; | 
|  | 188 | +                        let next_outer_ty2 = cx.val_ty(next_outer_arg2); | 
|  | 189 | +                        assert!(llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Integer); | 
|  | 190 | +                        args.push(next_outer_arg); | 
|  | 191 | +                        args.push(next_outer_arg2); | 
|  | 192 | +                    } | 
|  | 193 | +                    args.push(cx.get_metadata_value(enzyme_const)); | 
|  | 194 | +                    args.push(next_outer_arg); | 
|  | 195 | +                    outer_pos += 4; | 
|  | 196 | +                    activity_pos += 2; | 
|  | 197 | +                } else { | 
|  | 198 | +                    // A vectorized pointer will have the following two outer_fn arguments: | 
|  | 199 | +                    // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call: | 
|  | 200 | +                    // (..., metadata! enzyme_dup, ptr, ptr, ...). | 
|  | 201 | +                    assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer); | 
|  | 202 | +                    args.push(next_outer_arg); | 
|  | 203 | +                    outer_pos += 2; | 
|  | 204 | +                    activity_pos += 1; | 
|  | 205 | +                } | 
|  | 206 | +            } else { | 
|  | 207 | +                // We do not differentiate with resprect to this argument. | 
|  | 208 | +                // We already added the metadata and argument above, so just increase the counters. | 
|  | 209 | +                outer_pos += 1; | 
|  | 210 | +                activity_pos += 1; | 
|  | 211 | +            } | 
|  | 212 | +        } | 
|  | 213 | + | 
|  | 214 | +        let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); | 
|  | 215 | + | 
|  | 216 | +        // This part is a bit iffy. LLVM requires that a call to an inlineable function has some | 
|  | 217 | +        // metadata attachted to it, but we just created this code oota. Given that the | 
|  | 218 | +        // differentiated function already has partly confusing metadata, and given that this | 
|  | 219 | +        // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the | 
|  | 220 | +        // dummy code which we inserted at a higher level. | 
|  | 221 | +        // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have, | 
|  | 222 | +        // and how to best improve it for enzyme core and rust-enzyme. | 
|  | 223 | +        let md_ty = cx.get_md_kind_id("dbg"); | 
|  | 224 | +        if llvm::LLVMRustHasMetadata(last_inst, md_ty) { | 
|  | 225 | +            let md = llvm::LLVMRustDIGetInstMetadata(last_inst) | 
|  | 226 | +                .expect("failed to get instruction metadata"); | 
|  | 227 | +            let md_todiff = cx.get_metadata_value(md); | 
|  | 228 | +            llvm::LLVMSetMetadata(call, md_ty, md_todiff); | 
|  | 229 | +        } else { | 
|  | 230 | +            // We don't panic, since depending on whether we are in debug or release mode, we might | 
|  | 231 | +            // have no debug info to copy, which would then be ok. | 
|  | 232 | +            trace!("no dbg info"); | 
|  | 233 | +        } | 
|  | 234 | +        // Now that we copied the metadata, get rid of dummy code. | 
|  | 235 | +        llvm::LLVMRustEraseInstBefore(entry, last_inst); | 
|  | 236 | +        llvm::LLVMRustEraseInstFromParent(last_inst); | 
|  | 237 | + | 
|  | 238 | +        if cx.val_ty(outer_fn) != cx.type_void() { | 
|  | 239 | +            builder.ret(call); | 
|  | 240 | +        } else { | 
|  | 241 | +            builder.ret_void(); | 
|  | 242 | +        } | 
|  | 243 | + | 
|  | 244 | +        // Let's crash in case that we messed something up above and generated invalid IR. | 
|  | 245 | +        llvm::LLVMRustVerifyFunction( | 
|  | 246 | +            outer_fn, | 
|  | 247 | +            llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction, | 
|  | 248 | +        ); | 
|  | 249 | +    } | 
|  | 250 | +} | 
|  | 251 | + | 
|  | 252 | +pub(crate) fn batch<'ll, 'tcx>( | 
|  | 253 | +    module: &'ll ModuleCodegen<ModuleLlvm>, | 
|  | 254 | +    cgcx: &CodegenContext<LlvmCodegenBackend>, | 
|  | 255 | +    tcx: TyCtxt<'tcx>, | 
|  | 256 | +    batch_items: Vec<BatchItem>, | 
|  | 257 | +    config: &ModuleConfig, | 
|  | 258 | +) -> Result<(), FatalError> { | 
|  | 259 | +    for item in &batch_items { | 
|  | 260 | +        trace!("{}", item); | 
|  | 261 | +    } | 
|  | 262 | + | 
|  | 263 | +    let diag_handler = cgcx.create_dcx(); | 
|  | 264 | +    let (_, _, cgus) = tcx.collect_and_partition_mono_items(()); | 
|  | 265 | +    let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm); | 
|  | 266 | + | 
|  | 267 | +    // Before dumping the module, we want all the TypeTrees to become part of the module. | 
|  | 268 | +    for item in batch_items.iter() { | 
|  | 269 | +        let name = item.source.clone(); | 
|  | 270 | +        let fn_def: Option<&llvm::Value> = cx.get_function(&name); | 
|  | 271 | +        let Some(fn_def) = fn_def else { | 
|  | 272 | +            return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareBatching { | 
|  | 273 | +                src: item.source.clone(), | 
|  | 274 | +                target: item.target.clone(), | 
|  | 275 | +                error: "could not find source function".to_owned(), | 
|  | 276 | +            })); | 
|  | 277 | +        }; | 
|  | 278 | +        debug!(?item.target); | 
|  | 279 | +        let fn_target: Option<&llvm::Value> = cx.get_function(&item.target); | 
|  | 280 | +        let Some(fn_target) = fn_target else { | 
|  | 281 | +            return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareBatching { | 
|  | 282 | +                src: item.source.clone(), | 
|  | 283 | +                target: item.target.clone(), | 
|  | 284 | +                error: "could not find target function".to_owned(), | 
|  | 285 | +            })); | 
|  | 286 | +        }; | 
|  | 287 | + | 
|  | 288 | +        generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); | 
|  | 289 | +    } | 
|  | 290 | + | 
|  | 291 | +    // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts | 
|  | 292 | + | 
|  | 293 | +    if let Some(opt_level) = config.opt_level { | 
|  | 294 | +        let opt_stage = match cgcx.lto { | 
|  | 295 | +            Lto::Fat => llvm::OptStage::PreLinkFatLTO, | 
|  | 296 | +            Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, | 
|  | 297 | +            _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, | 
|  | 298 | +            _ => llvm::OptStage::PreLinkNoLTO, | 
|  | 299 | +        }; | 
|  | 300 | +        // This is our second opt call, so now we run all opts, | 
|  | 301 | +        // to make sure we get the best performance. | 
|  | 302 | +        let skip_size_increasing_opts = false; | 
|  | 303 | +        trace!("running Module Optimization after differentiation"); | 
|  | 304 | +        unsafe { | 
|  | 305 | +            llvm_optimize( | 
|  | 306 | +                cgcx, | 
|  | 307 | +                diag_handler.handle(), | 
|  | 308 | +                module, | 
|  | 309 | +                config, | 
|  | 310 | +                opt_level, | 
|  | 311 | +                opt_stage, | 
|  | 312 | +                skip_size_increasing_opts, | 
|  | 313 | +            )? | 
|  | 314 | +        }; | 
|  | 315 | +    } | 
|  | 316 | +    trace!("done with differentiate()"); | 
|  | 317 | + | 
|  | 318 | +    Ok(()) | 
|  | 319 | +} | 
0 commit comments