From ca702da264bce69c687de3d3ba37d9215d39a587 Mon Sep 17 00:00:00 2001 From: Jia Yuan Lo Date: Mon, 18 Sep 2023 22:00:24 +0800 Subject: [PATCH 01/14] Allow building for x86_64-linux-android (#7055) --- crates/jit/Cargo.toml | 2 +- crates/jit/src/profiling.rs | 2 +- crates/runtime/src/traphandlers/unix.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/jit/Cargo.toml b/crates/jit/Cargo.toml index 25cf74bbb1d4..13fccaf9f769 100644 --- a/crates/jit/Cargo.toml +++ b/crates/jit/Cargo.toml @@ -37,7 +37,7 @@ features = [ "Win32_System_Diagnostics_Debug", ] -[target.'cfg(target_arch = "x86_64")'.dependencies] +[target.'cfg(all(target_arch = "x86_64", not(target_os = "android")))'.dependencies] ittapi = { version = "0.3.3", optional = true } [features] diff --git a/crates/jit/src/profiling.rs b/crates/jit/src/profiling.rs index d816d65f8ba5..5e24344ac915 100644 --- a/crates/jit/src/profiling.rs +++ b/crates/jit/src/profiling.rs @@ -33,7 +33,7 @@ cfg_if::cfg_if! { cfg_if::cfg_if! { // Note: VTune support is disabled on windows mingw because the ittapi crate doesn't compile // there; see also https://github.com/bytecodealliance/wasmtime/pull/4003 for rationale. - if #[cfg(all(feature = "vtune", target_arch = "x86_64", not(all(target_os = "windows", target_env = "gnu"))))] { + if #[cfg(all(feature = "vtune", target_arch = "x86_64", not(any(target_os = "android", all(target_os = "windows", target_env = "gnu")))))] { mod vtune; pub use vtune::new as new_vtune; } else { diff --git a/crates/runtime/src/traphandlers/unix.rs b/crates/runtime/src/traphandlers/unix.rs index 4a7e2027b679..613119af77c4 100644 --- a/crates/runtime/src/traphandlers/unix.rs +++ b/crates/runtime/src/traphandlers/unix.rs @@ -190,7 +190,7 @@ struct ucontext_t { unsafe fn get_pc_and_fp(cx: *mut libc::c_void, _signum: libc::c_int) -> (*const u8, usize) { cfg_if::cfg_if! 
{ - if #[cfg(all(target_os = "linux", target_arch = "x86_64"))] { + if #[cfg(all(any(target_os = "linux", target_os = "android"), target_arch = "x86_64"))] { let cx = &*(cx as *const libc::ucontext_t); ( cx.uc_mcontext.gregs[libc::REG_RIP as usize] as *const u8, From b23f534c07e2d8cdc3dd4c9913183969081a10da Mon Sep 17 00:00:00 2001 From: Afonso Bordado Date: Mon, 18 Sep 2023 15:44:06 +0100 Subject: [PATCH 02/14] riscv64: Implement ELF TLS GD Relocations (#7003) * cranelift: Add support for public labels * cranelift: Allow targeting labels with relocations * cranelift: Emit label related relocations in module * riscv64: Implement TLS GD * cranelift: Rename `label_is_public` field * cranelift: Avoid making MachLabel part of the exposed API --- cranelift/codegen/src/binemit/mod.rs | 15 ++- .../codegen/src/isa/aarch64/inst/emit.rs | 8 +- cranelift/codegen/src/isa/riscv64/inst.isle | 11 +++ .../codegen/src/isa/riscv64/inst/emit.rs | 65 ++++++++++++- cranelift/codegen/src/isa/riscv64/inst/mod.rs | 11 +++ cranelift/codegen/src/isa/riscv64/lower.isle | 5 + cranelift/codegen/src/lib.rs | 3 +- cranelift/codegen/src/machinst/buffer.rs | 95 +++++++++++++++++-- cranelift/codegen/src/machinst/mod.rs | 2 +- .../filetests/isa/riscv64/tls-elf.clif | 58 +++++++++++ cranelift/jit/src/backend.rs | 28 +++--- cranelift/jit/src/compiled_blob.rs | 8 +- cranelift/module/src/data_context.rs | 22 ++--- cranelift/module/src/lib.rs | 2 +- cranelift/module/src/module.rs | 70 ++++++++------ cranelift/object/src/backend.rs | 83 +++++++++++++--- cranelift/src/disasm.rs | 12 +-- crates/cranelift-shared/src/lib.rs | 26 ++--- 18 files changed, 417 insertions(+), 107 deletions(-) create mode 100644 cranelift/filetests/filetests/isa/riscv64/tls-elf.clif diff --git a/cranelift/codegen/src/binemit/mod.rs b/cranelift/codegen/src/binemit/mod.rs index 9f00b93791ac..d5582d87f3be 100644 --- a/cranelift/codegen/src/binemit/mod.rs +++ b/cranelift/codegen/src/binemit/mod.rs @@ -92,6 +92,18 @@ pub enum Reloc { /// jalr ra, ra, 0 RiscvCall, + /// RISC-V TLS GD: High 20 bits of 32-bit PC-relative TLS GD GOT reference, + /// + /// This is the `R_RISCV_TLS_GD_HI20` relocation from the RISC-V ELF psABI document. + /// https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-elf.adoc#global-dynamic + RiscvTlsGdHi20, + + /// Low 12 bits of a 32-bit PC-relative relocation (I-Type instruction) + /// + /// This is the `R_RISCV_PCREL_LO12_I` relocation from the RISC-V ELF psABI document. 
+ /// https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-elf.adoc#pc-relative-symbol-addresses + RiscvPCRelLo12I, + /// s390x TLS GD64 - 64-bit offset of tls_index for GD symbol in GOT S390xTlsGd64, /// s390x TLS GDCall - marker to enable optimization of TLS calls @@ -114,7 +126,8 @@ impl fmt::Display for Reloc { Self::X86SecRel => write!(f, "SecRel"), Self::Arm32Call | Self::Arm64Call => write!(f, "Call"), Self::RiscvCall => write!(f, "RiscvCall"), - + Self::RiscvTlsGdHi20 => write!(f, "RiscvTlsGdHi20"), + Self::RiscvPCRelLo12I => write!(f, "RiscvPCRelLo12I"), Self::ElfX86_64TlsGd => write!(f, "ElfX86_64TlsGd"), Self::MachOX86_64Tlv => write!(f, "MachOX86_64Tlv"), Self::MachOAarch64TlsAdrPage21 => write!(f, "MachOAarch64TlsAdrPage21"), diff --git a/cranelift/codegen/src/isa/aarch64/inst/emit.rs b/cranelift/codegen/src/isa/aarch64/inst/emit.rs index 2a3c9b8b9b75..43bdd4a3e981 100644 --- a/cranelift/codegen/src/isa/aarch64/inst/emit.rs +++ b/cranelift/codegen/src/isa/aarch64/inst/emit.rs @@ -3174,7 +3174,7 @@ impl MachInstEmit for Inst { // Note: this is not `Inst::Jump { .. }.emit(..)` because we // have different metadata in this case: we don't have a label // for the target, but rather a function relocation. - sink.add_reloc(Reloc::Arm64Call, callee, 0); + sink.add_reloc(Reloc::Arm64Call, &**callee, 0); sink.put4(enc_jump26(0b000101, 0)); sink.add_call_site(ir::Opcode::ReturnCall); @@ -3382,12 +3382,12 @@ impl MachInstEmit for Inst { // ldr rd, [rd, :got_lo12:X] // adrp rd, symbol - sink.add_reloc(Reloc::Aarch64AdrGotPage21, name, 0); + sink.add_reloc(Reloc::Aarch64AdrGotPage21, &**name, 0); let inst = Inst::Adrp { rd, off: 0 }; inst.emit(&[], sink, emit_info, state); // ldr rd, [rd, :got_lo12:X] - sink.add_reloc(Reloc::Aarch64Ld64GotLo12Nc, name, 0); + sink.add_reloc(Reloc::Aarch64Ld64GotLo12Nc, &**name, 0); let inst = Inst::ULoad64 { rd, mem: AMode::reg(rd.to_reg()), @@ -3415,7 +3415,7 @@ impl MachInstEmit for Inst { dest: BranchTarget::ResolvedOffset(12), }; inst.emit(&[], sink, emit_info, state); - sink.add_reloc(Reloc::Abs8, name, offset); + sink.add_reloc(Reloc::Abs8, &**name, offset); sink.put8(0); } } diff --git a/cranelift/codegen/src/isa/riscv64/inst.isle b/cranelift/codegen/src/isa/riscv64/inst.isle index e14adf2818f1..2fe4fa6f8e02 100644 --- a/cranelift/codegen/src/isa/riscv64/inst.isle +++ b/cranelift/codegen/src/isa/riscv64/inst.isle @@ -148,6 +148,11 @@ (name BoxExternalName) (offset i64)) + ;; Load a TLS symbol address + (ElfTlsGetAddr + (rd WritableReg) + (name BoxExternalName)) + ;; Load address referenced by `mem` into `rd`. (LoadAddr (rd WritableReg) @@ -2625,6 +2630,12 @@ (decl load_ext_name (ExternalName i64) Reg) (extern constructor load_ext_name load_ext_name) +(decl elf_tls_get_addr (ExternalName) Reg) +(rule (elf_tls_get_addr name) + (let ((dst WritableReg (temp_writable_reg $I64)) + (_ Unit (emit (MInst.ElfTlsGetAddr dst name)))) + dst)) + (decl int_convert_2_float_op (Type bool Type) FpuOPRR) (extern constructor int_convert_2_float_op int_convert_2_float_op) diff --git a/cranelift/codegen/src/isa/riscv64/inst/emit.rs b/cranelift/codegen/src/isa/riscv64/inst/emit.rs index 7effa2b4cb2d..3f8ef674ff3a 100644 --- a/cranelift/codegen/src/isa/riscv64/inst/emit.rs +++ b/cranelift/codegen/src/isa/riscv64/inst/emit.rs @@ -1,7 +1,7 @@ //! Riscv64 ISA: binary code emission. 
use crate::binemit::StackMap; -use crate::ir::{self, RelSourceLoc, TrapCode}; +use crate::ir::{self, LibCall, RelSourceLoc, TrapCode}; use crate::isa::riscv64::inst::*; use crate::isa::riscv64::lower::isle::generated_code::{CaOp, CrOp}; use crate::machinst::{AllocationConsumer, Reg, Writable}; @@ -343,6 +343,7 @@ impl Inst { | Inst::Jal { .. } | Inst::CondBr { .. } | Inst::LoadExtName { .. } + | Inst::ElfTlsGetAddr { .. } | Inst::LoadAddr { .. } | Inst::VirtualSPOffsetAdj { .. } | Inst::Mov { .. } @@ -978,7 +979,7 @@ impl Inst { ); sink.add_call_site(ir::Opcode::ReturnCall); - sink.add_reloc(Reloc::RiscvCall, &callee, 0); + sink.add_reloc(Reloc::RiscvCall, &**callee, 0); Inst::construct_auipc_and_jalr(None, writable_spilltmp_reg(), 0) .into_iter() .for_each(|i| i.emit_uncompressed(sink, emit_info, state, start_off)); @@ -1958,6 +1959,60 @@ impl Inst { sink.bind_label(label_end, &mut state.ctrl_plane); } + + &Inst::ElfTlsGetAddr { rd, ref name } => { + // RISC-V's TLS GD model is slightly different from other arches. + // + // We have a relocation (R_RISCV_TLS_GD_HI20) that loads the high 20 bits + // of the address relative to the GOT entry. This relocation points to + // the symbol as usual. + // + // However when loading the bottom 12bits of the address, we need to + // use a label that points to the previous AUIPC instruction. + // + // label: + // auipc a0,0 # R_RISCV_TLS_GD_HI20 (symbol) + // addi a0,a0,0 # R_RISCV_PCREL_LO12_I (label) + // + // https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-elf.adoc#global-dynamic + + // Create the lable that is going to be published to the final binary object. + let auipc_label = sink.get_label(); + sink.bind_label(auipc_label, &mut state.ctrl_plane); + + // Get the current PC. + sink.add_reloc(Reloc::RiscvTlsGdHi20, &**name, 0); + Inst::Auipc { + rd: rd, + imm: Imm20::from_i32(0), + } + .emit_uncompressed(sink, emit_info, state, start_off); + + // The `addi` here, points to the `auipc` label instead of directly to the symbol. + sink.add_reloc(Reloc::RiscvPCRelLo12I, &auipc_label, 0); + Inst::AluRRImm12 { + alu_op: AluOPRRI::Addi, + rd: rd, + rs: rd.to_reg(), + imm12: Imm12::from_i16(0), + } + .emit_uncompressed(sink, emit_info, state, start_off); + + Inst::Call { + info: Box::new(CallInfo { + dest: ExternalName::LibCall(LibCall::ElfTlsGetAddr), + uses: smallvec![], + defs: smallvec![], + opcode: crate::ir::Opcode::TlsValue, + caller_callconv: CallConv::SystemV, + callee_callconv: CallConv::SystemV, + callee_pop_size: 0, + clobbers: PRegSet::empty(), + }), + } + .emit_uncompressed(sink, emit_info, state, start_off); + } + &Inst::TrapIfC { rs1, rs2, @@ -3273,6 +3328,12 @@ impl Inst { offset, }, + Inst::ElfTlsGetAddr { rd, name } => { + let rd = allocs.next_writable(rd); + debug_assert_eq!(a0(), rd.to_reg()); + Inst::ElfTlsGetAddr { rd, name } + } + Inst::TrapIfC { rs1, rs2, diff --git a/cranelift/codegen/src/isa/riscv64/inst/mod.rs b/cranelift/codegen/src/isa/riscv64/inst/mod.rs index 334af1c53419..5a12644f82e0 100644 --- a/cranelift/codegen/src/isa/riscv64/inst/mod.rs +++ b/cranelift/codegen/src/isa/riscv64/inst/mod.rs @@ -473,6 +473,13 @@ fn riscv64_get_operands VReg>(inst: &Inst, collector: &mut Operan &Inst::LoadExtName { rd, .. } => { collector.reg_def(rd); } + &Inst::ElfTlsGetAddr { rd, .. } => { + // x10 is a0 which is both the first argument and the first return value. 
+ collector.reg_fixed_def(rd, a0()); + let mut clobbers = Riscv64MachineDeps::get_regs_clobbered_by_call(CallConv::SystemV); + clobbers.remove(px_reg(10)); + collector.reg_clobbers(clobbers); + } &Inst::LoadAddr { rd, mem } => { if let Some(r) = mem.get_allocatable_register() { collector.reg_use(r); @@ -1691,6 +1698,10 @@ impl Inst { let rd = format_reg(rd.to_reg(), allocs); format!("load_sym {},{}{:+}", rd, name.display(None), offset) } + &Inst::ElfTlsGetAddr { rd, ref name } => { + let rd = format_reg(rd.to_reg(), allocs); + format!("elf_tls_get_addr {rd},{}", name.display(None)) + } &MInst::LoadAddr { ref rd, ref mem } => { let rs = mem.to_string_with_alloc(allocs); let rd = format_reg(rd.to_reg(), allocs); diff --git a/cranelift/codegen/src/isa/riscv64/lower.isle b/cranelift/codegen/src/isa/riscv64/lower.isle index 26b375f3cdf3..ec264a221094 100644 --- a/cranelift/codegen/src/isa/riscv64/lower.isle +++ b/cranelift/codegen/src/isa/riscv64/lower.isle @@ -1675,6 +1675,11 @@ (lower (symbol_value (symbol_value_data name _ offset))) (load_ext_name name offset)) +;;;;; Rules for `tls_value` ;;;;;;;;;;;;;; + +(rule (lower (has_type (tls_model (TlsModel.ElfGd)) (tls_value (symbol_value_data name _ _)))) + (elf_tls_get_addr name)) + ;;;;; Rules for `bitcast`;;;;;;;;; (rule (lower (has_type out_ty (bitcast _ v @ (value_type in_ty)))) diff --git a/cranelift/codegen/src/lib.rs b/cranelift/codegen/src/lib.rs index 250b36cc897b..c856e1a1b645 100644 --- a/cranelift/codegen/src/lib.rs +++ b/cranelift/codegen/src/lib.rs @@ -56,7 +56,8 @@ pub mod write; pub use crate::entity::packed_option; pub use crate::machinst::buffer::{ - MachCallSite, MachReloc, MachSrcLoc, MachStackMap, MachTextSectionBuilder, MachTrap, + FinalizedMachReloc, FinalizedRelocTarget, MachCallSite, MachSrcLoc, MachStackMap, + MachTextSectionBuilder, MachTrap, }; pub use crate::machinst::{ CompiledCode, Final, MachBuffer, MachBufferFinalized, MachInst, MachInstEmit, diff --git a/cranelift/codegen/src/machinst/buffer.rs b/cranelift/codegen/src/machinst/buffer.rs index 8bf7cfdb4fc3..aa69feef4929 100644 --- a/cranelift/codegen/src/machinst/buffer.rs +++ b/cranelift/codegen/src/machinst/buffer.rs @@ -171,6 +171,7 @@ //! all longer-range fixups to later. use crate::binemit::{Addend, CodeOffset, Reloc, StackMap}; +use crate::ir::function::FunctionParameters; use crate::ir::{ExternalName, Opcode, RelSourceLoc, SourceLoc, TrapCode}; use crate::isa::unwind::UnwindInst; use crate::machinst::{ @@ -345,7 +346,7 @@ pub struct MachBufferFinalized { /// Any relocations referring to this code. Note that only *external* /// relocations are tracked here; references to labels within the buffer are /// resolved before emission. - pub(crate) relocs: SmallVec<[MachReloc; 16]>, + pub(crate) relocs: SmallVec<[FinalizedMachReloc; 16]>, /// Any trap records referring to this code. pub(crate) traps: SmallVec<[MachTrap; 16]>, /// Any call site records referring to this code. @@ -1451,12 +1452,31 @@ impl MachBuffer { let alignment = self.finish_constants(constants); + // Resolve all labels to their offsets. 
+ let finalized_relocs = self + .relocs + .iter() + .map(|reloc| FinalizedMachReloc { + offset: reloc.offset, + kind: reloc.kind, + addend: reloc.addend, + target: match &reloc.target { + RelocTarget::ExternalName(name) => { + FinalizedRelocTarget::ExternalName(name.clone()) + } + RelocTarget::Label(label) => { + FinalizedRelocTarget::Func(self.resolve_label_offset(*label)) + } + }, + }) + .collect(); + let mut srclocs = self.srclocs; srclocs.sort_by_key(|entry| entry.start); MachBufferFinalized { data: self.data, - relocs: self.relocs, + relocs: finalized_relocs, traps: self.traps, call_sites: self.call_sites, srclocs, @@ -1467,8 +1487,13 @@ impl MachBuffer { } /// Add an external relocation at the current offset. - pub fn add_reloc(&mut self, kind: Reloc, name: &ExternalName, addend: Addend) { - let name = name.clone(); + pub fn add_reloc + Clone>( + &mut self, + kind: Reloc, + target: &T, + addend: Addend, + ) { + let target: RelocTarget = target.clone().into(); // FIXME(#3277): This should use `I::LabelUse::from_reloc` to optionally // generate a label-use statement to track whether an island is possibly // needed to escape this function to actually get to the external name. @@ -1505,7 +1530,7 @@ impl MachBuffer { self.relocs.push(MachReloc { offset: self.data.len() as CodeOffset, kind, - name, + target, addend, }); } @@ -1622,7 +1647,7 @@ impl MachBufferFinalized { } /// Get the list of external relocations for this code. - pub fn relocs(&self) -> &[MachReloc] { + pub fn relocs(&self) -> &[FinalizedMachReloc] { &self.relocs[..] } @@ -1717,18 +1742,72 @@ impl Ord for MachLabelFixup { feature = "enable-serde", derive(serde_derive::Serialize, serde_derive::Deserialize) )] -pub struct MachReloc { +pub struct MachRelocBase { /// The offset at which the relocation applies, *relative to the /// containing section*. pub offset: CodeOffset, /// The kind of relocation. pub kind: Reloc, /// The external symbol / name to which this relocation refers. - pub name: ExternalName, + pub target: T, /// The addend to add to the symbol value. pub addend: i64, } +type MachReloc = MachRelocBase; + +/// A relocation resulting from a compilation. +pub type FinalizedMachReloc = MachRelocBase; + +/// A Relocation target +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum RelocTarget { + /// Points to an [ExternalName] outside the current function. + ExternalName(ExternalName), + /// Points to a [MachLabel] inside this function. + /// This is different from [MachLabelFixup] in that both the relocation and the + /// label will be emitted and are only resolved at link time. + /// + /// There is no reason to prefer this over [MachLabelFixup] unless the ABI requires it. + Label(MachLabel), +} + +impl From for RelocTarget { + fn from(name: ExternalName) -> Self { + Self::ExternalName(name) + } +} + +impl From for RelocTarget { + fn from(label: MachLabel) -> Self { + Self::Label(label) + } +} + +/// A Relocation target +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr( + feature = "enable-serde", + derive(serde_derive::Serialize, serde_derive::Deserialize) +)] +pub enum FinalizedRelocTarget { + /// Points to an [ExternalName] outside the current function. + ExternalName(ExternalName), + /// Points to a [CodeOffset] from the start of the current function. + Func(CodeOffset), +} + +impl FinalizedRelocTarget { + /// Returns a display for the current [FinalizedRelocTarget], with extra context to prettify the + /// output. 
+ pub fn display<'a>(&'a self, params: Option<&'a FunctionParameters>) -> String { + match self { + FinalizedRelocTarget::ExternalName(name) => format!("{}", name.display(params)), + FinalizedRelocTarget::Func(offset) => format!("func+{offset}"), + } + } +} + /// A trap record resulting from a compilation. #[derive(Clone, Debug, PartialEq)] #[cfg_attr( diff --git a/cranelift/codegen/src/machinst/mod.rs b/cranelift/codegen/src/machinst/mod.rs index fd08b6be7988..a8d3ce376d12 100644 --- a/cranelift/codegen/src/machinst/mod.rs +++ b/cranelift/codegen/src/machinst/mod.rs @@ -430,7 +430,7 @@ impl CompiledCodeBase { buf, " ; reloc_external {} {} {}", reloc.kind, - reloc.name.display(params), + reloc.target.display(params), reloc.addend, )?; } diff --git a/cranelift/filetests/filetests/isa/riscv64/tls-elf.clif b/cranelift/filetests/filetests/isa/riscv64/tls-elf.clif new file mode 100644 index 000000000000..18e781b2fc0d --- /dev/null +++ b/cranelift/filetests/filetests/isa/riscv64/tls-elf.clif @@ -0,0 +1,58 @@ +test compile precise-output +set tls_model=elf_gd +target riscv64 + +function u0:0(i32) -> i32, i64 { +gv0 = symbol colocated tls u1:0 + +block0(v0: i32): + v1 = global_value.i64 gv0 + return v0, v1 +} + +; VCode: +; add sp,-16 +; sd ra,8(sp) +; sd fp,0(sp) +; mv fp,sp +; sd s1,-8(sp) +; add sp,-16 +; block0: +; mv s1,a0 +; elf_tls_get_addr a0,userextname0 +; mv a1,a0 +; mv a0,s1 +; add sp,+16 +; ld s1,-8(sp) +; ld ra,8(sp) +; ld fp,0(sp) +; add sp,+16 +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; addi sp, sp, -0x10 +; sd ra, 8(sp) +; sd s0, 0(sp) +; mv s0, sp +; sd s1, -8(sp) +; addi sp, sp, -0x10 +; block1: ; offset 0x18 +; mv s1, a0 +; auipc a0, 0 ; reloc_external RiscvTlsGdHi20 u1:0 0 +; mv a0, a0 ; reloc_external RiscvPCRelLo12I func+28 0 +; auipc t5, 0 +; ld t5, 0xc(t5) +; j 0xc +; .byte 0x00, 0x00, 0x00, 0x00 ; reloc_external Abs8 %ElfTlsGetAddr 0 +; .byte 0x00, 0x00, 0x00, 0x00 +; jalr t5 +; mv a1, a0 +; mv a0, s1 +; addi sp, sp, 0x10 +; ld s1, -8(sp) +; ld ra, 8(sp) +; ld s0, 0(sp) +; addi sp, sp, 0x10 +; ret + diff --git a/cranelift/jit/src/backend.rs b/cranelift/jit/src/backend.rs index b564b82b6eb4..438c60d49e27 100644 --- a/cranelift/jit/src/backend.rs +++ b/cranelift/jit/src/backend.rs @@ -4,12 +4,12 @@ use crate::{compiled_blob::CompiledBlob, memory::BranchProtection, memory::Memor use cranelift_codegen::binemit::Reloc; use cranelift_codegen::isa::{OwnedTargetIsa, TargetIsa}; use cranelift_codegen::settings::Configurable; -use cranelift_codegen::{self, ir, settings, MachReloc}; +use cranelift_codegen::{self, ir, settings, FinalizedMachReloc}; use cranelift_control::ControlPlane; use cranelift_entity::SecondaryMap; use cranelift_module::{ DataDescription, DataId, FuncId, Init, Linkage, Module, ModuleDeclarations, ModuleError, - ModuleExtName, ModuleReloc, ModuleResult, + ModuleReloc, ModuleRelocTarget, ModuleResult, }; use log::info; use std::cell::RefCell; @@ -300,9 +300,9 @@ impl JITModule { std::ptr::write(plt_ptr, plt_val); } - fn get_address(&self, name: &ModuleExtName) -> *const u8 { + fn get_address(&self, name: &ModuleRelocTarget) -> *const u8 { match *name { - ModuleExtName::User { .. } => { + ModuleRelocTarget::User { .. 
} => { let (name, linkage) = if ModuleDeclarations::is_function(name) { if self.hotswap_enabled { return self.get_plt_address(name); @@ -337,7 +337,7 @@ impl JITModule { panic!("can't resolve symbol {}", name); } } - ModuleExtName::LibCall(ref libcall) => { + ModuleRelocTarget::LibCall(ref libcall) => { let sym = (self.libcall_names)(*libcall); self.lookup_symbol(&sym) .unwrap_or_else(|| panic!("can't resolve libcall {}", sym)) @@ -354,9 +354,9 @@ impl JITModule { unsafe { got_entry.as_ref() }.load(Ordering::SeqCst) } - fn get_got_address(&self, name: &ModuleExtName) -> NonNull> { + fn get_got_address(&self, name: &ModuleRelocTarget) -> NonNull> { match *name { - ModuleExtName::User { .. } => { + ModuleRelocTarget::User { .. } => { if ModuleDeclarations::is_function(name) { let func_id = FuncId::from_name(name); self.function_got_entries[func_id].unwrap() @@ -365,7 +365,7 @@ impl JITModule { self.data_object_got_entries[data_id].unwrap() } } - ModuleExtName::LibCall(ref libcall) => *self + ModuleRelocTarget::LibCall(ref libcall) => *self .libcall_got_entries .get(libcall) .unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)), @@ -373,9 +373,9 @@ impl JITModule { } } - fn get_plt_address(&self, name: &ModuleExtName) -> *const u8 { + fn get_plt_address(&self, name: &ModuleRelocTarget) -> *const u8 { match *name { - ModuleExtName::User { .. } => { + ModuleRelocTarget::User { .. } => { if ModuleDeclarations::is_function(name) { let func_id = FuncId::from_name(name); self.function_plt_entries[func_id] @@ -386,7 +386,7 @@ impl JITModule { unreachable!("PLT relocations can only have functions as target"); } } - ModuleExtName::LibCall(ref libcall) => self + ModuleRelocTarget::LibCall(ref libcall) => self .libcall_plt_entries .get(libcall) .unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)) @@ -731,10 +731,10 @@ impl Module for JITModule { .unwrap() .perform_relocations( |name| match *name { - ModuleExtName::User { .. } => { + ModuleRelocTarget::User { .. 
} => { unreachable!("non GOT or PLT relocation in function {} to {}", id, name) } - ModuleExtName::LibCall(ref libcall) => self + ModuleRelocTarget::LibCall(ref libcall) => self .libcall_plt_entries .get(libcall) .unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)) @@ -758,7 +758,7 @@ impl Module for JITModule { func: &ir::Function, alignment: u64, bytes: &[u8], - relocs: &[MachReloc], + relocs: &[FinalizedMachReloc], ) -> ModuleResult<()> { info!("defining function {} with bytes", id); let decl = self.declarations.get_function_decl(id); diff --git a/cranelift/jit/src/compiled_blob.rs b/cranelift/jit/src/compiled_blob.rs index f8f8bcd70d8f..de01855c6d7b 100644 --- a/cranelift/jit/src/compiled_blob.rs +++ b/cranelift/jit/src/compiled_blob.rs @@ -1,6 +1,6 @@ use cranelift_codegen::binemit::Reloc; -use cranelift_module::ModuleExtName; use cranelift_module::ModuleReloc; +use cranelift_module::ModuleRelocTarget; use std::convert::TryFrom; /// Reads a 32bit instruction at `iptr`, and writes it again after @@ -21,9 +21,9 @@ pub(crate) struct CompiledBlob { impl CompiledBlob { pub(crate) fn perform_relocations( &self, - get_address: impl Fn(&ModuleExtName) -> *const u8, - get_got_entry: impl Fn(&ModuleExtName) -> *const u8, - get_plt_entry: impl Fn(&ModuleExtName) -> *const u8, + get_address: impl Fn(&ModuleRelocTarget) -> *const u8, + get_got_entry: impl Fn(&ModuleRelocTarget) -> *const u8, + get_plt_entry: impl Fn(&ModuleRelocTarget) -> *const u8, ) { use std::ptr::write_unaligned; diff --git a/cranelift/module/src/data_context.rs b/cranelift/module/src/data_context.rs index 2b262eca38b3..b55f96fcd3c5 100644 --- a/cranelift/module/src/data_context.rs +++ b/cranelift/module/src/data_context.rs @@ -9,7 +9,7 @@ use std::string::String; use std::vec::Vec; use crate::module::ModuleReloc; -use crate::ModuleExtName; +use crate::ModuleRelocTarget; /// This specifies how data is to be initialized. #[derive(Clone, PartialEq, Eq, Debug)] @@ -53,9 +53,9 @@ pub struct DataDescription { /// How the data should be initialized. pub init: Init, /// External function declarations. - pub function_decls: PrimaryMap, + pub function_decls: PrimaryMap, /// External data object declarations. - pub data_decls: PrimaryMap, + pub data_decls: PrimaryMap, /// Function addresses to write at specified offsets. pub function_relocs: Vec<(CodeOffset, ir::FuncRef)>, /// Data addresses to write at specified offsets. @@ -122,7 +122,7 @@ impl DataDescription { /// Users of the `Module` API generally should call /// `Module::declare_func_in_data` instead, as it takes care of generating /// the appropriate `ExternalName`. - pub fn import_function(&mut self, name: ModuleExtName) -> ir::FuncRef { + pub fn import_function(&mut self, name: ModuleRelocTarget) -> ir::FuncRef { self.function_decls.push(name) } @@ -133,7 +133,7 @@ impl DataDescription { /// Users of the `Module` API generally should call /// `Module::declare_data_in_data` instead, as it takes care of generating /// the appropriate `ExternalName`. 
- pub fn import_global_value(&mut self, name: ModuleExtName) -> ir::GlobalValue { + pub fn import_global_value(&mut self, name: ModuleRelocTarget) -> ir::GlobalValue { self.data_decls.push(name) } @@ -176,7 +176,7 @@ impl DataDescription { #[cfg(test)] mod tests { - use crate::ModuleExtName; + use crate::ModuleRelocTarget; use super::{DataDescription, Init}; @@ -191,11 +191,11 @@ mod tests { data.define_zeroinit(256); - let _func_a = data.import_function(ModuleExtName::user(0, 0)); - let func_b = data.import_function(ModuleExtName::user(0, 1)); - let func_c = data.import_function(ModuleExtName::user(0, 2)); - let _data_a = data.import_global_value(ModuleExtName::user(0, 3)); - let data_b = data.import_global_value(ModuleExtName::user(0, 4)); + let _func_a = data.import_function(ModuleRelocTarget::user(0, 0)); + let func_b = data.import_function(ModuleRelocTarget::user(0, 1)); + let func_c = data.import_function(ModuleRelocTarget::user(0, 2)); + let _data_a = data.import_global_value(ModuleRelocTarget::user(0, 3)); + let data_b = data.import_global_value(ModuleRelocTarget::user(0, 4)); data.write_function_addr(8, func_b); data.write_function_addr(16, func_c); diff --git a/cranelift/module/src/lib.rs b/cranelift/module/src/lib.rs index b2bbe088e4af..88313c3a224b 100644 --- a/cranelift/module/src/lib.rs +++ b/cranelift/module/src/lib.rs @@ -29,7 +29,7 @@ mod traps; pub use crate::data_context::{DataDescription, Init}; pub use crate::module::{ DataDeclaration, DataId, FuncId, FuncOrDataId, FunctionDeclaration, Linkage, Module, - ModuleDeclarations, ModuleError, ModuleExtName, ModuleReloc, ModuleResult, + ModuleDeclarations, ModuleError, ModuleReloc, ModuleRelocTarget, ModuleResult, }; pub use crate::traps::TrapSite; diff --git a/cranelift/module/src/module.rs b/cranelift/module/src/module.rs index 917b8ca688c8..cae49ef17e97 100644 --- a/cranelift/module/src/module.rs +++ b/cranelift/module/src/module.rs @@ -11,9 +11,11 @@ use core::fmt::Display; use cranelift_codegen::binemit::{CodeOffset, Reloc}; use cranelift_codegen::entity::{entity_impl, PrimaryMap}; use cranelift_codegen::ir::function::{Function, VersionMarker}; +use cranelift_codegen::ir::{ExternalName, UserFuncName}; use cranelift_codegen::settings::SetError; -use cranelift_codegen::MachReloc; -use cranelift_codegen::{ir, isa, CodegenError, CompileError, Context}; +use cranelift_codegen::{ + ir, isa, CodegenError, CompileError, Context, FinalizedMachReloc, FinalizedRelocTarget, +}; use cranelift_control::ControlPlane; use std::borrow::{Cow, ToOwned}; use std::string::String; @@ -27,22 +29,29 @@ pub struct ModuleReloc { /// The kind of relocation. pub kind: Reloc, /// The external symbol / name to which this relocation refers. - pub name: ModuleExtName, + pub name: ModuleRelocTarget, /// The addend to add to the symbol value. pub addend: i64, } impl ModuleReloc { - /// Converts a `MachReloc` produced from a `Function` into a `ModuleReloc`. - pub fn from_mach_reloc(mach_reloc: &MachReloc, func: &Function) -> Self { - let name = match mach_reloc.name { - ir::ExternalName::User(reff) => { + /// Converts a `FinalizedMachReloc` produced from a `Function` into a `ModuleReloc`. 
+ pub fn from_mach_reloc(mach_reloc: &FinalizedMachReloc, func: &Function) -> Self { + let name = match mach_reloc.target { + FinalizedRelocTarget::ExternalName(ExternalName::User(reff)) => { let name = &func.params.user_named_funcs()[reff]; - ModuleExtName::user(name.namespace, name.index) + ModuleRelocTarget::user(name.namespace, name.index) + } + FinalizedRelocTarget::ExternalName(ExternalName::TestCase(_)) => unimplemented!(), + FinalizedRelocTarget::ExternalName(ExternalName::LibCall(libcall)) => { + ModuleRelocTarget::LibCall(libcall) + } + FinalizedRelocTarget::ExternalName(ExternalName::KnownSymbol(ks)) => { + ModuleRelocTarget::KnownSymbol(ks) + } + FinalizedRelocTarget::Func(offset) => { + ModuleRelocTarget::FunctionOffset(func.name.clone(), offset) } - ir::ExternalName::TestCase(_) => unimplemented!(), - ir::ExternalName::LibCall(libcall) => ModuleExtName::LibCall(libcall), - ir::ExternalName::KnownSymbol(ks) => ModuleExtName::KnownSymbol(ks), }; Self { offset: mach_reloc.offset, @@ -63,7 +72,7 @@ pub struct FuncId(u32); entity_impl!(FuncId, "funcid"); /// Function identifiers are namespace 0 in `ir::ExternalName` -impl From for ModuleExtName { +impl From for ModuleRelocTarget { fn from(id: FuncId) -> Self { Self::User { namespace: 0, @@ -74,8 +83,8 @@ impl From for ModuleExtName { impl FuncId { /// Get the `FuncId` for the function named by `name`. - pub fn from_name(name: &ModuleExtName) -> FuncId { - if let ModuleExtName::User { namespace, index } = name { + pub fn from_name(name: &ModuleRelocTarget) -> FuncId { + if let ModuleRelocTarget::User { namespace, index } = name { debug_assert_eq!(*namespace, 0); FuncId::from_u32(*index) } else { @@ -94,7 +103,7 @@ pub struct DataId(u32); entity_impl!(DataId, "dataid"); /// Data identifiers are namespace 1 in `ir::ExternalName` -impl From for ModuleExtName { +impl From for ModuleRelocTarget { fn from(id: DataId) -> Self { Self::User { namespace: 1, @@ -105,8 +114,8 @@ impl From for ModuleExtName { impl DataId { /// Get the `DataId` for the data object named by `name`. - pub fn from_name(name: &ModuleExtName) -> DataId { - if let ModuleExtName::User { namespace, index } = name { + pub fn from_name(name: &ModuleRelocTarget) -> DataId { + if let ModuleRelocTarget::User { namespace, index } = name { debug_assert_eq!(*namespace, 1); DataId::from_u32(*index) } else { @@ -191,7 +200,7 @@ pub enum FuncOrDataId { } /// Mapping to `ModuleExtName` is trivial based on the `FuncId` and `DataId` mapping. -impl From for ModuleExtName { +impl From for ModuleRelocTarget { fn from(id: FuncOrDataId) -> Self { match id { FuncOrDataId::Func(funcid) => Self::from(funcid), @@ -408,7 +417,7 @@ impl DataDeclaration { feature = "enable-serde", derive(serde_derive::Serialize, serde_derive::Deserialize) )] -pub enum ModuleExtName { +pub enum ModuleRelocTarget { /// User defined function, converted from `ExternalName::User`. User { /// Arbitrary. @@ -420,21 +429,24 @@ pub enum ModuleExtName { LibCall(ir::LibCall), /// Symbols known to the linker. KnownSymbol(ir::KnownSymbol), + /// A offset inside a function + FunctionOffset(UserFuncName, CodeOffset), } -impl ModuleExtName { +impl ModuleRelocTarget { /// Creates a user-defined external name. 
pub fn user(namespace: u32, index: u32) -> Self { Self::User { namespace, index } } } -impl Display for ModuleExtName { +impl Display for ModuleRelocTarget { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Self::User { namespace, index } => write!(f, "u{}:{}", namespace, index), Self::LibCall(lc) => write!(f, "%{}", lc), Self::KnownSymbol(ks) => write!(f, "{}", ks), + Self::FunctionOffset(fname, offset) => write!(f, "{fname}+{offset}"), } } } @@ -686,10 +698,12 @@ impl ModuleDeclarations { } /// Return whether `name` names a function, rather than a data object. - pub fn is_function(name: &ModuleExtName) -> bool { + pub fn is_function(name: &ModuleRelocTarget) -> bool { match name { - ModuleExtName::User { namespace, .. } => *namespace == 0, - ModuleExtName::LibCall(_) | ModuleExtName::KnownSymbol(_) => { + ModuleRelocTarget::User { namespace, .. } => *namespace == 0, + ModuleRelocTarget::LibCall(_) + | ModuleRelocTarget::KnownSymbol(_) + | ModuleRelocTarget::FunctionOffset(..) => { panic!("unexpected module ext name") } } @@ -918,12 +932,12 @@ pub trait Module { /// TODO: Same as above. fn declare_func_in_data(&self, func_id: FuncId, data: &mut DataDescription) -> ir::FuncRef { - data.import_function(ModuleExtName::user(0, func_id.as_u32())) + data.import_function(ModuleRelocTarget::user(0, func_id.as_u32())) } /// TODO: Same as above. fn declare_data_in_data(&self, data_id: DataId, data: &mut DataDescription) -> ir::GlobalValue { - data.import_global_value(ModuleExtName::user(1, data_id.as_u32())) + data.import_global_value(ModuleRelocTarget::user(1, data_id.as_u32())) } /// Define a function, producing the function body from the given `Context`. @@ -965,7 +979,7 @@ pub trait Module { func: &ir::Function, alignment: u64, bytes: &[u8], - relocs: &[MachReloc], + relocs: &[FinalizedMachReloc], ) -> ModuleResult<()>; /// Define a data object, producing the data contents from the given `DataContext`. 
@@ -1067,7 +1081,7 @@ impl Module for &mut M { func: &ir::Function, alignment: u64, bytes: &[u8], - relocs: &[MachReloc], + relocs: &[FinalizedMachReloc], ) -> ModuleResult<()> { (**self).define_function_bytes(func_id, func, alignment, bytes, relocs) } diff --git a/cranelift/object/src/backend.rs b/cranelift/object/src/backend.rs index 453936aeffb3..d083d84ce12f 100644 --- a/cranelift/object/src/backend.rs +++ b/cranelift/object/src/backend.rs @@ -3,12 +3,13 @@ use anyhow::anyhow; use cranelift_codegen::binemit::{Addend, CodeOffset, Reloc}; use cranelift_codegen::entity::SecondaryMap; +use cranelift_codegen::ir::UserFuncName; use cranelift_codegen::isa::{OwnedTargetIsa, TargetIsa}; -use cranelift_codegen::{self, ir, MachReloc}; +use cranelift_codegen::{self, ir, FinalizedMachReloc}; use cranelift_control::ControlPlane; use cranelift_module::{ DataDescription, DataId, FuncId, Init, Linkage, Module, ModuleDeclarations, ModuleError, - ModuleExtName, ModuleReloc, ModuleResult, + ModuleReloc, ModuleRelocTarget, ModuleResult, }; use log::info; use object::write::{ @@ -17,6 +18,7 @@ use object::write::{ use object::{ RelocationEncoding, RelocationKind, SectionKind, SymbolFlags, SymbolKind, SymbolScope, }; +use std::collections::hash_map::Entry; use std::collections::HashMap; use std::mem; use target_lexicon::PointerWidth; @@ -130,6 +132,7 @@ pub struct ObjectModule { libcalls: HashMap, libcall_names: Box String + Send + Sync>, known_symbols: HashMap, + known_labels: HashMap<(UserFuncName, CodeOffset), SymbolId>, per_function_section: bool, } @@ -149,6 +152,7 @@ impl ObjectModule { libcalls: HashMap::new(), libcall_names: builder.libcall_names, known_symbols: HashMap::new(), + known_labels: HashMap::new(), per_function_section: builder.per_function_section, } } @@ -333,21 +337,18 @@ impl Module for ObjectModule { func: &ir::Function, alignment: u64, bytes: &[u8], - relocs: &[MachReloc], + relocs: &[FinalizedMachReloc], ) -> ModuleResult<()> { info!("defining function {} with bytes", func_id); let decl = self.declarations.get_function_decl(func_id); + let decl_name = decl.linkage_name(func_id); if !decl.linkage.is_definable() { - return Err(ModuleError::InvalidImportDefinition( - decl.linkage_name(func_id).into_owned(), - )); + return Err(ModuleError::InvalidImportDefinition(decl_name.into_owned())); } let &mut (symbol, ref mut defined) = self.functions[func_id].as_mut().unwrap(); if *defined { - return Err(ModuleError::DuplicateDefinition( - decl.linkage_name(func_id).into_owned(), - )); + return Err(ModuleError::DuplicateDefinition(decl_name.into_owned())); } *defined = true; @@ -528,9 +529,9 @@ impl ObjectModule { /// This should only be called during finish because it creates /// symbols for missing libcalls. - fn get_symbol(&mut self, name: &ModuleExtName) -> SymbolId { + fn get_symbol(&mut self, name: &ModuleRelocTarget) -> SymbolId { match *name { - ModuleExtName::User { .. } => { + ModuleRelocTarget::User { .. } => { if ModuleDeclarations::is_function(name) { let id = FuncId::from_name(name); self.functions[id].unwrap().0 @@ -539,7 +540,7 @@ impl ObjectModule { self.data_objects[id].unwrap().0 } } - ModuleExtName::LibCall(ref libcall) => { + ModuleRelocTarget::LibCall(ref libcall) => { let name = (self.libcall_names)(*libcall); if let Some(symbol) = self.object.symbol_id(name.as_bytes()) { symbol @@ -562,7 +563,7 @@ impl ObjectModule { } // These are "magic" names well-known to the linker. // They require special treatment. 
- ModuleExtName::KnownSymbol(ref known_symbol) => { + ModuleRelocTarget::KnownSymbol(ref known_symbol) => { if let Some(symbol) = self.known_symbols.get(known_symbol) { *symbol } else { @@ -592,6 +593,36 @@ impl ObjectModule { symbol } } + + ModuleRelocTarget::FunctionOffset(ref fname, offset) => { + match self.known_labels.entry((fname.clone(), offset)) { + Entry::Occupied(o) => *o.get(), + Entry::Vacant(v) => { + let func_user_name = fname.get_user().unwrap(); + let func_id = FuncId::from_name(&ModuleRelocTarget::user( + func_user_name.namespace, + func_user_name.index, + )); + let func_symbol_id = self.functions[func_id].unwrap().0; + let func_symbol = self.object.symbol(func_symbol_id); + + let name = format!(".L{}_{}", func_id.as_u32(), offset); + let symbol_id = self.object.add_symbol(Symbol { + name: name.as_bytes().to_vec(), + value: func_symbol.value + offset as u64, + size: 0, + kind: SymbolKind::Label, + scope: SymbolScope::Compilation, + weak: false, + section: SymbolSection::Section(func_symbol.section.id().unwrap()), + flags: SymbolFlags::None, + }); + + v.insert(symbol_id); + symbol_id + } + } + } } } @@ -776,6 +807,30 @@ impl ObjectModule { 0, ) } + Reloc::RiscvTlsGdHi20 => { + assert_eq!( + self.object.format(), + object::BinaryFormat::Elf, + "RiscvTlsGdHi20 is not supported for this file format" + ); + ( + RelocationKind::Elf(object::elf::R_RISCV_TLS_GD_HI20), + RelocationEncoding::Generic, + 0, + ) + } + Reloc::RiscvPCRelLo12I => { + assert_eq!( + self.object.format(), + object::BinaryFormat::Elf, + "RiscvPCRelLo12I is not supported for this file format" + ); + ( + RelocationKind::Elf(object::elf::R_RISCV_PCREL_LO12_I), + RelocationEncoding::Generic, + 0, + ) + } // FIXME reloc => unimplemented!("{:?}", reloc), }; @@ -846,7 +901,7 @@ struct SymbolRelocs { #[derive(Clone)] struct ObjectRelocRecord { offset: CodeOffset, - name: ModuleExtName, + name: ModuleRelocTarget, kind: RelocationKind, encoding: RelocationEncoding, size: u8, diff --git a/cranelift/src/disasm.rs b/cranelift/src/disasm.rs index 1739f526b8c5..0b0235b67959 100644 --- a/cranelift/src/disasm.rs +++ b/cranelift/src/disasm.rs @@ -2,15 +2,15 @@ use anyhow::Result; use cfg_if::cfg_if; use cranelift_codegen::ir::function::FunctionParameters; use cranelift_codegen::isa::TargetIsa; -use cranelift_codegen::{MachReloc, MachStackMap, MachTrap}; +use cranelift_codegen::{FinalizedMachReloc, MachStackMap, MachTrap}; use std::fmt::Write; -fn print_relocs(func_params: &FunctionParameters, relocs: &[MachReloc]) -> String { +fn print_relocs(func_params: &FunctionParameters, relocs: &[FinalizedMachReloc]) -> String { let mut text = String::new(); - for &MachReloc { + for &FinalizedMachReloc { kind, offset, - ref name, + ref target, addend, } in relocs { @@ -18,7 +18,7 @@ fn print_relocs(func_params: &FunctionParameters, relocs: &[MachReloc]) -> Strin text, "reloc_external: {} {} {} at {}", kind, - name.display(Some(func_params)), + target.display(Some(func_params)), addend, offset ) @@ -121,7 +121,7 @@ pub fn print_all( mem: &[u8], code_size: u32, print: bool, - relocs: &[MachReloc], + relocs: &[FinalizedMachReloc], traps: &[MachTrap], stack_maps: &[MachStackMap], ) -> Result<()> { diff --git a/crates/cranelift-shared/src/lib.rs b/crates/cranelift-shared/src/lib.rs index add380288703..b508e04a79e6 100644 --- a/crates/cranelift-shared/src/lib.rs +++ b/crates/cranelift-shared/src/lib.rs @@ -1,7 +1,7 @@ use cranelift_codegen::{ binemit, ir::{self, ExternalName, UserExternalNameRef}, - settings, MachReloc, MachTrap, + settings, 
FinalizedMachReloc, FinalizedRelocTarget, MachTrap, }; use std::collections::BTreeMap; use wasmtime_environ::{FlagValue, FuncIndex, Trap, TrapInformation}; @@ -101,24 +101,26 @@ pub fn mach_trap_to_trap(trap: &MachTrap) -> Option { /// Converts machine relocations to relocation information /// to perform. -fn mach_reloc_to_reloc(reloc: &MachReloc, transform_user_func_ref: F) -> Relocation +fn mach_reloc_to_reloc(reloc: &FinalizedMachReloc, transform_user_func_ref: F) -> Relocation where F: Fn(UserExternalNameRef) -> (u32, u32), { - let &MachReloc { + let &FinalizedMachReloc { offset, kind, - ref name, + ref target, addend, } = reloc; - let reloc_target = if let ExternalName::User(user_func_ref) = *name { - let (namespace, index) = transform_user_func_ref(user_func_ref); - debug_assert_eq!(namespace, 0); - RelocationTarget::UserFunc(FuncIndex::from_u32(index)) - } else if let ExternalName::LibCall(libcall) = *name { - RelocationTarget::LibCall(libcall) - } else { - panic!("unrecognized external name") + let reloc_target = match *target { + FinalizedRelocTarget::ExternalName(ExternalName::User(user_func_ref)) => { + let (namespace, index) = transform_user_func_ref(user_func_ref); + debug_assert_eq!(namespace, 0); + RelocationTarget::UserFunc(FuncIndex::from_u32(index)) + } + FinalizedRelocTarget::ExternalName(ExternalName::LibCall(libcall)) => { + RelocationTarget::LibCall(libcall) + } + _ => panic!("unrecognized external name"), }; Relocation { reloc: kind, From 27345059c7bd8403b9a8bacae8a544383a59578c Mon Sep 17 00:00:00 2001 From: Afonso Bordado Date: Mon, 18 Sep 2023 20:46:16 +0100 Subject: [PATCH 03/14] riscv64: Cleanup trap handling (#7047) * riscv64: Deduplicate Trap Instruction * riscv64: Use `defer_trap` in TrapIf instruction This places the actual trap opcode at the end. * riscv64: Emit islands before `br_table` sequence This fixes a slightly subtle issue with our island emission in BrTable. We used to emit islands right before the jump table targets. This causes issues because if the island is actually emitted, we have the potential to jump right into the middle of it. This happens because we have calculated a fixed offset from the `auipc` instruction assuming no island is emitted. This commit changes the island to be emitted right at the start of the br_table sequence so that this cannot happen. 
* riscv64: Add trapz and trapnz helpers * riscv64: Emit inline traps on `TrapIf` --- cranelift/codegen/src/isa/riscv64/abi.rs | 2 +- cranelift/codegen/src/isa/riscv64/inst.isle | 51 ++++++----- .../codegen/src/isa/riscv64/inst/emit.rs | 84 ++++++++----------- cranelift/codegen/src/isa/riscv64/inst/mod.rs | 12 +-- cranelift/codegen/src/isa/riscv64/lower.isle | 2 +- .../filetests/isa/riscv64/arithmetic.clif | 70 ++++++++-------- .../filetests/isa/riscv64/stack-limit.clif | 18 ++-- .../isa/riscv64/uadd_overflow_trap.clif | 12 +-- ...o_spectre_i32_access_0xffff0000_offset.wat | 4 +- ...no_spectre_i8_access_0xffff0000_offset.wat | 4 +- ...s_spectre_i32_access_0xffff0000_offset.wat | 4 +- ...es_spectre_i8_access_0xffff0000_offset.wat | 4 +- ...o_spectre_i32_access_0xffff0000_offset.wat | 4 +- ...no_spectre_i8_access_0xffff0000_offset.wat | 4 +- ...s_spectre_i32_access_0xffff0000_offset.wat | 4 +- ...es_spectre_i8_access_0xffff0000_offset.wat | 4 +- 16 files changed, 130 insertions(+), 153 deletions(-) diff --git a/cranelift/codegen/src/isa/riscv64/abi.rs b/cranelift/codegen/src/isa/riscv64/abi.rs index 4ace58b8e386..9abc85451d0e 100644 --- a/cranelift/codegen/src/isa/riscv64/abi.rs +++ b/cranelift/codegen/src/isa/riscv64/abi.rs @@ -290,7 +290,7 @@ impl ABIMachineSpec for Riscv64MachineDeps { fn gen_stack_lower_bound_trap(limit_reg: Reg) -> SmallInstVec { let mut insts = SmallVec::new(); - insts.push(Inst::TrapIfC { + insts.push(Inst::TrapIf { cc: IntCC::UnsignedLessThan, rs1: stack_reg(), rs2: limit_reg, diff --git a/cranelift/codegen/src/isa/riscv64/inst.isle b/cranelift/codegen/src/isa/riscv64/inst.isle index 2fe4fa6f8e02..72ffdcf087c4 100644 --- a/cranelift/codegen/src/isa/riscv64/inst.isle +++ b/cranelift/codegen/src/isa/riscv64/inst.isle @@ -122,12 +122,8 @@ (callee Reg) (info BoxReturnCallInfo)) + ;; Emits a trap with the given trap code if the comparison succeeds (TrapIf - (test Reg) - (trap_code TrapCode)) - - ;; use a simple compare to decide to cause trap or not. - (TrapIfC (rs1 Reg) (rs2 Reg) (cc IntCC) @@ -2877,36 +2873,39 @@ (gen_select_reg (IntCC.SignedGreaterThan) x y x y)) -(decl gen_trapif (XReg TrapCode) InstOutput) -(rule - (gen_trapif test trap_code) - (side_effect (SideEffectNoResult.Inst (MInst.TrapIf test trap_code)))) +;; Builds an instruction sequence that traps if the comparision succeeds. +(decl gen_trapif (IntCC XReg XReg TrapCode) InstOutput) +(rule (gen_trapif cc a b trap_code) + (side_effect (SideEffectNoResult.Inst (MInst.TrapIf a b cc trap_code)))) + +;; Builds an instruction sequence that traps if the input is non-zero. +(decl gen_trapnz (XReg TrapCode) InstOutput) +(rule (gen_trapnz test trap_code) + (gen_trapif (IntCC.NotEqual) test (zero_reg) trap_code)) + +;; Builds an instruction sequence that traps if the input is zero. +(decl gen_trapz (XReg TrapCode) InstOutput) +(rule (gen_trapz test trap_code) + (gen_trapif (IntCC.Equal) test (zero_reg) trap_code)) -(decl gen_trapifc (IntCC XReg XReg TrapCode) InstOutput) -(rule - (gen_trapifc cc a b trap_code) - (side_effect (SideEffectNoResult.Inst (MInst.TrapIfC a b cc trap_code)))) (decl shift_int_to_most_significant (XReg Type) XReg) (extern constructor shift_int_to_most_significant shift_int_to_most_significant) ;;; generate div overflow. 
(decl gen_div_overflow (XReg XReg Type) InstOutput) -(rule - (gen_div_overflow rs1 rs2 ty) - (let - ((r_const_neg_1 XReg (imm $I64 (i64_as_u64 -1))) - (r_const_min XReg (rv_slli (imm $I64 1) (imm12_const 63))) - (tmp_rs1 XReg (shift_int_to_most_significant rs1 ty)) - (t1 XReg (gen_icmp (IntCC.Equal) r_const_neg_1 rs2 ty)) - (t2 XReg (gen_icmp (IntCC.Equal) r_const_min tmp_rs1 ty)) - (test XReg (rv_and t1 t2))) - (gen_trapif test (TrapCode.IntegerOverflow)))) +(rule (gen_div_overflow rs1 rs2 ty) + (let ((r_const_neg_1 XReg (imm $I64 (i64_as_u64 -1))) + (r_const_min XReg (rv_slli (imm $I64 1) (imm12_const 63))) + (tmp_rs1 XReg (shift_int_to_most_significant rs1 ty)) + (t1 XReg (gen_icmp (IntCC.Equal) r_const_neg_1 rs2 ty)) + (t2 XReg (gen_icmp (IntCC.Equal) r_const_min tmp_rs1 ty)) + (test XReg (rv_and t1 t2))) + (gen_trapnz test (TrapCode.IntegerOverflow)))) (decl gen_div_by_zero (XReg) InstOutput) -(rule - (gen_div_by_zero r) - (gen_trapifc (IntCC.Equal) (zero_reg) r (TrapCode.IntegerDivisionByZero))) +(rule (gen_div_by_zero r) + (gen_trapz r (TrapCode.IntegerDivisionByZero))) ;;;; Helpers for Emitting Calls ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; diff --git a/cranelift/codegen/src/isa/riscv64/inst/emit.rs b/cranelift/codegen/src/isa/riscv64/inst/emit.rs index 3f8ef674ff3a..1f5a9f7684ce 100644 --- a/cranelift/codegen/src/isa/riscv64/inst/emit.rs +++ b/cranelift/codegen/src/isa/riscv64/inst/emit.rs @@ -339,7 +339,6 @@ impl Inst { | Inst::CallInd { .. } | Inst::ReturnCall { .. } | Inst::ReturnCallInd { .. } - | Inst::TrapIf { .. } | Inst::Jal { .. } | Inst::CondBr { .. } | Inst::LoadExtName { .. } @@ -365,7 +364,7 @@ impl Inst { | Inst::AtomicStore { .. } | Inst::AtomicLoad { .. } | Inst::AtomicRmwLoop { .. } - | Inst::TrapIfC { .. } + | Inst::TrapIf { .. } | Inst::Unwind { .. } | Inst::DummyUse { .. } | Inst::FloatRound { .. } @@ -1097,6 +1096,20 @@ impl Inst { let default_target = targets[0]; let targets = &targets[1..]; + // We are going to potentially emit a large amount of instructions, so ensure that we emit an island + // now if we need one. + // + // The worse case PC calculations are 12 instructions. And each entry in the jump table is 2 instructions. + // Check if we need to emit a jump table here to support that jump. + let inst_count = 12 + (targets.len() * 2); + let distance = (inst_count * Inst::UNCOMPRESSED_INSTRUCTION_SIZE as usize) as u32; + if sink.island_needed(distance) { + let jump_around_label = sink.get_label(); + Inst::gen_jump(jump_around_label).emit(&[], sink, emit_info, state); + sink.emit_island(distance + 4, &mut state.ctrl_plane); + sink.bind_label(jump_around_label, &mut state.ctrl_plane); + } + // We emit a bounds check on the index, if the index is larger than the number of // jump table entries, we jump to the default block. Otherwise we compute a jump // offset by multiplying the index by 8 (the size of each entry) and then jump to @@ -1204,17 +1217,9 @@ impl Inst { // Emit the jump table. // - // Each entry is a aupc + jalr to the target block. We also start with a island + // Each entry is a auipc + jalr to the target block. We also start with a island // if necessary. - // Each entry in the jump table is 2 instructions, so 8 bytes. Check if - // we need to emit a jump table here to support that jump. 
- let distance = - (targets.len() * 2 * Inst::UNCOMPRESSED_INSTRUCTION_SIZE as usize) as u32; - if sink.island_needed(distance) { - sink.emit_island(distance, &mut state.ctrl_plane); - } - // Emit the jumps back to back for target in targets.iter() { sink.use_label_at_offset(sink.cur_offset(), *target, LabelUse::PCRel32); @@ -1824,7 +1829,9 @@ impl Inst { } .emit(&[], sink, emit_info, state); Inst::TrapIf { - test: rd.to_reg(), + cc: IntCC::NotEqual, + rs1: rd.to_reg(), + rs2: zero_reg(), trap_code: TrapCode::IntegerOverflow, } .emit(&[], sink, emit_info, state); @@ -1852,7 +1859,9 @@ impl Inst { .emit(&[], sink, emit_info, state); Inst::TrapIf { - test: rd.to_reg(), + cc: IntCC::NotEqual, + rs1: rd.to_reg(), + rs2: zero_reg(), trap_code: TrapCode::IntegerOverflow, } .emit(&[], sink, emit_info, state); @@ -2013,45 +2022,25 @@ impl Inst { .emit_uncompressed(sink, emit_info, state, start_off); } - &Inst::TrapIfC { + &Inst::TrapIf { rs1, rs2, cc, trap_code, } => { - let label_trap = sink.get_label(); - let label_jump_over = sink.get_label(); + let label_end = sink.get_label(); + let cond = IntegerCompare { kind: cc, rs1, rs2 }; + + // Jump over the trap if we the condition is false. Inst::CondBr { - taken: CondBrTarget::Label(label_trap), - not_taken: CondBrTarget::Label(label_jump_over), - kind: IntegerCompare { kind: cc, rs1, rs2 }, + taken: CondBrTarget::Label(label_end), + not_taken: CondBrTarget::Fallthrough, + kind: cond.inverse(), } .emit(&[], sink, emit_info, state); - // trap - sink.bind_label(label_trap, &mut state.ctrl_plane); Inst::Udf { trap_code }.emit(&[], sink, emit_info, state); - sink.bind_label(label_jump_over, &mut state.ctrl_plane); - } - &Inst::TrapIf { test, trap_code } => { - let label_trap = sink.get_label(); - let label_jump_over = sink.get_label(); - Inst::CondBr { - taken: CondBrTarget::Label(label_trap), - not_taken: CondBrTarget::Label(label_jump_over), - kind: IntegerCompare { - kind: IntCC::NotEqual, - rs1: test, - rs2: zero_reg(), - }, - } - .emit(&[], sink, emit_info, state); - // trap - sink.bind_label(label_trap, &mut state.ctrl_plane); - Inst::Udf { - trap_code: trap_code, - } - .emit(&[], sink, emit_info, state); - sink.bind_label(label_jump_over, &mut state.ctrl_plane); + + sink.bind_label(label_end, &mut state.ctrl_plane); } &Inst::Udf { trap_code } => { sink.add_trap(trap_code); @@ -3334,23 +3323,18 @@ impl Inst { Inst::ElfTlsGetAddr { rd, name } } - Inst::TrapIfC { + Inst::TrapIf { rs1, rs2, cc, trap_code, - } => Inst::TrapIfC { + } => Inst::TrapIf { rs1: allocs.next(rs1), rs2: allocs.next(rs2), cc, trap_code, }, - Inst::TrapIf { test, trap_code } => Inst::TrapIf { - test: allocs.next(test), - trap_code, - }, - Inst::Udf { .. } => self, Inst::AtomicLoad { rd, ty, p } => Inst::AtomicLoad { diff --git a/cranelift/codegen/src/isa/riscv64/inst/mod.rs b/cranelift/codegen/src/isa/riscv64/inst/mod.rs index 5a12644f82e0..77a207e21788 100644 --- a/cranelift/codegen/src/isa/riscv64/inst/mod.rs +++ b/cranelift/codegen/src/isa/riscv64/inst/mod.rs @@ -459,9 +459,6 @@ fn riscv64_get_operands VReg>(inst: &Inst, collector: &mut Operan collector.reg_fixed_use(u.vreg, u.preg); } } - &Inst::TrapIf { test, .. } => { - collector.reg_use(test); - } &Inst::Jal { .. } => { // JAL technically has a rd register, but we currently always // hardcode it to x0. @@ -603,7 +600,7 @@ fn riscv64_get_operands VReg>(inst: &Inst, collector: &mut Operan collector.reg_early_def(t0); collector.reg_early_def(dst); } - &Inst::TrapIfC { rs1, rs2, .. } => { + &Inst::TrapIf { rs1, rs2, .. 
} => { collector.reg_use(rs1); collector.reg_use(rs2); } @@ -1635,10 +1632,7 @@ impl Inst { } s } - &MInst::TrapIf { test, trap_code } => { - format!("trap_if {},{}", format_reg(test, allocs), trap_code,) - } - &MInst::TrapIfC { + &MInst::TrapIf { rs1, rs2, cc, @@ -1646,7 +1640,7 @@ impl Inst { } => { let rs1 = format_reg(rs1, allocs); let rs2 = format_reg(rs2, allocs); - format!("trap_ifc {}##({} {} {})", trap_code, rs1, cc, rs2) + format!("trap_if {trap_code}##({rs1} {cc} {rs2})") } &MInst::Jal { label } => { format!("j {}", label.to_string()) diff --git a/cranelift/codegen/src/isa/riscv64/lower.isle b/cranelift/codegen/src/isa/riscv64/lower.isle index ec264a221094..d67adcc198e2 100644 --- a/cranelift/codegen/src/isa/riscv64/lower.isle +++ b/cranelift/codegen/src/isa/riscv64/lower.isle @@ -297,7 +297,7 @@ (rule (lower (has_type (fits_in_64 ty) (uadd_overflow_trap x y tc))) (let ((res ValueRegs (lower_uadd_overflow x y ty)) - (_ InstOutput (gen_trapif (value_regs_get res 1) tc))) + (_ InstOutput (gen_trapnz (value_regs_get res 1) tc))) (value_regs_get res 0))) diff --git a/cranelift/filetests/filetests/isa/riscv64/arithmetic.clif b/cranelift/filetests/filetests/isa/riscv64/arithmetic.clif index 723b57d311d5..a04b3b8cfefe 100644 --- a/cranelift/filetests/filetests/isa/riscv64/arithmetic.clif +++ b/cranelift/filetests/filetests/isa/riscv64/arithmetic.clif @@ -96,8 +96,8 @@ block0(v0: i64, v1: i64): ; eq a3,a3,a1##ty=i64 ; eq a5,a2,a0##ty=i64 ; and a2,a3,a5 -; trap_if a2,int_ovf -; trap_ifc int_divz##(zero eq a1) +; trap_if int_ovf##(a2 ne zero) +; trap_if int_divz##(a1 eq zero) ; div a0,a0,a1 ; ret ; @@ -117,7 +117,7 @@ block0(v0: i64, v1: i64): ; and a2, a3, a5 ; beqz a2, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_ovf -; bne zero, a1, 8 +; bnez a1, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; div a0, a0, a1 ; ret @@ -138,8 +138,8 @@ block0(v0: i64): ; eq a4,a4,a3##ty=i64 ; eq a5,a1,a0##ty=i64 ; and a1,a4,a5 -; trap_if a1,int_ovf -; trap_ifc int_divz##(zero eq a3) +; trap_if int_ovf##(a1 ne zero) +; trap_if int_divz##(a3 eq zero) ; div a0,a0,a3 ; ret ; @@ -160,7 +160,7 @@ block0(v0: i64): ; and a1, a4, a5 ; beqz a1, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_ovf -; bne zero, a3, 8 +; bnez a3, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; div a0, a0, a3 ; ret @@ -173,13 +173,13 @@ block0(v0: i64, v1: i64): ; VCode: ; block0: -; trap_ifc int_divz##(zero eq a1) +; trap_if int_divz##(a1 eq zero) ; divu a0,a0,a1 ; ret ; ; Disassembled: ; block0: ; offset 0x0 -; bne zero, a1, 8 +; bnez a1, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; divu a0, a0, a1 ; ret @@ -194,14 +194,14 @@ block0(v0: i64): ; VCode: ; block0: ; li a3,2 -; trap_ifc int_divz##(zero eq a3) +; trap_if int_divz##(a3 eq zero) ; divu a0,a0,a3 ; ret ; ; Disassembled: ; block0: ; offset 0x0 ; addi a3, zero, 2 -; bne zero, a3, 8 +; bnez a3, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; divu a0, a0, a3 ; ret @@ -214,13 +214,13 @@ block0(v0: i64, v1: i64): ; VCode: ; block0: -; trap_ifc int_divz##(zero eq a1) +; trap_if int_divz##(a1 eq zero) ; rem a0,a0,a1 ; ret ; ; Disassembled: ; block0: ; offset 0x0 -; bne zero, a1, 8 +; bnez a1, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; rem a0, a0, a1 ; ret @@ -233,13 +233,13 @@ block0(v0: i64, v1: i64): ; VCode: ; block0: -; trap_ifc int_divz##(zero eq a1) +; trap_if int_divz##(a1 eq zero) ; remu a0,a0,a1 ; ret ; ; Disassembled: ; block0: ; offset 0x0 -; bne zero, a1, 8 +; bnez a1, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; remu a0, a0, a1 ; ret @@ 
-261,8 +261,8 @@ block0(v0: i32, v1: i32): ; eq a4,a1,a5##ty=i32 ; eq a0,a0,a2##ty=i32 ; and a1,a4,a0 -; trap_if a1,int_ovf -; trap_ifc int_divz##(zero eq a5) +; trap_if int_ovf##(a1 ne zero) +; trap_if int_divz##(a5 eq zero) ; divw a0,a3,a5 ; ret ; @@ -285,7 +285,7 @@ block0(v0: i32, v1: i32): ; and a1, a4, a0 ; beqz a1, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_ovf -; bne zero, a5, 8 +; bnez a5, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; divw a0, a3, a5 ; ret @@ -310,8 +310,8 @@ block0(v0: i32): ; eq a4,a1,a5##ty=i32 ; eq a0,a0,a2##ty=i32 ; and a1,a4,a0 -; trap_if a1,int_ovf -; trap_ifc int_divz##(zero eq a5) +; trap_if int_ovf##(a1 ne zero) +; trap_if int_divz##(a5 eq zero) ; divw a0,a3,a5 ; ret ; @@ -336,7 +336,7 @@ block0(v0: i32): ; and a1, a4, a0 ; beqz a1, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_ovf -; bne zero, a5, 8 +; bnez a5, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; divw a0, a3, a5 ; ret @@ -351,7 +351,7 @@ block0(v0: i32, v1: i32): ; block0: ; slli a3,a1,32 ; srli a5,a3,32 -; trap_ifc int_divz##(zero eq a5) +; trap_if int_divz##(a5 eq zero) ; slli a2,a0,32 ; srli a4,a2,32 ; divuw a0,a4,a5 @@ -361,7 +361,7 @@ block0(v0: i32, v1: i32): ; block0: ; offset 0x0 ; slli a3, a1, 0x20 ; srli a5, a3, 0x20 -; bne zero, a5, 8 +; bnez a5, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; slli a2, a0, 0x20 ; srli a4, a2, 0x20 @@ -380,7 +380,7 @@ block0(v0: i32): ; li a1,2 ; slli a3,a1,32 ; srli a5,a3,32 -; trap_ifc int_divz##(zero eq a5) +; trap_if int_divz##(a5 eq zero) ; slli a2,a0,32 ; srli a4,a2,32 ; divuw a0,a4,a5 @@ -391,7 +391,7 @@ block0(v0: i32): ; addi a1, zero, 2 ; slli a3, a1, 0x20 ; srli a5, a3, 0x20 -; bne zero, a5, 8 +; bnez a5, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; slli a2, a0, 0x20 ; srli a4, a2, 0x20 @@ -407,14 +407,14 @@ block0(v0: i32, v1: i32): ; VCode: ; block0: ; sext.w a3,a1 -; trap_ifc int_divz##(zero eq a3) +; trap_if int_divz##(a3 eq zero) ; remw a0,a0,a3 ; ret ; ; Disassembled: ; block0: ; offset 0x0 ; sext.w a3, a1 -; bne zero, a3, 8 +; bnez a3, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; remw a0, a0, a3 ; ret @@ -429,7 +429,7 @@ block0(v0: i32, v1: i32): ; block0: ; slli a3,a1,32 ; srli a5,a3,32 -; trap_ifc int_divz##(zero eq a5) +; trap_if int_divz##(a5 eq zero) ; remuw a0,a0,a5 ; ret ; @@ -437,7 +437,7 @@ block0(v0: i32, v1: i32): ; block0: ; offset 0x0 ; slli a3, a1, 0x20 ; srli a5, a3, 0x20 -; bne zero, a5, 8 +; bnez a5, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; remuw a0, a0, a5 ; ret @@ -807,14 +807,14 @@ block0(v0: i64): ; VCode: ; block0: ; li a3,2 -; trap_ifc int_divz##(zero eq a3) +; trap_if int_divz##(a3 eq zero) ; rem a0,a0,a3 ; ret ; ; Disassembled: ; block0: ; offset 0x0 ; addi a3, zero, 2 -; bne zero, a3, 8 +; bnez a3, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; rem a0, a0, a3 ; ret @@ -829,14 +829,14 @@ block0(v0: i64): ; VCode: ; block0: ; li a3,2 -; trap_ifc int_divz##(zero eq a3) +; trap_if int_divz##(a3 eq zero) ; remu a0,a0,a3 ; ret ; ; Disassembled: ; block0: ; offset 0x0 ; addi a3, zero, 2 -; bne zero, a3, 8 +; bnez a3, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; remu a0, a0, a3 ; ret @@ -857,8 +857,8 @@ block0(v0: i64): ; eq a4,a4,a3##ty=i64 ; eq a5,a1,a0##ty=i64 ; and a1,a4,a5 -; trap_if a1,int_ovf -; trap_ifc int_divz##(zero eq a3) +; trap_if int_ovf##(a1 ne zero) +; trap_if int_divz##(a3 eq zero) ; div a0,a0,a3 ; ret ; @@ -879,7 +879,7 @@ block0(v0: i64): ; and a1, a4, a5 ; beqz a1, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_ovf -; bne zero, a3, 8 +; bnez 
a3, 8 ; .byte 0x00, 0x00, 0x00, 0x00 ; trap: int_divz ; div a0, a0, a3 ; ret diff --git a/cranelift/filetests/filetests/isa/riscv64/stack-limit.clif b/cranelift/filetests/filetests/isa/riscv64/stack-limit.clif index a3a2b5c1b2fc..d33f83bdd740 100644 --- a/cranelift/filetests/filetests/isa/riscv64/stack-limit.clif +++ b/cranelift/filetests/filetests/isa/riscv64/stack-limit.clif @@ -58,7 +58,7 @@ block0(v0: i64): ; sd ra,8(sp) ; sd fp,0(sp) ; mv fp,sp -; trap_ifc stk_ovf##(sp ult a0) +; trap_if stk_ovf##(sp ult a0) ; block0: ; load_sym a2,%foo+0 ; callind a2 @@ -105,7 +105,7 @@ block0(v0: i64): ; mv fp,sp ; ld t6,0(a0) ; ld t6,4(t6) -; trap_ifc stk_ovf##(sp ult t6) +; trap_if stk_ovf##(sp ult t6) ; block0: ; load_sym a2,%foo+0 ; callind a2 @@ -148,7 +148,7 @@ block0(v0: i64): ; sd fp,0(sp) ; mv fp,sp ; addi t6,a0,176 -; trap_ifc stk_ovf##(sp ult t6) +; trap_if stk_ovf##(sp ult t6) ; add sp,-176 ; block0: ; add sp,+176 @@ -185,11 +185,11 @@ block0(v0: i64): ; sd ra,8(sp) ; sd fp,0(sp) ; mv fp,sp -; trap_ifc stk_ovf##(sp ult a0) +; trap_if stk_ovf##(sp ult a0) ; lui t5,98 ; addi t5,t5,-1408 ; add t6,t5,a0 -; trap_ifc stk_ovf##(sp ult t6) +; trap_if stk_ovf##(sp ult t6) ; lui a0,98 ; addi a0,a0,-1408 ; call %Probestack @@ -252,7 +252,7 @@ block0(v0: i64): ; ld t6,0(a0) ; ld t6,4(t6) ; addi t6,t6,32 -; trap_ifc stk_ovf##(sp ult t6) +; trap_if stk_ovf##(sp ult t6) ; add sp,-32 ; block0: ; add sp,+32 @@ -297,11 +297,11 @@ block0(v0: i64): ; mv fp,sp ; ld t6,0(a0) ; ld t6,4(t6) -; trap_ifc stk_ovf##(sp ult t6) +; trap_if stk_ovf##(sp ult t6) ; lui t5,98 ; addi t5,t5,-1408 ; add t6,t5,t6 -; trap_ifc stk_ovf##(sp ult t6) +; trap_if stk_ovf##(sp ult t6) ; lui a0,98 ; addi a0,a0,-1408 ; call %Probestack @@ -364,7 +364,7 @@ block0(v0: i64): ; mv fp,sp ; ld t6,400000(a0) ; addi t6,t6,32 -; trap_ifc stk_ovf##(sp ult t6) +; trap_if stk_ovf##(sp ult t6) ; add sp,-32 ; block0: ; add sp,+32 diff --git a/cranelift/filetests/filetests/isa/riscv64/uadd_overflow_trap.clif b/cranelift/filetests/filetests/isa/riscv64/uadd_overflow_trap.clif index cae0b880d26b..d0c89e0c4f23 100644 --- a/cranelift/filetests/filetests/isa/riscv64/uadd_overflow_trap.clif +++ b/cranelift/filetests/filetests/isa/riscv64/uadd_overflow_trap.clif @@ -17,7 +17,7 @@ block0(v0: i32): ; srli a3,a1,32 ; add a0,a5,a3 ; srli a1,a0,32 -; trap_if a1,user0 +; trap_if user0##(a1 ne zero) ; ret ; ; Disassembled: @@ -49,7 +49,7 @@ block0(v0: i32): ; srli a3,a1,32 ; add a0,a5,a3 ; srli a1,a0,32 -; trap_if a1,user0 +; trap_if user0##(a1 ne zero) ; ret ; ; Disassembled: @@ -79,7 +79,7 @@ block0(v0: i32, v1: i32): ; srli a3,a1,32 ; add a0,a5,a3 ; srli a1,a0,32 -; trap_if a1,user0 +; trap_if user0##(a1 ne zero) ; ret ; ; Disassembled: @@ -107,7 +107,7 @@ block0(v0: i64): ; li a4,127 ; add a0,a1,a4 ; ult a5,a0,a1##ty=i64 -; trap_if a5,user0 +; trap_if user0##(a5 ne zero) ; ret ; ; Disassembled: @@ -135,7 +135,7 @@ block0(v0: i64): ; li a4,127 ; add a0,a4,a0 ; ult a5,a0,a4##ty=i64 -; trap_if a5,user0 +; trap_if user0##(a5 ne zero) ; ret ; ; Disassembled: @@ -162,7 +162,7 @@ block0(v0: i64, v1: i64): ; mv a1,a3 ; ult a5,a1,a0##ty=i64 ; mv a0,a1 -; trap_if a5,user0 +; trap_if user0##(a5 ne zero) ; ret ; ; Disassembled: diff --git a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_no_spectre_i32_access_0xffff0000_offset.wat b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_no_spectre_i32_access_0xffff0000_offset.wat index 468923cbb9c5..f80497f60b79 100644 --- 
a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_no_spectre_i32_access_0xffff0000_offset.wat +++ b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_no_spectre_i32_access_0xffff0000_offset.wat @@ -46,7 +46,7 @@ ;; ld a4,[const(1)] ;; add a4,a3,a4 ;; ult a5,a4,a3##ty=i64 -;; trap_if a5,heap_oob +;; trap_if heap_oob##(a5 ne zero) ;; ld a5,8(a2) ;; ugt a4,a4,a5##ty=i64 ;; bne a4,zero,taken(label3),not_taken(label1) @@ -69,7 +69,7 @@ ;; ld a2,[const(1)] ;; add a2,a3,a2 ;; ult a4,a2,a3##ty=i64 -;; trap_if a4,heap_oob +;; trap_if heap_oob##(a4 ne zero) ;; ld a4,8(a1) ;; ugt a4,a2,a4##ty=i64 ;; bne a4,zero,taken(label3),not_taken(label1) diff --git a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_no_spectre_i8_access_0xffff0000_offset.wat b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_no_spectre_i8_access_0xffff0000_offset.wat index 80e09cd8feb4..eebf3e0a9c0f 100644 --- a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_no_spectre_i8_access_0xffff0000_offset.wat +++ b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_no_spectre_i8_access_0xffff0000_offset.wat @@ -46,7 +46,7 @@ ;; ld a4,[const(1)] ;; add a4,a3,a4 ;; ult a5,a4,a3##ty=i64 -;; trap_if a5,heap_oob +;; trap_if heap_oob##(a5 ne zero) ;; ld a5,8(a2) ;; ugt a4,a4,a5##ty=i64 ;; bne a4,zero,taken(label3),not_taken(label1) @@ -69,7 +69,7 @@ ;; ld a2,[const(1)] ;; add a2,a3,a2 ;; ult a4,a2,a3##ty=i64 -;; trap_if a4,heap_oob +;; trap_if heap_oob##(a4 ne zero) ;; ld a4,8(a1) ;; ugt a4,a2,a4##ty=i64 ;; bne a4,zero,taken(label3),not_taken(label1) diff --git a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_yes_spectre_i32_access_0xffff0000_offset.wat b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_yes_spectre_i32_access_0xffff0000_offset.wat index c1b25aea6d34..22d4a1643b34 100644 --- a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_yes_spectre_i32_access_0xffff0000_offset.wat +++ b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_yes_spectre_i32_access_0xffff0000_offset.wat @@ -46,7 +46,7 @@ ;; ld a4,[const(1)] ;; add a3,a5,a4 ;; ult a0,a3,a5##ty=i64 -;; trap_if a0,heap_oob +;; trap_if heap_oob##(a0 ne zero) ;; ld a0,8(a2) ;; ugt a0,a3,a0##ty=i64 ;; ld a2,0(a2) @@ -73,7 +73,7 @@ ;; ld a4,[const(1)] ;; add a3,a5,a4 ;; ult a0,a3,a5##ty=i64 -;; trap_if a0,heap_oob +;; trap_if heap_oob##(a0 ne zero) ;; ld a0,8(a1) ;; ugt a0,a3,a0##ty=i64 ;; ld a1,0(a1) diff --git a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_yes_spectre_i8_access_0xffff0000_offset.wat b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_yes_spectre_i8_access_0xffff0000_offset.wat index f16c448937d6..4573556ee391 100644 --- a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_yes_spectre_i8_access_0xffff0000_offset.wat +++ b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i32_index_0_guard_yes_spectre_i8_access_0xffff0000_offset.wat @@ -46,7 +46,7 @@ ;; ld a4,[const(1)] ;; add a3,a5,a4 ;; ult a0,a3,a5##ty=i64 -;; trap_if a0,heap_oob +;; trap_if heap_oob##(a0 ne zero) ;; ld a0,8(a2) ;; ugt a0,a3,a0##ty=i64 ;; ld a2,0(a2) @@ -73,7 +73,7 @@ 
;; ld a4,[const(1)] ;; add a3,a5,a4 ;; ult a0,a3,a5##ty=i64 -;; trap_if a0,heap_oob +;; trap_if heap_oob##(a0 ne zero) ;; ld a0,8(a1) ;; ugt a0,a3,a0##ty=i64 ;; ld a1,0(a1) diff --git a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_no_spectre_i32_access_0xffff0000_offset.wat b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_no_spectre_i32_access_0xffff0000_offset.wat index 600d6c81d28d..e4263aebdc52 100644 --- a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_no_spectre_i32_access_0xffff0000_offset.wat +++ b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_no_spectre_i32_access_0xffff0000_offset.wat @@ -44,7 +44,7 @@ ;; ld a3,[const(1)] ;; add a5,a0,a3 ;; ult a3,a5,a0##ty=i64 -;; trap_if a3,heap_oob +;; trap_if heap_oob##(a3 ne zero) ;; ld a3,8(a2) ;; ugt a3,a5,a3##ty=i64 ;; bne a3,zero,taken(label3),not_taken(label1) @@ -65,7 +65,7 @@ ;; ld a2,[const(1)] ;; add a5,a0,a2 ;; ult a2,a5,a0##ty=i64 -;; trap_if a2,heap_oob +;; trap_if heap_oob##(a2 ne zero) ;; ld a2,8(a1) ;; ugt a2,a5,a2##ty=i64 ;; bne a2,zero,taken(label3),not_taken(label1) diff --git a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_no_spectre_i8_access_0xffff0000_offset.wat b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_no_spectre_i8_access_0xffff0000_offset.wat index 6c5de64ee055..c08c2c65c62e 100644 --- a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_no_spectre_i8_access_0xffff0000_offset.wat +++ b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_no_spectre_i8_access_0xffff0000_offset.wat @@ -44,7 +44,7 @@ ;; ld a3,[const(1)] ;; add a5,a0,a3 ;; ult a3,a5,a0##ty=i64 -;; trap_if a3,heap_oob +;; trap_if heap_oob##(a3 ne zero) ;; ld a3,8(a2) ;; ugt a3,a5,a3##ty=i64 ;; bne a3,zero,taken(label3),not_taken(label1) @@ -65,7 +65,7 @@ ;; ld a2,[const(1)] ;; add a5,a0,a2 ;; ult a2,a5,a0##ty=i64 -;; trap_if a2,heap_oob +;; trap_if heap_oob##(a2 ne zero) ;; ld a2,8(a1) ;; ugt a2,a5,a2##ty=i64 ;; bne a2,zero,taken(label3),not_taken(label1) diff --git a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_yes_spectre_i32_access_0xffff0000_offset.wat b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_yes_spectre_i32_access_0xffff0000_offset.wat index 988ad32934ea..d27e29e4b2f9 100644 --- a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_yes_spectre_i32_access_0xffff0000_offset.wat +++ b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_yes_spectre_i32_access_0xffff0000_offset.wat @@ -44,7 +44,7 @@ ;; ld a3,[const(1)] ;; add a3,a0,a3 ;; ult a4,a3,a0##ty=i64 -;; trap_if a4,heap_oob +;; trap_if heap_oob##(a4 ne zero) ;; ld a4,8(a2) ;; ugt a4,a3,a4##ty=i64 ;; ld a3,0(a2) @@ -69,7 +69,7 @@ ;; ld a2,[const(1)] ;; add a2,a0,a2 ;; ult a3,a2,a0##ty=i64 -;; trap_if a3,heap_oob +;; trap_if heap_oob##(a3 ne zero) ;; ld a3,8(a1) ;; ugt a4,a2,a3##ty=i64 ;; ld a3,0(a1) diff --git a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_yes_spectre_i8_access_0xffff0000_offset.wat b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_yes_spectre_i8_access_0xffff0000_offset.wat index 
37cc3b77e3bf..f51b54592618 100644 --- a/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_yes_spectre_i8_access_0xffff0000_offset.wat +++ b/cranelift/filetests/filetests/isa/riscv64/wasm/load_store_dynamic_kind_i64_index_0_guard_yes_spectre_i8_access_0xffff0000_offset.wat @@ -44,7 +44,7 @@ ;; ld a3,[const(1)] ;; add a3,a0,a3 ;; ult a4,a3,a0##ty=i64 -;; trap_if a4,heap_oob +;; trap_if heap_oob##(a4 ne zero) ;; ld a4,8(a2) ;; ugt a4,a3,a4##ty=i64 ;; ld a3,0(a2) @@ -69,7 +69,7 @@ ;; ld a2,[const(1)] ;; add a2,a0,a2 ;; ult a3,a2,a0##ty=i64 -;; trap_if a3,heap_oob +;; trap_if heap_oob##(a3 ne zero) ;; ld a3,8(a1) ;; ugt a4,a2,a3##ty=i64 ;; ld a3,0(a1) From 891cbf0bc8b1cd2d2c22509aa31ae1ed7b92502f Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Mon, 18 Sep 2023 17:18:08 -0500 Subject: [PATCH 04/14] Update wasm-tools family of crates (#7059) * Update wasm-tools family of crates Mostly minor updates, but staying up-to-date. * Update text format syntax * Update cargo vet entries * Update more old text syntax --- Cargo.lock | 52 ++++++------- Cargo.toml | 18 ++--- cranelift/wasm/src/code_translator.rs | 2 +- cranelift/wasm/src/func_translator.rs | 8 +- cranelift/wasm/wasmtests/arith.wat | 6 +- cranelift/wasm/wasmtests/call.wat | 2 +- cranelift/wasm/wasmtests/fac-multi-value.wat | 6 +- cranelift/wasm/wasmtests/fibonacci.wat | 20 ++--- cranelift/wasm/wasmtests/globals.wat | 2 +- cranelift/wasm/wasmtests/icall.wat | 2 +- cranelift/wasm/wasmtests/multi-0.wat | 2 +- crates/wast/src/core.rs | 24 +++--- supply-chain/imports.lock | 77 ++++++++++++++++++++ tests/all/cli_tests/rs2wasm-add-func.wat | 4 +- tests/all/host_funcs.rs | 4 +- tests/misc_testsuite/rs2wasm-add-func.wast | 4 +- 16 files changed, 156 insertions(+), 77 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b79d222c9770..63a5aab2b99c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2416,9 +2416,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.80" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f972498cf015f7c0746cac89ebe1d6ef10c293b94175a243a2d9442c163d9944" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ "itoa", "ryu", @@ -3153,18 +3153,18 @@ checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" [[package]] name = "wasm-encoder" -version = "0.32.0" +version = "0.33.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ba64e81215916eaeb48fee292f29401d69235d62d8b8fd92a7b2844ec5ae5f7" +checksum = "b39de0723a53d3c8f54bed106cfbc0d06b3e4d945c5c5022115a61e3b29183ae" dependencies = [ "leb128", ] [[package]] name = "wasm-metadata" -version = "0.10.3" +version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08dc59d1fa569150851542143ca79438ca56845ccb31696c70225c638e063471" +checksum = "9fab01638cbecc57afec7b53ce0e28620b44d7ae1dea53120c96dd08486c07ce" dependencies = [ "anyhow", "indexmap 2.0.0", @@ -3177,9 +3177,9 @@ dependencies = [ [[package]] name = "wasm-mutate" -version = "0.2.32" +version = "0.2.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a893c0b2bad91ea78171dbff7fe6d764f8b1c9e20617061d284cfbdf902e329" +checksum = "730c1644f14f3dfa52d8c63bb26c6f3fe89ba80430f1ce29b0388e04825ec514" dependencies = [ "egg", "log", @@ -3191,9 +3191,9 @@ dependencies = [ [[package]] name = "wasm-smith" -version = "0.12.15" +version = "0.12.17" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad8e84473dc3e1ca4759114ba8f3beaf3b7272d67868e351ba5d3fc006139f93" +checksum = "154bb82cb9f17e5c0773e192e800094583767f752077d5116be376de747539eb" dependencies = [ "arbitrary", "flagset", @@ -3243,9 +3243,9 @@ dependencies = [ [[package]] name = "wasmparser" -version = "0.112.0" +version = "0.113.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e986b010f47fcce49cf8ea5d5f9e5d2737832f12b53ae8ae785bbe895d0877bf" +checksum = "a128cea7b8516703ab41b10a0b1aa9ba18d0454cd3792341489947ddeee268db" dependencies = [ "indexmap 2.0.0", "semver", @@ -3262,9 +3262,9 @@ dependencies = [ [[package]] name = "wasmprinter" -version = "0.2.64" +version = "0.2.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34ddf5892036cd4b780d505eff1194a0cbc10ed896097656fdcea3744b5e7c2f" +checksum = "ab2e5e818f88cee5311e9a5df15cba0a8f772978baf3109af97004bce6e8e3c6" dependencies = [ "anyhow", "wasmparser", @@ -3425,7 +3425,7 @@ dependencies = [ "wasmtime-wasi-nn", "wasmtime-wasi-threads", "wasmtime-wast", - "wast 64.0.0", + "wast 65.0.1", "wat", "windows-sys", ] @@ -3803,7 +3803,7 @@ dependencies = [ "anyhow", "log", "wasmtime", - "wast 64.0.0", + "wast 65.0.1", ] [[package]] @@ -3846,9 +3846,9 @@ dependencies = [ [[package]] name = "wast" -version = "64.0.0" +version = "65.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a259b226fd6910225aa7baeba82f9d9933b6d00f2ce1b49b80fa4214328237cc" +checksum = "5fd8c1cbadf94a0b0d1071c581d3cfea1b7ed5192c79808dd15406e508dd0afb" dependencies = [ "leb128", "memchr", @@ -3858,11 +3858,11 @@ dependencies = [ [[package]] name = "wat" -version = "1.0.71" +version = "1.0.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53253d920ab413fca1c7dc2161d601c79b4fdf631d0ba51dd4343bf9b556c3f6" +checksum = "3209e35eeaf483714f4c6be93f4a03e69aad5f304e3fa66afa7cb90fe1c8051f" dependencies = [ - "wast 64.0.0", + "wast 65.0.1", ] [[package]] @@ -4178,9 +4178,9 @@ dependencies = [ [[package]] name = "wit-component" -version = "0.14.0" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66d9f2d16dd55d1a372dcfd4b7a466ea876682a5a3cb97e71ec9eef04affa876" +checksum = "af872ef43ecb73cc49c7bd2dd19ef9117168e183c78cf70000dca0e14b6a5473" dependencies = [ "anyhow", "bitflags 2.3.3", @@ -4196,9 +4196,9 @@ dependencies = [ [[package]] name = "wit-parser" -version = "0.11.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e8b849bea13cc2315426b16efe6eb6813466d78f5fde69b0bb150c9c40e0dc" +checksum = "1dcd022610436a1873e60bfdd9b407763f2404adf7d1cb57912c7ae4059e57a5" dependencies = [ "anyhow", "id-arena", @@ -4206,6 +4206,8 @@ dependencies = [ "log", "pulldown-cmark", "semver", + "serde", + "serde_json", "unicode-xid", "url", ] diff --git a/Cargo.toml b/Cargo.toml index b9a6c3ea4df1..5c823d258df9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -209,15 +209,15 @@ is-terminal = "0.4.0" wit-bindgen = { version = "0.11.0", default-features = false } # wasm-tools family: -wasmparser = "0.112.0" -wat = "1.0.71" -wast = "64.0.0" -wasmprinter = "0.2.64" -wasm-encoder = "0.32.0" -wasm-smith = "0.12.15" -wasm-mutate = "0.2.32" -wit-parser = "0.11.0" -wit-component = "0.14.0" +wasmparser = "0.113.1" +wat = "1.0.73" +wast = "65.0.1" +wasmprinter = "0.2.66" +wasm-encoder = "0.33.1" +wasm-smith = "0.12.17" +wasm-mutate = "0.2.34" +wit-parser = "0.11.1" 
+wit-component = "0.14.2" # Non-Bytecode Alliance maintained dependencies: # -------------------------- diff --git a/cranelift/wasm/src/code_translator.rs b/cranelift/wasm/src/code_translator.rs index d683a420d667..6a0090c10089 100644 --- a/cranelift/wasm/src/code_translator.rs +++ b/cranelift/wasm/src/code_translator.rs @@ -2503,7 +2503,7 @@ pub fn translate_operator( state.push1(r); } - Operator::I31New | Operator::I31GetS | Operator::I31GetU => { + Operator::RefI31 | Operator::I31GetS | Operator::I31GetU => { unimplemented!("GC operators not yet implemented") } }; diff --git a/cranelift/wasm/src/func_translator.rs b/cranelift/wasm/src/func_translator.rs index 2ed2b995c1e5..ed666cdef982 100644 --- a/cranelift/wasm/src/func_translator.rs +++ b/cranelift/wasm/src/func_translator.rs @@ -314,7 +314,7 @@ mod tests { " (module (func $small2 (param i32) (result i32) - (i32.add (get_local 0) (i32.const 1)) + (i32.add (local.get 0) (i32.const 1)) ) ) ", @@ -349,7 +349,7 @@ mod tests { " (module (func $small2 (param i32) (result i32) - (return (i32.add (get_local 0) (i32.const 1))) + (return (i32.add (local.get 0) (i32.const 1))) ) ) ", @@ -386,8 +386,8 @@ mod tests { (func $infloop (result i32) (local i32) (loop (result i32) - (i32.add (get_local 0) (i32.const 1)) - (set_local 0) + (i32.add (local.get 0) (i32.const 1)) + (local.set 0) (br 0) ) ) diff --git a/cranelift/wasm/wasmtests/arith.wat b/cranelift/wasm/wasmtests/arith.wat index fa7115696b61..80005a6742de 100644 --- a/cranelift/wasm/wasmtests/arith.wat +++ b/cranelift/wasm/wasmtests/arith.wat @@ -1,11 +1,11 @@ (module (memory 1) (func $main (local i32) - (set_local 0 (i32.sub (i32.const 4) (i32.const 4))) + (local.set 0 (i32.sub (i32.const 4) (i32.const 4))) (if - (get_local 0) + (local.get 0) (then unreachable) - (else (drop (i32.mul (i32.const 6) (get_local 0)))) + (else (drop (i32.mul (i32.const 6) (local.get 0)))) ) ) (start $main) diff --git a/cranelift/wasm/wasmtests/call.wat b/cranelift/wasm/wasmtests/call.wat index e8640d2342a4..d6d75cb5c9c7 100644 --- a/cranelift/wasm/wasmtests/call.wat +++ b/cranelift/wasm/wasmtests/call.wat @@ -1,6 +1,6 @@ (module (func $main (local i32) - (set_local 0 (i32.const 0)) + (local.set 0 (i32.const 0)) (drop (call $inc)) ) (func $inc (result i32) diff --git a/cranelift/wasm/wasmtests/fac-multi-value.wat b/cranelift/wasm/wasmtests/fac-multi-value.wat index 7a4d5c0fabd3..ad5e3d9a04c7 100644 --- a/cranelift/wasm/wasmtests/fac-multi-value.wat +++ b/cranelift/wasm/wasmtests/fac-multi-value.wat @@ -1,13 +1,13 @@ (module ;; Iterative factorial without locals. 
(func $pick0 (param i64) (result i64 i64) - (get_local 0) (get_local 0) + (local.get 0) (local.get 0) ) (func $pick1 (param i64 i64) (result i64 i64 i64) - (get_local 0) (get_local 1) (get_local 0) + (local.get 0) (local.get 1) (local.get 0) ) (func (export "fac-ssa") (param i64) (result i64) - (i64.const 1) (get_local 0) + (i64.const 1) (local.get 0) (loop $l (param i64 i64) (result i64) (call $pick1) (call $pick1) (i64.mul) (call $pick1) (i64.const 1) (i64.sub) diff --git a/cranelift/wasm/wasmtests/fibonacci.wat b/cranelift/wasm/wasmtests/fibonacci.wat index 1788a467ca0f..a40b2c91e442 100644 --- a/cranelift/wasm/wasmtests/fibonacci.wat +++ b/cranelift/wasm/wasmtests/fibonacci.wat @@ -1,21 +1,21 @@ (module (memory 1) (func $main (local i32 i32 i32 i32) - (set_local 0 (i32.const 0)) - (set_local 1 (i32.const 1)) - (set_local 2 (i32.const 1)) - (set_local 3 (i32.const 0)) + (local.set 0 (i32.const 0)) + (local.set 1 (i32.const 1)) + (local.set 2 (i32.const 1)) + (local.set 3 (i32.const 0)) (block (loop - (br_if 1 (i32.gt_s (get_local 0) (i32.const 5))) - (set_local 3 (get_local 2)) - (set_local 2 (i32.add (get_local 2) (get_local 1))) - (set_local 1 (get_local 3)) - (set_local 0 (i32.add (get_local 0) (i32.const 1))) + (br_if 1 (i32.gt_s (local.get 0) (i32.const 5))) + (local.set 3 (local.get 2)) + (local.set 2 (i32.add (local.get 2) (local.get 1))) + (local.set 1 (local.get 3)) + (local.set 0 (i32.add (local.get 0) (i32.const 1))) (br 0) ) ) - (i32.store (i32.const 0) (get_local 2)) + (i32.store (i32.const 0) (local.get 2)) ) (start $main) (data (i32.const 0) "0000") diff --git a/cranelift/wasm/wasmtests/globals.wat b/cranelift/wasm/wasmtests/globals.wat index 646e5f0f453d..fb9a591172bc 100644 --- a/cranelift/wasm/wasmtests/globals.wat +++ b/cranelift/wasm/wasmtests/globals.wat @@ -2,7 +2,7 @@ (global $x (mut i32) (i32.const 4)) (memory 1) (func $main (local i32) - (i32.store (i32.const 0) (get_global $x)) + (i32.store (i32.const 0) (global.get $x)) ) (start $main) ) diff --git a/cranelift/wasm/wasmtests/icall.wat b/cranelift/wasm/wasmtests/icall.wat index 76f28f47a92b..f1dde0eafea5 100644 --- a/cranelift/wasm/wasmtests/icall.wat +++ b/cranelift/wasm/wasmtests/icall.wat @@ -1,7 +1,7 @@ (module (type $ft (func (param f32) (result i32))) (func $foo (export "foo") (param i32 f32) (result i32) - (call_indirect (type $ft) (get_local 1) (get_local 0)) + (call_indirect (type $ft) (local.get 1) (local.get 0)) ) (table (;0;) 23 23 anyfunc) ) diff --git a/cranelift/wasm/wasmtests/multi-0.wat b/cranelift/wasm/wasmtests/multi-0.wat index d1cc24c59691..806493918c30 100644 --- a/cranelift/wasm/wasmtests/multi-0.wat +++ b/cranelift/wasm/wasmtests/multi-0.wat @@ -1,3 +1,3 @@ (module (func (export "i64.dup") (param i64) (result i64 i64) - (get_local 0) (get_local 0))) + (local.get 0) (local.get 0))) diff --git a/crates/wast/src/core.rs b/crates/wast/src/core.rs index cbcf6bfeb950..de01b1023c9e 100644 --- a/crates/wast/src/core.rs +++ b/crates/wast/src/core.rs @@ -65,19 +65,19 @@ pub fn match_val(actual: &Val, expected: &WastRetCore) -> Result<()> { Ok(()) } } - (Val::ExternRef(x), WastRetCore::RefExtern(y)) => { - if let Some(x) = x { - let x = x - .data() - .downcast_ref::() - .expect("only u32 externrefs created in wast test suites"); - if x == y { - Ok(()) - } else { - bail!("expected {} found {}", y, x); - } + (Val::ExternRef(_), WastRetCore::RefExtern(None)) => Ok(()), + (Val::ExternRef(None), WastRetCore::RefExtern(Some(_))) => { + bail!("expected non-null externref, found null") + } + 
(Val::ExternRef(Some(x)), WastRetCore::RefExtern(Some(y))) => { + let x = x + .data() + .downcast_ref::() + .expect("only u32 externrefs created in wast test suites"); + if x == y { + Ok(()) } else { - bail!("expected non-null externref, found null") + bail!("expected {} found {}", y, x); } } (Val::FuncRef(actual), WastRetCore::RefNull(expected)) => match (actual, expected) { diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index 9f5c949891ed..29747fbcc690 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -999,6 +999,13 @@ user-id = 3618 user-login = "dtolnay" user-name = "David Tolnay" +[[publisher.serde_json]] +version = "1.0.107" +when = "2023-09-13" +user-id = 3618 +user-login = "dtolnay" +user-name = "David Tolnay" + [[publisher.spdx]] version = "0.10.1" when = "2023-04-06" @@ -1188,6 +1195,13 @@ user-id = 1 user-login = "alexcrichton" user-name = "Alex Crichton" +[[publisher.wasm-encoder]] +version = "0.33.1" +when = "2023-09-18" +user-id = 1 +user-login = "alexcrichton" +user-name = "Alex Crichton" + [[publisher.wasm-metadata]] version = "0.9.0" when = "2023-07-11" @@ -1216,6 +1230,13 @@ user-id = 1 user-login = "alexcrichton" user-name = "Alex Crichton" +[[publisher.wasm-metadata]] +version = "0.10.5" +when = "2023-09-18" +user-id = 1 +user-login = "alexcrichton" +user-name = "Alex Crichton" + [[publisher.wasm-mutate]] version = "0.2.30" when = "2023-07-26" @@ -1237,6 +1258,13 @@ user-id = 1 user-login = "alexcrichton" user-name = "Alex Crichton" +[[publisher.wasm-mutate]] +version = "0.2.34" +when = "2023-09-18" +user-id = 1 +user-login = "alexcrichton" +user-name = "Alex Crichton" + [[publisher.wasm-smith]] version = "0.12.13" when = "2023-07-26" @@ -1258,6 +1286,13 @@ user-id = 1 user-login = "alexcrichton" user-name = "Alex Crichton" +[[publisher.wasm-smith]] +version = "0.12.17" +when = "2023-09-18" +user-id = 1 +user-login = "alexcrichton" +user-name = "Alex Crichton" + [[publisher.wasmparser]] version = "0.108.0" when = "2023-07-11" @@ -1286,6 +1321,13 @@ user-id = 1 user-login = "alexcrichton" user-name = "Alex Crichton" +[[publisher.wasmparser]] +version = "0.113.1" +when = "2023-09-18" +user-id = 1 +user-login = "alexcrichton" +user-name = "Alex Crichton" + [[publisher.wasmprinter]] version = "0.2.62" when = "2023-07-26" @@ -1307,6 +1349,13 @@ user-id = 1 user-login = "alexcrichton" user-name = "Alex Crichton" +[[publisher.wasmprinter]] +version = "0.2.66" +when = "2023-09-18" +user-id = 1 +user-login = "alexcrichton" +user-name = "Alex Crichton" + [[publisher.wasmtime]] version = "11.0.1" when = "2023-07-24" @@ -1616,6 +1665,13 @@ user-id = 1 user-login = "alexcrichton" user-name = "Alex Crichton" +[[publisher.wast]] +version = "65.0.1" +when = "2023-09-18" +user-id = 1 +user-login = "alexcrichton" +user-name = "Alex Crichton" + [[publisher.wat]] version = "1.0.69" when = "2023-07-26" @@ -1637,6 +1693,13 @@ user-id = 1 user-login = "alexcrichton" user-name = "Alex Crichton" +[[publisher.wat]] +version = "1.0.73" +when = "2023-09-18" +user-id = 1 +user-login = "alexcrichton" +user-name = "Alex Crichton" + [[publisher.wiggle]] version = "11.0.1" when = "2023-07-24" @@ -1867,6 +1930,13 @@ user-id = 1 user-login = "alexcrichton" user-name = "Alex Crichton" +[[publisher.wit-component]] +version = "0.14.2" +when = "2023-09-18" +user-id = 1 +user-login = "alexcrichton" +user-name = "Alex Crichton" + [[publisher.wit-parser]] version = "0.9.2" when = "2023-07-26" @@ -1888,6 +1958,13 @@ user-id = 1 user-login = "alexcrichton" user-name = 
"Alex Crichton" +[[publisher.wit-parser]] +version = "0.11.1" +when = "2023-09-18" +user-id = 1 +user-login = "alexcrichton" +user-name = "Alex Crichton" + [[audits.embark-studios.wildcard-audits.spdx]] who = "Jake Shadle " criteria = "safe-to-deploy" diff --git a/tests/all/cli_tests/rs2wasm-add-func.wat b/tests/all/cli_tests/rs2wasm-add-func.wat index 21ff7f5d7a0d..89af3ffc4dca 100644 --- a/tests/all/cli_tests/rs2wasm-add-func.wat +++ b/tests/all/cli_tests/rs2wasm-add-func.wat @@ -2,8 +2,8 @@ (type (;0;) (func)) (type (;1;) (func (param i32 i32) (result i32))) (func $add (type 1) (param i32 i32) (result i32) - get_local 1 - get_local 0 + local.get 1 + local.get 0 i32.add) (func $start (type 0)) (table (;0;) 1 1 anyfunc) diff --git a/tests/all/host_funcs.rs b/tests/all/host_funcs.rs index 62d37d55b7e2..76033acd62ca 100644 --- a/tests/all/host_funcs.rs +++ b/tests/all/host_funcs.rs @@ -380,14 +380,14 @@ fn call_wasm_many_args() -> Result<()> { r#" (func (export "run") (param i32 i32 i32 i32 i32 i32 i32 i32 i32 i32) i32.const 1 - get_local 0 + local.get 0 i32.ne if unreachable end i32.const 10 - get_local 9 + local.get 9 i32.ne if unreachable diff --git a/tests/misc_testsuite/rs2wasm-add-func.wast b/tests/misc_testsuite/rs2wasm-add-func.wast index 21ff7f5d7a0d..89af3ffc4dca 100644 --- a/tests/misc_testsuite/rs2wasm-add-func.wast +++ b/tests/misc_testsuite/rs2wasm-add-func.wast @@ -2,8 +2,8 @@ (type (;0;) (func)) (type (;1;) (func (param i32 i32) (result i32))) (func $add (type 1) (param i32 i32) (result i32) - get_local 1 - get_local 0 + local.get 1 + local.get 0 i32.add) (func $start (type 0)) (table (;0;) 1 1 anyfunc) From 4cc3525cf7f0a3187030529a11b33919c602019c Mon Sep 17 00:00:00 2001 From: Afonso Bordado Date: Tue, 19 Sep 2023 00:30:46 +0100 Subject: [PATCH 05/14] riscv64: Add compressed `addi` (#7057) * riscv64: Add `c.ebreak` instruction * riscv64: Implement `c.unimp` * riscv64: Add `c.addi` * riscv64: Add `c.addiw` * riscv64: Add `c.addi16sp` * riscv64: Add `c.slli` * riscv64: Add `c.addi4spn` * riscv64: Update `c.addiw` comment * riscv64: Centralize Zca Check * riscv64: Avoid double construction in some match arms --- cranelift/codegen/src/isa/riscv64/inst.isle | 16 ++ .../codegen/src/isa/riscv64/inst/args.rs | 44 +++- .../codegen/src/isa/riscv64/inst/emit.rs | 141 +++++++++++-- .../codegen/src/isa/riscv64/inst/encode.rs | 58 +++++- .../codegen/src/isa/riscv64/inst/imms.rs | 32 +++ .../filetests/filetests/isa/riscv64/zca.clif | 196 +++++++++++++++++- 6 files changed, 457 insertions(+), 30 deletions(-) diff --git a/cranelift/codegen/src/isa/riscv64/inst.isle b/cranelift/codegen/src/isa/riscv64/inst.isle index 72ffdcf087c4..dc5aca2d130e 100644 --- a/cranelift/codegen/src/isa/riscv64/inst.isle +++ b/cranelift/codegen/src/isa/riscv64/inst.isle @@ -715,6 +715,9 @@ (CAdd) (CJr) (CJalr) + ;; c.ebreak technically isn't a CR format instruction, but it's encoding + ;; lines up with this format. 
+ (CEbreak) )) ;; Opcodes for the CA compressed instruction format @@ -732,6 +735,19 @@ (CJ) )) +;; Opcodes for the CI compressed instruction format +(type CiOp (enum + (CAddi) + (CAddiw) + (CAddi16sp) + (CSlli) +)) + +;; Opcodes for the CIW compressed instruction format +(type CiwOp (enum + (CAddi4spn) +)) + (type CsrRegOP (enum ;; Atomic Read/Write CSR diff --git a/cranelift/codegen/src/isa/riscv64/inst/args.rs b/cranelift/codegen/src/isa/riscv64/inst/args.rs index 721551f2bc54..469dc531e20c 100644 --- a/cranelift/codegen/src/isa/riscv64/inst/args.rs +++ b/cranelift/codegen/src/isa/riscv64/inst/args.rs @@ -6,7 +6,10 @@ use super::*; use crate::ir::condcodes::CondCode; use crate::isa::riscv64::inst::{reg_name, reg_to_gpr_num}; -use crate::isa::riscv64::lower::isle::generated_code::{COpcodeSpace, CaOp, CjOp, CrOp}; + +use crate::isa::riscv64::lower::isle::generated_code::{ + COpcodeSpace, CaOp, CiOp, CiwOp, CjOp, CrOp, +}; use crate::machinst::isle::WritableReg; use std::fmt::{Display, Formatter, Result}; @@ -1916,14 +1919,14 @@ impl CrOp { match self { // `c.jr` has the same op/funct4 as C.MV, but RS2 is 0, which is illegal for mv. CrOp::CMv | CrOp::CJr => 0b1000, - CrOp::CAdd | CrOp::CJalr => 0b1001, + CrOp::CAdd | CrOp::CJalr | CrOp::CEbreak => 0b1001, } } pub fn op(&self) -> COpcodeSpace { // https://five-embeddev.com/riscv-isa-manual/latest/rvc-opcode-map.html#rvcopcodemap match self { - CrOp::CMv | CrOp::CAdd | CrOp::CJr | CrOp::CJalr => COpcodeSpace::C2, + CrOp::CMv | CrOp::CAdd | CrOp::CJr | CrOp::CJalr | CrOp::CEbreak => COpcodeSpace::C2, } } } @@ -1974,3 +1977,38 @@ impl CjOp { } } } + +impl CiOp { + pub fn funct3(&self) -> u32 { + // https://github.com/michaeljclark/riscv-meta/blob/master/opcodes + match self { + CiOp::CAddi | CiOp::CSlli => 0b000, + CiOp::CAddiw => 0b001, + CiOp::CAddi16sp => 0b011, + } + } + + pub fn op(&self) -> COpcodeSpace { + // https://five-embeddev.com/riscv-isa-manual/latest/rvc-opcode-map.html#rvcopcodemap + match self { + CiOp::CAddi | CiOp::CAddiw | CiOp::CAddi16sp => COpcodeSpace::C1, + CiOp::CSlli => COpcodeSpace::C2, + } + } +} + +impl CiwOp { + pub fn funct3(&self) -> u32 { + // https://github.com/michaeljclark/riscv-meta/blob/master/opcodes + match self { + CiwOp::CAddi4spn => 0b000, + } + } + + pub fn op(&self) -> COpcodeSpace { + // https://five-embeddev.com/riscv-isa-manual/latest/rvc-opcode-map.html#rvcopcodemap + match self { + CiwOp::CAddi4spn => COpcodeSpace::C0, + } + } +} diff --git a/cranelift/codegen/src/isa/riscv64/inst/emit.rs b/cranelift/codegen/src/isa/riscv64/inst/emit.rs index 1f5a9f7684ce..ffc1173d493f 100644 --- a/cranelift/codegen/src/isa/riscv64/inst/emit.rs +++ b/cranelift/codegen/src/isa/riscv64/inst/emit.rs @@ -3,7 +3,7 @@ use crate::binemit::StackMap; use crate::ir::{self, LibCall, RelSourceLoc, TrapCode}; use crate::isa::riscv64::inst::*; -use crate::isa::riscv64::lower::isle::generated_code::{CaOp, CrOp}; +use crate::isa::riscv64::lower::isle::generated_code::{CaOp, CiOp, CiwOp, CrOp}; use crate::machinst::{AllocationConsumer, Reg, Writable}; use crate::trace; use cranelift_control::ControlPlane; @@ -460,11 +460,17 @@ impl Inst { &self, sink: &mut MachBuffer, emit_info: &EmitInfo, - _state: &mut EmitState, + state: &mut EmitState, start_off: &mut u32, ) -> bool { let has_zca = emit_info.isa_flags.has_zca(); + // Currently all compressed extensions (Zcb, Zcd, Zcmp, Zcmt, etc..) require Zca + // to be enabled, so check it early. 
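        // With the check centralized here, the per-arm `has_zca &&` guards that
        // the hunks below remove are no longer needed.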
+ if !has_zca { + return false; + } + fn reg_is_compressible(r: Reg) -> bool { r.to_real_reg() .map(|r| r.hw_enc() >= 8 && r.hw_enc() < 16) @@ -478,7 +484,7 @@ impl Inst { rd, rs1, rs2, - } if has_zca && rd.to_reg() == rs1 && rs1 != zero_reg() && rs2 != zero_reg() => { + } if rd.to_reg() == rs1 && rs1 != zero_reg() && rs2 != zero_reg() => { sink.put2(encode_cr_type(CrOp::CAdd, rd, rs2)); } @@ -488,8 +494,7 @@ impl Inst { rd, rs, imm12, - } if has_zca - && rd.to_reg() != rs + } if rd.to_reg() != rs && rd.to_reg() != zero_reg() && rs != zero_reg() && imm12.as_i16() == 0 => @@ -509,11 +514,7 @@ impl Inst { rd, rs1, rs2, - } if has_zca - && rd.to_reg() == rs1 - && reg_is_compressible(rs1) - && reg_is_compressible(rs2) => - { + } if rd.to_reg() == rs1 && reg_is_compressible(rs1) && reg_is_compressible(rs2) => { let op = match alu_op { AluOPRRR::And => CaOp::CAnd, AluOPRRR::Or => CaOp::COr, @@ -530,7 +531,7 @@ impl Inst { // c.j // // We don't have a separate JAL as that is only availabile in RV32C - Inst::Jal { label } if has_zca => { + Inst::Jal { label } => { sink.use_label_at_offset(*start_off, label, LabelUse::RVCJump); sink.add_uncond_branch(*start_off, *start_off + 2, label); sink.put2(encode_cj_type(CjOp::CJ, Imm12::ZERO)); @@ -538,24 +539,121 @@ impl Inst { // c.jr Inst::Jalr { rd, base, offset } - if has_zca - && rd.to_reg() == zero_reg() - && base != zero_reg() - && offset.as_i16() == 0 => + if rd.to_reg() == zero_reg() && base != zero_reg() && offset.as_i16() == 0 => { sink.put2(encode_cr2_type(CrOp::CJr, base)); } // c.jalr Inst::Jalr { rd, base, offset } - if has_zca - && rd.to_reg() == link_reg() - && base != zero_reg() - && offset.as_i16() == 0 => + if rd.to_reg() == link_reg() && base != zero_reg() && offset.as_i16() == 0 => { sink.put2(encode_cr2_type(CrOp::CJalr, base)); } + // c.ebreak + Inst::EBreak => { + sink.put2(encode_cr_type( + CrOp::CEbreak, + writable_zero_reg(), + zero_reg(), + )); + } + + // c.unimp + Inst::Udf { trap_code } => { + sink.add_trap(trap_code); + if let Some(s) = state.take_stack_map() { + sink.add_stack_map(StackMapExtent::UpcomingBytes(2), s); + } + sink.put2(0x0000); + } + + // c.addi16sp + // + // c.addi16sp shares the opcode with c.lui, but has a destination field of x2. + // c.addi16sp adds the non-zero sign-extended 6-bit immediate to the value in the stack pointer (sp=x2), + // where the immediate is scaled to represent multiples of 16 in the range (-512,496). c.addi16sp is used + // to adjust the stack pointer in procedure prologues and epilogues. It expands into addi x2, x2, nzimm. c.addi16sp + // is only valid when nzimm≠0; the code point with nzimm=0 is reserved. + Inst::AluRRImm12 { + alu_op: AluOPRRI::Addi, + rd, + rs, + imm12, + } if rd.to_reg() == rs + && rs == stack_reg() + && imm12.as_i16() != 0 + && (imm12.as_i16() % 16) == 0 + && Imm6::maybe_from_i16(imm12.as_i16() / 16).is_some() => + { + let imm6 = Imm6::maybe_from_i16(imm12.as_i16() / 16).unwrap(); + sink.put2(encode_c_addi16sp(imm6)); + } + + // c.addi4spn + // + // c.addi4spn is a CIW-format instruction that adds a zero-extended non-zero + // immediate, scaled by 4, to the stack pointer, x2, and writes the result to + // rd. This instruction is used to generate pointers to stack-allocated variables + // and expands to addi rd, x2, nzuimm. c.addi4spn is only valid when nzuimm≠0; + // the code points with nzuimm=0 are reserved. 
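            // Worked example: `addi a0, sp, 24` passes every guard below (a0 is a
            // compressible register, 24 is a non-zero multiple of 4, and 24 / 4 = 6
            // fits in the zero-extended immediate), so it becomes
            // `c.addi4spn a0, sp, 24`, as the `zca.clif` filetest shows;
            // `addi a0, sp, 2` falls back to the 32-bit encoding because 2 is not
            // a multiple of 4.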
+ Inst::AluRRImm12 { + alu_op: AluOPRRI::Addi, + rd, + rs, + imm12, + } if reg_is_compressible(rd.to_reg()) + && rs == stack_reg() + && imm12.as_i16() != 0 + && (imm12.as_i16() % 4) == 0 + && u8::try_from(imm12.as_i16() / 4).is_ok() => + { + let imm = u8::try_from(imm12.as_i16() / 4).unwrap(); + sink.put2(encode_ciw_type(CiwOp::CAddi4spn, rd, imm)); + } + + // c.addi + Inst::AluRRImm12 { + alu_op: AluOPRRI::Addi, + rd, + rs, + imm12, + } if rd.to_reg() == rs && rs != zero_reg() && imm12.as_i16() != 0 => { + let imm6 = match Imm6::maybe_from_imm12(imm12) { + Some(imm6) => imm6, + None => return false, + }; + + sink.put2(encode_ci_type(CiOp::CAddi, rd, imm6)); + } + + // c.addiw + Inst::AluRRImm12 { + alu_op: AluOPRRI::Addiw, + rd, + rs, + imm12, + } if rd.to_reg() == rs && rs != zero_reg() => { + let imm6 = match Imm6::maybe_from_imm12(imm12) { + Some(imm6) => imm6, + None => return false, + }; + sink.put2(encode_ci_type(CiOp::CAddiw, rd, imm6)); + } + + // c.slli + Inst::AluRRImm12 { + alu_op: AluOPRRI::Slli, + rd, + rs, + imm12, + } if rd.to_reg() == rs && rs != zero_reg() && imm12.as_i16() != 0 => { + // The shift amount is unsigned, but we encode it as signed. + let shift = imm12.as_i16() & 0x3f; + let imm6 = Imm6::maybe_from_i16(shift << 10 >> 10).unwrap(); + sink.put2(encode_ci_type(CiOp::CSlli, rd, imm6)); + } _ => return false, } @@ -2045,7 +2143,10 @@ impl Inst { &Inst::Udf { trap_code } => { sink.add_trap(trap_code); if let Some(s) = state.take_stack_map() { - sink.add_stack_map(StackMapExtent::UpcomingBytes(4), s); + sink.add_stack_map( + StackMapExtent::UpcomingBytes(Inst::TRAP_OPCODE.len() as u32), + s, + ); } sink.put_data(Inst::TRAP_OPCODE); } diff --git a/cranelift/codegen/src/isa/riscv64/inst/encode.rs b/cranelift/codegen/src/isa/riscv64/inst/encode.rs index aafb686a8883..2a06b1093a5f 100644 --- a/cranelift/codegen/src/isa/riscv64/inst/encode.rs +++ b/cranelift/codegen/src/isa/riscv64/inst/encode.rs @@ -9,8 +9,8 @@ use super::*; use crate::isa::riscv64::inst::reg_to_gpr_num; use crate::isa::riscv64::lower::isle::generated_code::{ - CaOp, CjOp, CrOp, VecAluOpRImm5, VecAluOpRR, VecAluOpRRImm5, VecAluOpRRR, VecAluOpRRRImm5, - VecAluOpRRRR, VecElementWidth, VecOpCategory, VecOpMasking, + CaOp, CiOp, CiwOp, CjOp, CrOp, VecAluOpRImm5, VecAluOpRR, VecAluOpRRImm5, VecAluOpRRR, + VecAluOpRRRImm5, VecAluOpRRRR, VecElementWidth, VecOpCategory, VecOpMasking, }; use crate::machinst::isle::WritableReg; use crate::Reg; @@ -388,3 +388,57 @@ pub fn encode_cj_type(op: CjOp, imm: Imm12) -> u16 { bits |= unsigned_field_width(op.funct3(), 3) << 13; bits.try_into().unwrap() } + +// Encode a CI type instruction. +// +// The imm field is a 6 bit signed immediate. 
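// (Concrete instance: `c.addi a0, 31`, exercised by the `c_addi_max` filetest
// below, has op = C1, funct3 = 0b000, rd = x10 and imm = 31, which packs to
// 0x057d under the layout shown next.)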
+// +// 0--1-2-------6-7-------11-12-----12-13-----15 +// |op | imm[4:0] | src | imm[5] | funct3 | +pub fn encode_ci_type(op: CiOp, rd: WritableReg, imm: Imm6) -> u16 { + let imm = imm.bits(); + + let mut bits = 0; + bits |= unsigned_field_width(op.op().bits(), 2); + bits |= unsigned_field_width((imm & 0x1f) as u32, 5) << 2; + bits |= reg_to_gpr_num(rd.to_reg()) << 7; + bits |= unsigned_field_width(((imm >> 5) & 1) as u32, 1) << 12; + bits |= unsigned_field_width(op.funct3(), 3) << 13; + bits.try_into().unwrap() +} + +/// c.addi16sp is a regular CI op, but the immediate field is encoded in a weird way +pub fn encode_c_addi16sp(imm: Imm6) -> u16 { + let imm = imm.bits(); + + // [6|1|3|5:4|2] + let mut enc_imm = 0; + enc_imm |= ((imm >> 5) & 1) << 5; + enc_imm |= ((imm >> 0) & 1) << 4; + enc_imm |= ((imm >> 2) & 1) << 3; + enc_imm |= ((imm >> 3) & 3) << 1; + enc_imm |= ((imm >> 1) & 1) << 0; + let enc_imm = Imm6::maybe_from_i16((enc_imm as i16) << 10 >> 10).unwrap(); + + encode_ci_type(CiOp::CAddi16sp, writable_stack_reg(), enc_imm) +} + +// Encode a CIW type instruction. +// +// 0--1-2------4-5------12-13--------15 +// |op | rd | imm | funct3 | +pub fn encode_ciw_type(op: CiwOp, rd: WritableReg, imm: u8) -> u16 { + // [3:2|7:4|0|1] + let mut imm_field = 0; + imm_field |= ((imm >> 1) & 1) << 0; + imm_field |= ((imm >> 0) & 1) << 1; + imm_field |= ((imm >> 4) & 7) << 2; + imm_field |= ((imm >> 2) & 3) << 6; + + let mut bits = 0; + bits |= unsigned_field_width(op.op().bits(), 2); + bits |= reg_to_compressed_gpr_num(rd.to_reg()) << 2; + bits |= unsigned_field_width(imm_field as u32, 8) << 5; + bits |= unsigned_field_width(op.funct3(), 3) << 13; + bits.try_into().unwrap() +} diff --git a/cranelift/codegen/src/isa/riscv64/inst/imms.rs b/cranelift/codegen/src/isa/riscv64/inst/imms.rs index c9ddb6e70a29..6f4d7075db70 100644 --- a/cranelift/codegen/src/isa/riscv64/inst/imms.rs +++ b/cranelift/codegen/src/isa/riscv64/inst/imms.rs @@ -167,6 +167,38 @@ impl Display for Imm5 { } } +/// A Signed 6-bit immediate. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct Imm6 { + value: i8, +} + +impl Imm6 { + /// Create an signed 6-bit immediate from an i16 + pub fn maybe_from_i16(value: i16) -> Option { + if value >= -32 && value <= 31 { + Some(Self { value: value as i8 }) + } else { + None + } + } + + pub fn maybe_from_imm12(value: Imm12) -> Option { + Imm6::maybe_from_i16(value.as_i16()) + } + + /// Bits for encoding. 
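    /// For example, the most negative value, -32, is stored as 0b100000 (0x20)
    /// after masking, while 31 is stored as 0b011111; 32 and -33 are both
    /// rejected by `maybe_from_i16`.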
+ pub fn bits(&self) -> u8 { + self.value as u8 & 0x3f + } +} + +impl Display for Imm6 { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!(f, "{}", self.value) + } +} + impl Inst { pub(crate) fn imm_min() -> i64 { let imm20_max: i64 = (1 << 19) << 12; diff --git a/cranelift/filetests/filetests/isa/riscv64/zca.clif b/cranelift/filetests/filetests/isa/riscv64/zca.clif index 5852d450dc48..84ee5c541fb3 100644 --- a/cranelift/filetests/filetests/isa/riscv64/zca.clif +++ b/cranelift/filetests/filetests/isa/riscv64/zca.clif @@ -188,11 +188,11 @@ block0(v0: i8): ; ; Disassembled: ; block0: ; offset 0x0 -; addi sp, sp, -0x10 +; c.addi16sp sp, -0x10 ; sd ra, 8(sp) ; sd s0, 0(sp) ; c.mv s0, sp -; block1: ; offset 0xe +; block1: ; offset 0xc ; auipc a2, 0 ; ld a2, 0xa(a2) ; c.j 0xa @@ -241,14 +241,200 @@ block0(v0: i64, v1: i64): ; ; Disassembled: ; block0: ; offset 0x0 -; addi sp, sp, -0x10 +; c.addi16sp sp, -0x10 ; sd ra, 8(sp) ; sd s0, 0(sp) ; c.mv s0, sp -; block1: ; offset 0xe +; block1: ; offset 0xc ; c.jalr a1 ; ld ra, 8(sp) ; ld s0, 0(sp) -; addi sp, sp, 0x10 +; c.addi16sp sp, 0x10 +; c.jr ra + +function %c_ebreak() { +block0: + debugtrap + return +} + +; VCode: +; block0: +; ebreak +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; c.ebreak +; c.jr ra + +function %c_unimp() { +block0: + trap user0 +} + +; VCode: +; block0: +; udf##trap_code=user0 +; +; Disassembled: +; block0: ; offset 0x0 +; c.unimp ; trap: user0 + + +function %c_addi_max(i64) -> i64 { +block0(v0: i64): + v2 = iadd_imm.i64 v0, 31 + return v2 +} + +; VCode: +; block0: +; addi a0,a0,31 +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; c.addi a0, 0x1f +; c.jr ra + +function %c_addi_min(i64) -> i64 { +block0(v0: i64): + v2 = iadd_imm.i64 v0, -32 + return v2 +} + +; VCode: +; block0: +; addi a0,a0,-32 +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; c.addi a0, -0x20 +; c.jr ra + +function %c_sext_w(i32) -> i64 { +block0(v0: i32): + v1 = sextend.i64 v0 + return v1 +} + +; VCode: +; block0: +; sext.w a0,a0 +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; c.addiw a0, 0 +; c.jr ra + +function %c_addiw(i32) -> i64 { +block0(v0: i32): + v1 = iadd_imm.i32 v0, -32 + v2 = sextend.i64 v1 + return v2 +} + +; VCode: +; block0: +; addiw a0,a0,-32 +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; c.addiw a0, -0x20 +; c.jr ra + +function %c_addi16sp() -> i64 { + ss0 = explicit_slot 8 + +block0: + v0 = stack_addr.i64 ss0 + return v0 +} + +; VCode: +; add sp,-16 +; sd ra,8(sp) +; sd fp,0(sp) +; mv fp,sp +; add sp,-16 +; block0: +; load_addr a0,0(nominal_sp) +; add sp,+16 +; ld ra,8(sp) +; ld fp,0(sp) +; add sp,+16 +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; c.addi16sp sp, -0x10 +; sd ra, 8(sp) +; sd s0, 0(sp) +; c.mv s0, sp +; c.addi16sp sp, -0x10 +; block1: ; offset 0xe +; c.mv a0, sp +; c.addi16sp sp, 0x10 +; ld ra, 8(sp) +; ld s0, 0(sp) +; c.addi16sp sp, 0x10 +; c.jr ra + +function %c_slli(i64) -> i64 { +block0(v0: i64): + v1 = ishl_imm.i64 v0, 63 + return v1 +} + +; VCode: +; block0: +; slli a0,a0,63 +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; c.slli a0, 0x3f +; c.jr ra + + +function %c_addi4spn() -> i64 { + ss0 = explicit_slot 64 + +block0: + v0 = stack_addr.i64 ss0+24 + return v0 +} + +; VCode: +; add sp,-16 +; sd ra,8(sp) +; sd fp,0(sp) +; mv fp,sp +; add sp,-64 +; block0: +; load_addr a0,24(nominal_sp) +; add sp,+64 +; ld ra,8(sp) +; ld fp,0(sp) +; add sp,+16 +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; c.addi16sp sp, -0x10 +; sd ra, 8(sp) +; sd s0, 0(sp) +; c.mv s0, sp +; 
c.addi16sp sp, -0x40 +; block1: ; offset 0xe +; c.addi4spn a0, sp, 0x18 +; c.addi16sp sp, 0x40 +; ld ra, 8(sp) +; ld s0, 0(sp) +; c.addi16sp sp, 0x10 ; c.jr ra From 5359a8c2c59c3c89a0b0a0584255791f87314e69 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Mon, 18 Sep 2023 22:46:21 -0500 Subject: [PATCH 06/14] Force usage of a worker thread for stdin on all platforms (#7058) This commit is a follow-up to #6833 to remove the `unix` module for handling stdio which sets stdin to nonblocking mode. I've just now discovered that on macOS at least configuring `O_NONBLOCK` for stdin affects the stdout/stderr descriptors too. This program for example will panic: fn main() { unsafe { let r = libc::fcntl( libc::STDIN_FILENO, libc::F_SETFL, libc::fcntl(libc::STDIN_FILENO, libc::F_GETFL) | libc::O_NONBLOCK, ); assert_eq!(r, 0); } loop { println!("hello"); } } It was originally assumed that updating the flags for stdin wouldn't affect anything else except Wasmtime, but because this looks to not be the case this commit removes the logic of registering stdin raw with Tokio and instead unconditionally using the worker thread solution which should work in all situations. --- crates/wasi/src/preview2/stdio.rs | 7 -- crates/wasi/src/preview2/stdio/unix.rs | 160 ------------------------- 2 files changed, 167 deletions(-) delete mode 100644 crates/wasi/src/preview2/stdio/unix.rs diff --git a/crates/wasi/src/preview2/stdio.rs b/crates/wasi/src/preview2/stdio.rs index ca8453a3b999..f385329c1a37 100644 --- a/crates/wasi/src/preview2/stdio.rs +++ b/crates/wasi/src/preview2/stdio.rs @@ -8,14 +8,7 @@ use crate::preview2::{HostOutputStream, OutputStreamError, WasiView}; use bytes::Bytes; use is_terminal::IsTerminal; -#[cfg(unix)] -mod unix; -#[cfg(unix)] -pub use self::unix::{stdin, Stdin}; - -#[allow(dead_code)] mod worker_thread_stdin; -#[cfg(windows)] pub use self::worker_thread_stdin::{stdin, Stdin}; // blocking-write-and-flush must accept 4k. It doesn't seem likely that we need to diff --git a/crates/wasi/src/preview2/stdio/unix.rs b/crates/wasi/src/preview2/stdio/unix.rs deleted file mode 100644 index e43e64b8d6dc..000000000000 --- a/crates/wasi/src/preview2/stdio/unix.rs +++ /dev/null @@ -1,160 +0,0 @@ -use super::worker_thread_stdin; -use crate::preview2::{pipe::AsyncReadStream, HostInputStream, StreamState}; -use anyhow::Error; -use bytes::Bytes; -use futures::ready; -use std::future::Future; -use std::io::{self, Read}; -use std::pin::Pin; -use std::sync::{Arc, Mutex, OnceLock}; -use std::task::{Context, Poll}; -use tokio::io::unix::AsyncFd; -use tokio::io::{AsyncRead, Interest, ReadBuf}; - -// We need a single global instance of the AsyncFd because creating -// this instance registers the process's stdin fd with epoll, which will -// return an error if an fd is registered more than once. -static STDIN: OnceLock = OnceLock::new(); - -#[derive(Clone)] -pub enum Stdin { - // The process's standard input can be successfully registered with `epoll`, - // so it's tracked by a native async stream. - Async(Arc>), - - // The process's stdin can't be registered with epoll, for example it's a - // file on Linux or `/dev/null` on macOS. The fallback implementation of a - // worker thread is used in these situations. 
- Blocking(worker_thread_stdin::Stdin), -} - -pub fn stdin() -> Stdin { - fn init_stdin() -> anyhow::Result { - use crate::preview2::RUNTIME; - match tokio::runtime::Handle::try_current() { - Ok(_) => Ok(AsyncReadStream::new(InnerStdin::new()?)), - Err(_) => { - let _enter = RUNTIME.enter(); - RUNTIME.block_on(async { Ok(AsyncReadStream::new(InnerStdin::new()?)) }) - } - } - } - - let handle = STDIN - .get_or_init(|| match init_stdin() { - Ok(stream) => Stdin::Async(Arc::new(Mutex::new(stream))), - Err(_) => Stdin::Blocking(worker_thread_stdin::stdin()), - }) - .clone(); - - if let Stdin::Async(stream) = &handle { - let mut guard = stream.lock().unwrap(); - - // The backing task exited. This can happen in two cases: - // - // 1. the task crashed - // 2. the runtime has exited and been restarted in the same process - // - // As we can't tell the difference between these two, we assume the latter and restart the - // task. - if guard.join_handle.is_finished() { - *guard = init_stdin().unwrap(); - } - } - - handle -} - -impl is_terminal::IsTerminal for Stdin { - fn is_terminal(&self) -> bool { - std::io::stdin().is_terminal() - } -} - -#[async_trait::async_trait] -impl crate::preview2::HostInputStream for Stdin { - fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> { - match self { - Stdin::Async(s) => HostInputStream::read(&mut *s.lock().unwrap(), size), - Stdin::Blocking(s) => s.read(size), - } - } - - async fn ready(&mut self) -> Result<(), Error> { - match self { - Stdin::Async(handle) => { - // Custom Future impl takes the std mutex in each invocation of poll. - // Required so we don't have to use a tokio mutex, which we can't take from - // inside a sync context in Self::read. - // - // Taking the lock, creating a fresh ready() future, polling it once, and - // then releasing the lock is acceptable here because the ready() future - // is only ever going to await on a single channel recv, plus some management - // of a state machine (for buffering). 
- struct Ready<'a> { - handle: &'a Arc>, - } - impl<'a> Future for Ready<'a> { - type Output = Result<(), Error>; - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let mut locked = self.handle.lock().unwrap(); - let fut = locked.ready(); - tokio::pin!(fut); - fut.poll(cx) - } - } - Ready { handle }.await - } - Stdin::Blocking(s) => s.ready().await, - } - } -} - -struct InnerStdin { - inner: AsyncFd, -} - -impl InnerStdin { - pub fn new() -> anyhow::Result { - use rustix::fs::OFlags; - use std::os::fd::AsRawFd; - - let stdin = std::io::stdin(); - - let borrowed_fd = unsafe { rustix::fd::BorrowedFd::borrow_raw(stdin.as_raw_fd()) }; - let flags = rustix::fs::fcntl_getfl(borrowed_fd)?; - if !flags.contains(OFlags::NONBLOCK) { - rustix::fs::fcntl_setfl(borrowed_fd, flags.union(OFlags::NONBLOCK))?; - } - - Ok(Self { - inner: AsyncFd::with_interest(stdin, Interest::READABLE)?, - }) - } -} - -impl AsyncRead for InnerStdin { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - loop { - let mut guard = ready!(self.inner.poll_read_ready_mut(cx))?; - - let unfilled = buf.initialize_unfilled(); - match guard.try_io(|inner| inner.get_mut().read(unfilled)) { - Ok(Ok(len)) => { - buf.advance(len); - return Poll::Ready(Ok(())); - } - Ok(Err(err)) => { - return Poll::Ready(Err(err)); - } - Err(_would_block) => { - continue; - } - } - } - } -} From a4d036ca8e9ea56c23521ade7c847000d624a368 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 19 Sep 2023 17:49:51 -0500 Subject: [PATCH 07/14] wasi: Fix a few issues around stdin (#7063) * wasi: Fix a few issues around stdin This commit is intended to address #6986 and some other issues related to stdin and reading it, notably: * Previously once EOF was reached the `closed` flag was mistakenly not set. * Previously data would be infinitely buffered regardless of how fast the guest program would consume it. * Previously stdin would be immediately ready by Wasmtime regardless of whether the guest wanted to read stdin or not. * The host-side preview1-to-preview2 adapter didn't perform a blocking read meaning that it never blocked. These issues are addressed by refactoring the code in question. Note that this is similar to the logic of `AsyncReadStream` somewhat but that type is not appropriate in this context due to the singleton nature of stdin meaning that the per-stream helper task and per-stream buffer of `AsyncReadStream` are not appropriate. 
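As a rough, self-contained sketch of the resulting model (illustrative only;
the real code in `worker_thread_stdin.rs` also distinguishes errors from EOF
and integrates with async wakeups), the handshake looks like this: the
consumer flips shared state to "read requested" and waits, and a helper
thread performs the blocking read of stdin only after it observes that
request, so no bytes are consumed until someone actually asks for them:

    use std::io::Read;
    use std::sync::{Arc, Condvar, Mutex};

    // Simplified three-state machine; the real one also has Error and Closed.
    enum State {
        Idle,
        Requested,
        Data(Vec<u8>),
    }

    fn main() {
        let shared = Arc::new((Mutex::new(State::Idle), Condvar::new()));

        // Helper thread: only touches stdin once a read has been requested.
        let worker = shared.clone();
        std::thread::spawn(move || loop {
            let (lock, cvar) = &*worker;
            let guard = lock.lock().unwrap();
            let guard = cvar
                .wait_while(guard, |s| !matches!(s, State::Requested))
                .unwrap();
            drop(guard);

            // Blocking read happens outside the lock.
            let mut buf = vec![0; 1024];
            let n = std::io::stdin().read(&mut buf).unwrap_or(0);
            buf.truncate(n);
            *lock.lock().unwrap() = State::Data(buf);
            cvar.notify_all();
        });

        // Consumer: request one chunk, then block until it shows up.
        let (lock, cvar) = &*shared;
        *lock.lock().unwrap() = State::Requested;
        cvar.notify_all();
        let guard = cvar
            .wait_while(lock.lock().unwrap(), |s| matches!(s, State::Requested))
            .unwrap();
        if let State::Data(bytes) = &*guard {
            println!("got {} bytes from stdin", bytes.len());
        }
    }

The `unwrap_or(0)` shortcut is where the real implementation instead records
the error and marks the stream closed.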
Closees #6986 * Increase slop size for windows --- crates/wasi/src/preview2/pipe.rs | 5 +- crates/wasi/src/preview2/preview1.rs | 2 +- .../src/preview2/stdio/worker_thread_stdin.rs | 213 ++++++++++-------- tests/all/cli_tests.rs | 86 ++++++- tests/all/cli_tests/count-stdin.wat | 59 +++++ 5 files changed, 267 insertions(+), 98 deletions(-) create mode 100644 tests/all/cli_tests/count-stdin.wat diff --git a/crates/wasi/src/preview2/pipe.rs b/crates/wasi/src/preview2/pipe.rs index d4878b3b4df7..a01215c11323 100644 --- a/crates/wasi/src/preview2/pipe.rs +++ b/crates/wasi/src/preview2/pipe.rs @@ -120,8 +120,7 @@ pub struct AsyncReadStream { state: StreamState, buffer: Option>, receiver: mpsc::Receiver>, - #[allow(unused)] // just used to implement unix stdin - pub(crate) join_handle: crate::preview2::AbortOnDropJoinHandle<()>, + _join_handle: crate::preview2::AbortOnDropJoinHandle<()>, } impl AsyncReadStream { @@ -150,7 +149,7 @@ impl AsyncReadStream { state: StreamState::Open, buffer: None, receiver, - join_handle, + _join_handle: join_handle, } } } diff --git a/crates/wasi/src/preview2/preview1.rs b/crates/wasi/src/preview2/preview1.rs index fb90c6a98d27..0ad722bbed6c 100644 --- a/crates/wasi/src/preview2/preview1.rs +++ b/crates/wasi/src/preview2/preview1.rs @@ -1336,7 +1336,7 @@ impl< return Ok(0); }; let (read, state) = stream_res( - streams::Host::read( + streams::Host::blocking_read( self, input_stream, buf.len().try_into().unwrap_or(u64::MAX), diff --git a/crates/wasi/src/preview2/stdio/worker_thread_stdin.rs b/crates/wasi/src/preview2/stdio/worker_thread_stdin.rs index c5be39b6dfb8..e1505dff844b 100644 --- a/crates/wasi/src/preview2/stdio/worker_thread_stdin.rs +++ b/crates/wasi/src/preview2/stdio/worker_thread_stdin.rs @@ -1,91 +1,103 @@ +//! Handling for standard in using a worker task. +//! +//! Standard input is a global singleton resource for the entire program which +//! needs special care. Currently this implementation adheres to a few +//! constraints which make this nontrivial to implement. +//! +//! * Any number of guest wasm programs can read stdin. While this doesn't make +//! a ton of sense semantically they shouldn't block forever. Instead it's a +//! race to see who actually reads which parts of stdin. +//! +//! * Data from stdin isn't actually read unless requested. This is done to try +//! to be a good neighbor to others running in the process. Under the +//! assumption that most programs have one "thing" which reads stdin the +//! actual consumption of bytes is delayed until the wasm guest is dynamically +//! chosen to be that "thing". Before that data from stdin is not consumed to +//! avoid taking it from other components in the process. +//! +//! * Tokio's documentation indicates that "interactive stdin" is best done with +//! a helper thread to avoid blocking shutdown of the event loop. That's +//! respected here where all stdin reading happens on a blocking helper thread +//! that, at this time, is never shut down. +//! +//! This module is one that's likely to change over time though as new systems +//! are encountered along with preexisting bugs. + use crate::preview2::{HostInputStream, StreamState}; use anyhow::Error; use bytes::{Bytes, BytesMut}; use std::io::Read; -use std::sync::Arc; -use tokio::sync::watch; - -// wasmtime cant use std::sync::OnceLock yet because of a llvm regression in -// 1.70. when 1.71 is released, we can switch to using std here. 
-use once_cell::sync::OnceCell as OnceLock; - -use std::sync::Mutex; +use std::mem; +use std::sync::{Condvar, Mutex, OnceLock}; +use tokio::sync::Notify; +#[derive(Default)] struct GlobalStdin { - // Worker thread uses this to notify of new events. Ready checks use this - // to create a new Receiver via .subscribe(). The newly created receiver - // will only wait for events created after the call to subscribe(). - tx: Arc>, - // Worker thread and receivers share this state to get bytes read off - // stdin, or the error/closed state. - state: Arc>, + state: Mutex, + read_requested: Condvar, + read_completed: Notify, } -#[derive(Debug)] -struct StdinState { - // Bytes read off stdin. - buffer: BytesMut, - // Error read off stdin, if any. - error: Option, - // If an error has occured in the past, we consider the stream closed. - closed: bool, +#[derive(Default, Debug)] +enum StdinState { + #[default] + ReadNotRequested, + ReadRequested, + Data(BytesMut), + Error(std::io::Error), + Closed, } -static STDIN: OnceLock = OnceLock::new(); +impl GlobalStdin { + fn get() -> &'static GlobalStdin { + static STDIN: OnceLock = OnceLock::new(); + STDIN.get_or_init(|| create()) + } +} fn create() -> GlobalStdin { - let (tx, _rx) = watch::channel(()); - let tx = Arc::new(tx); - - let state = Arc::new(Mutex::new(StdinState { - buffer: BytesMut::new(), - error: None, - closed: false, - })); - - let ret = GlobalStdin { - state: state.clone(), - tx: tx.clone(), - }; - - std::thread::spawn(move || loop { - let mut bytes = BytesMut::zeroed(1024); - match std::io::stdin().lock().read(&mut bytes) { - // Reading `0` indicates that stdin has reached EOF, so we break - // the loop to allow the thread to exit. - Ok(0) => break, - - Ok(nbytes) => { - // Append to the buffer: - bytes.truncate(nbytes); - let mut locked = state.lock().unwrap(); - locked.buffer.extend_from_slice(&bytes); - } - Err(e) => { - // Set the error, and mark the stream as closed: - let mut locked = state.lock().unwrap(); - if locked.error.is_none() { - locked.error = Some(e) + std::thread::spawn(|| { + let state = GlobalStdin::get(); + loop { + // Wait for a read to be requested, but don't hold the lock across + // the blocking read. + let mut lock = state.state.lock().unwrap(); + lock = state + .read_requested + .wait_while(lock, |state| !matches!(state, StdinState::ReadRequested)) + .unwrap(); + drop(lock); + + let mut bytes = BytesMut::zeroed(1024); + let (new_state, done) = match std::io::stdin().read(&mut bytes) { + Ok(0) => (StdinState::Closed, true), + Ok(nbytes) => { + bytes.truncate(nbytes); + (StdinState::Data(bytes), false) } - locked.closed = true; + Err(e) => (StdinState::Error(e), true), + }; + + // After the blocking read completes the state should not have been + // tampered with. + debug_assert!(matches!( + *state.state.lock().unwrap(), + StdinState::ReadRequested + )); + *state.state.lock().unwrap() = new_state; + state.read_completed.notify_waiters(); + if done { + break; } } - // Receivers may or may not exist - fine if they dont, new - // ones will be created with subscribe() - let _ = tx.send(()); }); - ret + + GlobalStdin::default() } /// Only public interface is the [`HostInputStream`] impl. #[derive(Clone)] pub struct Stdin; -impl Stdin { - // Private! Only required internally. 
- fn get_global() -> &'static GlobalStdin { - STDIN.get_or_init(|| create()) - } -} pub fn stdin() -> Stdin { Stdin @@ -100,40 +112,55 @@ impl is_terminal::IsTerminal for Stdin { #[async_trait::async_trait] impl HostInputStream for Stdin { fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> { - let g = Stdin::get_global(); + let g = GlobalStdin::get(); let mut locked = g.state.lock().unwrap(); - - if let Some(e) = locked.error.take() { - return Err(e.into()); + match mem::replace(&mut *locked, StdinState::ReadRequested) { + StdinState::ReadNotRequested => { + g.read_requested.notify_one(); + Ok((Bytes::new(), StreamState::Open)) + } + StdinState::ReadRequested => Ok((Bytes::new(), StreamState::Open)), + StdinState::Data(mut data) => { + let size = data.len().min(size); + let bytes = data.split_to(size); + *locked = if data.is_empty() { + StdinState::ReadNotRequested + } else { + StdinState::Data(data) + }; + Ok((bytes.freeze(), StreamState::Open)) + } + StdinState::Error(e) => { + *locked = StdinState::Closed; + return Err(e.into()); + } + StdinState::Closed => { + *locked = StdinState::Closed; + Ok((Bytes::new(), StreamState::Closed)) + } } - let size = locked.buffer.len().min(size); - let bytes = locked.buffer.split_to(size); - let state = if locked.buffer.is_empty() && locked.closed { - StreamState::Closed - } else { - StreamState::Open - }; - Ok((bytes.freeze(), state)) } async fn ready(&mut self) -> Result<(), Error> { - let g = Stdin::get_global(); - - // Block makes sure we dont hold the mutex across the await: - let mut rx = { - let locked = g.state.lock().unwrap(); - // read() will only return (empty, open) when the buffer is empty, - // AND there is no error AND the stream is still open: - if !locked.buffer.is_empty() || locked.error.is_some() || locked.closed { - return Ok(()); + let g = GlobalStdin::get(); + + // Scope the synchronous `state.lock()` to this block which does not + // `.await` inside of it. + let notified = { + let mut locked = g.state.lock().unwrap(); + match *locked { + // If a read isn't requested yet + StdinState::ReadNotRequested => { + g.read_requested.notify_one(); + *locked = StdinState::ReadRequested; + g.read_completed.notified() + } + StdinState::ReadRequested => g.read_completed.notified(), + StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return Ok(()), } - // Sender will take the mutex before updating the state of - // subscribe, so this ensures we will only await for any stdin - // events that are recorded after we drop the mutex: - g.tx.subscribe() }; - rx.changed().await.expect("impossible for sender to drop"); + notified.await; Ok(()) } diff --git a/tests/all/cli_tests.rs b/tests/all/cli_tests.rs index cf6ce1f751b0..de5f2cb06a03 100644 --- a/tests/all/cli_tests.rs +++ b/tests/all/cli_tests.rs @@ -4,7 +4,7 @@ use anyhow::{bail, Result}; use std::fs::File; use std::io::Write; use std::path::Path; -use std::process::{Command, Output}; +use std::process::{Command, Output, Stdio}; use tempfile::{NamedTempFile, TempDir}; // Run the wasmtime CLI with the provided args and return the `Output`. 
@@ -932,3 +932,87 @@ fn option_group_boolean_parsing() -> Result<()> { ])?; Ok(()) } + +#[test] +fn preview2_stdin() -> Result<()> { + let test = "tests/all/cli_tests/count-stdin.wat"; + let cmd = || -> Result<_> { + let mut cmd = get_wasmtime_command()?; + cmd.arg("--invoke=count").arg("-Spreview2").arg(test); + Ok(cmd) + }; + + // read empty pipe is ok + let output = cmd()?.output()?; + assert!(output.status.success()); + assert_eq!(String::from_utf8_lossy(&output.stdout), "0\n"); + + // read itself is ok + let file = File::open(test)?; + let size = file.metadata()?.len(); + let output = cmd()?.stdin(File::open(test)?).output()?; + assert!(output.status.success()); + assert_eq!(String::from_utf8_lossy(&output.stdout), format!("{size}\n")); + + // read piped input ok is ok + let mut child = cmd()? + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + let mut stdin = child.stdin.take().unwrap(); + std::thread::spawn(move || { + stdin.write_all(b"hello").unwrap(); + }); + let output = child.wait_with_output()?; + assert!(output.status.success()); + assert_eq!(String::from_utf8_lossy(&output.stdout), "5\n"); + + let count_up_to = |n: usize| -> Result<_> { + let mut child = get_wasmtime_command()? + .arg("--invoke=count-up-to") + .arg("-Spreview2") + .arg(test) + .arg(n.to_string()) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + let mut stdin = child.stdin.take().unwrap(); + let t = std::thread::spawn(move || { + let mut written = 0; + let bytes = [0; 64 * 1024]; + loop { + written += match stdin.write(&bytes) { + Ok(n) => n, + Err(_) => break written, + }; + } + }); + let output = child.wait_with_output()?; + assert!(output.status.success()); + let written = t.join().unwrap(); + let read = String::from_utf8_lossy(&output.stdout) + .trim() + .parse::() + .unwrap(); + // The test reads in 1000 byte chunks so make sure that it doesn't read + // more than 1000 bytes than requested. + assert!(read < n + 1000, "test read too much {read}"); + Ok(written) + }; + + // wasmtime shouldn't eat information that the guest never actually tried to + // read. + // + // NB: this may be a bit flaky. Exactly how much we wrote in the above + // helper thread depends on how much the OS buffers for us. For now give + // some some slop and assume that OSes are unlikely to buffer more than + // that. + let slop = 256 * 1024; + for amt in [0, 100, 100_000] { + let written = count_up_to(amt)?; + assert!(written < slop + amt, "wrote too much {written}"); + } + Ok(()) +} diff --git a/tests/all/cli_tests/count-stdin.wat b/tests/all/cli_tests/count-stdin.wat new file mode 100644 index 000000000000..80d1a03890c3 --- /dev/null +++ b/tests/all/cli_tests/count-stdin.wat @@ -0,0 +1,59 @@ +(module + (import "wasi_snapshot_preview1" "fd_read" + (func $read (param i32 i32 i32 i32) (result i32))) + + (memory (export "memory") 1) + + (func (export "count") (result i32) + (call $count-up-to (i32.const -1)) + ) + + (func $count-up-to (export "count-up-to") (param $up-to i32) (result i32) + (local $size i32) + + (i32.eqz (local.get $up-to)) + if + local.get 0 + return + end + loop $the-loop + ;; setup a basic ciovec pointing into memory + (i32.store + (i32.const 100) + (i32.const 200)) + (i32.store + (i32.const 104) + (i32.const 1000)) + + + (call $read + (i32.const 0) ;; stdin fileno + (i32.const 100) ;; ciovec base + (i32.const 1) ;; ciovec len + (i32.const 8) ;; ret val ptr + ) + ;; reading stdin must succeed (e.g. 
return 0) + if unreachable end + + ;; update with how many bytes were read + (local.set $size + (i32.add + (local.get $size) + (i32.load (i32.const 8)))) + + + ;; if no data was read, exit the loop + ;; if the size read exceeds what we're supposed to read, also exit the + ;; loop + (i32.load (i32.const 8)) + if + (i32.lt_u (local.get $size) (local.get $up-to)) + if + br $the-loop + end + end + end + + local.get $size + ) +) From 2d43a28fd55195664f4a5798248d4be3f3a6773d Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 20 Sep 2023 09:35:40 -0500 Subject: [PATCH 08/14] c-api: Expose image_range for modules (#7064) We're using this to monitor the amount of executable memory each module needs. Signed-off-by: Tyler Rockwood --- crates/c-api/include/wasmtime/module.h | 15 +++++++++++++++ crates/c-api/src/module.rs | 11 +++++++++++ 2 files changed, 26 insertions(+) diff --git a/crates/c-api/include/wasmtime/module.h b/crates/c-api/include/wasmtime/module.h index deb6bcec1efa..8eca72a4899e 100644 --- a/crates/c-api/include/wasmtime/module.h +++ b/crates/c-api/include/wasmtime/module.h @@ -148,6 +148,21 @@ WASM_API_EXTERN wasmtime_error_t *wasmtime_module_deserialize_file( wasmtime_module_t **ret ); + +/** + * \brief Returns the range of bytes in memory where this module’s compilation image resides. + * + * The compilation image for a module contains executable code, data, debug information, etc. + * This is roughly the same as the wasmtime_module_serialize but not the exact same. + * + * For more details see: https://docs.wasmtime.dev/api/wasmtime/struct.Module.html#method.image_range + */ +WASM_API_EXTERN void wasmtime_module_image_range( + const wasm_module_t *module, + size_t *start, + size_t *end +); + #ifdef __cplusplus } // extern "C" #endif diff --git a/crates/c-api/src/module.rs b/crates/c-api/src/module.rs index a1dce264931d..3a1a6e6664cc 100644 --- a/crates/c-api/src/module.rs +++ b/crates/c-api/src/module.rs @@ -183,6 +183,17 @@ pub extern "C" fn wasmtime_module_serialize( handle_result(module.module.serialize(), |buf| ret.set_buffer(buf)) } +#[no_mangle] +pub extern "C" fn wasmtime_module_image_range( + module: &wasmtime_module_t, + start: &mut usize, + end: &mut usize, +) { + let range = module.module.image_range(); + *start = range.start; + *end = range.end; +} + #[no_mangle] pub unsafe extern "C" fn wasmtime_module_deserialize( engine: &wasm_engine_t, From f2b43d84317654cc3664602cf35f74eca4ecd428 Mon Sep 17 00:00:00 2001 From: wasmtime-publish <59749941+wasmtime-publish@users.noreply.github.com> Date: Wed, 20 Sep 2023 09:35:59 -0500 Subject: [PATCH 09/14] Update release date of Wasmtime 13.0.0 (#7066) Co-authored-by: Wasmtime Publish --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index de377e1243cc..6d1ba437dda8 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -12,7 +12,7 @@ Unreleased. ## 13.0.0 -Unreleased. +Released 2023-09-20 ### Added From 04b09c80ee3909fedc8a1b038ea4e8b85c6219b8 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Wed, 20 Sep 2023 11:06:35 -0500 Subject: [PATCH 10/14] Fix character boundary issues in preview1 host adapter (#7011) * Fix character boundary issues in preview1 host adapter This fixes two separate issues in the preview1-to-preview2 host-side adapter in the `wasmtime-wasi` crate. Both instances were copying a truncated string to the guest program but the truncation was happening in the middle of a unicode character so truncation now happens on bytes instead of a string. 
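The distinction matters because truncating a Rust `String` in the middle of a multi-byte character panics, while truncating the raw bytes always succeeds and simply hands the guest a shorter partial name to retry with a larger buffer. A standalone illustration, not the adapter's actual code:

    fn main() {
        // "Действие" (the filename used by the new tests below) is 8
        // characters but 16 bytes in UTF-8.
        let name = String::from("Действие");
        assert_eq!(name.chars().count(), 8);
        assert_eq!(name.len(), 16);

        // Truncating the *string* to fit a 9-byte guest buffer would panic,
        // because byte 9 is not a character boundary:
        //     let mut s = name.clone();
        //     s.truncate(9); // panics: not a char boundary

        // Truncating the *bytes* always succeeds; the guest sees a partial
        // name and can retry with a bigger buffer.
        let mut bytes = name.into_bytes();
        bytes.truncate(9);
        assert_eq!(bytes.len(), 9);
    }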
* Update `write_bytes` to take `&[u8]` Try to weed out any accidental string-related issues for the future. --- .../wasi-tests/src/bin/fd_readdir.rs | 35 +++++++++++++++++++ .../wasi-tests/src/bin/readlink.rs | 25 +++++++++++++ crates/wasi/src/preview2/preview1.rs | 31 +++++++++------- 3 files changed, 78 insertions(+), 13 deletions(-) diff --git a/crates/test-programs/wasi-tests/src/bin/fd_readdir.rs b/crates/test-programs/wasi-tests/src/bin/fd_readdir.rs index 0238292558df..3045f166b0c6 100644 --- a/crates/test-programs/wasi-tests/src/bin/fd_readdir.rs +++ b/crates/test-programs/wasi-tests/src/bin/fd_readdir.rs @@ -205,6 +205,40 @@ unsafe fn test_fd_readdir_lots(dir_fd: wasi::Fd) { } } +unsafe fn test_fd_readdir_unicode_boundary(dir_fd: wasi::Fd) { + let filename = "Действие"; + let file_fd = wasi::path_open( + dir_fd, + 0, + filename, + wasi::OFLAGS_CREAT, + wasi::RIGHTS_FD_READ | wasi::RIGHTS_FD_WRITE, + 0, + 0, + ) + .expect("failed to create file"); + assert!( + file_fd > libc::STDERR_FILENO as wasi::Fd, + "file descriptor range check", + ); + wasi::fd_close(file_fd).expect("closing a file"); + + let mut buf = Vec::new(); + 'outer: loop { + let len = wasi::fd_readdir(dir_fd, buf.as_mut_ptr(), buf.capacity(), 0).unwrap(); + buf.set_len(len); + + for entry in ReadDir::from_slice(&buf) { + if entry.name == filename { + break 'outer; + } + } + buf = Vec::with_capacity(buf.capacity() + 1); + } + + wasi::path_unlink_file(dir_fd, filename).expect("removing a file"); +} + fn main() { let mut args = env::args(); let prog = args.next().unwrap(); @@ -227,4 +261,5 @@ fn main() { // Run the tests. unsafe { test_fd_readdir(dir_fd) } unsafe { test_fd_readdir_lots(dir_fd) } + unsafe { test_fd_readdir_unicode_boundary(dir_fd) } } diff --git a/crates/test-programs/wasi-tests/src/bin/readlink.rs b/crates/test-programs/wasi-tests/src/bin/readlink.rs index eb8b8aa5f734..f6bcc309587b 100644 --- a/crates/test-programs/wasi-tests/src/bin/readlink.rs +++ b/crates/test-programs/wasi-tests/src/bin/readlink.rs @@ -32,6 +32,30 @@ unsafe fn test_readlink(dir_fd: wasi::Fd) { wasi::path_unlink_file(dir_fd, "symlink").expect("removing a file"); } +unsafe fn test_incremental_readlink(dir_fd: wasi::Fd) { + let filename = "Действие"; + create_file(dir_fd, filename); + + wasi::path_symlink(filename, dir_fd, "symlink").expect("creating a symlink"); + + let mut buf = Vec::new(); + loop { + if buf.capacity() > 2 * filename.len() { + panic!() + } + let bufused = wasi::path_readlink(dir_fd, "symlink", buf.as_mut_ptr(), buf.capacity()) + .expect("readlink should succeed"); + buf.set_len(bufused); + if buf.capacity() > filename.len() { + assert!(buf.starts_with(filename.as_bytes())); + break; + } + buf = Vec::with_capacity(buf.capacity() + 1); + } + wasi::path_unlink_file(dir_fd, filename).expect("removing a file"); + wasi::path_unlink_file(dir_fd, "symlink").expect("removing a file"); +} + fn main() { let mut args = env::args(); let prog = args.next().unwrap(); @@ -53,4 +77,5 @@ fn main() { // Run the tests. 
unsafe { test_readlink(dir_fd) } + unsafe { test_incremental_readlink(dir_fd) } } diff --git a/crates/wasi/src/preview2/preview1.rs b/crates/wasi/src/preview2/preview1.rs index 0ad722bbed6c..a44330b25b11 100644 --- a/crates/wasi/src/preview2/preview1.rs +++ b/crates/wasi/src/preview2/preview1.rs @@ -779,7 +779,7 @@ type Result = std::result::Result; fn write_bytes<'a>( ptr: impl Borrow>, - buf: impl AsRef<[u8]>, + buf: &[u8], ) -> Result, types::Error> { // NOTE: legacy implementation always returns Inval errno @@ -870,7 +870,7 @@ impl< argv.write(argv_buf)?; let argv = argv.add(1)?; - let argv_buf = write_bytes(argv_buf, arg)?; + let argv_buf = write_bytes(argv_buf, arg.as_bytes())?; let argv_buf = write_byte(argv_buf, 0)?; Ok((argv, argv_buf)) @@ -910,9 +910,9 @@ impl< environ.write(environ_buf)?; let environ = environ.add(1)?; - let environ_buf = write_bytes(environ_buf, k)?; + let environ_buf = write_bytes(environ_buf, k.as_bytes())?; let environ_buf = write_byte(environ_buf, b'=')?; - let environ_buf = write_bytes(environ_buf, v)?; + let environ_buf = write_bytes(environ_buf, v.as_bytes())?; let environ_buf = write_byte(environ_buf, 0)?; Ok((environ, environ_buf)) @@ -1514,7 +1514,7 @@ impl< if p.len() > path_max_len { return Err(types::Errno::Nametoolong.into()); } - write_bytes(path, p)?; + write_bytes(path, p.as_bytes())?; return Ok(()); } Err(types::Errno::Notdir.into()) // NOTE: legacy implementation returns NOTDIR here @@ -1672,7 +1672,8 @@ impl< ); let mut buf = *buf; let mut cap = buf_len; - for (ref entry, mut path) in head.into_iter().chain(dir.into_iter()).skip(cookie) { + for (ref entry, path) in head.into_iter().chain(dir.into_iter()).skip(cookie) { + let mut path = path.into_bytes(); assert_eq!( 1, size_of_val(&entry.d_type), @@ -1692,7 +1693,7 @@ impl< path.truncate(cap); } cap = cap.checked_sub(path.len() as _).unwrap(); - buf = write_bytes(buf, path)?; + buf = write_bytes(buf, &path)?; if cap == 0 { return Ok(buf_len); } @@ -1901,11 +1902,15 @@ impl< ) -> Result { let dirfd = self.get_dir_fd(dirfd)?; let path = read_string(path)?; - let mut path = self.readlink_at(dirfd, path).await.map_err(|e| { - e.try_into() - .context("failed to call `readlink-at`") - .unwrap_or_else(types::Error::trap) - })?; + let mut path = self + .readlink_at(dirfd, path) + .await + .map_err(|e| { + e.try_into() + .context("failed to call `readlink-at`") + .unwrap_or_else(types::Error::trap) + })? + .into_bytes(); if let Ok(buf_len) = buf_len.try_into() { // `path` cannot be longer than `usize`, only truncate if `buf_len` fits in `usize` path.truncate(buf_len); @@ -2031,7 +2036,7 @@ impl< .get_random_bytes(buf_len.into()) .context("failed to call `get-random-bytes`") .map_err(types::Error::trap)?; - write_bytes(buf, rand)?; + write_bytes(buf, &rand)?; Ok(()) } From b1511dceb27a3a446780542c37dd68d85a25f4e3 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Wed, 20 Sep 2023 16:48:17 -0700 Subject: [PATCH 11/14] Fix documentation typos (#7071) * Fix documentation typos While working on integrating MPK into Wasmtime, I found these typos and line-wrapping issues. This fix is a separate commit to separate it from the already-complex MPK implementation. 
* fix: remove extra changes --- cranelift/wasm/src/heap.rs | 12 ++++++------ crates/environ/src/tunables.rs | 6 ++++-- crates/jit-icache-coherence/src/libc.rs | 2 +- crates/runtime/src/instance/allocator.rs | 11 +++++++---- crates/runtime/src/lib.rs | 4 ++-- crates/wasmtime/src/config.rs | 4 ++-- 6 files changed, 22 insertions(+), 17 deletions(-) diff --git a/cranelift/wasm/src/heap.rs b/cranelift/wasm/src/heap.rs index e71fae8e3d12..b253c62dcee0 100644 --- a/cranelift/wasm/src/heap.rs +++ b/cranelift/wasm/src/heap.rs @@ -31,8 +31,8 @@ entity_impl!(Heap, "heap"); /// always present. /// /// 2. The *unmapped pages* is a possibly empty range of address space that may -/// be mapped in the future when the heap is grown. They are addressable -/// but not accessible. +/// be mapped in the future when the heap is grown. They are addressable but +/// not accessible. /// /// 3. The *offset-guard pages* is a range of address space that is guaranteed /// to always cause a trap when accessed. It is used to optimize bounds @@ -48,10 +48,10 @@ entity_impl!(Heap, "heap"); /// /// #### Static heaps /// -/// A *static heap* starts out with all the address space it will ever need, so it -/// never moves to a different address. At the base address is a number of mapped -/// pages corresponding to the heap's current size. Then follows a number of -/// unmapped pages where the heap can grow up to its maximum size. After the +/// A *static heap* starts out with all the address space it will ever need, so +/// it never moves to a different address. At the base address is a number of +/// mapped pages corresponding to the heap's current size. Then follows a number +/// of unmapped pages where the heap can grow up to its maximum size. After the /// unmapped pages follow the offset-guard pages which are also guaranteed to /// generate a trap when accessed. /// diff --git a/crates/environ/src/tunables.rs b/crates/environ/src/tunables.rs index 06d6b5ca5207..7ba3e6fdfc9b 100644 --- a/crates/environ/src/tunables.rs +++ b/crates/environ/src/tunables.rs @@ -3,7 +3,8 @@ use serde_derive::{Deserialize, Serialize}; /// Tunable parameters for WebAssembly compilation. #[derive(Clone, Hash, Serialize, Deserialize)] pub struct Tunables { - /// For static heaps, the size in wasm pages of the heap protected by bounds checking. + /// For static heaps, the size in wasm pages of the heap protected by bounds + /// checking. pub static_memory_bound: u64, /// The size in bytes of the offset guard for static heaps. @@ -31,7 +32,8 @@ pub struct Tunables { /// Whether or not we use epoch-based interruption. pub epoch_interruption: bool, - /// Whether or not to treat the static memory bound as the maximum for unbounded heaps. + /// Whether or not to treat the static memory bound as the maximum for + /// unbounded heaps. pub static_memory_bound_is_maximum: bool, /// Whether or not linear memory allocations will have a guard region at the diff --git a/crates/jit-icache-coherence/src/libc.rs b/crates/jit-icache-coherence/src/libc.rs index 557cd06921a6..364658bd1813 100644 --- a/crates/jit-icache-coherence/src/libc.rs +++ b/crates/jit-icache-coherence/src/libc.rs @@ -104,7 +104,7 @@ fn riscv_flush_icache(start: u64, end: u64) -> Result<()> { match unsafe { libc::syscall( { - // The syscall isn't defined in `libc`, so we definfe the syscall number here. + // The syscall isn't defined in `libc`, so we define the syscall number here. 
// https://github.com/torvalds/linux/search?q=__NR_arch_specific_syscall #[allow(non_upper_case_globals)] const __NR_arch_specific_syscall :i64 = 244; diff --git a/crates/runtime/src/instance/allocator.rs b/crates/runtime/src/instance/allocator.rs index 6a17e14682f9..35c1201495d6 100644 --- a/crates/runtime/src/instance/allocator.rs +++ b/crates/runtime/src/instance/allocator.rs @@ -42,7 +42,8 @@ pub struct InstanceAllocationRequest<'a> { /// A pointer to the "store" for this instance to be allocated. The store /// correlates with the `Store` in wasmtime itself, and lots of contextual - /// information about the execution of wasm can be learned through the store. + /// information about the execution of wasm can be learned through the + /// store. /// /// Note that this is a raw pointer and has a static lifetime, both of which /// are a bit of a lie. This is done purely so a store can learn about @@ -172,7 +173,7 @@ pub unsafe trait InstanceAllocatorImpl { // associated types are not object safe. // // 2. We would want a parameterized `Drop` implementation so that we could - // pass in the `InstaceAllocatorImpl` on drop, but this doesn't exist in + // pass in the `InstanceAllocatorImpl` on drop, but this doesn't exist in // Rust. Therefore, we would be forced to add reference counting and // stuff like that to keep a handle on the instance allocator from this // theoretical type. That's a bummer. @@ -250,11 +251,13 @@ pub unsafe trait InstanceAllocatorImpl { #[cfg(feature = "async")] fn allocate_fiber_stack(&self) -> Result; - /// Deallocates a fiber stack that was previously allocated with `allocate_fiber_stack`. + /// Deallocates a fiber stack that was previously allocated with + /// `allocate_fiber_stack`. /// /// # Safety /// - /// The provided stack is required to have been allocated with `allocate_fiber_stack`. + /// The provided stack is required to have been allocated with + /// `allocate_fiber_stack`. #[cfg(feature = "async")] unsafe fn deallocate_fiber_stack(&self, stack: &wasmtime_fiber::FiberStack); diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index 340b780c51c1..ed2c4415ded6 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -182,8 +182,8 @@ pub trait ModuleRuntimeInfo: Send + Sync + 'static { /// not callable from outside the Wasm module itself. fn array_to_wasm_trampoline(&self, index: DefinedFuncIndex) -> Option; - /// Return the addres, in memory, of the trampoline that allows Wasm to call - /// a native function of the given signature. + /// Return the address, in memory, of the trampoline that allows Wasm to + /// call a native function of the given signature. fn wasm_to_native_trampoline( &self, signature: VMSharedSignatureIndex, diff --git a/crates/wasmtime/src/config.rs b/crates/wasmtime/src/config.rs index 69b7435cd6c9..2f657e462cd6 100644 --- a/crates/wasmtime/src/config.rs +++ b/crates/wasmtime/src/config.rs @@ -1908,7 +1908,7 @@ impl PoolingAllocationConfig { /// allocator additionally track an "affinity" flag to a particular core /// wasm module. When a module is instantiated into a slot then the slot is /// considered affine to that module, even after the instance has been - /// dealloocated. + /// deallocated. /// /// When a new instance is created then a slot must be chosen, and the /// current algorithm for selecting a slot is: @@ -1931,7 +1931,7 @@ impl PoolingAllocationConfig { /// impact of "unused slots" for a long-running wasm server. 
/// /// If this setting is set to 0, for example, then affine slots are - /// aggressively resused on a least-recently-used basis. A "cold" slot is + /// aggressively reused on a least-recently-used basis. A "cold" slot is /// only used if there are no affine slots available to allocate from. This /// means that the set of slots used over the lifetime of a program is the /// same as the maximum concurrent number of wasm instances. From e69a7f732e3e32976316b6b2d50fa721adfdb6ff Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Thu, 21 Sep 2023 19:55:43 +0200 Subject: [PATCH 12/14] Do proper type checking for type handles. (#7065) Instead of relying purely on the assumption that type handles can be compared cheaply by pointer equality, fallback to a more expensive walk of the type tree that recursively compares types structurally. This allows different components to call into each other as long as their types are structurally equivalent. Signed-off-by: Ryan Levick --- crates/wasmtime/src/component/types.rs | 269 +++++++++++++++++++++++-- tests/all/component_model/import.rs | 61 ++++++ 2 files changed, 310 insertions(+), 20 deletions(-) diff --git a/crates/wasmtime/src/component/types.rs b/crates/wasmtime/src/component/types.rs index 43b25a13c388..f751079fa037 100644 --- a/crates/wasmtime/src/component/types.rs +++ b/crates/wasmtime/src/component/types.rs @@ -9,8 +9,8 @@ use std::ops::Deref; use std::sync::Arc; use wasmtime_environ::component::{ CanonicalAbiInfo, ComponentTypes, InterfaceType, ResourceIndex, TypeEnumIndex, TypeFlagsIndex, - TypeListIndex, TypeOptionIndex, TypeRecordIndex, TypeResultIndex, TypeTupleIndex, - TypeVariantIndex, + TypeListIndex, TypeOptionIndex, TypeRecordIndex, TypeResourceTableIndex, TypeResultIndex, + TypeTupleIndex, TypeVariantIndex, }; use wasmtime_environ::PrimaryMap; @@ -56,6 +56,29 @@ impl Handle { resources: &self.resources, } } + + fn equivalent<'a>( + &'a self, + other: &'a Self, + type_check: fn(&TypeChecker<'a>, T, T) -> bool, + ) -> bool + where + T: PartialEq + Copy, + { + (self.index == other.index + && Arc::ptr_eq(&self.types, &other.types) + && Arc::ptr_eq(&self.resources, &other.resources)) + || type_check( + &TypeChecker { + a_types: &self.types, + b_types: &other.types, + a_resource: &self.resources, + b_resource: &other.resources, + }, + self.index, + other.index, + ) + } } impl fmt::Debug for Handle { @@ -66,23 +89,173 @@ impl fmt::Debug for Handle { } } -impl PartialEq for Handle { - fn eq(&self, other: &Self) -> bool { - // FIXME: This is an overly-restrictive definition of equality in that it doesn't consider types to be - // equal unless they refer to the same declaration in the same component. It's a good shortcut for the - // common case, but we should also do a recursive structural equality test if the shortcut test fails. 
- self.index == other.index - && Arc::ptr_eq(&self.types, &other.types) - && Arc::ptr_eq(&self.resources, &other.resources) - } +/// Type checker between two `Handle`s +struct TypeChecker<'a> { + a_types: &'a ComponentTypes, + a_resource: &'a PrimaryMap, + b_types: &'a ComponentTypes, + b_resource: &'a PrimaryMap, } -impl Eq for Handle {} +impl TypeChecker<'_> { + fn interface_types_equal(&self, a: InterfaceType, b: InterfaceType) -> bool { + match (a, b) { + (InterfaceType::Own(o1), InterfaceType::Own(o2)) => self.resources_equal(o1, o2), + (InterfaceType::Own(_), _) => false, + (InterfaceType::Borrow(b1), InterfaceType::Borrow(b2)) => self.resources_equal(b1, b2), + (InterfaceType::Borrow(_), _) => false, + (InterfaceType::List(l1), InterfaceType::List(l2)) => self.lists_equal(l1, l2), + (InterfaceType::List(_), _) => false, + (InterfaceType::Record(r1), InterfaceType::Record(r2)) => self.records_equal(r1, r2), + (InterfaceType::Record(_), _) => false, + (InterfaceType::Variant(v1), InterfaceType::Variant(v2)) => self.variants_equal(v1, v2), + (InterfaceType::Variant(_), _) => false, + (InterfaceType::Result(r1), InterfaceType::Result(r2)) => self.results_equal(r1, r2), + (InterfaceType::Result(_), _) => false, + (InterfaceType::Option(o1), InterfaceType::Option(o2)) => self.options_equal(o1, o2), + (InterfaceType::Option(_), _) => false, + (InterfaceType::Enum(e1), InterfaceType::Enum(e2)) => self.enums_equal(e1, e2), + (InterfaceType::Enum(_), _) => false, + (InterfaceType::Tuple(t1), InterfaceType::Tuple(t2)) => self.tuples_equal(t1, t2), + (InterfaceType::Tuple(_), _) => false, + (InterfaceType::Flags(f1), InterfaceType::Flags(f2)) => self.flags_equal(f1, f2), + (InterfaceType::Flags(_), _) => false, + (InterfaceType::Bool, InterfaceType::Bool) => true, + (InterfaceType::Bool, _) => false, + (InterfaceType::U8, InterfaceType::U8) => true, + (InterfaceType::U8, _) => false, + (InterfaceType::U16, InterfaceType::U16) => true, + (InterfaceType::U16, _) => false, + (InterfaceType::U32, InterfaceType::U32) => true, + (InterfaceType::U32, _) => false, + (InterfaceType::U64, InterfaceType::U64) => true, + (InterfaceType::U64, _) => false, + (InterfaceType::S8, InterfaceType::S8) => true, + (InterfaceType::S8, _) => false, + (InterfaceType::S16, InterfaceType::S16) => true, + (InterfaceType::S16, _) => false, + (InterfaceType::S32, InterfaceType::S32) => true, + (InterfaceType::S32, _) => false, + (InterfaceType::S64, InterfaceType::S64) => true, + (InterfaceType::S64, _) => false, + (InterfaceType::Float32, InterfaceType::Float32) => true, + (InterfaceType::Float32, _) => false, + (InterfaceType::Float64, InterfaceType::Float64) => true, + (InterfaceType::Float64, _) => false, + (InterfaceType::String, InterfaceType::String) => true, + (InterfaceType::String, _) => false, + (InterfaceType::Char, InterfaceType::Char) => true, + (InterfaceType::Char, _) => false, + } + } + + fn lists_equal(&self, l1: TypeListIndex, l2: TypeListIndex) -> bool { + let a = &self.a_types[l1]; + let b = &self.b_types[l2]; + self.interface_types_equal(a.element, b.element) + } + + fn resources_equal(&self, o1: TypeResourceTableIndex, o2: TypeResourceTableIndex) -> bool { + let a = &self.a_types[o1]; + let b = &self.b_types[o2]; + self.a_resource[a.ty] == self.b_resource[b.ty] + } + + fn records_equal(&self, r1: TypeRecordIndex, r2: TypeRecordIndex) -> bool { + let a = &self.a_types[r1]; + let b = &self.b_types[r2]; + if a.fields.len() != b.fields.len() { + return false; + } + a.fields + .iter() + 
.zip(b.fields.iter()) + .all(|(a_field, b_field)| { + a_field.name == b_field.name && self.interface_types_equal(a_field.ty, b_field.ty) + }) + } + + fn variants_equal(&self, v1: TypeVariantIndex, v2: TypeVariantIndex) -> bool { + let a = &self.a_types[v1]; + let b = &self.b_types[v2]; + if a.cases.len() != b.cases.len() { + return false; + } + a.cases.iter().zip(b.cases.iter()).all(|(a_case, b_case)| { + if a_case.name != b_case.name { + return false; + } + match (a_case.ty, b_case.ty) { + (Some(a_case_ty), Some(b_case_ty)) => { + self.interface_types_equal(a_case_ty, b_case_ty) + } + (None, None) => true, + _ => false, + } + }) + } + + fn results_equal(&self, r1: TypeResultIndex, r2: TypeResultIndex) -> bool { + let a = &self.a_types[r1]; + let b = &self.b_types[r2]; + let oks = match (a.ok, b.ok) { + (Some(ok1), Some(ok2)) => self.interface_types_equal(ok1, ok2), + (None, None) => true, + _ => false, + }; + if !oks { + return false; + } + match (a.err, b.err) { + (Some(err1), Some(err2)) => self.interface_types_equal(err1, err2), + (None, None) => true, + _ => false, + } + } + + fn options_equal(&self, o1: TypeOptionIndex, o2: TypeOptionIndex) -> bool { + let a = &self.a_types[o1]; + let b = &self.b_types[o2]; + self.interface_types_equal(a.ty, b.ty) + } + + fn enums_equal(&self, e1: TypeEnumIndex, e2: TypeEnumIndex) -> bool { + let a = &self.a_types[e1]; + let b = &self.b_types[e2]; + a.names == b.names + } + + fn tuples_equal(&self, t1: TypeTupleIndex, t2: TypeTupleIndex) -> bool { + let a = &self.a_types[t1]; + let b = &self.b_types[t2]; + if a.types.len() != b.types.len() { + return false; + } + a.types + .iter() + .zip(b.types.iter()) + .all(|(&a, &b)| self.interface_types_equal(a, b)) + } + + fn flags_equal(&self, f1: TypeFlagsIndex, f2: TypeFlagsIndex) -> bool { + let a = &self.a_types[f1]; + let b = &self.b_types[f2]; + a.names == b.names + } +} /// A `list` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct List(Handle); +impl PartialEq for List { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::lists_equal) + } +} + +impl Eq for List {} + impl List { /// Instantiate this type with the specified `values`. 
pub fn new_val(&self, values: Box<[Val]>) -> Result { @@ -108,7 +281,7 @@ pub struct Field<'a> { } /// A `record` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct Record(Handle); impl Record { @@ -130,8 +303,16 @@ impl Record { } } +impl PartialEq for Record { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::records_equal) + } +} + +impl Eq for Record {} + /// A `tuple` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct Tuple(Handle); impl Tuple { @@ -153,6 +334,14 @@ impl Tuple { } } +impl PartialEq for Tuple { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::tuples_equal) + } +} + +impl Eq for Tuple {} + /// A case declaration belonging to a `variant` pub struct Case<'a> { /// The name of the case @@ -162,7 +351,7 @@ pub struct Case<'a> { } /// A `variant` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct Variant(Handle); impl Variant { @@ -187,8 +376,16 @@ impl Variant { } } +impl PartialEq for Variant { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::variants_equal) + } +} + +impl Eq for Variant {} + /// An `enum` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct Enum(Handle); impl Enum { @@ -210,8 +407,16 @@ impl Enum { } } +impl PartialEq for Enum { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::enums_equal) + } +} + +impl Eq for Enum {} + /// An `option` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct OptionType(Handle); impl OptionType { @@ -230,8 +435,16 @@ impl OptionType { } } +impl PartialEq for OptionType { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::options_equal) + } +} + +impl Eq for OptionType {} + /// An `expected` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct ResultType(Handle); impl ResultType { @@ -261,8 +474,16 @@ impl ResultType { } } +impl PartialEq for ResultType { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::results_equal) + } +} + +impl Eq for ResultType {} + /// A `flags` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct Flags(Handle); impl Flags { @@ -288,6 +509,14 @@ impl Flags { } } +impl PartialEq for Flags { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::flags_equal) + } +} + +impl Eq for Flags {} + /// Represents a component model interface type #[derive(Clone, PartialEq, Eq, Debug)] #[allow(missing_docs)] diff --git a/tests/all/component_model/import.rs b/tests/all/component_model/import.rs index 38e23edd4f95..e77b9a73433b 100644 --- a/tests/all/component_model/import.rs +++ b/tests/all/component_model/import.rs @@ -924,3 +924,64 @@ fn no_actual_wasm_code() -> Result<()> { Ok(()) } + +#[test] +fn use_types_across_component_boundaries() -> Result<()> { + // Create a component that exports a function that returns a record + let engine = super::engine(); + let component = Component::new( + &engine, + r#"(component + (type (;0;) (record (field "a" u8) (field "b" string))) + (import "my-record" (type $my-record (eq 0))) + (core module $m + (memory $memory 17) + (export "memory" (memory $memory)) + (func (export "my-func") (result i32) + i32.const 4 + return)) + (core instance $instance (instantiate $m)) + (type $func-type (func 
(result $my-record))) + (alias core export $instance "my-func" (core func $my-func)) + (alias core export $instance "memory" (core memory $memory)) + (func $my-func (type $func-type) (canon lift (core func $my-func) (memory $memory) string-encoding=utf8)) + (export $export "my-func" (func $my-func)) + )"#, + )?; + let mut store = Store::new(&engine, 0); + let linker = Linker::new(&engine); + let instance = linker.instantiate(&mut store, &component)?; + let my_func = instance.get_func(&mut store, "my-func").unwrap(); + let mut results = vec![Val::Bool(false)]; + my_func.call(&mut store, &[], &mut results)?; + + // Create another component that exports a function that takes that record as an argument + let component = Component::new( + &engine, + format!( + r#"(component + (type (;0;) (record (field "a" u8) (field "b" string))) + (import "my-record" (type $my-record (eq 0))) + (core module $m + (memory $memory 17) + (export "memory" (memory $memory)) + {REALLOC_AND_FREE} + (func (export "my-func") (param i32 i32 i32))) + (core instance $instance (instantiate $m)) + (type $func-type (func (param "my-record" $my-record))) + (alias core export $instance "my-func" (core func $my-func)) + (alias core export $instance "memory" (core memory $memory)) + (func $my-func (type $func-type) (canon lift (core func $my-func) (memory $memory) string-encoding=utf8 (realloc (func $instance "realloc")))) + (export $export "my-func" (func $my-func)) + )"# + ), + )?; + let mut store = Store::new(&engine, 0); + let linker = Linker::new(&engine); + let instance = linker.instantiate(&mut store, &component)?; + let my_func = instance.get_func(&mut store, "my-func").unwrap(); + // Call the exported function with the return values of the call to the previous component's exported function + my_func.call(&mut store, &results, &mut [])?; + + Ok(()) +} From 30ee0dc8d990cb6126ede5c9ca7bc6d8000b44e5 Mon Sep 17 00:00:00 2001 From: Trevor Elliott Date: Thu, 21 Sep 2023 16:04:52 -0700 Subject: [PATCH 13/14] Use wasi-streams in the wasi-http implementation (#7056) * Start refactoring wasi-http * Checkpoint * Initial implementation of response future handling * Lazily initialize response headers and body * make wasmtime-wasi-http compile * wasi-http wit: make a way to reject outgoing-request in outgoing-handler before waiting for the future to resolve * wasi: sync wit from wasi-http * outgoing handler impl: report errors to userland * test-programs: get wasi-http-components kicking over, delete modules and components-sync tests wasi-http-components-sync will come back once we get done with other stuff, but its superfulous for now. wasi-http-modules will not be returning. 
* Process headers * Add HostIncomingBody::new * Add trailers functions * Add TODO for body task outline * Rework incoming-response-consume to return a future-trailers value as well * Fix the wit * First cut at the worker loop * wasi-http: change how we represent bodies/trailers, and annotate own/borrow/child throughout * Update types_impl.rs for wit changes * Split body management into its own module * Checkpoint * more work on incoming body and future trailers * Fill out some more functions * Implement future-trailers-{subscribe,get} * Implement drop-future-trailers * Rework fields, but make the borrow checker mad * Fix borrow error * wasi-http-tests: fix build * test-runner: report errors with stdout/stderr properly * fix two trivial wasi-http tests the error type here changed from a types::Error to an outbound_handler::Error * Remove unnecessary drops * Convert a `bail!` to a `todo!` * Remove a TODO that documented the body worker structure * fill in a bunch more of OutputBody * Remove the custom FrameFut future in favor of using http_body_util * Move the outgoing body types to body.rs * Rework the handling of outgoing bodies * Fix the `outgoing request get` test * Avoid deadlocking the post tests * future_incoming_request_get shouldn't delete the resource * Fix the invalid_dnsname test * implement drop-future-incoming-response * Fix invalid_port and invalid_dnsname tests * Fix the post test * Passing a too large string to println! caused the large post test to fail * Format * Plumb through `between_bytes_timeout` * Downgrade hyper * Revert "Downgrade hyper" This reverts commit fa0750e0c42823cb785288fcf6c0507c5ae9fd64. * Restore old https connection setup * Sync the wasi and wasi-http http deps * Fix tests * Remove the module and component-sync tests, as they are currently not supported * Fix the reference to the large_post test in the components test * Fix wasi-http integration * sync implementation of wasi-http * Slightly more robust error checking * Ignore the wasi-http cli test prtest:full * Consistent ignore attributes between sync and async tests * Fix doc errors * code motion: introduce intermediate `HostIncomingBodyBuilder` rather than a tuple * explain design * Turn FieldMap into a type synonym * Tidy up some future state (#7073) Co-authored-by: Pat Hickey * body HostInputStream: report runtime errors with StreamRuntimeError HostInputStream is designed wrong to need that in the first place. We will fix it in a follow-up as soon as resources land. 
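For embedders, the most visible API change in this refactor is that `WasiHttpView` now hands out mutable access to both the HTTP context and the resource table, mirroring `WasiView`. A minimal sketch of an embedding context after this change, modeled on the updated test harnesses below (the real harnesses also hold a `WasiCtx` and implement `WasiView`):

    use wasmtime_wasi::preview2::Table;
    use wasmtime_wasi_http::{WasiHttpCtx, WasiHttpView};

    struct Ctx {
        table: Table,
        http: WasiHttpCtx,
    }

    impl WasiHttpView for Ctx {
        // Both accessors take `&mut self`; the old `http_ctx()`/`http_ctx_mut()`
        // pair is gone, and `WasiHttpCtx` is constructed as a plain `WasiHttpCtx {}`.
        fn ctx(&mut self) -> &mut WasiHttpCtx {
            &mut self.http
        }
        fn table(&mut self) -> &mut Table {
            &mut self.table
        }
    }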
--------- Co-authored-by: Pat Hickey Co-authored-by: Alex Crichton --- .../tests/wasi-http-components-sync.rs | 13 +- .../tests/wasi-http-components.rs | 18 +- .../test-programs/tests/wasi-http-modules.rs | 196 ---- .../bin/outbound_request_invalid_dnsname.rs | 4 +- .../bin/outbound_request_invalid_version.rs | 2 +- .../src/bin/outbound_request_large_post.rs | 2 +- .../src/bin/outbound_request_post.rs | 2 +- .../bin/outbound_request_unknown_method.rs | 2 +- .../outbound_request_unsupported_scheme.rs | 2 +- .../test-programs/wasi-http-tests/src/lib.rs | 83 +- crates/wasi-http/src/body.rs | 498 ++++++++++ crates/wasi-http/src/http_impl.rs | 450 ++++----- crates/wasi-http/src/incoming_handler.rs | 21 +- crates/wasi-http/src/lib.rs | 59 +- crates/wasi-http/src/proxy.rs | 38 +- crates/wasi-http/src/types.rs | 593 +++++------- crates/wasi-http/src/types_impl.rs | 881 +++++++----------- .../wit/deps/http/incoming-handler.wit | 6 +- .../wit/deps/http/outgoing-handler.wit | 13 +- crates/wasi-http/wit/deps/http/types.wit | 148 ++- crates/wasi/src/preview2/mod.rs | 5 +- crates/wasi/src/preview2/pipe.rs | 195 +--- crates/wasi/src/preview2/table.rs | 11 +- crates/wasi/src/preview2/write_stream.rs | 196 ++++ .../wasi/wit/deps/http/incoming-handler.wit | 6 +- .../wasi/wit/deps/http/outgoing-handler.wit | 13 +- crates/wasi/wit/deps/http/types.wit | 148 ++- src/commands/run.rs | 24 +- tests/all/cli_tests.rs | 1 + 29 files changed, 1728 insertions(+), 1902 deletions(-) delete mode 100644 crates/test-programs/tests/wasi-http-modules.rs create mode 100644 crates/wasi-http/src/body.rs create mode 100644 crates/wasi/src/preview2/write_stream.rs diff --git a/crates/test-programs/tests/wasi-http-components-sync.rs b/crates/test-programs/tests/wasi-http-components-sync.rs index f525d5b44724..33c638fc6883 100644 --- a/crates/test-programs/tests/wasi-http-components-sync.rs +++ b/crates/test-programs/tests/wasi-http-components-sync.rs @@ -46,12 +46,13 @@ impl WasiView for Ctx { } impl WasiHttpView for Ctx { - fn http_ctx(&self) -> &WasiHttpCtx { - &self.http - } - fn http_ctx_mut(&mut self) -> &mut WasiHttpCtx { + fn ctx(&mut self) -> &mut WasiHttpCtx { &mut self.http } + + fn table(&mut self) -> &mut Table { + &mut self.table + } } fn instantiate_component( @@ -60,7 +61,7 @@ fn instantiate_component( ) -> Result<(Store, Command), anyhow::Error> { let mut linker = Linker::new(&ENGINE); add_to_linker(&mut linker)?; - wasmtime_wasi_http::proxy::sync::add_to_linker(&mut linker)?; + wasmtime_wasi_http::proxy::add_to_linker(&mut linker)?; let mut store = Store::new(&ENGINE, ctx); @@ -84,7 +85,7 @@ fn run(name: &str) -> anyhow::Result<()> { builder.env(var, val); } let wasi = builder.build(&mut table)?; - let http = WasiHttpCtx::new(); + let http = WasiHttpCtx {}; let (mut store, command) = instantiate_component(component, Ctx { table, wasi, http })?; command diff --git a/crates/test-programs/tests/wasi-http-components.rs b/crates/test-programs/tests/wasi-http-components.rs index 1f05411d0178..8d1fa1f96c52 100644 --- a/crates/test-programs/tests/wasi-http-components.rs +++ b/crates/test-programs/tests/wasi-http-components.rs @@ -47,10 +47,10 @@ impl WasiView for Ctx { } impl WasiHttpView for Ctx { - fn http_ctx(&self) -> &WasiHttpCtx { - &self.http + fn table(&mut self) -> &mut Table { + &mut self.table } - fn http_ctx_mut(&mut self) -> &mut WasiHttpCtx { + fn ctx(&mut self) -> &mut WasiHttpCtx { &mut self.http } } @@ -85,16 +85,11 @@ async fn run(name: &str) -> anyhow::Result<()> { builder.env(var, val); } let wasi = 
builder.build(&mut table)?; - let http = WasiHttpCtx::new(); + let http = WasiHttpCtx; let (mut store, command) = instantiate_component(component, Ctx { table, wasi, http }).await?; - command - .wasi_cli_run() - .call_run(&mut store) - .await? - .map_err(|()| anyhow::anyhow!("run returned a failure"))?; - Ok(()) + command.wasi_cli_run().call_run(&mut store).await }; r.map_err(move |trap: anyhow::Error| { let stdout = stdout.try_into_inner().expect("single ref to stdout"); @@ -109,7 +104,8 @@ async fn run(name: &str) -> anyhow::Result<()> { "error while testing wasi-tests {} with http-components", name )) - })?; + })? + .map_err(|()| anyhow::anyhow!("run returned an error"))?; Ok(()) } diff --git a/crates/test-programs/tests/wasi-http-modules.rs b/crates/test-programs/tests/wasi-http-modules.rs deleted file mode 100644 index e73251d5a1f8..000000000000 --- a/crates/test-programs/tests/wasi-http-modules.rs +++ /dev/null @@ -1,196 +0,0 @@ -#![cfg(all(feature = "test_programs", not(skip_wasi_http_tests)))] -use wasmtime::{Config, Engine, Func, Linker, Module, Store}; -use wasmtime_wasi::preview2::{ - pipe::MemoryOutputPipe, - preview1::{WasiPreview1Adapter, WasiPreview1View}, - IsATTY, Table, WasiCtx, WasiCtxBuilder, WasiView, -}; -use wasmtime_wasi_http::{WasiHttpCtx, WasiHttpView}; - -use test_programs::http_server::{setup_http1, setup_http2}; - -lazy_static::lazy_static! { - static ref ENGINE: Engine = { - let mut config = Config::new(); - config.wasm_backtrace_details(wasmtime::WasmBacktraceDetails::Enable); - config.wasm_component_model(true); - config.async_support(true); - - let engine = Engine::new(&config).unwrap(); - engine - }; -} -// uses ENGINE, creates a fn get_module(&str) -> Module -include!(concat!(env!("OUT_DIR"), "/wasi_http_tests_modules.rs")); - -struct Ctx { - table: Table, - wasi: WasiCtx, - adapter: WasiPreview1Adapter, - http: WasiHttpCtx, -} - -impl WasiView for Ctx { - fn table(&self) -> &Table { - &self.table - } - fn table_mut(&mut self) -> &mut Table { - &mut self.table - } - fn ctx(&self) -> &WasiCtx { - &self.wasi - } - fn ctx_mut(&mut self) -> &mut WasiCtx { - &mut self.wasi - } -} -impl WasiPreview1View for Ctx { - fn adapter(&self) -> &WasiPreview1Adapter { - &self.adapter - } - fn adapter_mut(&mut self) -> &mut WasiPreview1Adapter { - &mut self.adapter - } -} -impl WasiHttpView for Ctx { - fn http_ctx(&self) -> &WasiHttpCtx { - &self.http - } - fn http_ctx_mut(&mut self) -> &mut WasiHttpCtx { - &mut self.http - } -} - -async fn instantiate_module(module: Module, ctx: Ctx) -> Result<(Store, Func), anyhow::Error> { - let mut linker = Linker::new(&ENGINE); - wasmtime_wasi_http::add_to_linker(&mut linker)?; - wasmtime_wasi::preview2::preview1::add_to_linker_async(&mut linker)?; - - let mut store = Store::new(&ENGINE, ctx); - - let instance = linker.instantiate_async(&mut store, &module).await?; - let command = instance.get_func(&mut store, "_start").unwrap(); - Ok((store, command)) -} - -async fn run(name: &str) -> anyhow::Result<()> { - let stdout = MemoryOutputPipe::new(4096); - let stderr = MemoryOutputPipe::new(4096); - let r = { - let mut table = Table::new(); - let module = get_module(name); - - // Create our wasi context. 
- let mut builder = WasiCtxBuilder::new(); - builder.stdout(stdout.clone(), IsATTY::No); - builder.stderr(stderr.clone(), IsATTY::No); - builder.arg(name); - for (var, val) in test_programs::wasi_tests_environment() { - builder.env(var, val); - } - let wasi = builder.build(&mut table)?; - let http = WasiHttpCtx::new(); - - let adapter = WasiPreview1Adapter::new(); - - let (mut store, command) = instantiate_module( - module, - Ctx { - table, - wasi, - http, - adapter, - }, - ) - .await?; - command.call_async(&mut store, &[], &mut []).await - }; - r.map_err(move |trap: anyhow::Error| { - let stdout = stdout.try_into_inner().expect("single ref to stdout"); - if !stdout.is_empty() { - println!("[guest] stdout:\n{}\n===", String::from_utf8_lossy(&stdout)); - } - let stderr = stderr.try_into_inner().expect("single ref to stderr"); - if !stderr.is_empty() { - println!("[guest] stderr:\n{}\n===", String::from_utf8_lossy(&stderr)); - } - trap.context(format!( - "error while testing wasi-tests {} with http-modules", - name - )) - })?; - Ok(()) -} - -#[test_log::test(tokio::test(flavor = "multi_thread"))] -#[cfg_attr( - windows, - ignore = "test is currently flaky in ci and needs to be debugged" -)] -async fn outbound_request_get() { - setup_http1(run("outbound_request_get")).await.unwrap(); -} - -#[test_log::test(tokio::test(flavor = "multi_thread"))] -#[cfg_attr( - windows, - ignore = "test is currently flaky in ci and needs to be debugged" -)] -async fn outbound_request_post() { - setup_http1(run("outbound_request_post")).await.unwrap(); -} - -#[test_log::test(tokio::test(flavor = "multi_thread"))] -#[cfg_attr( - windows, - ignore = "test is currently flaky in ci and needs to be debugged" -)] -async fn outbound_request_large_post() { - setup_http1(run("outbound_request_large_post")) - .await - .unwrap(); -} - -#[test_log::test(tokio::test(flavor = "multi_thread"))] -#[cfg_attr( - windows, - ignore = "test is currently flaky in ci and needs to be debugged" -)] -async fn outbound_request_put() { - setup_http1(run("outbound_request_put")).await.unwrap(); -} - -#[test_log::test(tokio::test(flavor = "multi_thread"))] -#[cfg_attr( - windows, - ignore = "test is currently flaky in ci and needs to be debugged" -)] -async fn outbound_request_invalid_version() { - setup_http2(run("outbound_request_invalid_version")) - .await - .unwrap(); -} - -#[test_log::test(tokio::test(flavor = "multi_thread"))] -async fn outbound_request_unknown_method() { - run("outbound_request_unknown_method").await.unwrap(); -} - -#[test_log::test(tokio::test(flavor = "multi_thread"))] -async fn outbound_request_unsupported_scheme() { - run("outbound_request_unsupported_scheme").await.unwrap(); -} - -#[test_log::test(tokio::test(flavor = "multi_thread"))] -async fn outbound_request_invalid_port() { - run("outbound_request_invalid_port").await.unwrap(); -} - -#[test_log::test(tokio::test(flavor = "multi_thread"))] -#[cfg_attr( - windows, - ignore = "test is currently flaky in ci and needs to be debugged" -)] -async fn outbound_request_invalid_dnsname() { - run("outbound_request_invalid_dnsname").await.unwrap(); -} diff --git a/crates/test-programs/wasi-http-tests/src/bin/outbound_request_invalid_dnsname.rs b/crates/test-programs/wasi-http-tests/src/bin/outbound_request_invalid_dnsname.rs index 7c3ba5909b8f..6cc5ce8bc97b 100644 --- a/crates/test-programs/wasi-http-tests/src/bin/outbound_request_invalid_dnsname.rs +++ b/crates/test-programs/wasi-http-tests/src/bin/outbound_request_invalid_dnsname.rs @@ -15,6 +15,6 @@ async fn run() { ) 
.await; - let error = res.unwrap_err(); - assert_eq!(error.to_string(), "Error::InvalidUrl(\"invalid dnsname\")"); + let error = res.unwrap_err().to_string(); + assert!(error.starts_with("Error::InvalidUrl(\"failed to lookup address information:")); } diff --git a/crates/test-programs/wasi-http-tests/src/bin/outbound_request_invalid_version.rs b/crates/test-programs/wasi-http-tests/src/bin/outbound_request_invalid_version.rs index cadabda4dec1..53c767edec8f 100644 --- a/crates/test-programs/wasi-http-tests/src/bin/outbound_request_invalid_version.rs +++ b/crates/test-programs/wasi-http-tests/src/bin/outbound_request_invalid_version.rs @@ -17,7 +17,7 @@ async fn run() { let error = res.unwrap_err().to_string(); if error.ne("Error::ProtocolError(\"invalid HTTP version parsed\")") - && error.ne("Error::ProtocolError(\"operation was canceled\")") + && !error.starts_with("Error::ProtocolError(\"operation was canceled") { panic!( r#"assertion failed: `(left == right)` diff --git a/crates/test-programs/wasi-http-tests/src/bin/outbound_request_large_post.rs b/crates/test-programs/wasi-http-tests/src/bin/outbound_request_large_post.rs index 9f4bacef8f73..80e0688d47fc 100644 --- a/crates/test-programs/wasi-http-tests/src/bin/outbound_request_large_post.rs +++ b/crates/test-programs/wasi-http-tests/src/bin/outbound_request_large_post.rs @@ -23,7 +23,7 @@ async fn run() { .context("localhost:3000 /post large") .unwrap(); - println!("localhost:3000 /post large: {res:?}"); + println!("localhost:3000 /post large: {}", res.status); assert_eq!(res.status, 200); let method = res.header("x-wasmtime-test-method").unwrap(); assert_eq!(std::str::from_utf8(method).unwrap(), "POST"); diff --git a/crates/test-programs/wasi-http-tests/src/bin/outbound_request_post.rs b/crates/test-programs/wasi-http-tests/src/bin/outbound_request_post.rs index 69a03d754ee9..131356fa91bc 100644 --- a/crates/test-programs/wasi-http-tests/src/bin/outbound_request_post.rs +++ b/crates/test-programs/wasi-http-tests/src/bin/outbound_request_post.rs @@ -22,5 +22,5 @@ async fn run() { assert_eq!(res.status, 200); let method = res.header("x-wasmtime-test-method").unwrap(); assert_eq!(std::str::from_utf8(method).unwrap(), "POST"); - assert_eq!(res.body, b"{\"foo\": \"bar\"}"); + assert_eq!(res.body, b"{\"foo\": \"bar\"}", "invalid body returned"); } diff --git a/crates/test-programs/wasi-http-tests/src/bin/outbound_request_unknown_method.rs b/crates/test-programs/wasi-http-tests/src/bin/outbound_request_unknown_method.rs index a2ab5e48dc02..727294861b00 100644 --- a/crates/test-programs/wasi-http-tests/src/bin/outbound_request_unknown_method.rs +++ b/crates/test-programs/wasi-http-tests/src/bin/outbound_request_unknown_method.rs @@ -18,6 +18,6 @@ async fn run() { let error = res.unwrap_err(); assert_eq!( error.to_string(), - "Error::InvalidUrl(\"unknown method OTHER\")" + "Error::Invalid(\"unknown method OTHER\")" ); } diff --git a/crates/test-programs/wasi-http-tests/src/bin/outbound_request_unsupported_scheme.rs b/crates/test-programs/wasi-http-tests/src/bin/outbound_request_unsupported_scheme.rs index 482550627e8d..c8a66b7da648 100644 --- a/crates/test-programs/wasi-http-tests/src/bin/outbound_request_unsupported_scheme.rs +++ b/crates/test-programs/wasi-http-tests/src/bin/outbound_request_unsupported_scheme.rs @@ -18,6 +18,6 @@ async fn run() { let error = res.unwrap_err(); assert_eq!( error.to_string(), - "Error::InvalidUrl(\"unsupported scheme WS\")" + "Error::Invalid(\"unsupported scheme WS\")" ); } diff --git 
a/crates/test-programs/wasi-http-tests/src/lib.rs b/crates/test-programs/wasi-http-tests/src/lib.rs index e1d9fb76dd6c..d516af03da16 100644 --- a/crates/test-programs/wasi-http-tests/src/lib.rs +++ b/crates/test-programs/wasi-http-tests/src/lib.rs @@ -42,29 +42,22 @@ impl Response { } } -struct DropPollable { - pollable: poll::Pollable, -} - -impl Drop for DropPollable { - fn drop(&mut self) { - poll::drop_pollable(self.pollable); - } -} - pub async fn request( method: http_types::Method, scheme: http_types::Scheme, authority: &str, path_with_query: &str, body: Option<&[u8]>, - additional_headers: Option<&[(String, String)]>, + additional_headers: Option<&[(String, Vec)]>, ) -> Result { + fn header_val(v: &str) -> Vec { + v.to_string().into_bytes() + } let headers = http_types::new_fields( &[ &[ - ("User-agent".to_string(), "WASI-HTTP/0.0.1".to_string()), - ("Content-type".to_string(), "application/json".to_string()), + ("User-agent".to_string(), header_val("WASI-HTTP/0.0.1")), + ("Content-type".to_string(), header_val("application/json")), ], additional_headers.unwrap_or(&[]), ] @@ -79,15 +72,16 @@ pub async fn request( headers, ); - let request_body = http_types::outgoing_request_write(request) + let outgoing_body = http_types::outgoing_request_write(request) .map_err(|_| anyhow!("outgoing request write failed"))?; if let Some(mut buf) = body { - let sub = DropPollable { - pollable: streams::subscribe_to_output_stream(request_body), - }; + let request_body = http_types::outgoing_body_write(outgoing_body) + .map_err(|_| anyhow!("outgoing request write failed"))?; + + let pollable = streams::subscribe_to_output_stream(request_body); while !buf.is_empty() { - poll::poll_oneoff(&[sub.pollable]); + poll::poll_oneoff(&[pollable]); let permit = match streams::check_write(request_body) { Ok(n) => n, @@ -109,36 +103,39 @@ pub async fn request( _ => {} } - poll::poll_oneoff(&[sub.pollable]); + poll::poll_oneoff(&[pollable]); + poll::drop_pollable(pollable); match streams::check_write(request_body) { Ok(_) => {} Err(_) => anyhow::bail!("output stream error"), }; + + streams::drop_output_stream(request_body); } - let future_response = outgoing_handler::handle(request, None); + let future_response = outgoing_handler::handle(request, None)?; + + // TODO: The current implementation requires this drop after the request is sent. + // The ownership semantics are unclear in wasi-http we should clarify exactly what is + // supposed to happen here. + http_types::drop_outgoing_body(outgoing_body); let incoming_response = match http_types::future_incoming_response_get(future_response) { - Some(result) => result, + Some(result) => result.map_err(|_| anyhow!("incoming response errored"))?, None => { let pollable = http_types::listen_to_future_incoming_response(future_response); let _ = poll::poll_oneoff(&[pollable]); + poll::drop_pollable(pollable); http_types::future_incoming_response_get(future_response) .expect("incoming response available") + .map_err(|_| anyhow!("incoming response errored"))? } } // TODO: maybe anything that appears in the Result<_, E> position should impl // Error? anyway, just use its Debug here: .map_err(|e| anyhow!("{e:?}"))?; - // TODO: The current implementation requires this drop after the request is sent. - // The ownership semantics are unclear in wasi-http we should clarify exactly what is - // supposed to happen here. 
- streams::drop_output_stream(request_body); - - http_types::drop_outgoing_request(request); - http_types::drop_future_incoming_response(future_response); let status = http_types::incoming_response_status(incoming_response); @@ -147,26 +144,32 @@ pub async fn request( let headers = http_types::fields_entries(headers_handle); http_types::drop_fields(headers_handle); - let body_stream = http_types::incoming_response_consume(incoming_response) + let incoming_body = http_types::incoming_response_consume(incoming_response) .map_err(|()| anyhow!("incoming response has no body stream"))?; - let input_stream_pollable = streams::subscribe_to_input_stream(body_stream); + + http_types::drop_incoming_response(incoming_response); + + let input_stream = http_types::incoming_body_stream(incoming_body).unwrap(); + let input_stream_pollable = streams::subscribe_to_input_stream(input_stream); let mut body = Vec::new(); let mut eof = streams::StreamStatus::Open; while eof != streams::StreamStatus::Ended { - let (mut body_chunk, stream_status) = - streams::read(body_stream, u64::MAX).map_err(|_| anyhow!("body_stream read failed"))?; - eof = if body_chunk.is_empty() { - streams::StreamStatus::Ended - } else { - stream_status - }; - body.append(&mut body_chunk); + poll::poll_oneoff(&[input_stream_pollable]); + + let (mut body_chunk, stream_status) = streams::read(input_stream, 1024 * 1024) + .map_err(|_| anyhow!("input_stream read failed"))?; + + eof = stream_status; + + if !body_chunk.is_empty() { + body.append(&mut body_chunk); + } } poll::drop_pollable(input_stream_pollable); - streams::drop_input_stream(body_stream); - http_types::drop_incoming_response(incoming_response); + streams::drop_input_stream(input_stream); + http_types::drop_incoming_body(incoming_body); Ok(Response { status, diff --git a/crates/wasi-http/src/body.rs b/crates/wasi-http/src/body.rs new file mode 100644 index 000000000000..298fe7b92bfe --- /dev/null +++ b/crates/wasi-http/src/body.rs @@ -0,0 +1,498 @@ +use crate::{bindings::http::types, types::FieldMap}; +use anyhow::anyhow; +use bytes::Bytes; +use http_body_util::combinators::BoxBody; +use std::future::Future; +use std::{ + convert::Infallible, + pin::Pin, + sync::{Arc, Mutex}, + time::Duration, +}; +use tokio::sync::{mpsc, oneshot}; +use wasmtime_wasi::preview2::{ + self, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, OutputStreamError, + StreamRuntimeError, StreamState, +}; + +/// Holds onto the things needed to construct a [`HostIncomingBody`] until we are ready to build +/// one. The HostIncomingBody spawns a task that starts consuming the incoming body, and we don't +/// want to do that unless the user asks to consume the body. +pub struct HostIncomingBodyBuilder { + pub body: hyper::body::Incoming, + pub between_bytes_timeout: Duration, +} + +impl HostIncomingBodyBuilder { + /// Consume the state held in the [`HostIncomingBodyBuilder`] to spawn a task that will drive the + /// streaming body to completion. Data segments will be communicated out over the + /// [`HostIncomingBodyStream`], and a [`HostFutureTrailers`] gives a way to block on/retrieve + /// the trailers. 
+ pub fn build(mut self) -> HostIncomingBody { + let (body_writer, body_receiver) = mpsc::channel(1); + let (trailer_writer, trailers) = oneshot::channel(); + + let worker = preview2::spawn(async move { + loop { + let frame = match tokio::time::timeout( + self.between_bytes_timeout, + http_body_util::BodyExt::frame(&mut self.body), + ) + .await + { + Ok(None) => break, + + Ok(Some(Ok(frame))) => frame, + + Ok(Some(Err(e))) => { + match body_writer.send(Err(anyhow::anyhow!(e))).await { + Ok(_) => {} + // If the body read end has dropped, then we report this error with the + // trailers. unwrap and rewrap Err because the Ok side of these two Results + // are different. + Err(e) => { + let _ = trailer_writer.send(Err(e.0.unwrap_err())); + } + } + break; + } + + Err(_) => { + match body_writer + .send(Err(types::Error::TimeoutError( + "data frame timed out".to_string(), + ) + .into())) + .await + { + Ok(_) => {} + Err(e) => { + let _ = trailer_writer.send(Err(e.0.unwrap_err())); + } + } + break; + } + }; + + if frame.is_trailers() { + // We know we're not going to write any more data frames at this point, so we + // explicitly drop the body_writer so that anything waiting on the read end returns + // immediately. + drop(body_writer); + + let trailers = frame.into_trailers().unwrap(); + + // TODO: this will fail in two cases: + // 1. we've already used the channel once, which should be impossible, + // 2. the read end is closed. + // I'm not sure how to differentiate between these two cases, or really + // if we need to do anything to handle either. + let _ = trailer_writer.send(Ok(trailers)); + + break; + } + + assert!(frame.is_data(), "frame wasn't data"); + + let data = frame.into_data().unwrap(); + + // If the receiver no longer exists, that's ok - in that case we want to keep the + // loop running to relieve backpressure, so we get to the trailers.
+ let _ = body_writer.send(Ok(data)).await; + } + }); + + HostIncomingBody { + worker, + stream: Some(HostIncomingBodyStream::new(body_receiver)), + trailers, + } + } +} + +pub struct HostIncomingBody { + pub worker: AbortOnDropJoinHandle<()>, + pub stream: Option, + pub trailers: oneshot::Receiver>, +} + +impl HostIncomingBody { + pub fn into_future_trailers(self) -> HostFutureTrailers { + HostFutureTrailers { + _worker: self.worker, + state: HostFutureTrailersState::Waiting(self.trailers), + } + } +} + +pub struct HostIncomingBodyStream { + pub open: bool, + pub receiver: mpsc::Receiver>, + pub buffer: Bytes, + pub error: Option, +} + +impl HostIncomingBodyStream { + fn new(receiver: mpsc::Receiver>) -> Self { + Self { + open: true, + receiver, + buffer: Bytes::new(), + error: None, + } + } +} + +#[async_trait::async_trait] +impl HostInputStream for HostIncomingBodyStream { + fn read(&mut self, size: usize) -> anyhow::Result<(Bytes, StreamState)> { + use mpsc::error::TryRecvError; + + if !self.buffer.is_empty() { + let len = size.min(self.buffer.len()); + let chunk = self.buffer.split_to(len); + return Ok((chunk, StreamState::Open)); + } + + if let Some(e) = self.error.take() { + return Err(StreamRuntimeError::from(e).into()); + } + + if !self.open { + return Ok((Bytes::new(), StreamState::Closed)); + } + + match self.receiver.try_recv() { + Ok(Ok(mut bytes)) => { + let len = bytes.len().min(size); + let chunk = bytes.split_to(len); + if !bytes.is_empty() { + self.buffer = bytes; + } + + return Ok((chunk, StreamState::Open)); + } + + Ok(Err(e)) => { + self.open = false; + return Err(StreamRuntimeError::from(e).into()); + } + + Err(TryRecvError::Empty) => { + return Ok((Bytes::new(), StreamState::Open)); + } + + Err(TryRecvError::Disconnected) => { + self.open = false; + return Ok((Bytes::new(), StreamState::Closed)); + } + } + } + + async fn ready(&mut self) -> anyhow::Result<()> { + if !self.buffer.is_empty() { + return Ok(()); + } + + if !self.open { + return Ok(()); + } + + match self.receiver.recv().await { + Some(Ok(bytes)) => self.buffer = bytes, + + Some(Err(e)) => { + self.error = Some(e); + self.open = false; + } + + None => self.open = false, + } + + Ok(()) + } +} + +pub struct HostFutureTrailers { + _worker: AbortOnDropJoinHandle<()>, + pub state: HostFutureTrailersState, +} + +pub enum HostFutureTrailersState { + Waiting(oneshot::Receiver>), + Done(Result), +} + +impl HostFutureTrailers { + pub async fn ready(&mut self) -> anyhow::Result<()> { + if let HostFutureTrailersState::Waiting(rx) = &mut self.state { + let result = match rx.await { + Ok(Ok(headers)) => Ok(FieldMap::from(headers)), + Ok(Err(e)) => Err(types::Error::ProtocolError(format!("hyper error: {e:?}"))), + Err(_) => Err(types::Error::ProtocolError( + "stream hung up before trailers were received".to_string(), + )), + }; + self.state = HostFutureTrailersState::Done(result); + } + Ok(()) + } +} + +pub type HyperBody = BoxBody; + +pub struct HostOutgoingBody { + pub body_output_stream: Option>, + pub trailers_sender: Option>, +} + +impl HostOutgoingBody { + pub fn new() -> (Self, HyperBody) { + use http_body_util::BodyExt; + use hyper::{ + body::{Body, Frame}, + HeaderMap, + }; + use std::task::{Context, Poll}; + use tokio::sync::oneshot::error::RecvError; + struct BodyImpl { + body_receiver: mpsc::Receiver, + trailers_receiver: Option>, + } + impl Body for BodyImpl { + type Data = Bytes; + type Error = Infallible; + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + 
match self.as_mut().body_receiver.poll_recv(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(frame)) => Poll::Ready(Some(Ok(Frame::data(frame)))), + + // This means that the `body_sender` end of the channel has been dropped. + Poll::Ready(None) => { + if let Some(mut trailers_receiver) = self.as_mut().trailers_receiver.take() + { + match Pin::new(&mut trailers_receiver).poll(cx) { + Poll::Pending => { + self.as_mut().trailers_receiver = Some(trailers_receiver); + Poll::Pending + } + Poll::Ready(Ok(trailers)) => { + Poll::Ready(Some(Ok(Frame::trailers(trailers)))) + } + Poll::Ready(Err(RecvError { .. })) => Poll::Ready(None), + } + } else { + Poll::Ready(None) + } + } + } + } + } + + let (body_sender, body_receiver) = mpsc::channel(1); + let (trailers_sender, trailers_receiver) = oneshot::channel(); + let body_impl = BodyImpl { + body_receiver, + trailers_receiver: Some(trailers_receiver), + } + .boxed(); + ( + Self { + // TODO: this capacity constant is arbitrary, and should be configurable + body_output_stream: Some(Box::new(BodyWriteStream::new(1024 * 1024, body_sender))), + trailers_sender: Some(trailers_sender), + }, + body_impl, + ) + } +} + +// copied in from preview2::write_stream + +#[derive(Debug)] +struct WorkerState { + alive: bool, + items: std::collections::VecDeque, + write_budget: usize, + flush_pending: bool, + error: Option, +} + +impl WorkerState { + fn check_error(&mut self) -> Result<(), OutputStreamError> { + if let Some(e) = self.error.take() { + return Err(OutputStreamError::LastOperationFailed(e)); + } + if !self.alive { + return Err(OutputStreamError::Closed); + } + Ok(()) + } +} + +struct Worker { + state: Mutex, + new_work: tokio::sync::Notify, + write_ready_changed: tokio::sync::Notify, +} + +enum Job { + Flush, + Write(Bytes), +} + +enum WriteStatus<'a> { + Done(Result), + Pending(tokio::sync::futures::Notified<'a>), +} + +impl Worker { + fn new(write_budget: usize) -> Self { + Self { + state: Mutex::new(WorkerState { + alive: true, + items: std::collections::VecDeque::new(), + write_budget, + flush_pending: false, + error: None, + }), + new_work: tokio::sync::Notify::new(), + write_ready_changed: tokio::sync::Notify::new(), + } + } + fn check_write(&self) -> WriteStatus<'_> { + let mut state = self.state(); + if let Err(e) = state.check_error() { + return WriteStatus::Done(Err(e)); + } + + if state.flush_pending || state.write_budget == 0 { + return WriteStatus::Pending(self.write_ready_changed.notified()); + } + + WriteStatus::Done(Ok(state.write_budget)) + } + fn state(&self) -> std::sync::MutexGuard { + self.state.lock().unwrap() + } + fn pop(&self) -> Option { + let mut state = self.state(); + if state.items.is_empty() { + if state.flush_pending { + return Some(Job::Flush); + } + } else if let Some(bytes) = state.items.pop_front() { + return Some(Job::Write(bytes)); + } + + None + } + fn report_error(&self, e: std::io::Error) { + { + let mut state = self.state(); + state.alive = false; + state.error = Some(e.into()); + state.flush_pending = false; + } + self.write_ready_changed.notify_waiters(); + } + + async fn work(&self, writer: mpsc::Sender) { + loop { + let notified = self.new_work.notified(); + while let Some(job) = self.pop() { + match job { + Job::Flush => { + self.state().flush_pending = false; + } + + Job::Write(bytes) => { + tracing::debug!("worker writing: {bytes:?}"); + let len = bytes.len(); + match writer.send(bytes).await { + Err(_) => { + self.report_error(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "Outgoing stream 
body reader has dropped", + )); + return; + } + Ok(_) => { + self.state().write_budget += len; + } + } + } + } + + self.write_ready_changed.notify_waiters(); + } + + notified.await; + } + } +} + +/// Provides a [`HostOutputStream`] impl from a [`tokio::sync::mpsc::Sender`]. +pub struct BodyWriteStream { + worker: Arc, + _join_handle: preview2::AbortOnDropJoinHandle<()>, +} + +impl BodyWriteStream { + /// Create a [`BodyWriteStream`]. + pub fn new(write_budget: usize, writer: mpsc::Sender) -> Self { + let worker = Arc::new(Worker::new(write_budget)); + + let w = Arc::clone(&worker); + let join_handle = preview2::spawn(async move { w.work(writer).await }); + + BodyWriteStream { + worker, + _join_handle: join_handle, + } + } +} + +#[async_trait::async_trait] +impl HostOutputStream for BodyWriteStream { + fn write(&mut self, bytes: Bytes) -> Result<(), OutputStreamError> { + let mut state = self.worker.state(); + state.check_error()?; + if state.flush_pending { + return Err(OutputStreamError::Trap(anyhow!( + "write not permitted while flush pending" + ))); + } + match state.write_budget.checked_sub(bytes.len()) { + Some(remaining_budget) => { + state.write_budget = remaining_budget; + state.items.push_back(bytes); + } + None => return Err(OutputStreamError::Trap(anyhow!("write exceeded budget"))), + } + drop(state); + self.worker.new_work.notify_waiters(); + Ok(()) + } + fn flush(&mut self) -> Result<(), OutputStreamError> { + let mut state = self.worker.state(); + state.check_error()?; + + state.flush_pending = true; + self.worker.new_work.notify_waiters(); + + Ok(()) + } + + async fn write_ready(&mut self) -> Result { + loop { + match self.worker.check_write() { + WriteStatus::Done(r) => return r, + WriteStatus::Pending(notifier) => notifier.await, + } + } + } +} diff --git a/crates/wasi-http/src/http_impl.rs b/crates/wasi-http/src/http_impl.rs index e181ef143e0d..570d229cf7b2 100644 --- a/crates/wasi-http/src/http_impl.rs +++ b/crates/wasi-http/src/http_impl.rs @@ -1,122 +1,45 @@ -use crate::bindings::http::types::{ - FutureIncomingResponse, OutgoingRequest, RequestOptions, Scheme, +use crate::bindings::http::{ + outgoing_handler, + types::{FutureIncomingResponse, OutgoingRequest, RequestOptions, Scheme}, }; -use crate::types::{ActiveFields, ActiveFuture, ActiveResponse, HttpResponse, TableHttpExt}; +use crate::types::{HostFutureIncomingResponse, IncomingResponseInternal, TableHttpExt}; use crate::WasiHttpView; use anyhow::Context; -use bytes::{Bytes, BytesMut}; -use http_body_util::{BodyExt, Empty, Full}; -use hyper::{Method, Request}; -#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] -use std::sync::Arc; +use bytes::Bytes; +use http_body_util::{BodyExt, Empty}; +use hyper::Method; use std::time::Duration; use tokio::net::TcpStream; use tokio::time::timeout; -#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] -use tokio_rustls::rustls::{self, OwnedTrustAnchor}; -use wasmtime_wasi::preview2::{StreamState, TableStreamExt}; +use wasmtime_wasi::preview2; -#[async_trait::async_trait] -impl crate::bindings::http::outgoing_handler::Host for T { - async fn handle( +impl outgoing_handler::Host for T { + fn handle( &mut self, request_id: OutgoingRequest, options: Option, - ) -> wasmtime::Result { - let future = ActiveFuture::new(request_id, options); - let future_id = self - .table_mut() - .push_future(Box::new(future)) - .context("[handle] pushing future")?; - Ok(future_id) - } -} - -#[cfg(feature = "sync")] -pub mod sync { - use 
crate::bindings::http::outgoing_handler::{ - Host as AsyncHost, RequestOptions as AsyncRequestOptions, - }; - use crate::bindings::sync::http::types::{ - FutureIncomingResponse, OutgoingRequest, RequestOptions, - }; - use crate::WasiHttpView; - use wasmtime_wasi::preview2::in_tokio; - - // same boilerplate everywhere, converting between two identical types with different - // definition sites. one day wasmtime-wit-bindgen will make all this unnecessary - impl From for AsyncRequestOptions { - fn from(other: RequestOptions) -> Self { - Self { - connect_timeout_ms: other.connect_timeout_ms, - first_byte_timeout_ms: other.first_byte_timeout_ms, - between_bytes_timeout_ms: other.between_bytes_timeout_ms, - } - } - } - - impl crate::bindings::sync::http::outgoing_handler::Host for T { - fn handle( - &mut self, - request_id: OutgoingRequest, - options: Option, - ) -> wasmtime::Result { - in_tokio(async { AsyncHost::handle(self, request_id, options.map(|v| v.into())).await }) - } - } -} - -fn port_for_scheme(scheme: &Option) -> &str { - match scheme { - Some(s) => match s { - Scheme::Http => ":80", - Scheme::Https => ":443", - // This should never happen. - _ => panic!("unsupported scheme!"), - }, - None => ":443", - } -} + ) -> wasmtime::Result> { + let connect_timeout = Duration::from_millis( + options + .and_then(|opts| opts.connect_timeout_ms) + .unwrap_or(600 * 1000) as u64, + ); -#[async_trait::async_trait] -pub trait WasiHttpViewExt { - async fn handle_async( - &mut self, - request_id: OutgoingRequest, - options: Option, - ) -> wasmtime::Result; -} + let first_byte_timeout = Duration::from_millis( + options + .and_then(|opts| opts.first_byte_timeout_ms) + .unwrap_or(600 * 1000) as u64, + ); -#[async_trait::async_trait] -impl WasiHttpViewExt for T { - async fn handle_async( - &mut self, - request_id: OutgoingRequest, - options: Option, - ) -> wasmtime::Result { - tracing::debug!("preparing outgoing request"); - let opts = options.unwrap_or( - // TODO: Configurable defaults here? 
- RequestOptions { - connect_timeout_ms: Some(600 * 1000), - first_byte_timeout_ms: Some(600 * 1000), - between_bytes_timeout_ms: Some(600 * 1000), - }, + let between_bytes_timeout = Duration::from_millis( + options + .and_then(|opts| opts.between_bytes_timeout_ms) + .unwrap_or(600 * 1000) as u64, ); - let connect_timeout = - Duration::from_millis(opts.connect_timeout_ms.unwrap_or(600 * 1000).into()); - let first_bytes_timeout = - Duration::from_millis(opts.first_byte_timeout_ms.unwrap_or(600 * 1000).into()); - let between_bytes_timeout = - Duration::from_millis(opts.between_bytes_timeout_ms.unwrap_or(600 * 1000).into()); - let request = self - .table() - .get_request(request_id) - .context("[handle_async] getting request")?; - tracing::debug!("http request retrieved from table"); + let req = self.table().delete_outgoing_request(request_id)?; - let method = match request.method() { + let method = match req.method { crate::bindings::http::types::Method::Get => Method::GET, crate::bindings::http::types::Method::Head => Method::HEAD, crate::bindings::http::types::Method::Post => Method::POST, @@ -126,214 +49,155 @@ impl WasiHttpViewExt for T { crate::bindings::http::types::Method::Options => Method::OPTIONS, crate::bindings::http::types::Method::Trace => Method::TRACE, crate::bindings::http::types::Method::Patch => Method::PATCH, - crate::bindings::http::types::Method::Other(s) => { - return Err(crate::bindings::http::types::Error::InvalidUrl(format!( - "unknown method {}", - s - )) - .into()); + crate::bindings::http::types::Method::Other(method) => { + return Ok(Err(outgoing_handler::Error::Invalid(format!( + "unknown method {method}" + )))); } }; - let scheme = match request.scheme().as_ref().unwrap_or(&Scheme::Https) { - Scheme::Http => "http://", - Scheme::Https => "https://", - Scheme::Other(s) => { - return Err(crate::bindings::http::types::Error::InvalidUrl(format!( - "unsupported scheme {}", - s - )) - .into()); + let (use_tls, scheme, port) = match req.scheme.unwrap_or(Scheme::Https) { + Scheme::Http => (false, "http://", 80), + Scheme::Https => (true, "https://", 443), + Scheme::Other(scheme) => { + return Ok(Err(outgoing_handler::Error::Invalid(format!( + "unsupported scheme {scheme}" + )))) } }; - // Largely adapted from https://hyper.rs/guides/1/client/basic/ - let authority = match request.authority().find(":") { - Some(_) => request.authority().to_owned(), - None => request.authority().to_owned() + port_for_scheme(request.scheme()), + let authority = if req.authority.find(':').is_some() { + req.authority.clone() + } else { + format!("{}:{port}", req.authority) }; - let tcp_stream = TcpStream::connect(authority.clone()).await?; - let mut sender = if scheme == "https://" { - tracing::debug!("initiating client connection client with TLS"); - #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] - { - //TODO: uncomment this code and make the tls implementation a feature decision. 
- //let connector = tokio_native_tls::native_tls::TlsConnector::builder().build()?; - //let connector = tokio_native_tls::TlsConnector::from(connector); - //let host = authority.split(":").next().unwrap_or(&authority); - //let stream = connector.connect(&host, stream).await?; - // derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs - let mut root_cert_store = rustls::RootCertStore::empty(); - root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map( - |ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }, - )); - let config = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_cert_store) - .with_no_client_auth(); - let connector = tokio_rustls::TlsConnector::from(Arc::new(config)); - let mut parts = authority.split(":"); - let host = parts.next().unwrap_or(&authority); - let domain = rustls::ServerName::try_from(host)?; - let stream = connector.connect(domain, tcp_stream).await.map_err(|e| { - crate::bindings::http::types::Error::ProtocolError(e.to_string()) - })?; + let mut builder = hyper::Request::builder() + .method(method) + .uri(format!("{scheme}{authority}{}", req.path_with_query)) + .header(hyper::header::HOST, &authority); - let t = timeout( - connect_timeout, - hyper::client::conn::http1::handshake(stream), - ) - .await?; - let (s, conn) = t?; - tokio::task::spawn(async move { - if let Err(err) = conn.await { - tracing::debug!("[host/client] Connection failed: {:?}", err); - } - }); - s - } - #[cfg(any(target_arch = "riscv64", target_arch = "s390x"))] - return Err(crate::bindings::http::types::Error::UnexpectedError( - "unsupported architecture for SSL".to_string(), - )); - } else { - tracing::debug!("initiating client connection without TLS"); - let t = timeout( - connect_timeout, - hyper::client::conn::http1::handshake(tcp_stream), - ) - .await?; - let (s, conn) = t?; - tokio::task::spawn(async move { - if let Err(err) = conn.await { - tracing::debug!("[host/client] Connection failed: {:?}", err); - } - }); - s - }; + for (k, v) in req.headers.iter() { + builder = builder.header(k, v); + } - let url = scheme.to_owned() + &request.authority() + &request.path_with_query(); + let body = req.body.unwrap_or_else(|| Empty::::new().boxed()); - tracing::debug!("request to url {:?}", &url); - let mut call = Request::builder() - .method(method) - .uri(url) - .header(hyper::header::HOST, request.authority()); + let request = builder.body(body).map_err(http_protocol_error)?; - if let Some(headers) = request.headers() { - for (key, val) in self - .table() - .get_fields(headers) - .context("[handle_async] getting request headers")? 
- .iter() - { - for item in val { - call = call.header(key, item.clone()); + let handle = preview2::spawn(async move { + let tcp_stream = TcpStream::connect(authority.clone()) + .await + .map_err(invalid_url)?; + + let (mut sender, worker) = if use_tls { + #[cfg(any(target_arch = "riscv64", target_arch = "s390x"))] + { + anyhow::bail!(crate::bindings::http::types::Error::UnexpectedError( + "unsupported architecture for SSL".to_string(), + )); } - } - } - let mut response = ActiveResponse::new(); - let body = match request.body() { - Some(id) => { - let table = self.table_mut(); - let stream = table - .get_stream(id) - .context("[handle_async] getting stream")?; - let input_stream = table - .get_input_stream_mut(stream.incoming()) - .context("[handle_async] getting mutable input stream")?; - let mut bytes = BytesMut::new(); - let mut eof = StreamState::Open; - while eof != StreamState::Closed { - let (chunk, state) = input_stream.read(4096)?; - eof = if chunk.is_empty() { - StreamState::Closed - } else { - state - }; - bytes.extend_from_slice(&chunk[..]); + #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] + { + use tokio_rustls::rustls::OwnedTrustAnchor; + + // derived from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs + let mut root_cert_store = rustls::RootCertStore::empty(); + root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map( + |ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }, + )); + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config)); + let mut parts = authority.split(":"); + let host = parts.next().unwrap_or(&authority); + let domain = rustls::ServerName::try_from(host)?; + let stream = connector.connect(domain, tcp_stream).await.map_err(|e| { + crate::bindings::http::types::Error::ProtocolError(e.to_string()) + })?; + + let (sender, conn) = timeout( + connect_timeout, + hyper::client::conn::http1::handshake(stream), + ) + .await + .map_err(|_| timeout_error("connection"))??; + + let worker = preview2::spawn(async move { + conn.await.context("hyper connection failed")?; + Ok::<_, anyhow::Error>(()) + }); + + (sender, worker) } - Full::::new(bytes.freeze()).boxed() - } - None => Empty::::new().boxed(), - }; - let request = call.body(body)?; - tracing::trace!("hyper request {:?}", request); - let t = timeout(first_bytes_timeout, sender.send_request(request)).await?; - let mut res = t?; - tracing::trace!("hyper response {:?}", res); - response.status = res.status().as_u16(); + } else { + let (sender, conn) = timeout( + connect_timeout, + // TODO: we should plumb the builder through the http context, and use it here + hyper::client::conn::http1::handshake(tcp_stream), + ) + .await + .map_err(|_| timeout_error("connection"))??; - let mut map = ActiveFields::new(); - for (key, value) in res.headers().iter() { - let mut vec = Vec::new(); - vec.push(value.as_bytes().to_vec()); - map.insert(key.as_str().to_string(), vec); - } - let headers = self - .table_mut() - .push_fields(Box::new(map)) - .context("[handle_async] pushing response headers")?; - response.set_headers(headers); + let worker = preview2::spawn(async move { + conn.await.context("hyper connection failed")?; + Ok::<_, anyhow::Error>(()) + }); - let mut buf: Vec = Vec::new(); - while let Some(next) = 
timeout(between_bytes_timeout, res.frame()).await? { - let frame = next?; - tracing::debug!("response body next frame"); - if let Some(chunk) = frame.data_ref() { - tracing::trace!("response body chunk size {:?}", chunk.len()); - buf.extend_from_slice(chunk); - } - if let Some(trailers) = frame.trailers_ref() { - tracing::debug!("response trailers present"); - let mut map = ActiveFields::new(); - for (name, value) in trailers.iter() { - let key = name.to_string(); - match map.get_mut(&key) { - Some(vec) => vec.push(value.as_bytes().to_vec()), - None => { - let mut vec = Vec::new(); - vec.push(value.as_bytes().to_vec()); - map.insert(key, vec); - } - }; - } - let trailers = self - .table_mut() - .push_fields(Box::new(map)) - .context("[handle_async] pushing response trailers")?; - response.set_trailers(trailers); - tracing::debug!("http trailers saved to table"); - } - } + (sender, worker) + }; - let response_id = self - .table_mut() - .push_response(Box::new(response)) - .context("[handle_async] pushing response")?; - tracing::trace!("response body {:?}", std::str::from_utf8(&buf[..]).unwrap()); - let (stream_id, stream) = self - .table_mut() - .push_stream(Bytes::from(buf), response_id) - .await - .context("[handle_async] pushing stream")?; - let response = self - .table_mut() - .get_response_mut(response_id) - .context("[handle_async] getting mutable response")?; - response.set_body(stream_id); - tracing::debug!("http response saved to table with id {:?}", response_id); + let resp = timeout(first_byte_timeout, sender.send_request(request)) + .await + .map_err(|_| timeout_error("first byte"))? + .map_err(hyper_protocol_error)?; - self.http_ctx_mut().streams.insert(stream_id, stream); + Ok(IncomingResponseInternal { + resp, + worker, + between_bytes_timeout, + }) + }); - Ok(response_id) + let fut = self + .table() + .push_future_incoming_response(HostFutureIncomingResponse::new(handle))?; + + Ok(Ok(fut)) } } + +fn timeout_error(kind: &str) -> anyhow::Error { + anyhow::anyhow!(crate::bindings::http::types::Error::TimeoutError(format!( + "{kind} timed out" + ))) +} + +fn http_protocol_error(e: http::Error) -> anyhow::Error { + anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError( + e.to_string() + )) +} + +fn hyper_protocol_error(e: hyper::Error) -> anyhow::Error { + anyhow::anyhow!(crate::bindings::http::types::Error::ProtocolError( + e.to_string() + )) +} + +fn invalid_url(e: std::io::Error) -> anyhow::Error { + // TODO: DNS errors show up as a Custom io error, what subset of errors should we consider for + // InvalidUrl here? 
+ anyhow::anyhow!(crate::bindings::http::types::Error::InvalidUrl( + e.to_string() + )) +} diff --git a/crates/wasi-http/src/incoming_handler.rs b/crates/wasi-http/src/incoming_handler.rs index e65a88e27d53..3fd26bf7d0af 100644 --- a/crates/wasi-http/src/incoming_handler.rs +++ b/crates/wasi-http/src/incoming_handler.rs @@ -1,9 +1,8 @@ use crate::bindings::http::types::{IncomingRequest, ResponseOutparam}; use crate::WasiHttpView; -#[async_trait::async_trait] impl crate::bindings::http::incoming_handler::Host for T { - async fn handle( + fn handle( &mut self, _request: IncomingRequest, _response_out: ResponseOutparam, @@ -11,21 +10,3 @@ impl crate::bindings::http::incoming_handler::Host for T { anyhow::bail!("unimplemented: [incoming_handler] handle") } } - -#[cfg(feature = "sync")] -pub mod sync { - use crate::bindings::http::incoming_handler::Host as AsyncHost; - use crate::bindings::sync::http::types::{IncomingRequest, ResponseOutparam}; - use crate::WasiHttpView; - use wasmtime_wasi::preview2::in_tokio; - - impl crate::bindings::sync::http::incoming_handler::Host for T { - fn handle( - &mut self, - request: IncomingRequest, - response_out: ResponseOutparam, - ) -> wasmtime::Result<()> { - in_tokio(async { AsyncHost::handle(self, request, response_out).await }) - } - } -} diff --git a/crates/wasi-http/src/lib.rs b/crates/wasi-http/src/lib.rs index 7a99a0b79a8f..c6b072b1e60f 100644 --- a/crates/wasi-http/src/lib.rs +++ b/crates/wasi-http/src/lib.rs @@ -1,9 +1,8 @@ -pub use crate::http_impl::WasiHttpViewExt; pub use crate::types::{WasiHttpCtx, WasiHttpView}; use core::fmt::Formatter; use std::fmt::{self, Display}; -pub mod component_impl; +pub mod body; pub mod http_impl; pub mod incoming_handler; pub mod proxy; @@ -11,56 +10,22 @@ pub mod types; pub mod types_impl; pub mod bindings { - #[cfg(feature = "sync")] - pub mod sync { - pub(crate) mod _internal { - wasmtime::component::bindgen!({ - path: "wit", - interfaces: " - import wasi:http/incoming-handler - import wasi:http/outgoing-handler - import wasi:http/types - ", - tracing: true, - with: { - "wasi:io/streams": wasmtime_wasi::preview2::bindings::sync_io::io::streams, - "wasi:poll/poll": wasmtime_wasi::preview2::bindings::sync_io::poll::poll, - } - }); - } - pub use self::_internal::wasi::http; - } - - pub(crate) mod _internal_rest { - wasmtime::component::bindgen!({ - path: "wit", - interfaces: " + wasmtime::component::bindgen!({ + path: "wit", + interfaces: " import wasi:http/incoming-handler import wasi:http/outgoing-handler import wasi:http/types ", - tracing: true, - async: true, - with: { - "wasi:io/streams": wasmtime_wasi::preview2::bindings::io::streams, - "wasi:poll/poll": wasmtime_wasi::preview2::bindings::poll::poll, - } - }); - } - - pub use self::_internal_rest::wasi::http; -} - -pub fn add_to_linker(linker: &mut wasmtime::Linker) -> anyhow::Result<()> { - crate::component_impl::add_component_to_linker::(linker, |t| t) -} - -pub mod sync { - use crate::types::WasiHttpView; + tracing: true, + async: false, + with: { + "wasi:io/streams": wasmtime_wasi::preview2::bindings::io::streams, + "wasi:poll/poll": wasmtime_wasi::preview2::bindings::poll::poll, + } + }); - pub fn add_to_linker(linker: &mut wasmtime::Linker) -> anyhow::Result<()> { - crate::component_impl::sync::add_component_to_linker::(linker, |t| t) - } + pub use wasi::http; } impl std::error::Error for crate::bindings::http::types::Error {} diff --git a/crates/wasi-http/src/proxy.rs b/crates/wasi-http/src/proxy.rs index 8cd78d700265..9f156fd09a14 100644 --- 
a/crates/wasi-http/src/proxy.rs +++ b/crates/wasi-http/src/proxy.rs @@ -4,7 +4,7 @@ use wasmtime_wasi::preview2; wasmtime::component::bindgen!({ world: "wasi:http/proxy", tracing: true, - async: true, + async: false, with: { "wasi:cli/stderr": preview2::bindings::cli::stderr, "wasi:cli/stdin": preview2::bindings::cli::stdin, @@ -30,39 +30,3 @@ where bindings::http::types::add_to_linker(l, |t| t)?; Ok(()) } - -#[cfg(feature = "sync")] -pub mod sync { - use crate::{bindings, WasiHttpView}; - use wasmtime_wasi::preview2; - - wasmtime::component::bindgen!({ - world: "wasi:http/proxy", - tracing: true, - async: false, - with: { - "wasi:cli/stderr": preview2::bindings::cli::stderr, - "wasi:cli/stdin": preview2::bindings::cli::stdin, - "wasi:cli/stdout": preview2::bindings::cli::stdout, - "wasi:clocks/monotonic-clock": preview2::bindings::clocks::monotonic_clock, - "wasi:clocks/timezone": preview2::bindings::clocks::timezone, - "wasi:clocks/wall-clock": preview2::bindings::clocks::wall_clock, - "wasi:http/incoming-handler": bindings::sync::http::incoming_handler, - "wasi:http/outgoing-handler": bindings::sync::http::outgoing_handler, - "wasi:http/types": bindings::sync::http::types, - "wasi:io/streams": preview2::bindings::sync_io::io::streams, - "wasi:poll/poll": preview2::bindings::sync_io::poll::poll, - "wasi:random/random": preview2::bindings::random::random, - }, - }); - - pub fn add_to_linker(l: &mut wasmtime::component::Linker) -> anyhow::Result<()> - where - T: WasiHttpView + bindings::sync::http::types::Host, - { - bindings::sync::http::incoming_handler::add_to_linker(l, |t| t)?; - bindings::sync::http::outgoing_handler::add_to_linker(l, |t| t)?; - bindings::sync::http::types::add_to_linker(l, |t| t)?; - Ok(()) - } -} diff --git a/crates/wasi-http/src/types.rs b/crates/wasi-http/src/types.rs index 050a83470947..3b73f7f3fd49 100644 --- a/crates/wasi-http/src/types.rs +++ b/crates/wasi-http/src/types.rs @@ -1,449 +1,302 @@ //! Implements the base structure (i.e. [WasiHttpCtx]) that will provide the //! implementation of the wasi-http API. -use crate::bindings::http::types::{ - IncomingStream, Method, OutgoingRequest, OutgoingStream, RequestOptions, Scheme, +use crate::{ + bindings::http::types::{FutureTrailers, IncomingBody, Method, OutgoingBody, Scheme}, + body::{ + HostFutureTrailers, HostIncomingBody, HostIncomingBodyBuilder, HostOutgoingBody, HyperBody, + }, }; -use bytes::Bytes; use std::any::Any; -use std::collections::HashMap; -use std::ops::{Deref, DerefMut}; -use wasmtime_wasi::preview2::{ - pipe::{AsyncReadStream, AsyncWriteStream}, - HostInputStream, HostOutputStream, Table, TableError, TableStreamExt, WasiView, -}; - -const MAX_BUF_SIZE: usize = 65_536; +use std::pin::Pin; +use std::task; +use wasmtime_wasi::preview2::{AbortOnDropJoinHandle, Table, TableError}; /// Capture the state necessary for use in the wasi-http API implementation. -pub struct WasiHttpCtx { - pub streams: HashMap, -} - -impl WasiHttpCtx { - /// Make a new context from the default state. 
- pub fn new() -> Self { - Self { - streams: HashMap::new(), - } - } -} +pub struct WasiHttpCtx; -pub trait WasiHttpView: WasiView { - fn http_ctx(&self) -> &WasiHttpCtx; - fn http_ctx_mut(&mut self) -> &mut WasiHttpCtx; +pub trait WasiHttpView: Send { + fn ctx(&mut self) -> &mut WasiHttpCtx; + fn table(&mut self) -> &mut Table; } -pub type FieldsMap = HashMap>>; - -#[derive(Clone, Debug)] -pub struct ActiveRequest { - pub active: bool, +pub struct HostOutgoingRequest { pub method: Method, pub scheme: Option, pub path_with_query: String, pub authority: String, - pub headers: Option, - pub body: Option, + pub headers: FieldMap, + pub body: Option, } -pub trait HttpRequest: Send + Sync { - fn new() -> Self - where - Self: Sized; - - fn as_any(&self) -> &dyn Any; - - fn method(&self) -> &Method; - fn scheme(&self) -> &Option; - fn path_with_query(&self) -> &str; - fn authority(&self) -> &str; - fn headers(&self) -> Option; - fn set_headers(&mut self, headers: u32); - fn body(&self) -> Option; - fn set_body(&mut self, body: u32); +pub struct HostIncomingResponse { + pub status: u16, + pub headers: FieldMap, + pub body: Option, + pub worker: AbortOnDropJoinHandle>, } -impl HttpRequest for ActiveRequest { - fn new() -> Self { - Self { - active: false, - method: Method::Get, - scheme: Some(Scheme::Http), - path_with_query: "".to_string(), - authority: "".to_string(), - headers: None, - body: None, - } - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn method(&self) -> &Method { - &self.method - } +pub type FieldMap = hyper::HeaderMap; - fn scheme(&self) -> &Option { - &self.scheme - } - - fn path_with_query(&self) -> &str { - &self.path_with_query - } - - fn authority(&self) -> &str { - &self.authority - } - - fn headers(&self) -> Option { - self.headers - } - - fn set_headers(&mut self, headers: u32) { - self.headers = Some(headers); - } - - fn body(&self) -> Option { - self.body - } +pub enum HostFields { + Ref { + parent: u32, - fn set_body(&mut self, body: u32) { - self.body = Some(body); - } + // NOTE: there's not failure in the result here because we assume that HostFields will + // always be registered as a child of the entry with the `parent` id. This ensures that the + // entry will always exist while this `HostFields::Ref` entry exists in the table, thus we + // don't need to account for failure when fetching the fields ref from the parent. 
+ get_fields: for<'a> fn(elem: &'a mut (dyn Any + 'static)) -> &'a mut FieldMap, + }, + Owned { + fields: FieldMap, + }, } -#[derive(Clone, Debug)] -pub struct ActiveResponse { - pub active: bool, - pub status: u16, - pub headers: Option, - pub body: Option, - pub trailers: Option, +pub struct IncomingResponseInternal { + pub resp: hyper::Response, + pub worker: AbortOnDropJoinHandle>, + pub between_bytes_timeout: std::time::Duration, } -pub trait HttpResponse: Send + Sync { - fn new() -> Self - where - Self: Sized; +type FutureIncomingResponseHandle = AbortOnDropJoinHandle>; - fn as_any(&self) -> &dyn Any; - - fn status(&self) -> u16; - fn headers(&self) -> Option; - fn set_headers(&mut self, headers: u32); - fn body(&self) -> Option; - fn set_body(&mut self, body: u32); - fn trailers(&self) -> Option; - fn set_trailers(&mut self, trailers: u32); +pub enum HostFutureIncomingResponse { + Pending(FutureIncomingResponseHandle), + Ready(anyhow::Result), + Consumed, } -impl HttpResponse for ActiveResponse { - fn new() -> Self { - Self { - active: false, - status: 0, - headers: None, - body: None, - trailers: None, - } +impl HostFutureIncomingResponse { + pub fn new(handle: FutureIncomingResponseHandle) -> Self { + Self::Pending(handle) } - fn as_any(&self) -> &dyn Any { - self + pub fn is_ready(&self) -> bool { + matches!(self, Self::Ready(_)) } - fn status(&self) -> u16 { - self.status + pub fn unwrap_ready(self) -> anyhow::Result { + match self { + Self::Ready(res) => res, + Self::Pending(_) | Self::Consumed => { + panic!("unwrap_ready called on a pending HostFutureIncomingResponse") + } + } } +} - fn headers(&self) -> Option { - self.headers +impl std::future::Future for HostFutureIncomingResponse { + type Output = anyhow::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { + let s = self.get_mut(); + match s { + Self::Pending(ref mut handle) => match Pin::new(handle).poll(cx) { + task::Poll::Pending => task::Poll::Pending, + task::Poll::Ready(r) => { + *s = Self::Ready(r); + task::Poll::Ready(Ok(())) + } + }, + + Self::Consumed | Self::Ready(_) => task::Poll::Ready(Ok(())), + } } +} - fn set_headers(&mut self, headers: u32) { - self.headers = Some(headers); - } +#[async_trait::async_trait] +pub trait TableHttpExt { + fn push_outgoing_response(&mut self, request: HostOutgoingRequest) -> Result; + fn get_outgoing_request(&self, id: u32) -> Result<&HostOutgoingRequest, TableError>; + fn get_outgoing_request_mut(&mut self, id: u32) + -> Result<&mut HostOutgoingRequest, TableError>; + fn delete_outgoing_request(&mut self, id: u32) -> Result; + + fn push_incoming_response(&mut self, response: HostIncomingResponse) + -> Result; + fn get_incoming_response(&self, id: u32) -> Result<&HostIncomingResponse, TableError>; + fn get_incoming_response_mut( + &mut self, + id: u32, + ) -> Result<&mut HostIncomingResponse, TableError>; + fn delete_incoming_response(&mut self, id: u32) -> Result; - fn body(&self) -> Option { - self.body - } + fn push_fields(&mut self, fields: HostFields) -> Result; + fn get_fields(&mut self, id: u32) -> Result<&mut FieldMap, TableError>; + fn delete_fields(&mut self, id: u32) -> Result; - fn set_body(&mut self, body: u32) { - self.body = Some(body); - } + fn push_future_incoming_response( + &mut self, + response: HostFutureIncomingResponse, + ) -> Result; + fn get_future_incoming_response( + &self, + id: u32, + ) -> Result<&HostFutureIncomingResponse, TableError>; + fn get_future_incoming_response_mut( + &mut self, + id: u32, + ) -> 
Result<&mut HostFutureIncomingResponse, TableError>; + fn delete_future_incoming_response( + &mut self, + id: u32, + ) -> Result; - fn trailers(&self) -> Option { - self.trailers - } + fn push_incoming_body(&mut self, body: HostIncomingBody) -> Result; + fn get_incoming_body(&mut self, id: IncomingBody) -> Result<&mut HostIncomingBody, TableError>; + fn delete_incoming_body(&mut self, id: IncomingBody) -> Result; - fn set_trailers(&mut self, trailers: u32) { - self.trailers = Some(trailers); - } -} + fn push_outgoing_body(&mut self, body: HostOutgoingBody) -> Result; + fn get_outgoing_body(&mut self, id: OutgoingBody) -> Result<&mut HostOutgoingBody, TableError>; + fn delete_outgoing_body(&mut self, id: OutgoingBody) -> Result; -#[derive(Clone, Debug)] -pub struct ActiveFuture { - request_id: OutgoingRequest, - options: Option, - response_id: Option, - pollable_id: Option, + fn push_future_trailers( + &mut self, + trailers: HostFutureTrailers, + ) -> Result; + fn get_future_trailers( + &mut self, + id: FutureTrailers, + ) -> Result<&mut HostFutureTrailers, TableError>; + fn delete_future_trailers( + &mut self, + id: FutureTrailers, + ) -> Result; } -impl ActiveFuture { - pub fn new(request_id: OutgoingRequest, options: Option) -> Self { - Self { - request_id, - options, - response_id: None, - pollable_id: None, - } - } - - pub fn request_id(&self) -> u32 { - self.request_id - } - - pub fn options(&self) -> Option { - self.options - } - - pub fn response_id(&self) -> Option { - self.response_id +#[async_trait::async_trait] +impl TableHttpExt for Table { + fn push_outgoing_response(&mut self, request: HostOutgoingRequest) -> Result { + self.push(Box::new(request)) } - - pub fn set_response_id(&mut self, response_id: u32) { - self.response_id = Some(response_id); + fn get_outgoing_request(&self, id: u32) -> Result<&HostOutgoingRequest, TableError> { + self.get::(id) } - - pub fn pollable_id(&self) -> Option { - self.pollable_id + fn get_outgoing_request_mut( + &mut self, + id: u32, + ) -> Result<&mut HostOutgoingRequest, TableError> { + self.get_mut::(id) } - - pub fn set_pollable_id(&mut self, pollable_id: u32) { - self.pollable_id = Some(pollable_id); + fn delete_outgoing_request(&mut self, id: u32) -> Result { + let req = self.delete::(id)?; + Ok(req) } -} - -#[derive(Clone, Debug)] -pub struct ActiveFields(HashMap>>); -impl ActiveFields { - pub fn new() -> Self { - Self(FieldsMap::new()) + fn push_incoming_response( + &mut self, + response: HostIncomingResponse, + ) -> Result { + self.push(Box::new(response)) } -} - -pub trait HttpFields: Send + Sync { - fn as_any(&self) -> &dyn Any; -} - -impl HttpFields for ActiveFields { - fn as_any(&self) -> &dyn Any { - self + fn get_incoming_response(&self, id: u32) -> Result<&HostIncomingResponse, TableError> { + self.get::(id) } -} - -impl Deref for ActiveFields { - type Target = FieldsMap; - fn deref(&self) -> &FieldsMap { - &self.0 + fn get_incoming_response_mut( + &mut self, + id: u32, + ) -> Result<&mut HostIncomingResponse, TableError> { + self.get_mut::(id) } -} - -impl DerefMut for ActiveFields { - fn deref_mut(&mut self) -> &mut FieldsMap { - &mut self.0 + fn delete_incoming_response(&mut self, id: u32) -> Result { + let resp = self.delete::(id)?; + Ok(resp) } -} - -#[derive(Clone, Debug)] -pub struct Stream { - input_id: u32, - output_id: u32, - parent_id: u32, -} -impl Stream { - pub fn new(input_id: u32, output_id: u32, parent_id: u32) -> Self { - Self { - input_id, - output_id, - parent_id, + fn push_fields(&mut self, fields: 
HostFields) -> Result { + match fields { + HostFields::Ref { parent, .. } => self.push_child(Box::new(fields), parent), + HostFields::Owned { .. } => self.push(Box::new(fields)), } } + fn get_fields(&mut self, id: u32) -> Result<&mut FieldMap, TableError> { + let fields = self.get_mut::(id)?; + if let HostFields::Ref { parent, get_fields } = *fields { + let entry = self.get_any_mut(parent)?; + return Ok(get_fields(entry)); + } - pub fn incoming(&self) -> IncomingStream { - self.input_id + match self.get_mut::(id)? { + HostFields::Owned { fields } => Ok(fields), + // NB: ideally the `if let` above would go here instead. That makes + // the borrow-checker unhappy. Unclear why. If you, dear reader, can + // refactor this to remove the `unreachable!` please do. + HostFields::Ref { .. } => unreachable!(), + } } - - pub fn outgoing(&self) -> OutgoingStream { - self.output_id + fn delete_fields(&mut self, id: u32) -> Result { + let fields = self.delete::(id)?; + Ok(fields) } - pub fn parent_id(&self) -> u32 { - self.parent_id - } -} - -#[async_trait::async_trait] -pub trait TableHttpExt { - fn push_request(&mut self, request: Box) -> Result; - fn get_request(&self, id: u32) -> Result<&(dyn HttpRequest), TableError>; - fn get_request_mut(&mut self, id: u32) -> Result<&mut Box, TableError>; - fn delete_request(&mut self, id: u32) -> Result<(), TableError>; - - fn push_response(&mut self, response: Box) -> Result; - fn get_response(&self, id: u32) -> Result<&dyn HttpResponse, TableError>; - fn get_response_mut(&mut self, id: u32) -> Result<&mut Box, TableError>; - fn delete_response(&mut self, id: u32) -> Result<(), TableError>; - - fn push_future(&mut self, future: Box) -> Result; - fn get_future(&self, id: u32) -> Result<&ActiveFuture, TableError>; - fn get_future_mut(&mut self, id: u32) -> Result<&mut Box, TableError>; - fn delete_future(&mut self, id: u32) -> Result<(), TableError>; - - fn push_fields(&mut self, fields: Box) -> Result; - fn get_fields(&self, id: u32) -> Result<&ActiveFields, TableError>; - fn get_fields_mut(&mut self, id: u32) -> Result<&mut Box, TableError>; - fn delete_fields(&mut self, id: u32) -> Result<(), TableError>; - - async fn push_stream( + fn push_future_incoming_response( &mut self, - content: Bytes, - parent: u32, - ) -> Result<(u32, Stream), TableError>; - fn get_stream(&self, id: u32) -> Result<&Stream, TableError>; - fn get_stream_mut(&mut self, id: u32) -> Result<&mut Box, TableError>; - fn delete_stream(&mut self, id: u32) -> Result<(), TableError>; -} - -#[async_trait::async_trait] -impl TableHttpExt for Table { - fn push_request(&mut self, request: Box) -> Result { - self.push(Box::new(request)) - } - fn get_request(&self, id: u32) -> Result<&dyn HttpRequest, TableError> { - self.get::>(id).map(|f| f.as_ref()) - } - fn get_request_mut(&mut self, id: u32) -> Result<&mut Box, TableError> { - self.get_mut::>(id) - } - fn delete_request(&mut self, id: u32) -> Result<(), TableError> { - self.delete::>(id).map(|_old| ()) - } - - fn push_response(&mut self, response: Box) -> Result { + response: HostFutureIncomingResponse, + ) -> Result { self.push(Box::new(response)) } - fn get_response(&self, id: u32) -> Result<&dyn HttpResponse, TableError> { - self.get::>(id).map(|f| f.as_ref()) + fn get_future_incoming_response( + &self, + id: u32, + ) -> Result<&HostFutureIncomingResponse, TableError> { + self.get::(id) } - fn get_response_mut(&mut self, id: u32) -> Result<&mut Box, TableError> { - self.get_mut::>(id) + fn get_future_incoming_response_mut( + &mut self, + id: 
u32, + ) -> Result<&mut HostFutureIncomingResponse, TableError> { + self.get_mut::(id) } - fn delete_response(&mut self, id: u32) -> Result<(), TableError> { - self.delete::>(id).map(|_old| ()) + fn delete_future_incoming_response( + &mut self, + id: u32, + ) -> Result { + self.delete(id) } - fn push_future(&mut self, future: Box) -> Result { - self.push(Box::new(future)) - } - fn get_future(&self, id: u32) -> Result<&ActiveFuture, TableError> { - self.get::>(id).map(|f| f.as_ref()) + fn push_incoming_body(&mut self, body: HostIncomingBody) -> Result { + self.push(Box::new(body)) } - fn get_future_mut(&mut self, id: u32) -> Result<&mut Box, TableError> { - self.get_mut::>(id) - } - fn delete_future(&mut self, id: u32) -> Result<(), TableError> { - self.delete::>(id).map(|_old| ()) + + fn get_incoming_body(&mut self, id: IncomingBody) -> Result<&mut HostIncomingBody, TableError> { + self.get_mut(id) } - fn push_fields(&mut self, fields: Box) -> Result { - self.push(Box::new(fields)) + fn delete_incoming_body(&mut self, id: IncomingBody) -> Result { + self.delete(id) } - fn get_fields(&self, id: u32) -> Result<&ActiveFields, TableError> { - self.get::>(id).map(|f| f.as_ref()) + + fn push_outgoing_body(&mut self, body: HostOutgoingBody) -> Result { + self.push(Box::new(body)) } - fn get_fields_mut(&mut self, id: u32) -> Result<&mut Box, TableError> { - self.get_mut::>(id) + + fn get_outgoing_body(&mut self, id: OutgoingBody) -> Result<&mut HostOutgoingBody, TableError> { + self.get_mut(id) } - fn delete_fields(&mut self, id: u32) -> Result<(), TableError> { - self.delete::>(id).map(|_old| ()) + + fn delete_outgoing_body(&mut self, id: OutgoingBody) -> Result { + self.delete(id) } - async fn push_stream( + fn push_future_trailers( &mut self, - mut content: Bytes, - parent: u32, - ) -> Result<(u32, Stream), TableError> { - tracing::debug!("preparing http body stream"); - let (a, b) = tokio::io::duplex(MAX_BUF_SIZE); - let (_, write_stream) = tokio::io::split(a); - let (read_stream, _) = tokio::io::split(b); - let input_stream = AsyncReadStream::new(read_stream); - // TODO: more informed budget here - let mut output_stream = AsyncWriteStream::new(4096, write_stream); - - while !content.is_empty() { - let permit = output_stream - .write_ready() - .await - .map_err(|_| TableError::NotPresent)?; - - let len = content.len().min(permit); - let chunk = content.split_to(len); - - output_stream - .write(chunk) - .map_err(|_| TableError::NotPresent)?; - } - output_stream.flush().map_err(|_| TableError::NotPresent)?; - let _readiness = tokio::time::timeout( - std::time::Duration::from_millis(10), - output_stream.write_ready(), - ) - .await; - - let input_stream = Box::new(input_stream); - let output_id = self.push_output_stream(Box::new(output_stream))?; - let input_id = self.push_input_stream(input_stream)?; - let stream = Stream::new(input_id, output_id, parent); - let cloned_stream = stream.clone(); - let stream_id = self.push(Box::new(Box::new(stream)))?; - tracing::trace!( - "http body stream details ( id: {:?}, input: {:?}, output: {:?} )", - stream_id, - input_id, - output_id - ); - Ok((stream_id, cloned_stream)) + trailers: HostFutureTrailers, + ) -> Result { + self.push(Box::new(trailers)) } - fn get_stream(&self, id: u32) -> Result<&Stream, TableError> { - self.get::>(id).map(|f| f.as_ref()) - } - fn get_stream_mut(&mut self, id: u32) -> Result<&mut Box, TableError> { - self.get_mut::>(id) - } - fn delete_stream(&mut self, id: u32) -> Result<(), TableError> { - let stream = 
self.get_stream_mut(id)?; - let input_stream = stream.incoming(); - let output_stream = stream.outgoing(); - self.delete::>(id).map(|_old| ())?; - self.delete::>(input_stream) - .map(|_old| ())?; - self.delete::>(output_stream) - .map(|_old| ()) - } -} -#[cfg(test)] -mod test { - use super::*; + fn get_future_trailers( + &mut self, + id: FutureTrailers, + ) -> Result<&mut HostFutureTrailers, TableError> { + self.get_mut(id) + } - #[test] - fn instantiate() { - WasiHttpCtx::new(); + fn delete_future_trailers( + &mut self, + id: FutureTrailers, + ) -> Result { + self.delete(id) } } diff --git a/crates/wasi-http/src/types_impl.rs b/crates/wasi-http/src/types_impl.rs index c044cd1bdde9..1b2326f3553b 100644 --- a/crates/wasi-http/src/types_impl.rs +++ b/crates/wasi-http/src/types_impl.rs @@ -1,191 +1,156 @@ use crate::bindings::http::types::{ - Error, Fields, FutureIncomingResponse, Headers, IncomingRequest, IncomingResponse, - IncomingStream, Method, OutgoingRequest, OutgoingResponse, OutgoingStream, ResponseOutparam, + Error, Fields, FutureIncomingResponse, FutureTrailers, Headers, IncomingBody, IncomingRequest, + IncomingResponse, Method, OutgoingBody, OutgoingRequest, OutgoingResponse, ResponseOutparam, Scheme, StatusCode, Trailers, }; -use crate::http_impl::WasiHttpViewExt; -use crate::types::{ActiveFields, ActiveRequest, HttpRequest, TableHttpExt}; +use crate::body::{HostFutureTrailers, HostFutureTrailersState}; +use crate::types::FieldMap; use crate::WasiHttpView; -use anyhow::{anyhow, bail, Context}; -use bytes::Bytes; -use wasmtime_wasi::preview2::{bindings::poll::poll::Pollable, HostPollable, TablePollableExt}; - -#[async_trait::async_trait] -impl crate::bindings::http::types::Host for T { - async fn drop_fields(&mut self, fields: Fields) -> wasmtime::Result<()> { - self.table_mut() +use crate::{ + body::{HostIncomingBodyBuilder, HostOutgoingBody}, + types::{ + HostFields, HostFutureIncomingResponse, HostIncomingResponse, HostOutgoingRequest, + TableHttpExt, + }, +}; +use anyhow::{anyhow, Context}; +use std::any::Any; +use wasmtime_wasi::preview2::{ + bindings::io::streams::{InputStream, OutputStream}, + bindings::poll::poll::Pollable, + HostPollable, PollableFuture, TablePollableExt, TableStreamExt, +}; + +impl crate::bindings::http::types::Host for T { + fn drop_fields(&mut self, fields: Fields) -> wasmtime::Result<()> { + self.table() .delete_fields(fields) .context("[drop_fields] deleting fields")?; Ok(()) } - async fn new_fields(&mut self, entries: Vec<(String, String)>) -> wasmtime::Result { - let mut map = ActiveFields::new(); - for (key, value) in entries { - map.insert(key, vec![value.clone().into_bytes()]); + fn new_fields(&mut self, entries: Vec<(String, Vec)>) -> wasmtime::Result { + let mut map = hyper::HeaderMap::new(); + + for (header, value) in entries { + let header = hyper::header::HeaderName::from_bytes(header.as_bytes())?; + let value = hyper::header::HeaderValue::from_bytes(&value)?; + map.append(header, value); } let id = self - .table_mut() - .push_fields(Box::new(map)) + .table() + .push_fields(HostFields::Owned { fields: map }) .context("[new_fields] pushing fields")?; Ok(id) } - async fn fields_get(&mut self, fields: Fields, name: String) -> wasmtime::Result>> { + fn fields_get(&mut self, fields: Fields, name: String) -> wasmtime::Result>> { let res = self - .table_mut() + .table() .get_fields(fields) .context("[fields_get] getting fields")? - .get(&name) - .ok_or_else(|| anyhow!("key not found: {name}"))? 
- .clone(); + .get_all(hyper::header::HeaderName::from_bytes(name.as_bytes())?) + .into_iter() + .map(|val| val.as_bytes().to_owned()) + .collect(); Ok(res) } - async fn fields_set( + fn fields_set( &mut self, fields: Fields, name: String, - value: Vec>, + values: Vec>, ) -> wasmtime::Result<()> { - match self.table_mut().get_fields_mut(fields) { - Ok(m) => { - m.insert(name, value.clone()); - Ok(()) - } - Err(_) => bail!("fields not found"), + let m = self.table().get_fields(fields)?; + + let header = hyper::header::HeaderName::from_bytes(name.as_bytes())?; + + m.remove(&header); + for value in values { + let value = hyper::header::HeaderValue::from_bytes(&value)?; + m.append(&header, value); } + + Ok(()) } - async fn fields_delete(&mut self, fields: Fields, name: String) -> wasmtime::Result<()> { - match self.table_mut().get_fields_mut(fields) { - Ok(m) => m.remove(&name), - Err(_) => None, - }; + fn fields_delete(&mut self, fields: Fields, name: String) -> wasmtime::Result<()> { + let m = self.table().get_fields(fields)?; + let header = hyper::header::HeaderName::from_bytes(name.as_bytes())?; + m.remove(header); Ok(()) } - async fn fields_append( + fn fields_append( &mut self, fields: Fields, name: String, value: Vec, ) -> wasmtime::Result<()> { let m = self - .table_mut() - .get_fields_mut(fields) + .table() + .get_fields(fields) .context("[fields_append] getting mutable fields")?; - match m.get_mut(&name) { - Some(v) => v.push(value), - None => { - let mut vec = std::vec::Vec::new(); - vec.push(value); - m.insert(name, vec); - } - }; + let header = hyper::header::HeaderName::from_bytes(name.as_bytes())?; + let value = hyper::header::HeaderValue::from_bytes(&value)?; + m.append(header, value); Ok(()) } - async fn fields_entries(&mut self, fields: Fields) -> wasmtime::Result)>> { - let field_map = match self.table().get_fields(fields) { - Ok(m) => m.iter(), - Err(_) => bail!("fields not found."), - }; - let mut result = Vec::new(); - for (name, value) in field_map { - result.push((name.clone(), value[0].clone())); - } + fn fields_entries(&mut self, fields: Fields) -> wasmtime::Result)>> { + let fields = self.table().get_fields(fields)?; + let result = fields + .iter() + .map(|(name, value)| (name.as_str().to_owned(), value.as_bytes().to_owned())) + .collect(); Ok(result) } - async fn fields_clone(&mut self, fields: Fields) -> wasmtime::Result { - let table = self.table_mut(); - let m = table + fn fields_clone(&mut self, fields: Fields) -> wasmtime::Result { + let fields = self + .table() .get_fields(fields) - .context("[fields_clone] getting fields")?; - let id = table - .push_fields(Box::new(m.clone())) + .context("[fields_clone] getting fields")? 
+ .clone(); + let id = self + .table() + .push_fields(HostFields::Owned { fields }) .context("[fields_clone] pushing fields")?; Ok(id) } - async fn finish_incoming_stream( - &mut self, - stream_id: IncomingStream, - ) -> wasmtime::Result> { - for (_, stream) in self.http_ctx().streams.iter() { - if stream_id == stream.incoming() { - let response = self - .table() - .get_response(stream.parent_id()) - .context("[finish_incoming_stream] get trailers from response")?; - return Ok(response.trailers()); - } - } - bail!("unknown stream!") + fn drop_incoming_request(&mut self, _request: IncomingRequest) -> wasmtime::Result<()> { + todo!("we haven't implemented the server side of wasi-http yet") } - async fn finish_outgoing_stream( - &mut self, - _s: OutgoingStream, - _trailers: Option, - ) -> wasmtime::Result<()> { - bail!("unimplemented: finish_outgoing_stream") - } - async fn drop_incoming_request(&mut self, _request: IncomingRequest) -> wasmtime::Result<()> { - bail!("unimplemented: drop_incoming_request") - } - async fn drop_outgoing_request(&mut self, request: OutgoingRequest) -> wasmtime::Result<()> { - let r = self - .table_mut() - .get_request(request) - .context("[drop_outgoing_request] getting fields")?; - - // Cleanup dependent resources - let body = r.body(); - let headers = r.headers(); - if let Some(b) = body { - self.table_mut().delete_stream(b).ok(); - } - if let Some(h) = headers { - self.table_mut().delete_fields(h).ok(); - } - - self.table_mut() - .delete_request(request) - .context("[drop_outgoing_request] deleting request")?; - + fn drop_outgoing_request(&mut self, request: OutgoingRequest) -> wasmtime::Result<()> { + self.table().delete_outgoing_request(request)?; Ok(()) } - async fn incoming_request_method( - &mut self, - _request: IncomingRequest, - ) -> wasmtime::Result { - bail!("unimplemented: incoming_request_method") + fn incoming_request_method(&mut self, _request: IncomingRequest) -> wasmtime::Result { + todo!("we haven't implemented the server side of wasi-http yet") } - async fn incoming_request_path_with_query( + fn incoming_request_path_with_query( &mut self, _request: IncomingRequest, ) -> wasmtime::Result> { - bail!("unimplemented: incoming_request_path") + todo!("we haven't implemented the server side of wasi-http yet") } - async fn incoming_request_scheme( + fn incoming_request_scheme( &mut self, _request: IncomingRequest, ) -> wasmtime::Result> { - bail!("unimplemented: incoming_request_scheme") + todo!("we haven't implemented the server side of wasi-http yet") } - async fn incoming_request_authority( + fn incoming_request_authority( &mut self, _request: IncomingRequest, ) -> wasmtime::Result> { - bail!("unimplemented: incoming_request_authority") + todo!("we haven't implemented the server side of wasi-http yet") } - async fn incoming_request_headers( - &mut self, - _request: IncomingRequest, - ) -> wasmtime::Result { - bail!("unimplemented: incoming_request_headers") + fn incoming_request_headers(&mut self, _request: IncomingRequest) -> wasmtime::Result { + todo!("we haven't implemented the server side of wasi-http yet") } - async fn incoming_request_consume( + fn incoming_request_consume( &mut self, _request: IncomingRequest, - ) -> wasmtime::Result> { - bail!("unimplemented: incoming_request_consume") + ) -> wasmtime::Result> { + todo!("we haven't implemented the server side of wasi-http yet") } - async fn new_outgoing_request( + fn new_outgoing_request( &mut self, method: Method, path_with_query: Option, @@ -193,507 +158,305 @@ impl 
crate::bindings::http::types::Host for T authority: Option, headers: Headers, ) -> wasmtime::Result { - let mut req = ActiveRequest::new(); - req.path_with_query = path_with_query.unwrap_or("".to_string()); - req.authority = authority.unwrap_or("".to_string()); - req.method = method; - req.headers = Some(headers); - req.scheme = scheme; + let headers = self.table().get_fields(headers)?.clone(); + + let req = HostOutgoingRequest { + path_with_query: path_with_query.unwrap_or("".to_string()), + authority: authority.unwrap_or("".to_string()), + method, + headers, + scheme, + body: None, + }; let id = self - .table_mut() - .push_request(Box::new(req)) + .table() + .push_outgoing_response(req) .context("[new_outgoing_request] pushing request")?; Ok(id) } - async fn outgoing_request_write( + fn outgoing_request_write( &mut self, request: OutgoingRequest, - ) -> wasmtime::Result> { + ) -> wasmtime::Result> { let req = self .table() - .get_request(request) + .get_outgoing_request_mut(request) .context("[outgoing_request_write] getting request")?; - let stream_id = if let Some(stream_id) = req.body() { - stream_id - } else { - let (new, stream) = self - .table_mut() - .push_stream(Bytes::new(), request) - .await - .expect("[outgoing_request_write] valid output stream"); - self.http_ctx_mut().streams.insert(new, stream); - let req = self - .table_mut() - .get_request_mut(request) - .expect("[outgoing_request_write] request to be found"); - req.set_body(new); - new - }; - let stream = self - .table() - .get_stream(stream_id) - .context("[outgoing_request_write] getting stream")?; - Ok(Ok(stream.outgoing())) + + if req.body.is_some() { + return Ok(Err(())); + } + + let (host_body, hyper_body) = HostOutgoingBody::new(); + + req.body = Some(hyper_body); + + // The output stream will necessarily outlive the request, because we could be still + // writing to the stream after `outgoing-handler.handle` is called. + let outgoing_body = self.table().push_outgoing_body(host_body)?; + + Ok(Ok(outgoing_body)) } - async fn drop_response_outparam( - &mut self, - _response: ResponseOutparam, - ) -> wasmtime::Result<()> { - bail!("unimplemented: drop_response_outparam") + fn drop_response_outparam(&mut self, _response: ResponseOutparam) -> wasmtime::Result<()> { + todo!("we haven't implemented the server side of wasi-http yet") } - async fn set_response_outparam( + fn set_response_outparam( &mut self, _outparam: ResponseOutparam, _response: Result, - ) -> wasmtime::Result> { - bail!("unimplemented: set_response_outparam") + ) -> wasmtime::Result<()> { + todo!("we haven't implemented the server side of wasi-http yet") } - async fn drop_incoming_response(&mut self, response: IncomingResponse) -> wasmtime::Result<()> { - let r = self - .table() - .get_response(response) - .context("[drop_incoming_response] getting response")?; - - // Cleanup dependent resources - let body = r.body(); - let headers = r.headers(); - if let Some(id) = body { - let stream = self - .table() - .get_stream(id) - .context("[drop_incoming_response] getting stream")?; - let incoming_id = stream.incoming(); - if let Some(trailers) = self.finish_incoming_stream(incoming_id).await? 
{ - self.table_mut() - .delete_fields(trailers) - .context("[drop_incoming_response] deleting trailers") - .unwrap_or_else(|_| ()); - } - self.table_mut().delete_stream(id).ok(); - } - if let Some(h) = headers { - self.table_mut().delete_fields(h).ok(); - } - - self.table_mut() - .delete_response(response) + fn drop_incoming_response(&mut self, response: IncomingResponse) -> wasmtime::Result<()> { + self.table() + .delete_incoming_response(response) .context("[drop_incoming_response] deleting response")?; Ok(()) } - async fn drop_outgoing_response( - &mut self, - _response: OutgoingResponse, - ) -> wasmtime::Result<()> { - bail!("unimplemented: drop_outgoing_response") + fn drop_outgoing_response(&mut self, _response: OutgoingResponse) -> wasmtime::Result<()> { + todo!("we haven't implemented the server side of wasi-http yet") } - async fn incoming_response_status( + fn incoming_response_status( &mut self, response: IncomingResponse, ) -> wasmtime::Result { let r = self .table() - .get_response(response) + .get_incoming_response(response) .context("[incoming_response_status] getting response")?; - Ok(r.status()) + Ok(r.status) } - async fn incoming_response_headers( + fn incoming_response_headers( &mut self, response: IncomingResponse, ) -> wasmtime::Result { - let r = self + let _ = self .table() - .get_response(response) + .get_incoming_response_mut(response) .context("[incoming_response_headers] getting response")?; - Ok(r.headers().unwrap_or(0 as Headers)) + + fn get_fields(elem: &mut dyn Any) -> &mut FieldMap { + &mut elem.downcast_mut::().unwrap().headers + } + + let id = self.table().push_fields(HostFields::Ref { + parent: response, + get_fields, + })?; + + Ok(id) } - async fn incoming_response_consume( + fn incoming_response_consume( &mut self, response: IncomingResponse, - ) -> wasmtime::Result> { - let table = self.table_mut(); + ) -> wasmtime::Result> { + let table = self.table(); let r = table - .get_response(response) + .get_incoming_response_mut(response) .context("[incoming_response_consume] getting response")?; - Ok(Ok(r - .body() - .map(|id| { - table - .get_stream(id) - .map(|stream| stream.incoming()) - .expect("[incoming_response_consume] response body stream") - }) - .unwrap_or(0 as IncomingStream))) - } - async fn new_outgoing_response( + + match r.body.take() { + Some(builder) => { + let id = self.table().push_incoming_body(builder.build())?; + Ok(Ok(id)) + } + + None => Ok(Err(())), + } + } + fn drop_future_trailers(&mut self, id: FutureTrailers) -> wasmtime::Result<()> { + self.table() + .delete_future_trailers(id) + .context("[drop future-trailers] deleting future-trailers")?; + Ok(()) + } + + fn future_trailers_subscribe(&mut self, index: FutureTrailers) -> wasmtime::Result { + // Eagerly force errors about the validity of the index. 
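
The `HostFields::Ref { parent, get_fields }` handling above, and the `HostPollable::TableEntry { index, make_future }` used by the subscribe functions, both lean on the same idea: a child handle carries only its parent's table index plus a plain `fn(&mut dyn Any) -> &mut ...` that projects the wanted field out of the type-erased parent entry. A minimal std-only sketch of that projection pattern; `Table`, `FieldsRef`, `Response` and `project_response_headers` are illustrative stand-ins, not the real wasmtime-wasi types:

    use std::any::Any;
    use std::collections::HashMap;

    type FieldMap = HashMap<String, Vec<Vec<u8>>>;

    // Type-erased table entries, standing in for the Box<dyn Any> slots of the
    // real resource table.
    struct Table {
        entries: HashMap<u32, Box<dyn Any>>,
    }

    // A "child" handle: no data of its own, only the parent's index and a
    // projection function that finds the FieldMap inside the parent's entry.
    struct FieldsRef {
        parent: u32,
        get_fields: fn(&mut dyn Any) -> &mut FieldMap,
    }

    struct Response {
        status: u16,
        headers: FieldMap,
    }

    fn project_response_headers(entry: &mut dyn Any) -> &mut FieldMap {
        // Unwrapping is fine as long as FieldsRef handles are only created
        // with a `parent` index that really points at a Response.
        &mut entry.downcast_mut::<Response>().unwrap().headers
    }

    impl Table {
        fn get_fields(&mut self, handle: &FieldsRef) -> Option<&mut FieldMap> {
            let entry = self.entries.get_mut(&handle.parent)?;
            Some((handle.get_fields)(entry.as_mut()))
        }
    }

    fn main() {
        let mut table = Table { entries: HashMap::new() };
        table.entries.insert(
            7,
            Box::new(Response { status: 200, headers: FieldMap::new() }),
        );

        let headers_handle = FieldsRef { parent: 7, get_fields: project_response_headers };
        table
            .get_fields(&headers_handle)
            .unwrap()
            .insert("content-type".into(), vec![b"text/plain".to_vec()]);

        let resp = table.entries.get(&7).unwrap().downcast_ref::<Response>().unwrap();
        assert_eq!((resp.status, resp.headers.len()), (200, 1));
    }

The payoff of the function-pointer approach is that the child handle stays `'static` and table-friendly: no borrow of the parent is held across calls, it is re-projected on every access.
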
+ let _ = self.table().get_future_trailers(index)?; + + fn make_future(elem: &mut dyn Any) -> PollableFuture { + Box::pin(elem.downcast_mut::().unwrap().ready()) + } + + let id = self + .table() + .push_host_pollable(HostPollable::TableEntry { index, make_future })?; + + Ok(id) + } + + fn future_trailers_get( + &mut self, + id: FutureTrailers, + ) -> wasmtime::Result>> { + let trailers = self.table().get_future_trailers(id)?; + match &trailers.state { + HostFutureTrailersState::Waiting(_) => return Ok(None), + HostFutureTrailersState::Done(Err(e)) => return Ok(Some(Err(e.clone()))), + HostFutureTrailersState::Done(Ok(_)) => {} + } + + fn get_fields(elem: &mut dyn Any) -> &mut FieldMap { + let trailers = elem.downcast_mut::().unwrap(); + match &mut trailers.state { + HostFutureTrailersState::Done(Ok(e)) => e, + _ => unreachable!(), + } + } + + let hdrs = self.table().push_fields(HostFields::Ref { + parent: id, + get_fields, + })?; + + Ok(Some(Ok(hdrs))) + } + + fn new_outgoing_response( &mut self, _status_code: StatusCode, _headers: Headers, ) -> wasmtime::Result { - bail!("unimplemented: new_outgoing_response") + todo!("we haven't implemented the server side of wasi-http yet") } - async fn outgoing_response_write( + fn outgoing_response_write( &mut self, _response: OutgoingResponse, - ) -> wasmtime::Result> { - bail!("unimplemented: outgoing_response_write") + ) -> wasmtime::Result> { + todo!("we haven't implemented the server side of wasi-http yet") } - async fn drop_future_incoming_response( + fn drop_future_incoming_response( &mut self, - future: FutureIncomingResponse, + id: FutureIncomingResponse, ) -> wasmtime::Result<()> { - self.table_mut() - .delete_future(future) - .context("[drop_future_incoming_response] deleting future")?; + let _ = self.table().delete_future_incoming_response(id)?; Ok(()) } - async fn future_incoming_response_get( + fn future_incoming_response_get( &mut self, - future: FutureIncomingResponse, - ) -> wasmtime::Result>> { - let f = self - .table() - .get_future(future) - .context("[future_incoming_response_get] getting future")?; - Ok(match f.pollable_id() { - Some(_) => { - let result = match f.response_id() { - Some(id) => Ok(id), - None => { - let response = self.handle_async(f.request_id(), f.options()).await; - match response { - Ok(id) => { - tracing::debug!( - "including response id to future incoming response" - ); - let future_mut = self.table_mut().get_future_mut(future)?; - future_mut.set_response_id(id); - tracing::trace!( - "future incoming response details {:?}", - *future_mut - ); - } - _ => {} - } - response - } - }; - Some(result) - } - None => None, - }) + id: FutureIncomingResponse, + ) -> wasmtime::Result, ()>>> { + let resp = self.table().get_future_incoming_response_mut(id)?; + + match resp { + HostFutureIncomingResponse::Pending(_) => return Ok(None), + HostFutureIncomingResponse::Consumed => return Ok(Some(Err(()))), + HostFutureIncomingResponse::Ready(_) => {} + } + + let resp = + match std::mem::replace(resp, HostFutureIncomingResponse::Consumed).unwrap_ready() { + Err(e) => { + // Trapping if it's not possible to downcast to an wasi-http error + let e = e.downcast::()?; + return Ok(Some(Ok(Err(e)))); + } + + Ok(resp) => resp, + }; + + let (parts, body) = resp.resp.into_parts(); + + let resp = self.table().push_incoming_response(HostIncomingResponse { + status: parts.status.as_u16(), + headers: FieldMap::from(parts.headers), + body: Some(HostIncomingBodyBuilder { + body, + between_bytes_timeout: resp.between_bytes_timeout, + }), + 
worker: resp.worker, + })?; + + Ok(Some(Ok(Ok(resp)))) } - async fn listen_to_future_incoming_response( + fn listen_to_future_incoming_response( &mut self, - future: FutureIncomingResponse, + id: FutureIncomingResponse, ) -> wasmtime::Result { - let f = self - .table() - .get_future(future) - .context("[listen_to_future_incoming_response] getting future")?; - Ok(match f.pollable_id() { - Some(pollable_id) => pollable_id, - None => { - tracing::debug!("including pollable id to future incoming response"); - let pollable = - HostPollable::Closure(Box::new(|| Box::pin(futures::future::ready(Ok(()))))); - let pollable_id = self - .table_mut() - .push_host_pollable(pollable) - .context("[listen_to_future_incoming_response] pushing host pollable")?; - let f = self - .table_mut() - .get_future_mut(future) - .context("[listen_to_future_incoming_response] getting future")?; - f.set_pollable_id(pollable_id); - tracing::trace!("future incoming response details {:?}", *f); - pollable_id - } - }) - } -} + let _ = self.table().get_future_incoming_response(id)?; -#[cfg(feature = "sync")] -pub mod sync { - use crate::bindings::http::types::{ - Error as AsyncError, Host as AsyncHost, Method as AsyncMethod, Scheme as AsyncScheme, - }; - use crate::bindings::sync::http::types::{ - Error, Fields, FutureIncomingResponse, Headers, IncomingRequest, IncomingResponse, - IncomingStream, Method, OutgoingRequest, OutgoingResponse, OutgoingStream, - ResponseOutparam, Scheme, StatusCode, Trailers, - }; - use crate::http_impl::WasiHttpViewExt; - use crate::WasiHttpView; - use wasmtime_wasi::preview2::{bindings::poll::poll::Pollable, in_tokio}; - - // same boilerplate everywhere, converting between two identical types with different - // definition sites. one day wasmtime-wit-bindgen will make all this unnecessary - impl From for Error { - fn from(other: AsyncError) -> Self { - match other { - AsyncError::InvalidUrl(v) => Self::InvalidUrl(v), - AsyncError::ProtocolError(v) => Self::ProtocolError(v), - AsyncError::TimeoutError(v) => Self::TimeoutError(v), - AsyncError::UnexpectedError(v) => Self::UnexpectedError(v), - } + fn make_future<'a>(elem: &'a mut dyn Any) -> PollableFuture<'a> { + Box::pin( + elem.downcast_mut::() + .expect("parent resource is HostFutureIncomingResponse"), + ) } + + let pollable = self.table().push_host_pollable(HostPollable::TableEntry { + index: id, + make_future, + })?; + + Ok(pollable) } - impl From for AsyncError { - fn from(other: Error) -> Self { - match other { - Error::InvalidUrl(v) => Self::InvalidUrl(v), - Error::ProtocolError(v) => Self::ProtocolError(v), - Error::TimeoutError(v) => Self::TimeoutError(v), - Error::UnexpectedError(v) => Self::UnexpectedError(v), - } + fn incoming_body_stream( + &mut self, + id: IncomingBody, + ) -> wasmtime::Result> { + let body = self.table().get_incoming_body(id)?; + + if let Some(stream) = body.stream.take() { + let stream = self.table().push_input_stream_child(Box::new(stream), id)?; + return Ok(Ok(stream)); } + + Ok(Err(())) } - impl From for Method { - fn from(other: AsyncMethod) -> Self { - match other { - AsyncMethod::Connect => Self::Connect, - AsyncMethod::Delete => Self::Delete, - AsyncMethod::Get => Self::Get, - AsyncMethod::Head => Self::Head, - AsyncMethod::Options => Self::Options, - AsyncMethod::Patch => Self::Patch, - AsyncMethod::Post => Self::Post, - AsyncMethod::Put => Self::Put, - AsyncMethod::Trace => Self::Trace, - AsyncMethod::Other(v) => Self::Other(v), - } - } + fn incoming_body_finish(&mut self, id: IncomingBody) -> 
wasmtime::Result { + let body = self.table().delete_incoming_body(id)?; + let trailers = self + .table() + .push_future_trailers(body.into_future_trailers())?; + Ok(trailers) } - impl From for AsyncMethod { - fn from(other: Method) -> Self { - match other { - Method::Connect => Self::Connect, - Method::Delete => Self::Delete, - Method::Get => Self::Get, - Method::Head => Self::Head, - Method::Options => Self::Options, - Method::Patch => Self::Patch, - Method::Post => Self::Post, - Method::Put => Self::Put, - Method::Trace => Self::Trace, - Method::Other(v) => Self::Other(v), - } - } + fn drop_incoming_body(&mut self, id: IncomingBody) -> wasmtime::Result<()> { + let _ = self.table().delete_incoming_body(id)?; + Ok(()) } - impl From for Scheme { - fn from(other: AsyncScheme) -> Self { - match other { - AsyncScheme::Http => Self::Http, - AsyncScheme::Https => Self::Https, - AsyncScheme::Other(v) => Self::Other(v), - } + fn outgoing_body_write( + &mut self, + id: OutgoingBody, + ) -> wasmtime::Result> { + let body = self.table().get_outgoing_body(id)?; + if let Some(stream) = body.body_output_stream.take() { + let id = self.table().push_output_stream_child(stream, id)?; + Ok(Ok(id)) + } else { + Ok(Err(())) } } - impl From for AsyncScheme { - fn from(other: Scheme) -> Self { - match other { - Scheme::Http => Self::Http, - Scheme::Https => Self::Https, - Scheme::Other(v) => Self::Other(v), - } + fn outgoing_body_write_trailers( + &mut self, + id: OutgoingBody, + ts: Trailers, + ) -> wasmtime::Result<()> { + let mut body = self.table().delete_outgoing_body(id)?; + let trailers = self.table().get_fields(ts)?.clone(); + + match body + .trailers_sender + .take() + // Should be unreachable - this is the only place we take the trailers sender, + // at the end of the HostOutgoingBody's lifetime + .ok_or_else(|| anyhow!("trailers_sender missing"))? + .send(trailers.into()) + { + Ok(()) => {} + Err(_) => {} // Ignoring failure: receiver died sending body, but we can't report that + // here. 
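
`outgoing_body_write_trailers` above takes the trailers sender out of the body exactly once and deliberately ignores a failed send, since a dropped receiver only means nobody is reading the response any more. A small std-only illustration of that take-once, best-effort-send shape; `OutgoingBody` and `Trailers` are placeholders, and std's mpsc channel stands in for the real sender:

    use std::collections::HashMap;
    use std::sync::mpsc;

    type Trailers = HashMap<String, Vec<u8>>;

    struct OutgoingBody {
        // Present until the trailers have been written (or the body dropped).
        trailers_sender: Option<mpsc::Sender<Trailers>>,
    }

    impl OutgoingBody {
        fn write_trailers(&mut self, trailers: Trailers) -> Result<(), &'static str> {
            // `take` guarantees the sender is used at most once.
            let sender = self.trailers_sender.take().ok_or("trailers already written")?;
            // If the receiving end is gone the write is silently dropped;
            // there is nobody left to report the failure to.
            let _ = sender.send(trailers);
            Ok(())
        }
    }

    fn main() {
        let (tx, rx) = mpsc::channel();
        let mut body = OutgoingBody { trailers_sender: Some(tx) };

        body.write_trailers(Trailers::new()).unwrap();
        assert!(rx.recv().is_ok());

        // A second attempt fails loudly instead of sending twice.
        assert!(body.write_trailers(Trailers::new()).is_err());
    }
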
} + + Ok(()) } - impl crate::bindings::sync::http::types::Host for T { - fn drop_fields(&mut self, fields: Fields) -> wasmtime::Result<()> { - in_tokio(async { AsyncHost::drop_fields(self, fields).await }) - } - fn new_fields(&mut self, entries: Vec<(String, String)>) -> wasmtime::Result { - in_tokio(async { AsyncHost::new_fields(self, entries).await }) - } - fn fields_get(&mut self, fields: Fields, name: String) -> wasmtime::Result>> { - in_tokio(async { AsyncHost::fields_get(self, fields, name).await }) - } - fn fields_set( - &mut self, - fields: Fields, - name: String, - value: Vec>, - ) -> wasmtime::Result<()> { - in_tokio(async { AsyncHost::fields_set(self, fields, name, value).await }) - } - fn fields_delete(&mut self, fields: Fields, name: String) -> wasmtime::Result<()> { - in_tokio(async { AsyncHost::fields_delete(self, fields, name).await }) - } - fn fields_append( - &mut self, - fields: Fields, - name: String, - value: Vec, - ) -> wasmtime::Result<()> { - in_tokio(async { AsyncHost::fields_append(self, fields, name, value).await }) - } - fn fields_entries(&mut self, fields: Fields) -> wasmtime::Result)>> { - in_tokio(async { AsyncHost::fields_entries(self, fields).await }) - } - fn fields_clone(&mut self, fields: Fields) -> wasmtime::Result { - in_tokio(async { AsyncHost::fields_clone(self, fields).await }) - } - fn finish_incoming_stream( - &mut self, - stream_id: IncomingStream, - ) -> wasmtime::Result> { - in_tokio(async { AsyncHost::finish_incoming_stream(self, stream_id).await }) - } - fn finish_outgoing_stream( - &mut self, - stream: OutgoingStream, - trailers: Option, - ) -> wasmtime::Result<()> { - in_tokio(async { AsyncHost::finish_outgoing_stream(self, stream, trailers).await }) - } - fn drop_incoming_request(&mut self, request: IncomingRequest) -> wasmtime::Result<()> { - in_tokio(async { AsyncHost::drop_incoming_request(self, request).await }) - } - fn drop_outgoing_request(&mut self, request: OutgoingRequest) -> wasmtime::Result<()> { - in_tokio(async { AsyncHost::drop_outgoing_request(self, request).await }) - } - fn incoming_request_method( - &mut self, - request: IncomingRequest, - ) -> wasmtime::Result { - in_tokio(async { AsyncHost::incoming_request_method(self, request).await }) - .map(Method::from) - } - fn incoming_request_path_with_query( - &mut self, - request: IncomingRequest, - ) -> wasmtime::Result> { - in_tokio(async { AsyncHost::incoming_request_path_with_query(self, request).await }) - } - fn incoming_request_scheme( - &mut self, - request: IncomingRequest, - ) -> wasmtime::Result> { - Ok( - in_tokio(async { AsyncHost::incoming_request_scheme(self, request).await })? 
- .map(Scheme::from), - ) - } - fn incoming_request_authority( - &mut self, - request: IncomingRequest, - ) -> wasmtime::Result> { - in_tokio(async { AsyncHost::incoming_request_authority(self, request).await }) - } - fn incoming_request_headers( - &mut self, - request: IncomingRequest, - ) -> wasmtime::Result { - in_tokio(async { AsyncHost::incoming_request_headers(self, request).await }) - } - fn incoming_request_consume( - &mut self, - request: IncomingRequest, - ) -> wasmtime::Result> { - in_tokio(async { AsyncHost::incoming_request_consume(self, request).await }) - } - fn new_outgoing_request( - &mut self, - method: Method, - path_with_query: Option, - scheme: Option, - authority: Option, - headers: Headers, - ) -> wasmtime::Result { - in_tokio(async { - AsyncHost::new_outgoing_request( - self, - method.into(), - path_with_query, - scheme.map(AsyncScheme::from), - authority, - headers, - ) - .await - }) - } - fn outgoing_request_write( - &mut self, - request: OutgoingRequest, - ) -> wasmtime::Result> { - in_tokio(async { AsyncHost::outgoing_request_write(self, request).await }) - } - fn drop_response_outparam(&mut self, response: ResponseOutparam) -> wasmtime::Result<()> { - in_tokio(async { AsyncHost::drop_response_outparam(self, response).await }) - } - fn set_response_outparam( - &mut self, - outparam: ResponseOutparam, - response: Result, - ) -> wasmtime::Result> { - in_tokio(async { - AsyncHost::set_response_outparam(self, outparam, response.map_err(AsyncError::from)) - .await - }) - } - fn drop_incoming_response(&mut self, response: IncomingResponse) -> wasmtime::Result<()> { - in_tokio(async { AsyncHost::drop_incoming_response(self, response).await }) - } - fn drop_outgoing_response(&mut self, response: OutgoingResponse) -> wasmtime::Result<()> { - in_tokio(async { AsyncHost::drop_outgoing_response(self, response).await }) - } - fn incoming_response_status( - &mut self, - response: IncomingResponse, - ) -> wasmtime::Result { - in_tokio(async { AsyncHost::incoming_response_status(self, response).await }) - } - fn incoming_response_headers( - &mut self, - response: IncomingResponse, - ) -> wasmtime::Result { - in_tokio(async { AsyncHost::incoming_response_headers(self, response).await }) - } - fn incoming_response_consume( - &mut self, - response: IncomingResponse, - ) -> wasmtime::Result> { - in_tokio(async { AsyncHost::incoming_response_consume(self, response).await }) - } - fn new_outgoing_response( - &mut self, - status_code: StatusCode, - headers: Headers, - ) -> wasmtime::Result { - in_tokio(async { AsyncHost::new_outgoing_response(self, status_code, headers).await }) - } - fn outgoing_response_write( - &mut self, - response: OutgoingResponse, - ) -> wasmtime::Result> { - in_tokio(async { AsyncHost::outgoing_response_write(self, response).await }) - } - fn drop_future_incoming_response( - &mut self, - future: FutureIncomingResponse, - ) -> wasmtime::Result<()> { - in_tokio(async { AsyncHost::drop_future_incoming_response(self, future).await }) - } - fn future_incoming_response_get( - &mut self, - future: FutureIncomingResponse, - ) -> wasmtime::Result>> { - Ok( - in_tokio(async { AsyncHost::future_incoming_response_get(self, future).await })? 
- .map(|v| v.map_err(Error::from)), - ) - } - fn listen_to_future_incoming_response( - &mut self, - future: FutureIncomingResponse, - ) -> wasmtime::Result { - in_tokio(async { AsyncHost::listen_to_future_incoming_response(self, future).await }) - } + fn drop_outgoing_body(&mut self, id: OutgoingBody) -> wasmtime::Result<()> { + let _ = self.table().delete_outgoing_body(id)?; + Ok(()) } } diff --git a/crates/wasi-http/wit/deps/http/incoming-handler.wit b/crates/wasi-http/wit/deps/http/incoming-handler.wit index d0e270465593..ad8a43f8ccf0 100644 --- a/crates/wasi-http/wit/deps/http/incoming-handler.wit +++ b/crates/wasi-http/wit/deps/http/incoming-handler.wit @@ -12,13 +12,13 @@ interface incoming-handler { // The `handle` function takes an outparam instead of returning its response // so that the component may stream its response while streaming any other // request or response bodies. The callee MUST write a response to the - // `response-out` and then finish the response before returning. The `handle` + // `response-outparam` and then finish the response before returning. The `handle` // function is allowed to continue execution after finishing the response's // output stream. While this post-response execution is taken off the // critical path, since there is no return value, there is no way to report // its success or failure. handle: func( - request: incoming-request, - response-out: response-outparam + request: /* own */ incoming-request, + response-out: /* own */ response-outparam ) } diff --git a/crates/wasi-http/wit/deps/http/outgoing-handler.wit b/crates/wasi-http/wit/deps/http/outgoing-handler.wit index 06c8e469f95b..3e03327d742b 100644 --- a/crates/wasi-http/wit/deps/http/outgoing-handler.wit +++ b/crates/wasi-http/wit/deps/http/outgoing-handler.wit @@ -8,11 +8,20 @@ interface outgoing-handler { use types.{outgoing-request, request-options, future-incoming-response} + // FIXME: we would want to use the types.error here but there is a + // wasmtime-wit-bindgen bug that prevents us from using the same error in + // the two different interfaces, right now... + variant error { + invalid(string) + } + // The parameter and result types of the `handle` function allow the caller // to concurrently stream the bodies of the outgoing request and the incoming // response. + // Consumes the outgoing-request. Gives an error if the outgoing-request + // is invalid or cannot be satisfied by this handler. handle: func( - request: outgoing-request, + request: /* own */ outgoing-request, options: option - ) -> future-incoming-response + ) -> result } diff --git a/crates/wasi-http/wit/deps/http/types.wit b/crates/wasi-http/wit/deps/http/types.wit index 7b7b015529c0..821e15f96213 100644 --- a/crates/wasi-http/wit/deps/http/types.wit +++ b/crates/wasi-http/wit/deps/http/types.wit @@ -41,30 +41,28 @@ interface types { // fields = u32` type alias can be replaced by a proper `resource fields` // definition containing all the functions using the method syntactic sugar. 
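
The hunk that follows replaces the string-typed fields constructor with `list<tuple<string, list<u8>>>`: repeated keys express multi-valued headers, and values are raw bytes because data off the wire need not be valid UTF-8. A std-only sketch of that entry shape, grouping duplicates roughly the way the host-side `HeaderMap::append` calls earlier do; `Entries`, `FieldMap`, `new_fields` and `fields_entries` are illustrative names, not the generated bindings:

    use std::collections::HashMap;

    // Wire-shaped entries: duplicate keys carry multiple values, values are bytes.
    type Entries = Vec<(String, Vec<u8>)>;
    // Host-shaped map: one key, many byte values.
    type FieldMap = HashMap<String, Vec<Vec<u8>>>;

    fn new_fields(entries: Entries) -> FieldMap {
        let mut map = FieldMap::new();
        for (name, value) in entries {
            // Header names are case-insensitive; normalize before grouping.
            map.entry(name.to_ascii_lowercase()).or_default().push(value);
        }
        map
    }

    fn fields_entries(map: &FieldMap) -> Entries {
        // Flatten back out: one tuple per (name, value) pair.
        map.iter()
            .flat_map(|(name, values)| values.iter().map(move |v| (name.clone(), v.clone())))
            .collect()
    }

    fn main() {
        let fields = new_fields(vec![
            ("Set-Cookie".into(), b"a=1".to_vec()),
            ("Set-Cookie".into(), b"b=2".to_vec()),
            ("Content-Length".into(), b"0".to_vec()),
        ]);
        assert_eq!(fields["set-cookie"].len(), 2);
        assert_eq!(fields_entries(&fields).len(), 3);
    }
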
type fields = u32 - drop-fields: func(fields: fields) - new-fields: func(entries: list>) -> fields - fields-get: func(fields: fields, name: string) -> list> - fields-set: func(fields: fields, name: string, value: list>) - fields-delete: func(fields: fields, name: string) - fields-append: func(fields: fields, name: string, value: list) - fields-entries: func(fields: fields) -> list>> - fields-clone: func(fields: fields) -> fields + drop-fields: func(fields: /* own */ fields) + // Multiple values for a header are multiple entries in the list with the + // same key. + new-fields: func(entries: list>>) -> fields + // Values off wire are not necessarily well formed, so they are given by + // list instead of string. + fields-get: func(fields: /* borrow */ fields, name: string) -> list> + // Values off wire are not necessarily well formed, so they are given by + // list instead of string. + fields-set: func(fields: /* borrow */ fields, name: string, value: list>) + fields-delete: func(fields: /* borrow */ fields, name: string) + fields-append: func(fields: /* borrow */ fields, name: string, value: list) + + // Values off wire are not necessarily well formed, so they are given by + // list instead of string. + fields-entries: func(fields: /* borrow */ fields) -> list>> + // Deep copy of all contents in a fields. + fields-clone: func(fields: /* borrow */ fields) -> fields type headers = fields type trailers = fields - // The following block defines stream types which corresponds to the HTTP - // standard Contents and Trailers. With Preview3, all of these fields can be - // replaced by a stream>. In the interim, we need to - // build on separate resource types defined by `wasi:io/streams`. The - // `finish-` functions emulate the stream's result value and MUST be called - // exactly once after the final read/write from/to the stream before dropping - // the stream. - type incoming-stream = input-stream - type outgoing-stream = output-stream - finish-incoming-stream: func(s: incoming-stream) -> option - finish-outgoing-stream: func(s: outgoing-stream, trailers: option) - // The following block defines the `incoming-request` and `outgoing-request` // resource types that correspond to HTTP standard Requests. Soon, when // resource types are added, the `u32` type aliases can be replaced by @@ -74,23 +72,30 @@ interface types { // above). The `consume` and `write` methods may only be called once (and // return failure thereafter). type incoming-request = u32 + drop-incoming-request: func(request: /* own */ incoming-request) + incoming-request-method: func(request: /* borrow */ incoming-request) -> method + incoming-request-path-with-query: func(request: /* borrow */ incoming-request) -> option + incoming-request-scheme: func(request: /* borrow */ incoming-request) -> option + incoming-request-authority: func(request: /* borrow */ incoming-request) -> option + + incoming-request-headers: func(request: /* borrow */ incoming-request) -> /* child */ headers + // Will return the input-stream child at most once. If called more than + // once, subsequent calls will return error. 
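
The at-most-once contract spelled out in the comment above (and implemented for `incoming-response-consume` and `outgoing-request-write` in the host code earlier) boils down to keeping the child in an `Option` and taking it on first use. A compact sketch with placeholder types rather than the real stream resources:

    struct InputStream(Vec<u8>);

    struct IncomingRequest {
        // Some until the body stream has been handed out once.
        body: Option<InputStream>,
    }

    impl IncomingRequest {
        // First call yields the child stream; every later call reports
        // failure, matching the WIT result<input-stream> return shape.
        fn consume(&mut self) -> Result<InputStream, ()> {
            self.body.take().ok_or(())
        }
    }

    fn main() {
        let mut req = IncomingRequest { body: Some(InputStream(b"hello".to_vec())) };
        assert!(req.consume().is_ok());
        assert!(req.consume().is_err());
    }
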
+ incoming-request-consume: func(request: /* borrow */ incoming-request) -> result< /* child */ input-stream> + type outgoing-request = u32 - drop-incoming-request: func(request: incoming-request) - drop-outgoing-request: func(request: outgoing-request) - incoming-request-method: func(request: incoming-request) -> method - incoming-request-path-with-query: func(request: incoming-request) -> option - incoming-request-scheme: func(request: incoming-request) -> option - incoming-request-authority: func(request: incoming-request) -> option - incoming-request-headers: func(request: incoming-request) -> headers - incoming-request-consume: func(request: incoming-request) -> result + drop-outgoing-request: func(request: /* own */ outgoing-request) new-outgoing-request: func( method: method, path-with-query: option, scheme: option, authority: option, - headers: headers + headers: /* borrow */ headers ) -> outgoing-request - outgoing-request-write: func(request: outgoing-request) -> result + + // Will return the outgoing-body child at most once. If called more than + // once, subsequent calls will return error. + outgoing-request-write: func(request: /* borrow */ outgoing-request) -> result< /* child */ outgoing-body> // Additional optional parameters that can be set when making a request. record request-options { @@ -115,8 +120,8 @@ interface types { // (the `wasi:http/handler` interface used for both incoming and outgoing can // simply return a `stream`). type response-outparam = u32 - drop-response-outparam: func(response: response-outparam) - set-response-outparam: func(param: response-outparam, response: result) -> result + drop-response-outparam: func(response: /* own */ response-outparam) + set-response-outparam: func(param: /* own */ response-outparam, response: result< /* own */ outgoing-response, error>) // This type corresponds to the HTTP standard Status Code. type status-code = u16 @@ -129,27 +134,72 @@ interface types { // type (that uses the single `stream` type mentioned above). The `consume` and // `write` methods may only be called once (and return failure thereafter). type incoming-response = u32 + drop-incoming-response: func(response: /* own */ incoming-response) + incoming-response-status: func(response: /* borrow */ incoming-response) -> status-code + incoming-response-headers: func(response: /* borrow */ incoming-response) -> /* child */ headers + // May be called at most once. returns error if called additional times. + // TODO: make incoming-request-consume work the same way, giving a child + // incoming-body. + incoming-response-consume: func(response: /* borrow */ incoming-response) -> result + + type incoming-body = u32 + drop-incoming-body: func(this: /* own */ incoming-body) + + // returned input-stream is a child - the implementation may trap if + // incoming-body is dropped (or consumed by call to + // incoming-body-finish) before the input-stream is dropped. + // May be called at most once. returns error if called additional times. + incoming-body-stream: func(this: /* borrow */ incoming-body) -> + result + // takes ownership of incoming-body. this will trap if the + // incoming-body-stream child is still alive! + incoming-body-finish: func(this: /* own */ incoming-body) -> + /* transitive child of the incoming-response of incoming-body */ future-trailers + + type future-trailers = u32 + drop-future-trailers: func(this: /* own */ future-trailers) + /// Pollable that resolves when the body has been fully read, and the trailers + /// are ready to be consumed. 
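
`future-trailers` is a readiness-style resource: `subscribe` returns a pollable and `get` yields nothing until the body has been read to the end, mirroring the `HostFutureTrailersState::Waiting`/`Done` enum in the host code earlier. A simplified, synchronous sketch of that state machine, using `mpsc::try_recv` where the real implementation polls a worker task; the names below are assumptions, not the actual types:

    use std::collections::HashMap;
    use std::sync::mpsc;

    type Trailers = HashMap<String, Vec<u8>>;

    enum FutureTrailers {
        // Body still streaming; the worker sends the trailers when done.
        Waiting(mpsc::Receiver<Trailers>),
        Done(Trailers),
    }

    impl FutureTrailers {
        // Non-blocking: None until the worker has delivered the trailers.
        fn get(&mut self) -> Option<&Trailers> {
            if let FutureTrailers::Waiting(rx) = self {
                match rx.try_recv() {
                    Ok(trailers) => *self = FutureTrailers::Done(trailers),
                    Err(_) => return None,
                }
            }
            match self {
                FutureTrailers::Done(t) => Some(t),
                FutureTrailers::Waiting(_) => None,
            }
        }
    }

    fn main() {
        let (tx, rx) = mpsc::channel();
        let mut fut = FutureTrailers::Waiting(rx);
        assert!(fut.get().is_none()); // body not finished yet

        tx.send(Trailers::new()).unwrap(); // worker finished reading the body
        assert!(fut.get().is_some());
    }
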
+ future-trailers-subscribe: func(this: /* borrow */ future-trailers) -> /* child */ pollable + + /// Retrieve reference to trailers, if they are ready. + future-trailers-get: func(response: /* borrow */ future-trailers) -> option> + type outgoing-response = u32 - drop-incoming-response: func(response: incoming-response) - drop-outgoing-response: func(response: outgoing-response) - incoming-response-status: func(response: incoming-response) -> status-code - incoming-response-headers: func(response: incoming-response) -> headers - incoming-response-consume: func(response: incoming-response) -> result + drop-outgoing-response: func(response: /* own */ outgoing-response) new-outgoing-response: func( status-code: status-code, - headers: headers + headers: /* borrow */ headers ) -> outgoing-response - outgoing-response-write: func(response: outgoing-response) -> result - // The following block defines a special resource type used by the - // `wasi:http/outgoing-handler` interface to emulate - // `future>` in advance of Preview3. Given a - // `future-incoming-response`, the client can call the non-blocking `get` - // method to get the result if it is available. If the result is not available, - // the client can call `listen` to get a `pollable` that can be passed to - // `io.poll.poll-oneoff`. + /// Will give the child outgoing-response at most once. subsequent calls will + /// return an error. + outgoing-response-write: func(this: /* borrow */ outgoing-response) -> result + + type outgoing-body = u32 + drop-outgoing-body: func(this: /* own */ outgoing-body) + /// Will give the child output-stream at most once. subsequent calls will + /// return an error. + outgoing-body-write: func(this: /* borrow */ outgoing-body) -> result + /// Write trailers as the way to finish an outgoing-body. To finish an + /// outgoing-body without writing trailers, use drop-outgoing-body. + outgoing-body-write-trailers: func(this: /* own */ outgoing-body, trailers: /* own */ trailers) + + /// The following block defines a special resource type used by the + /// `wasi:http/outgoing-handler` interface to emulate + /// `future>` in advance of Preview3. Given a + /// `future-incoming-response`, the client can call the non-blocking `get` + /// method to get the result if it is available. If the result is not available, + /// the client can call `listen` to get a `pollable` that can be passed to + /// `io.poll.poll-oneoff`. type future-incoming-response = u32 - drop-future-incoming-response: func(f: future-incoming-response) - future-incoming-response-get: func(f: future-incoming-response) -> option> - listen-to-future-incoming-response: func(f: future-incoming-response) -> pollable + drop-future-incoming-response: func(f: /* own */ future-incoming-response) + /// option indicates readiness. + /// outer result indicates you are allowed to get the + /// incoming-response-or-error at most once. subsequent calls after ready + /// will return an error here. + /// inner result indicates whether the incoming-response was available, or an + /// error occured. 
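
The triple-nested return type of `future-incoming-response-get` is easy to misread. Treating the outer result's error as unit, it decodes to something like `Option<Result<Result<IncomingResponse, HttpError>, ()>>`; the exact generated names differ, and the stand-in types below are hypothetical. A match that spells out what each layer means:

    // Stand-in types; real code uses the generated wasi-http bindings.
    struct IncomingResponse { status: u16 }
    #[derive(Debug)]
    enum HttpError { InvalidUrl(String), ProtocolError(String) }

    // Roughly the shape of `future-incoming-response-get`:
    //   outer Option -> readiness (None until the response future resolves)
    //   outer Result -> the value may be taken at most once; Err(()) afterwards
    //   inner Result -> the HTTP request itself succeeded or failed
    type GetResult = Option<Result<Result<IncomingResponse, HttpError>, ()>>;

    fn describe(r: GetResult) -> String {
        match r {
            None => "not ready yet; poll the pollable and retry".into(),
            Some(Err(())) => "already consumed by an earlier call".into(),
            Some(Ok(Err(e))) => format!("request failed: {e:?}"),
            Some(Ok(Ok(resp))) => format!("got response with status {}", resp.status),
        }
    }

    fn main() {
        assert_eq!(describe(None), "not ready yet; poll the pollable and retry");
        println!("{}", describe(Some(Ok(Err(HttpError::InvalidUrl("not a url".into()))))));
        println!("{}", describe(Some(Ok(Ok(IncomingResponse { status: 200 })))));
        let _ = HttpError::ProtocolError(String::new());
    }
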
+ future-incoming-response-get: func(f: /* borrow */ future-incoming-response) -> option>> + listen-to-future-incoming-response: func(f: /* borrow */ future-incoming-response) -> /* child */ pollable } diff --git a/crates/wasi/src/preview2/mod.rs b/crates/wasi/src/preview2/mod.rs index 2e0a3e95e114..6d9ad182dcec 100644 --- a/crates/wasi/src/preview2/mod.rs +++ b/crates/wasi/src/preview2/mod.rs @@ -31,6 +31,7 @@ mod stdio; mod stream; mod table; mod tcp; +mod write_stream; pub use self::clocks::{HostMonotonicClock, HostWallClock}; pub use self::ctx::{WasiCtx, WasiCtxBuilder, WasiView}; @@ -151,7 +152,7 @@ pub(crate) static RUNTIME: once_cell::sync::Lazy = .unwrap() }); -pub(crate) struct AbortOnDropJoinHandle(tokio::task::JoinHandle); +pub struct AbortOnDropJoinHandle(tokio::task::JoinHandle); impl Drop for AbortOnDropJoinHandle { fn drop(&mut self) { self.0.abort() @@ -188,7 +189,7 @@ impl std::future::Future for AbortOnDropJoinHandle { } } -pub(crate) fn spawn(f: F) -> AbortOnDropJoinHandle +pub fn spawn(f: F) -> AbortOnDropJoinHandle where F: std::future::Future + Send + 'static, G: Send + 'static, diff --git a/crates/wasi/src/preview2/pipe.rs b/crates/wasi/src/preview2/pipe.rs index a01215c11323..69749db9d43c 100644 --- a/crates/wasi/src/preview2/pipe.rs +++ b/crates/wasi/src/preview2/pipe.rs @@ -10,9 +10,10 @@ use crate::preview2::{HostInputStream, HostOutputStream, OutputStreamError, StreamState}; use anyhow::{anyhow, Error}; use bytes::Bytes; -use std::sync::{Arc, Mutex}; use tokio::sync::mpsc; +pub use crate::preview2::write_stream::AsyncWriteStream; + #[derive(Debug)] pub struct MemoryInputPipe { buffer: std::io::Cursor, @@ -221,198 +222,6 @@ impl HostInputStream for AsyncReadStream { } } -#[derive(Debug)] -struct WorkerState { - alive: bool, - items: std::collections::VecDeque, - write_budget: usize, - flush_pending: bool, - error: Option, -} - -impl WorkerState { - fn check_error(&mut self) -> Result<(), OutputStreamError> { - if let Some(e) = self.error.take() { - return Err(OutputStreamError::LastOperationFailed(e)); - } - if !self.alive { - return Err(OutputStreamError::Closed); - } - Ok(()) - } -} - -struct Worker { - state: Mutex, - new_work: tokio::sync::Notify, - write_ready_changed: tokio::sync::Notify, -} - -enum Job { - Flush, - Write(Bytes), -} - -enum WriteStatus<'a> { - Done(Result), - Pending(tokio::sync::futures::Notified<'a>), -} - -impl Worker { - fn new(write_budget: usize) -> Self { - Self { - state: Mutex::new(WorkerState { - alive: true, - items: std::collections::VecDeque::new(), - write_budget, - flush_pending: false, - error: None, - }), - new_work: tokio::sync::Notify::new(), - write_ready_changed: tokio::sync::Notify::new(), - } - } - fn check_write(&self) -> WriteStatus<'_> { - let mut state = self.state(); - if let Err(e) = state.check_error() { - return WriteStatus::Done(Err(e)); - } - - if state.flush_pending || state.write_budget == 0 { - return WriteStatus::Pending(self.write_ready_changed.notified()); - } - - WriteStatus::Done(Ok(state.write_budget)) - } - fn state(&self) -> std::sync::MutexGuard { - self.state.lock().unwrap() - } - fn pop(&self) -> Option { - let mut state = self.state(); - if state.items.is_empty() { - if state.flush_pending { - return Some(Job::Flush); - } - } else if let Some(bytes) = state.items.pop_front() { - return Some(Job::Write(bytes)); - } - - None - } - fn report_error(&self, e: std::io::Error) { - { - let mut state = self.state(); - state.alive = false; - state.error = Some(e.into()); - state.flush_pending = false; - 
} - self.write_ready_changed.notify_waiters(); - } - async fn work(&self, mut writer: T) { - use tokio::io::AsyncWriteExt; - loop { - let notified = self.new_work.notified(); - while let Some(job) = self.pop() { - match job { - Job::Flush => { - if let Err(e) = writer.flush().await { - self.report_error(e); - return; - } - - tracing::debug!("worker marking flush complete"); - self.state().flush_pending = false; - } - - Job::Write(mut bytes) => { - tracing::debug!("worker writing: {bytes:?}"); - let len = bytes.len(); - match writer.write_all_buf(&mut bytes).await { - Err(e) => { - self.report_error(e); - return; - } - Ok(_) => { - self.state().write_budget += len; - } - } - } - } - - self.write_ready_changed.notify_waiters(); - } - - notified.await; - } - } -} - -/// Provides a [`HostOutputStream`] impl from a [`tokio::io::AsyncWrite`] impl -pub struct AsyncWriteStream { - worker: Arc, - _join_handle: crate::preview2::AbortOnDropJoinHandle<()>, -} - -impl AsyncWriteStream { - /// Create a [`AsyncWriteStream`]. In order to use the [`HostOutputStream`] impl - /// provided by this struct, the argument must impl [`tokio::io::AsyncWrite`]. - pub fn new( - write_budget: usize, - writer: T, - ) -> Self { - let worker = Arc::new(Worker::new(write_budget)); - - let w = Arc::clone(&worker); - let join_handle = crate::preview2::spawn(async move { w.work(writer).await }); - - AsyncWriteStream { - worker, - _join_handle: join_handle, - } - } -} - -#[async_trait::async_trait] -impl HostOutputStream for AsyncWriteStream { - fn write(&mut self, bytes: Bytes) -> Result<(), OutputStreamError> { - let mut state = self.worker.state(); - state.check_error()?; - if state.flush_pending { - return Err(OutputStreamError::Trap(anyhow!( - "write not permitted while flush pending" - ))); - } - match state.write_budget.checked_sub(bytes.len()) { - Some(remaining_budget) => { - state.write_budget = remaining_budget; - state.items.push_back(bytes); - } - None => return Err(OutputStreamError::Trap(anyhow!("write exceeded budget"))), - } - drop(state); - self.worker.new_work.notify_waiters(); - Ok(()) - } - fn flush(&mut self) -> Result<(), OutputStreamError> { - let mut state = self.worker.state(); - state.check_error()?; - - state.flush_pending = true; - self.worker.new_work.notify_waiters(); - - Ok(()) - } - - async fn write_ready(&mut self) -> Result { - loop { - match self.worker.check_write() { - WriteStatus::Done(r) => return r, - WriteStatus::Pending(notifier) => notifier.await, - } - } - } -} - /// An output stream that consumes all input written to it, and is always ready. pub struct SinkOutputStream; diff --git a/crates/wasi/src/preview2/table.rs b/crates/wasi/src/preview2/table.rs index f4efd8005ed1..de9b50a6f772 100644 --- a/crates/wasi/src/preview2/table.rs +++ b/crates/wasi/src/preview2/table.rs @@ -174,6 +174,15 @@ impl Table { } } + /// Get a mutable reference to the underlying untyped cell for an entry in the table. + pub fn get_any_mut(&mut self, key: u32) -> Result<&mut dyn Any, TableError> { + if let Some(r) = self.map.get_mut(&key) { + Ok(&mut *r.entry) + } else { + Err(TableError::NotPresent) + } + } + /// Get an immutable reference to a resource of a given type at a given index. Multiple /// immutable references can be borrowed at any given time. Borrow failure /// results in a trapping error. @@ -203,7 +212,7 @@ impl Table { /// remove or replace the entry based on its contents. 
The methods available are a subset of /// [`std::collections::hash_map::OccupiedEntry`] - it does not give access to the key, it /// restricts replacing the entry to items of the same type, and it does not allow for deletion. - pub fn entry(&mut self, index: u32) -> Result { + pub fn entry<'a>(&'a mut self, index: u32) -> Result, TableError> { if self.map.contains_key(&index) { Ok(OccupiedEntry { table: self, index }) } else { diff --git a/crates/wasi/src/preview2/write_stream.rs b/crates/wasi/src/preview2/write_stream.rs new file mode 100644 index 000000000000..bf154aec0720 --- /dev/null +++ b/crates/wasi/src/preview2/write_stream.rs @@ -0,0 +1,196 @@ +use crate::preview2::{HostOutputStream, OutputStreamError}; +use anyhow::anyhow; +use bytes::Bytes; +use std::sync::{Arc, Mutex}; + +#[derive(Debug)] +struct WorkerState { + alive: bool, + items: std::collections::VecDeque, + write_budget: usize, + flush_pending: bool, + error: Option, +} + +impl WorkerState { + fn check_error(&mut self) -> Result<(), OutputStreamError> { + if let Some(e) = self.error.take() { + return Err(OutputStreamError::LastOperationFailed(e)); + } + if !self.alive { + return Err(OutputStreamError::Closed); + } + Ok(()) + } +} + +struct Worker { + state: Mutex, + new_work: tokio::sync::Notify, + write_ready_changed: tokio::sync::Notify, +} + +enum Job { + Flush, + Write(Bytes), +} + +enum WriteStatus<'a> { + Done(Result), + Pending(tokio::sync::futures::Notified<'a>), +} + +impl Worker { + fn new(write_budget: usize) -> Self { + Self { + state: Mutex::new(WorkerState { + alive: true, + items: std::collections::VecDeque::new(), + write_budget, + flush_pending: false, + error: None, + }), + new_work: tokio::sync::Notify::new(), + write_ready_changed: tokio::sync::Notify::new(), + } + } + fn check_write(&self) -> WriteStatus<'_> { + let mut state = self.state(); + if let Err(e) = state.check_error() { + return WriteStatus::Done(Err(e)); + } + + if state.flush_pending || state.write_budget == 0 { + return WriteStatus::Pending(self.write_ready_changed.notified()); + } + + WriteStatus::Done(Ok(state.write_budget)) + } + fn state(&self) -> std::sync::MutexGuard { + self.state.lock().unwrap() + } + fn pop(&self) -> Option { + let mut state = self.state(); + if state.items.is_empty() { + if state.flush_pending { + return Some(Job::Flush); + } + } else if let Some(bytes) = state.items.pop_front() { + return Some(Job::Write(bytes)); + } + + None + } + fn report_error(&self, e: std::io::Error) { + { + let mut state = self.state(); + state.alive = false; + state.error = Some(e.into()); + state.flush_pending = false; + } + self.write_ready_changed.notify_waiters(); + } + async fn work(&self, mut writer: T) { + use tokio::io::AsyncWriteExt; + loop { + let notified = self.new_work.notified(); + while let Some(job) = self.pop() { + match job { + Job::Flush => { + if let Err(e) = writer.flush().await { + self.report_error(e); + return; + } + + tracing::debug!("worker marking flush complete"); + self.state().flush_pending = false; + } + + Job::Write(mut bytes) => { + tracing::debug!("worker writing: {bytes:?}"); + let len = bytes.len(); + match writer.write_all_buf(&mut bytes).await { + Err(e) => { + self.report_error(e); + return; + } + Ok(_) => { + self.state().write_budget += len; + } + } + } + } + + self.write_ready_changed.notify_waiters(); + } + + notified.await; + } + } +} + +/// Provides a [`HostOutputStream`] impl from a [`tokio::io::AsyncWrite`] impl +pub struct AsyncWriteStream { + worker: Arc, + _join_handle: 
crate::preview2::AbortOnDropJoinHandle<()>, +} + +impl AsyncWriteStream { + /// Create a [`AsyncWriteStream`]. In order to use the [`HostOutputStream`] impl + /// provided by this struct, the argument must impl [`tokio::io::AsyncWrite`]. + pub fn new( + write_budget: usize, + writer: T, + ) -> Self { + let worker = Arc::new(Worker::new(write_budget)); + + let w = Arc::clone(&worker); + let join_handle = crate::preview2::spawn(async move { w.work(writer).await }); + + AsyncWriteStream { + worker, + _join_handle: join_handle, + } + } +} + +#[async_trait::async_trait] +impl HostOutputStream for AsyncWriteStream { + fn write(&mut self, bytes: Bytes) -> Result<(), OutputStreamError> { + let mut state = self.worker.state(); + state.check_error()?; + if state.flush_pending { + return Err(OutputStreamError::Trap(anyhow!( + "write not permitted while flush pending" + ))); + } + match state.write_budget.checked_sub(bytes.len()) { + Some(remaining_budget) => { + state.write_budget = remaining_budget; + state.items.push_back(bytes); + } + None => return Err(OutputStreamError::Trap(anyhow!("write exceeded budget"))), + } + drop(state); + self.worker.new_work.notify_waiters(); + Ok(()) + } + fn flush(&mut self) -> Result<(), OutputStreamError> { + let mut state = self.worker.state(); + state.check_error()?; + + state.flush_pending = true; + self.worker.new_work.notify_waiters(); + + Ok(()) + } + + async fn write_ready(&mut self) -> Result { + loop { + match self.worker.check_write() { + WriteStatus::Done(r) => return r, + WriteStatus::Pending(notifier) => notifier.await, + } + } + } +} diff --git a/crates/wasi/wit/deps/http/incoming-handler.wit b/crates/wasi/wit/deps/http/incoming-handler.wit index d0e270465593..ad8a43f8ccf0 100644 --- a/crates/wasi/wit/deps/http/incoming-handler.wit +++ b/crates/wasi/wit/deps/http/incoming-handler.wit @@ -12,13 +12,13 @@ interface incoming-handler { // The `handle` function takes an outparam instead of returning its response // so that the component may stream its response while streaming any other // request or response bodies. The callee MUST write a response to the - // `response-out` and then finish the response before returning. The `handle` + // `response-outparam` and then finish the response before returning. The `handle` // function is allowed to continue execution after finishing the response's // output stream. While this post-response execution is taken off the // critical path, since there is no return value, there is no way to report // its success or failure. handle: func( - request: incoming-request, - response-out: response-outparam + request: /* own */ incoming-request, + response-out: /* own */ response-outparam ) } diff --git a/crates/wasi/wit/deps/http/outgoing-handler.wit b/crates/wasi/wit/deps/http/outgoing-handler.wit index 06c8e469f95b..3e03327d742b 100644 --- a/crates/wasi/wit/deps/http/outgoing-handler.wit +++ b/crates/wasi/wit/deps/http/outgoing-handler.wit @@ -8,11 +8,20 @@ interface outgoing-handler { use types.{outgoing-request, request-options, future-incoming-response} + // FIXME: we would want to use the types.error here but there is a + // wasmtime-wit-bindgen bug that prevents us from using the same error in + // the two different interfaces, right now... + variant error { + invalid(string) + } + // The parameter and result types of the `handle` function allow the caller // to concurrently stream the bodies of the outgoing request and the incoming // response. + // Consumes the outgoing-request. 
Gives an error if the outgoing-request + // is invalid or cannot be satisfied by this handler. handle: func( - request: outgoing-request, + request: /* own */ outgoing-request, options: option - ) -> future-incoming-response + ) -> result } diff --git a/crates/wasi/wit/deps/http/types.wit b/crates/wasi/wit/deps/http/types.wit index 7b7b015529c0..821e15f96213 100644 --- a/crates/wasi/wit/deps/http/types.wit +++ b/crates/wasi/wit/deps/http/types.wit @@ -41,30 +41,28 @@ interface types { // fields = u32` type alias can be replaced by a proper `resource fields` // definition containing all the functions using the method syntactic sugar. type fields = u32 - drop-fields: func(fields: fields) - new-fields: func(entries: list>) -> fields - fields-get: func(fields: fields, name: string) -> list> - fields-set: func(fields: fields, name: string, value: list>) - fields-delete: func(fields: fields, name: string) - fields-append: func(fields: fields, name: string, value: list) - fields-entries: func(fields: fields) -> list>> - fields-clone: func(fields: fields) -> fields + drop-fields: func(fields: /* own */ fields) + // Multiple values for a header are multiple entries in the list with the + // same key. + new-fields: func(entries: list>>) -> fields + // Values off wire are not necessarily well formed, so they are given by + // list instead of string. + fields-get: func(fields: /* borrow */ fields, name: string) -> list> + // Values off wire are not necessarily well formed, so they are given by + // list instead of string. + fields-set: func(fields: /* borrow */ fields, name: string, value: list>) + fields-delete: func(fields: /* borrow */ fields, name: string) + fields-append: func(fields: /* borrow */ fields, name: string, value: list) + + // Values off wire are not necessarily well formed, so they are given by + // list instead of string. + fields-entries: func(fields: /* borrow */ fields) -> list>> + // Deep copy of all contents in a fields. + fields-clone: func(fields: /* borrow */ fields) -> fields type headers = fields type trailers = fields - // The following block defines stream types which corresponds to the HTTP - // standard Contents and Trailers. With Preview3, all of these fields can be - // replaced by a stream>. In the interim, we need to - // build on separate resource types defined by `wasi:io/streams`. The - // `finish-` functions emulate the stream's result value and MUST be called - // exactly once after the final read/write from/to the stream before dropping - // the stream. - type incoming-stream = input-stream - type outgoing-stream = output-stream - finish-incoming-stream: func(s: incoming-stream) -> option - finish-outgoing-stream: func(s: outgoing-stream, trailers: option) - // The following block defines the `incoming-request` and `outgoing-request` // resource types that correspond to HTTP standard Requests. Soon, when // resource types are added, the `u32` type aliases can be replaced by @@ -74,23 +72,30 @@ interface types { // above). The `consume` and `write` methods may only be called once (and // return failure thereafter). 
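The reworked fields functions above switch header values from strings to lists of bytes (values off the wire are not necessarily valid UTF-8) and represent repeated header names as repeated entries with the same key. As a rough host-side picture of that shape, a fields value behaves like the small multimap below; FieldMap and its method names are illustrative and are not the actual wasmtime-wasi-http representation.

// fields-clone corresponds to an ordinary deep copy, hence derive(Clone).
#[derive(Clone, Default, Debug)]
struct FieldMap {
    entries: Vec<(String, Vec<u8>)>,
}

impl FieldMap {
    // new-fields: seed from a list of (name, value) tuples.
    fn new(entries: Vec<(String, Vec<u8>)>) -> Self {
        Self { entries }
    }
    // fields-get: all values for a name, in insertion order.
    fn get(&self, name: &str) -> Vec<Vec<u8>> {
        self.entries
            .iter()
            .filter(|(n, _)| n == name)
            .map(|(_, v)| v.clone())
            .collect()
    }
    // fields-set: replace every existing value for `name`.
    fn set(&mut self, name: &str, values: Vec<Vec<u8>>) {
        self.delete(name);
        self.entries
            .extend(values.into_iter().map(|v| (name.to_string(), v)));
    }
    // fields-delete: drop all entries for `name`.
    fn delete(&mut self, name: &str) {
        self.entries.retain(|(n, _)| n != name);
    }
    // fields-append: add one more value without touching existing ones.
    fn append(&mut self, name: &str, value: Vec<u8>) {
        self.entries.push((name.to_string(), value));
    }
    // fields-entries: the full multimap as a flat list.
    fn entries(&self) -> &[(String, Vec<u8>)] {
        &self.entries
    }
}

fn main() {
    let mut headers = FieldMap::new(vec![("accept".into(), b"text/html".to_vec())]);
    headers.append("accept", b"application/json".to_vec());
    assert_eq!(headers.get("accept").len(), 2);
    headers.set("accept", vec![b"*/*".to_vec()]);
    assert_eq!(headers.entries().len(), 1);
}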
type incoming-request = u32 + drop-incoming-request: func(request: /* own */ incoming-request) + incoming-request-method: func(request: /* borrow */ incoming-request) -> method + incoming-request-path-with-query: func(request: /* borrow */ incoming-request) -> option + incoming-request-scheme: func(request: /* borrow */ incoming-request) -> option + incoming-request-authority: func(request: /* borrow */ incoming-request) -> option + + incoming-request-headers: func(request: /* borrow */ incoming-request) -> /* child */ headers + // Will return the input-stream child at most once. If called more than + // once, subsequent calls will return error. + incoming-request-consume: func(request: /* borrow */ incoming-request) -> result< /* child */ input-stream> + type outgoing-request = u32 - drop-incoming-request: func(request: incoming-request) - drop-outgoing-request: func(request: outgoing-request) - incoming-request-method: func(request: incoming-request) -> method - incoming-request-path-with-query: func(request: incoming-request) -> option - incoming-request-scheme: func(request: incoming-request) -> option - incoming-request-authority: func(request: incoming-request) -> option - incoming-request-headers: func(request: incoming-request) -> headers - incoming-request-consume: func(request: incoming-request) -> result + drop-outgoing-request: func(request: /* own */ outgoing-request) new-outgoing-request: func( method: method, path-with-query: option, scheme: option, authority: option, - headers: headers + headers: /* borrow */ headers ) -> outgoing-request - outgoing-request-write: func(request: outgoing-request) -> result + + // Will return the outgoing-body child at most once. If called more than + // once, subsequent calls will return error. + outgoing-request-write: func(request: /* borrow */ outgoing-request) -> result< /* child */ outgoing-body> // Additional optional parameters that can be set when making a request. record request-options { @@ -115,8 +120,8 @@ interface types { // (the `wasi:http/handler` interface used for both incoming and outgoing can // simply return a `stream`). type response-outparam = u32 - drop-response-outparam: func(response: response-outparam) - set-response-outparam: func(param: response-outparam, response: result) -> result + drop-response-outparam: func(response: /* own */ response-outparam) + set-response-outparam: func(param: /* own */ response-outparam, response: result< /* own */ outgoing-response, error>) // This type corresponds to the HTTP standard Status Code. type status-code = u16 @@ -129,27 +134,72 @@ interface types { // type (that uses the single `stream` type mentioned above). The `consume` and // `write` methods may only be called once (and return failure thereafter). type incoming-response = u32 + drop-incoming-response: func(response: /* own */ incoming-response) + incoming-response-status: func(response: /* borrow */ incoming-response) -> status-code + incoming-response-headers: func(response: /* borrow */ incoming-response) -> /* child */ headers + // May be called at most once. returns error if called additional times. + // TODO: make incoming-request-consume work the same way, giving a child + // incoming-body. 
+ incoming-response-consume: func(response: /* borrow */ incoming-response) -> result + + type incoming-body = u32 + drop-incoming-body: func(this: /* own */ incoming-body) + + // returned input-stream is a child - the implementation may trap if + // incoming-body is dropped (or consumed by call to + // incoming-body-finish) before the input-stream is dropped. + // May be called at most once. returns error if called additional times. + incoming-body-stream: func(this: /* borrow */ incoming-body) -> + result + // takes ownership of incoming-body. this will trap if the + // incoming-body-stream child is still alive! + incoming-body-finish: func(this: /* own */ incoming-body) -> + /* transitive child of the incoming-response of incoming-body */ future-trailers + + type future-trailers = u32 + drop-future-trailers: func(this: /* own */ future-trailers) + /// Pollable that resolves when the body has been fully read, and the trailers + /// are ready to be consumed. + future-trailers-subscribe: func(this: /* borrow */ future-trailers) -> /* child */ pollable + + /// Retrieve reference to trailers, if they are ready. + future-trailers-get: func(response: /* borrow */ future-trailers) -> option> + type outgoing-response = u32 - drop-incoming-response: func(response: incoming-response) - drop-outgoing-response: func(response: outgoing-response) - incoming-response-status: func(response: incoming-response) -> status-code - incoming-response-headers: func(response: incoming-response) -> headers - incoming-response-consume: func(response: incoming-response) -> result + drop-outgoing-response: func(response: /* own */ outgoing-response) new-outgoing-response: func( status-code: status-code, - headers: headers + headers: /* borrow */ headers ) -> outgoing-response - outgoing-response-write: func(response: outgoing-response) -> result - // The following block defines a special resource type used by the - // `wasi:http/outgoing-handler` interface to emulate - // `future>` in advance of Preview3. Given a - // `future-incoming-response`, the client can call the non-blocking `get` - // method to get the result if it is available. If the result is not available, - // the client can call `listen` to get a `pollable` that can be passed to - // `io.poll.poll-oneoff`. + /// Will give the child outgoing-response at most once. subsequent calls will + /// return an error. + outgoing-response-write: func(this: /* borrow */ outgoing-response) -> result + + type outgoing-body = u32 + drop-outgoing-body: func(this: /* own */ outgoing-body) + /// Will give the child output-stream at most once. subsequent calls will + /// return an error. + outgoing-body-write: func(this: /* borrow */ outgoing-body) -> result + /// Write trailers as the way to finish an outgoing-body. To finish an + /// outgoing-body without writing trailers, use drop-outgoing-body. + outgoing-body-write-trailers: func(this: /* own */ outgoing-body, trailers: /* own */ trailers) + + /// The following block defines a special resource type used by the + /// `wasi:http/outgoing-handler` interface to emulate + /// `future>` in advance of Preview3. Given a + /// `future-incoming-response`, the client can call the non-blocking `get` + /// method to get the result if it is available. If the result is not available, + /// the client can call `listen` to get a `pollable` that can be passed to + /// `io.poll.poll-oneoff`. 
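Several of the functions above hand out a child resource that may be obtained at most once; incoming-body-stream, outgoing-response-write, and outgoing-body-write all report an error on any call after the first. A minimal sketch of how a host can enforce that rule with Option::take follows; OutgoingBody and StreamHandle are illustrative names, not the real wasi-http table types.

struct StreamHandle(u32);

struct OutgoingBody {
    // `Some` until the guest asks for the write stream; `None` afterwards.
    stream: Option<StreamHandle>,
}

impl OutgoingBody {
    fn new(stream: StreamHandle) -> Self {
        Self { stream: Some(stream) }
    }

    /// The first call hands out the child stream; every later call reports an
    /// error, matching the WIT contract described above.
    fn write(&mut self) -> Result<StreamHandle, &'static str> {
        self.stream.take().ok_or("output stream already taken")
    }
}

fn main() {
    let mut body = OutgoingBody::new(StreamHandle(7));
    assert!(body.write().is_ok());
    assert!(body.write().is_err()); // subsequent calls fail
}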
type future-incoming-response = u32 - drop-future-incoming-response: func(f: future-incoming-response) - future-incoming-response-get: func(f: future-incoming-response) -> option> - listen-to-future-incoming-response: func(f: future-incoming-response) -> pollable + drop-future-incoming-response: func(f: /* own */ future-incoming-response) + /// option indicates readiness. + /// outer result indicates you are allowed to get the + /// incoming-response-or-error at most once. subsequent calls after ready + /// will return an error here. + /// inner result indicates whether the incoming-response was available, or an + /// error occured. + future-incoming-response-get: func(f: /* borrow */ future-incoming-response) -> option>> + listen-to-future-incoming-response: func(f: /* borrow */ future-incoming-response) -> /* child */ pollable } diff --git a/src/commands/run.rs b/src/commands/run.rs index 1c1bfe281799..01b04809b10b 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -848,22 +848,22 @@ impl RunCommand { } if self.common.wasi.http == Some(true) { - #[cfg(not(feature = "wasi-http"))] + #[cfg(not(all(feature = "wasi-http", feature = "component-model")))] { bail!("Cannot enable wasi-http when the binary is not compiled with this feature."); } - #[cfg(feature = "wasi-http")] + #[cfg(all(feature = "wasi-http", feature = "component-model"))] { match linker { - CliLinker::Core(linker) => { - wasmtime_wasi_http::sync::add_to_linker(linker)?; + CliLinker::Core(_) => { + bail!("Cannot enable wasi-http for core wasm modules"); } - #[cfg(feature = "component-model")] CliLinker::Component(linker) => { - wasmtime_wasi_http::proxy::sync::add_to_linker(linker)?; + wasmtime_wasi_http::proxy::add_to_linker(linker)?; } } - store.data_mut().wasi_http = Some(Arc::new(WasiHttpCtx::new())); + + store.data_mut().wasi_http = Some(Arc::new(WasiHttpCtx {})); } } @@ -998,13 +998,13 @@ impl preview2::preview1::WasiPreview1View for Host { #[cfg(feature = "wasi-http")] impl wasmtime_wasi_http::types::WasiHttpView for Host { - fn http_ctx(&self) -> &WasiHttpCtx { - self.wasi_http.as_ref().unwrap() + fn ctx(&mut self) -> &mut WasiHttpCtx { + let ctx = self.wasi_http.as_mut().unwrap(); + Arc::get_mut(ctx).expect("preview2 is not compatible with threads") } - fn http_ctx_mut(&mut self) -> &mut WasiHttpCtx { - let ctx = self.wasi_http.as_mut().unwrap(); - Arc::get_mut(ctx).expect("wasi-http is not compatible with threads") + fn table(&mut self) -> &mut preview2::Table { + Arc::get_mut(&mut self.preview2_table).expect("preview2 is not compatible with threads") } } diff --git a/tests/all/cli_tests.rs b/tests/all/cli_tests.rs index de5f2cb06a03..4f7dd4968dc0 100644 --- a/tests/all/cli_tests.rs +++ b/tests/all/cli_tests.rs @@ -791,6 +791,7 @@ fn run_basic_component() -> Result<()> { } #[cfg(feature = "wasi-http")] +#[ignore = "needs to be ported to components"] #[test] fn run_wasi_http_module() -> Result<()> { let output = run_wasmtime_for_output( From 6aca67c15dc33beb89e929de598fad13b723950b Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Thu, 21 Sep 2023 18:05:30 -0500 Subject: [PATCH 14/14] Support resource maps in `component::bindgen!` (#7069) * Support resource maps in `component::bindgen!` This commit adds support to `component::bindgen!` to specify resource types using the `with` key of the macro. This can be used to configure the `T` of `Resource` to use a preexisting type rather than unconditionally generating a new empty enum to have a fresh type. 
* Reenable tests --- crates/component-macro/tests/codegen.rs | 82 +++++++++++++++++++++++++ crates/wasmtime/src/component/mod.rs | 18 +++--- crates/wit-bindgen/src/lib.rs | 59 ++++++++++++------ 3 files changed, 132 insertions(+), 27 deletions(-) diff --git a/crates/component-macro/tests/codegen.rs b/crates/component-macro/tests/codegen.rs index 216d03e52062..90d58bd890f2 100644 --- a/crates/component-macro/tests/codegen.rs +++ b/crates/component-macro/tests/codegen.rs @@ -24,3 +24,85 @@ macro_rules! gentest { } component_macro_test_helpers::foreach!(gentest); + +mod with_key_and_resources { + use anyhow::Result; + use wasmtime::component::Resource; + + wasmtime::component::bindgen!({ + inline: " + package demo:pkg + + interface bar { + resource a + resource b + } + + world foo { + resource a + resource b + + import foo: interface { + resource a + resource b + } + + import bar + } + ", + with: { + "a": MyA, + "b": MyA, + "foo/a": MyA, + "foo/b": MyA, + "demo:pkg/bar/a": MyA, + "demo:pkg/bar/b": MyA, + }, + }); + + pub struct MyA; + + struct MyComponent; + + impl FooImports for MyComponent {} + + impl HostA for MyComponent { + fn drop(&mut self, _: Resource) -> Result<()> { + loop {} + } + } + + impl HostB for MyComponent { + fn drop(&mut self, _: Resource) -> Result<()> { + loop {} + } + } + + impl foo::Host for MyComponent {} + + impl foo::HostA for MyComponent { + fn drop(&mut self, _: Resource) -> Result<()> { + loop {} + } + } + + impl foo::HostB for MyComponent { + fn drop(&mut self, _: Resource) -> Result<()> { + loop {} + } + } + + impl demo::pkg::bar::Host for MyComponent {} + + impl demo::pkg::bar::HostA for MyComponent { + fn drop(&mut self, _: Resource) -> Result<()> { + loop {} + } + } + + impl demo::pkg::bar::HostB for MyComponent { + fn drop(&mut self, _: Resource) -> Result<()> { + loop {} + } + } +} diff --git a/crates/wasmtime/src/component/mod.rs b/crates/wasmtime/src/component/mod.rs index 618f829b2540..73aaad8e1cfc 100644 --- a/crates/wasmtime/src/component/mod.rs +++ b/crates/wasmtime/src/component/mod.rs @@ -326,16 +326,20 @@ pub(crate) use self::store::ComponentStoreData; /// // Restrict the code generated to what's needed for the interface /// // imports in the inlined WIT document fragment. /// interfaces: " -/// import package.foo +/// import wasi:cli/command /// ", /// -/// // Remap interface names to module names, imported from elsewhere. -/// // Using this option will prevent any code from being generated -/// // for the names mentioned in the mapping, assuming instead that the -/// // names mentioned come from a previous use of the `bindgen!` macro -/// // with `only_interfaces: true`. +/// // Remap imported interfaces or resources to types defined in Rust +/// // elsewhere. Using this option will prevent any code from being +/// // generated for interfaces mentioned here. Resources named here will +/// // not have a type generated to represent the resource. +/// // +/// // Interfaces mapped with this option should be previously generated +/// // with an invocation of this macro. Resources need to be mapped to a +/// // Rust type name. 
/// with: { -/// "a": somewhere::else::a, +/// "wasi:random/random": some::other::wasi::random::random, +/// "wasi:filesystem/types/descriptor": MyDescriptorType, /// }, /// }); /// ``` diff --git a/crates/wit-bindgen/src/lib.rs b/crates/wit-bindgen/src/lib.rs index 2a0d6d68b8aa..45fe6aedf3b1 100644 --- a/crates/wit-bindgen/src/lib.rs +++ b/crates/wit-bindgen/src/lib.rs @@ -912,16 +912,28 @@ impl<'a> InterfaceGenerator<'a> { self.assert_type(id, &name); } - fn type_resource(&mut self, id: TypeId, _name: &str, resource: &TypeDef, docs: &Docs) { - let camel = resource - .name - .as_ref() - .expect("resources are required to be named") - .to_upper_camel_case(); + fn type_resource(&mut self, id: TypeId, name: &str, resource: &TypeDef, docs: &Docs) { + let camel = name.to_upper_camel_case(); if self.types_imported() { self.rustdoc(docs); - uwriteln!(self.src, "pub enum {camel} {{}}"); + + let with_key = match self.current_interface { + Some((_, key, _)) => format!("{}/{name}", self.resolve.name_world_key(key)), + None => name.to_string(), + }; + match self.gen.opts.with.get(&with_key) { + Some(path) => { + uwriteln!( + self.src, + "pub use {}{path} as {camel};", + self.path_to_root() + ); + } + None => { + uwriteln!(self.src, "pub enum {camel} {{}}"); + } + } if self.gen.opts.async_.maybe_async() { uwriteln!(self.src, "#[wasmtime::component::__internal::async_trait]") @@ -1913,6 +1925,24 @@ impl<'a> InterfaceGenerator<'a> { self.push_str("\n"); } } + + fn path_to_root(&self) -> String { + let mut path_to_root = String::new(); + if let Some((_, key, is_export)) = self.current_interface { + match key { + WorldKey::Name(_) => { + path_to_root.push_str("super::"); + } + WorldKey::Interface(_) => { + path_to_root.push_str("super::super::super::"); + } + } + if is_export { + path_to_root.push_str("super::"); + } + } + path_to_root + } } impl<'a> RustGenerator<'a> for InterfaceGenerator<'a> { @@ -1925,23 +1955,12 @@ impl<'a> RustGenerator<'a> for InterfaceGenerator<'a> { } fn path_to_interface(&self, interface: InterfaceId) -> Option { - let mut path_to_root = String::new(); - if let Some((cur, key, is_export)) = self.current_interface { + if let Some((cur, _, _)) = self.current_interface { if cur == interface { return None; } - match key { - WorldKey::Name(_) => { - path_to_root.push_str("super::"); - } - WorldKey::Interface(_) => { - path_to_root.push_str("super::super::super::"); - } - } - if is_export { - path_to_root.push_str("super::"); - } } + let mut path_to_root = self.path_to_root(); let InterfaceName { path, .. } = &self.gen.interface_names[&interface]; path_to_root.push_str(path); Some(path_to_root)
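To round out the with example above: once a WIT resource is mapped to an existing Rust type, the generated host traits take and return wasmtime::component::Resource<MyDescriptorType> handles instead of handles to a generated empty enum, and the u32 rep inside each handle is what the host uses to find its own state. The sketch below assumes a wasmtime dependency with the component-model feature enabled; the Host struct and its open and read_at methods are hypothetical, while Resource::new_own and Resource::rep are the real handle API.

use std::collections::HashMap;
use wasmtime::component::Resource;

pub struct MyDescriptorType; // the preexisting type named in `with`

#[derive(Default)]
struct Host {
    // Host-side state keyed by the `rep` stored in each Resource handle.
    descriptors: HashMap<u32, Vec<u8>>,
    next_rep: u32,
}

impl Host {
    fn open(&mut self, contents: Vec<u8>) -> Resource<MyDescriptorType> {
        let rep = self.next_rep;
        self.next_rep += 1;
        self.descriptors.insert(rep, contents);
        // Hand an owned handle to the guest; `rep` is how we find it again.
        Resource::new_own(rep)
    }

    fn read_at(&mut self, d: &Resource<MyDescriptorType>, offset: usize) -> Option<u8> {
        self.descriptors.get(&d.rep()).and_then(|v| v.get(offset)).copied()
    }

    fn drop(&mut self, d: Resource<MyDescriptorType>) {
        self.descriptors.remove(&d.rep());
    }
}

fn main() {
    let mut host = Host::default();
    let handle = host.open(b"hello".to_vec());
    assert_eq!(host.read_at(&handle, 1), Some(b'e'));
    host.drop(handle);
}

The drop method here plays the same role as the generated HostA::drop and HostB::drop implementations in the codegen test above.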