diff --git a/RELEASENOTES.md b/RELEASENOTES.md index d4b6d8aea..e82544afd 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -6,7 +6,8 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the __Bug Fixes__: -#1383 `torch.linalg.vector_norm`: Make `ord`-argument optional, as specified in docs +#1383 `torch.linalg.vector_norm`: Make `ord`-argument optional, as specified in docs
+#1385 PackedSequence now participates in the DisposeScope system at the same level as Tensor objects.
# NuGet Version 0.103.0 diff --git a/build/BranchInfo.props b/build/BranchInfo.props index 2912bb1d1..494ead70f 100644 --- a/build/BranchInfo.props +++ b/build/BranchInfo.props @@ -2,8 +2,7 @@ 0 103 - 0 - 0.102.8 + 1 + 0.103.0 - - + \ No newline at end of file diff --git a/src/TorchSharp/DisposeScope.cs b/src/TorchSharp/DisposeScope.cs index 3e3411e51..13f5c15ce 100644 --- a/src/TorchSharp/DisposeScope.cs +++ b/src/TorchSharp/DisposeScope.cs @@ -213,6 +213,9 @@ public void Detach(IEnumerable disposables) if (disposable is torch.Tensor tensor) { tensor.OwningDisposeScope = null; } + else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) { + sequence.OwningDisposeScope = null; + } } } } @@ -239,9 +242,16 @@ public IReadOnlyList Attach(IEnumerable disposables) _disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--; } } + else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) { + if (sequence.OwningDisposeScope == null && !sequence.IsInvalid) { + _disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--; + } + } + AddToOther(this, disposable); result.Add(disposable); } + return result; } @@ -274,6 +284,12 @@ public void DisposeEverythingBut(IEnumerable inKeep) if (!tensor.IsInvalid) { _disposeScopeManager.StatisticsInstance.DisposedInScopeCount++; } + } else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) { + // No need to have the disposable call back to the scope + sequence.OwningDisposeScope = null; + if (!sequence.IsInvalid) { + _disposeScopeManager.StatisticsInstance.DisposedInScopeCount++; + } } else { _disposeScopeManager.StatisticsInstance.DisposedInScopeCount++; } @@ -358,6 +374,9 @@ public void MarkAsDisposed(IDisposable disposable) if (disposable is torch.Tensor tensor) { tensor.OwningDisposeScope = null; } + else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) { + sequence.OwningDisposeScope = null; + } } /// @@ -380,6 +399,9 @@ private void AddToOther(DisposeScope? scope, IDisposable disposable) if (disposable is torch.Tensor tensor) { tensor.OwningDisposeScope = scope; } + else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) { + sequence.OwningDisposeScope = scope; + } } internal HashSet DetachAllAndDispose() @@ -390,6 +412,9 @@ internal HashSet DetachAllAndDispose() if (disposable is torch.Tensor tensor) { tensor.OwningDisposeScope = null; } + else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) { + sequence.OwningDisposeScope = null; + } } this.Disposables = new(); diff --git a/src/TorchSharp/DisposeScopeManager.cs b/src/TorchSharp/DisposeScopeManager.cs index e1230779e..c6f665ea9 100644 --- a/src/TorchSharp/DisposeScopeManager.cs +++ b/src/TorchSharp/DisposeScopeManager.cs @@ -18,7 +18,7 @@ public class DisposeScopeManager internal ThreadDisposeScopeStatistics StatisticsInstance { get; } = new ThreadDisposeScopeStatistics(); internal DisposeScope? CurrentDisposeScope { get; private set; } = null; - internal DisposeScope? RegisterOnCurrentDisposeScope(torch.Tensor tensor) + internal DisposeScope? RegisterOnCurrentDisposeScope(IDisposable tensor) { if (this.CurrentDisposeScope is null) { StatisticsInstance.CreatedOutsideScopeCount++; diff --git a/src/TorchSharp/NN/Utils/PackedSequence.cs b/src/TorchSharp/NN/Utils/PackedSequence.cs index 78b53e7ff..642273809 100644 --- a/src/TorchSharp/NN/Utils/PackedSequence.cs +++ b/src/TorchSharp/NN/Utils/PackedSequence.cs @@ -1,5 +1,7 @@ // Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. + using System; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using static TorchSharp.PInvoke.NativeMethods; @@ -18,6 +20,17 @@ public static partial class rnn /// public sealed class PackedSequence : IDisposable { + internal DisposeScope OwningDisposeScope { + get => mOwningDisposeScope; + set { + mOwningDisposeScope = value; + this.batch_sizes.OwningDisposeScope = value; + this.data.OwningDisposeScope = value; + this.sorted_indices.OwningDisposeScope = value; + this.unsorted_indices.OwningDisposeScope = value; + } + } + /// /// Class wrapping PyTorch's packedsequence object reference. /// @@ -39,6 +52,7 @@ internal HType() : base(IntPtr.Zero, true) protected override bool ReleaseHandle() { THSNN_PackedSequence_dispose(handle); + handle = IntPtr.Zero; return true; } } @@ -62,15 +76,21 @@ protected override bool ReleaseHandle() /// The original indices /// public readonly Tensor unsorted_indices; + /// + /// Is true if the PackedSequence has been disposed, false otherwise. + /// + internal bool IsInvalid => handle.IsInvalid; private HType handle; + private DisposeScope mOwningDisposeScope; internal PackedSequence(HType handle) { this.handle = handle; - this.data = new Tensor(THSNN_PackedSequence_data(handle)); - this.batch_sizes = new Tensor(THSNN_PackedSequence_batch_sizes(handle)); - this.sorted_indices = new Tensor(THSNN_PackedSequence_sorted_indices(handle)); - this.unsorted_indices = new Tensor(THSNN_PackedSequence_unsorted_indices(handle)); + this.data = new Tensor(THSNN_PackedSequence_data(handle)).DetachFromDisposeScope(); + this.batch_sizes = new Tensor(THSNN_PackedSequence_batch_sizes(handle)).DetachFromDisposeScope(); + this.sorted_indices = new Tensor(THSNN_PackedSequence_sorted_indices(handle)).DetachFromDisposeScope(); + this.unsorted_indices = new Tensor(THSNN_PackedSequence_unsorted_indices(handle)).DetachFromDisposeScope(); + OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this); } internal HType Handle => handle; @@ -84,15 +104,53 @@ public void Dispose() this.batch_sizes.Dispose(); this.sorted_indices.Dispose(); this.unsorted_indices.Dispose(); + OwningDisposeScope?.MarkAsDisposed(this); if (handle != null && !handle.IsInvalid) { handle.Dispose(); handle.SetHandleAsInvalid(); + + } + } + /// + /// Moves PackedSequence to the outer DisposeScope. If there is no outer DisposeScope, it's detached from the + /// DisposeScope system. + /// + /// The same PackedSequence that the method was called on + public PackedSequence MoveToOuterDisposeScope() + { + OwningDisposeScope?.MoveToOuter(this); + return this; + } + + /// + /// Detaches the PackedSequence completely from the DisposeScope system. + /// + /// The same PackedSequence that the method was called on + public PackedSequence DetachFromDisposeScope() + { + OwningDisposeScope?.Detach(this); + return this; + } + + public PackedSequence MoveToOtherDisposeScope(PackedSequence other) + { + return MoveToOtherDisposeScope(other.OwningDisposeScope); + } + + public PackedSequence MoveToOtherDisposeScope(DisposeScope other) + { + if (OwningDisposeScope == null && other != null) { + other.Attach(this); + } + else { + OwningDisposeScope?.MoveToOther(other, this); } + return this; } -} + } } } } } -} +} \ No newline at end of file diff --git a/test/TorchSharpTest/TestDisposeScopesPackedSequence.cs b/test/TorchSharpTest/TestDisposeScopesPackedSequence.cs new file mode 100644 index 000000000..8d4333cc1 --- /dev/null +++ b/test/TorchSharpTest/TestDisposeScopesPackedSequence.cs @@ -0,0 +1,121 @@ +using System.Reflection; +using TorchSharp; +using Xunit; + +namespace TorchSharpTest; + +public class TestDisposeScopesPackedSequence +{ + [Fact] + public void MoveDisposeScope() + { + var sequences = CreateTestSequences(); + torch.nn.utils.rnn.PackedSequence packed_sequence; + var otherScope = torch.NewDisposeScope(); + using (torch.NewDisposeScope()) + { + using (torch.NewDisposeScope()) + { + packed_sequence = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: false); + AssertPackedSequenceValid(packed_sequence); + + packed_sequence.MoveToOuterDisposeScope(); + } + AssertPackedSequenceValid(packed_sequence); + + packed_sequence.MoveToOtherDisposeScope(otherScope); + } + + AssertPackedSequenceValid(packed_sequence); + otherScope.Dispose(); + + Assert.True(GetPackedSequenceIsInvalid(packed_sequence)); + Assert.True(packed_sequence.data.IsInvalid); + } + + [Fact] + public void DisposablesValidityWhenNotSorted() + { + var sequences = CreateTestSequences(); + using var scope = torch.NewDisposeScope(); + var packed = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: false); + Assert.Equal(1, scope.DisposablesCount); + AssertPackedSequenceValid(packed); + } + + [Fact] + public void DisposablesValidityWhenSorted() + { + var sequences = CreateTestSequences(); + using var scope = torch.NewDisposeScope(); + var packed = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true); + Assert.Equal(1, scope.DisposablesCount); + Assert.False(GetPackedSequenceIsInvalid(packed)); + Assert.False(packed.batch_sizes.IsInvalid); + Assert.False(packed.data.IsInvalid); + Assert.True(packed.sorted_indices.IsInvalid); + Assert.True(packed.unsorted_indices.IsInvalid); + } + + [Fact] + public void DisposeScopeStatistics() + { + DisposeScopeManager.Statistics.Reset(); + AssertStatCounts(0, 0, 0, 0, 0); + var sequences = CreateTestSequences(); + AssertStatCounts(0, 2, 0, 0, 0); + var outOfScope = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true); + AssertStatCounts(0, 7, 0, 0, 0); + using var scope = torch.NewDisposeScope(); + AssertStatCounts(0, 7, 0, 0, 0); + + var inScope = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true); + AssertStatCounts(5, 7, 4, 0, 1); + + scope.Attach(outOfScope); + //Possible subtle bug. When attaching an object that isn't owned by any scope, the count subtracts. + AssertStatCounts( 5, 7, 3, 0, 2); + + scope.Detach(inScope); + AssertStatCounts( 5, 7, 4, 0, 1); + + outOfScope.Dispose(); + AssertStatCounts( 5, 7, 4, 5, -4); + + } + + private static void AssertStatCounts(long createdInScope, long createdOutsideScope, long detachedFrom, long disposedIn, long threadTotalLive) + { + var stats = DisposeScopeManager.Statistics; + Assert.Equal(createdInScope, stats.CreatedInScopeCount); + Assert.Equal(createdOutsideScope, stats.CreatedOutsideScopeCount); + Assert.Equal(detachedFrom, stats.DetachedFromScopeCount); + Assert.Equal(disposedIn, stats.DisposedInScopeCount); + Assert.Equal(threadTotalLive, stats.ThreadTotalLiveCount); + } + + private static torch.Tensor[] CreateTestSequences() + { + return new[] + { + torch.tensor(new long[] { 1, 2, 3, 4 }), + torch.tensor(new long[] { 5, 6 }), + }; + } + + private static void AssertPackedSequenceValid(torch.nn.utils.rnn.PackedSequence packed_sequence) + { + Assert.False(GetPackedSequenceIsInvalid(packed_sequence)); + Assert.False(packed_sequence.batch_sizes.IsInvalid); + Assert.False(packed_sequence.data.IsInvalid); + Assert.False(packed_sequence.sorted_indices.IsInvalid); + Assert.False(packed_sequence.unsorted_indices.IsInvalid); + } + + private static bool GetPackedSequenceIsInvalid(torch.nn.utils.rnn.PackedSequence packed_sequence) + { + //HACK: reflection to avoid exposing internal method IsInvalid in API. + var getter = typeof(torch.nn.utils.rnn.PackedSequence).GetProperty("IsInvalid", BindingFlags.Instance | BindingFlags.NonPublic)!; + return (bool)getter.GetValue(packed_sequence)!; + } +} \ No newline at end of file