diff --git a/src/crystal/system/win32/iocp.cr b/src/crystal/system/win32/iocp.cr index ba0f11eb2af5..add5a29c2814 100644 --- a/src/crystal/system/win32/iocp.cr +++ b/src/crystal/system/win32/iocp.cr @@ -65,15 +65,11 @@ module Crystal::IOCP enum State STARTED DONE - CANCELLED end @overlapped = LibC::OVERLAPPED.new @fiber = Fiber.current @state : State = :started - property next : OverlappedOperation? - property previous : OverlappedOperation? - @@canceled = Thread::LinkedList(OverlappedOperation).new def initialize(@handle : LibC::HANDLE) end @@ -83,12 +79,9 @@ module Crystal::IOCP end def self.run(handle, &) - operation = OverlappedOperation.new(handle) - begin - yield operation - ensure - operation.done - end + operation_storage = uninitialized ReferenceStorage(OverlappedOperation) + operation = OverlappedOperation.unsafe_construct(pointerof(operation_storage), handle) + yield operation end def self.unbox(overlapped : LibC::OVERLAPPED*) @@ -103,8 +96,6 @@ module Crystal::IOCP def wait_for_result(timeout, &) wait_for_completion(timeout) - raise Exception.new("Invalid state #{@state}") unless @state.done? || @state.started? - result = LibC.GetOverlappedResult(@handle, self, out bytes, 0) if result.zero? error = WinError.value @@ -118,11 +109,7 @@ module Crystal::IOCP def wait_for_wsa_result(timeout, &) wait_for_completion(timeout) - wsa_result { |error| yield error } - end - def wsa_result(&) - raise Exception.new("Invalid state #{@state}") unless @state.done? || @state.started? flags = 0_u32 result = LibC.WSAGetOverlappedResult(LibC::SOCKET.new(@handle.address), self, out bytes, false, pointerof(flags)) if result.zero? @@ -136,49 +123,48 @@ module Crystal::IOCP end protected def schedule(&) - case @state - when .started? - yield @fiber - done! - when .cancelled? - @@canceled.delete(self) - else - raise Exception.new("Invalid state #{@state}") - end - end - - protected def done - case @state - when .started? - # https://learn.microsoft.com/en-us/windows/win32/api/ioapiset/nf-ioapiset-cancelioex - # > The application must not free or reuse the OVERLAPPED structure - # associated with the canceled I/O operations until they have completed - if LibC.CancelIoEx(@handle, self) != 0 - @state = :cancelled - @@canceled.push(self) # to increase lifetime - end - end + done! + yield @fiber end def done! + @fiber.cancel_timeout @state = :done end + def try_cancel : Bool + # Microsoft documentation: + # The application must not free or reuse the OVERLAPPED structure + # associated with the canceled I/O operations until they have completed + # (this does not apply to asynchronous operations that finished + # synchronously, as nothing would be queued to the IOCP) + ret = LibC.CancelIoEx(@handle, self) + if ret.zero? + case error = WinError.value + when .error_not_found? + # Operation has already completed, do nothing + return false + else + raise RuntimeError.from_os_error("CancelIOEx", os_error: error) + end + end + true + end + def wait_for_completion(timeout) if timeout - timeout_event = Crystal::IOCP::Event.new(Fiber.current) - timeout_event.add(timeout) + sleep timeout else - timeout_event = Crystal::IOCP::Event.new(Fiber.current, Time::Span::MAX) + Fiber.suspend end - # memoize event loop to make sure that we still target the same instance - # after wakeup (guaranteed by current MT model but let's be future proof) - event_loop = Crystal::EventLoop.current - event_loop.enqueue(timeout_event) - - Fiber.suspend - event_loop.dequeue(timeout_event) + unless @state.done? + if try_cancel + # Wait for cancellation to complete. We must not free the operation + # until it's completed. + Fiber.suspend + end + end end end @@ -200,13 +186,12 @@ module Crystal::IOCP raise IO::Error.from_os_error(method, error, target: target) end else - operation.done! return value end operation.wait_for_result(timeout) do |error| case error - when .error_io_incomplete? + when .error_io_incomplete?, .error_operation_aborted? raise IO::TimeoutError.new("#{method} timed out") when .error_handle_eof? return 0_u32 @@ -230,13 +215,12 @@ module Crystal::IOCP raise IO::Error.from_os_error(method, error, target: target) end else - operation.done! return value end operation.wait_for_wsa_result(timeout) do |error| case error - when .wsa_io_incomplete? + when .wsa_io_incomplete?, .error_operation_aborted? raise IO::TimeoutError.new("#{method} timed out") when .wsaeconnreset? return 0_u32 unless connreset_is_error diff --git a/src/crystal/system/win32/socket.cr b/src/crystal/system/win32/socket.cr index 2a540f4df88d..9bf1fd6ac853 100644 --- a/src/crystal/system/win32/socket.cr +++ b/src/crystal/system/win32/socket.cr @@ -142,7 +142,6 @@ module Crystal::System::Socket return ::Socket::Error.from_os_error("ConnectEx", error) end else - operation.done! return nil end @@ -204,18 +203,15 @@ module Crystal::System::Socket return false end else - operation.done! return true end - unless operation.wait_for_completion(read_timeout) - raise IO::TimeoutError.new("#{method} timed out") - end - - operation.wsa_result do |error| + operation.wait_for_wsa_result(read_timeout) do |error| case error when .wsa_io_incomplete?, .wsaenotsock? return false + when .error_operation_aborted? + raise IO::TimeoutError.new("#{method} timed out") end end