diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index d4b6d8aea..e82544afd 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -6,7 +6,8 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the
__Bug Fixes__:
-#1383 `torch.linalg.vector_norm`: Make `ord`-argument optional, as specified in docs
+#1383 `torch.linalg.vector_norm`: Make `ord`-argument optional, as specified in docs
+#1385 PackedSequence now participates in the DisposeScope system at the same level as Tensor objects.
# NuGet Version 0.103.0
diff --git a/build/BranchInfo.props b/build/BranchInfo.props
index 2912bb1d1..494ead70f 100644
--- a/build/BranchInfo.props
+++ b/build/BranchInfo.props
@@ -2,8 +2,7 @@
0
103
- 0
- 0.102.8
+ 1
+ 0.103.0
-
-
+
\ No newline at end of file
diff --git a/src/TorchSharp/DisposeScope.cs b/src/TorchSharp/DisposeScope.cs
index 3e3411e51..13f5c15ce 100644
--- a/src/TorchSharp/DisposeScope.cs
+++ b/src/TorchSharp/DisposeScope.cs
@@ -213,6 +213,9 @@ public void Detach(IEnumerable disposables)
if (disposable is torch.Tensor tensor) {
tensor.OwningDisposeScope = null;
}
+ else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+ sequence.OwningDisposeScope = null;
+ }
}
}
}
@@ -239,9 +242,16 @@ public IReadOnlyList Attach(IEnumerable disposables)
_disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
}
}
+ else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+ if (sequence.OwningDisposeScope == null && !sequence.IsInvalid) {
+ _disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
+ }
+ }
+
AddToOther(this, disposable);
result.Add(disposable);
}
+
return result;
}
@@ -274,6 +284,12 @@ public void DisposeEverythingBut(IEnumerable inKeep)
if (!tensor.IsInvalid) {
_disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
}
+ } else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+ // No need to have the disposable call back to the scope
+ sequence.OwningDisposeScope = null;
+ if (!sequence.IsInvalid) {
+ _disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
+ }
} else {
_disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
}
@@ -358,6 +374,9 @@ public void MarkAsDisposed(IDisposable disposable)
if (disposable is torch.Tensor tensor) {
tensor.OwningDisposeScope = null;
}
+ else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+ sequence.OwningDisposeScope = null;
+ }
}
///
@@ -380,6 +399,9 @@ private void AddToOther(DisposeScope? scope, IDisposable disposable)
if (disposable is torch.Tensor tensor) {
tensor.OwningDisposeScope = scope;
}
+ else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+ sequence.OwningDisposeScope = scope;
+ }
}
internal HashSet DetachAllAndDispose()
@@ -390,6 +412,9 @@ internal HashSet DetachAllAndDispose()
if (disposable is torch.Tensor tensor) {
tensor.OwningDisposeScope = null;
}
+ else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+ sequence.OwningDisposeScope = null;
+ }
}
this.Disposables = new();
diff --git a/src/TorchSharp/DisposeScopeManager.cs b/src/TorchSharp/DisposeScopeManager.cs
index e1230779e..c6f665ea9 100644
--- a/src/TorchSharp/DisposeScopeManager.cs
+++ b/src/TorchSharp/DisposeScopeManager.cs
@@ -18,7 +18,7 @@ public class DisposeScopeManager
internal ThreadDisposeScopeStatistics StatisticsInstance { get; } = new ThreadDisposeScopeStatistics();
internal DisposeScope? CurrentDisposeScope { get; private set; } = null;
- internal DisposeScope? RegisterOnCurrentDisposeScope(torch.Tensor tensor)
+ internal DisposeScope? RegisterOnCurrentDisposeScope(IDisposable tensor)
{
if (this.CurrentDisposeScope is null) {
StatisticsInstance.CreatedOutsideScopeCount++;
diff --git a/src/TorchSharp/NN/Utils/PackedSequence.cs b/src/TorchSharp/NN/Utils/PackedSequence.cs
index 78b53e7ff..642273809 100644
--- a/src/TorchSharp/NN/Utils/PackedSequence.cs
+++ b/src/TorchSharp/NN/Utils/PackedSequence.cs
@@ -1,5 +1,7 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+
using System;
+using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using static TorchSharp.PInvoke.NativeMethods;
@@ -18,6 +20,17 @@ public static partial class rnn
///
public sealed class PackedSequence : IDisposable
{
+ internal DisposeScope OwningDisposeScope {
+ get => mOwningDisposeScope;
+ set {
+ mOwningDisposeScope = value;
+ this.batch_sizes.OwningDisposeScope = value;
+ this.data.OwningDisposeScope = value;
+ this.sorted_indices.OwningDisposeScope = value;
+ this.unsorted_indices.OwningDisposeScope = value;
+ }
+ }
+
///
/// Class wrapping PyTorch's packedsequence object reference.
///
@@ -39,6 +52,7 @@ internal HType() : base(IntPtr.Zero, true)
protected override bool ReleaseHandle()
{
THSNN_PackedSequence_dispose(handle);
+ handle = IntPtr.Zero;
return true;
}
}
@@ -62,15 +76,21 @@ protected override bool ReleaseHandle()
/// The original indices
///
public readonly Tensor unsorted_indices;
+ ///
+ /// Is true if the PackedSequence has been disposed, false otherwise.
+ ///
+ internal bool IsInvalid => handle.IsInvalid;
private HType handle;
+ private DisposeScope mOwningDisposeScope;
internal PackedSequence(HType handle)
{
this.handle = handle;
- this.data = new Tensor(THSNN_PackedSequence_data(handle));
- this.batch_sizes = new Tensor(THSNN_PackedSequence_batch_sizes(handle));
- this.sorted_indices = new Tensor(THSNN_PackedSequence_sorted_indices(handle));
- this.unsorted_indices = new Tensor(THSNN_PackedSequence_unsorted_indices(handle));
+ this.data = new Tensor(THSNN_PackedSequence_data(handle)).DetachFromDisposeScope();
+ this.batch_sizes = new Tensor(THSNN_PackedSequence_batch_sizes(handle)).DetachFromDisposeScope();
+ this.sorted_indices = new Tensor(THSNN_PackedSequence_sorted_indices(handle)).DetachFromDisposeScope();
+ this.unsorted_indices = new Tensor(THSNN_PackedSequence_unsorted_indices(handle)).DetachFromDisposeScope();
+ OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this);
}
internal HType Handle => handle;
@@ -84,15 +104,53 @@ public void Dispose()
this.batch_sizes.Dispose();
this.sorted_indices.Dispose();
this.unsorted_indices.Dispose();
+ OwningDisposeScope?.MarkAsDisposed(this);
if (handle != null && !handle.IsInvalid) {
handle.Dispose();
handle.SetHandleAsInvalid();
+
+ }
+ }
+ ///
+ /// Moves PackedSequence to the outer DisposeScope. If there is no outer DisposeScope, it's detached from the
+ /// DisposeScope system.
+ ///
+ /// The same PackedSequence that the method was called on
+ public PackedSequence MoveToOuterDisposeScope()
+ {
+ OwningDisposeScope?.MoveToOuter(this);
+ return this;
+ }
+
+ ///
+ /// Detaches the PackedSequence completely from the DisposeScope system.
+ ///
+ /// The same PackedSequence that the method was called on
+ public PackedSequence DetachFromDisposeScope()
+ {
+ OwningDisposeScope?.Detach(this);
+ return this;
+ }
+
+ public PackedSequence MoveToOtherDisposeScope(PackedSequence other)
+ {
+ return MoveToOtherDisposeScope(other.OwningDisposeScope);
+ }
+
+ public PackedSequence MoveToOtherDisposeScope(DisposeScope other)
+ {
+ if (OwningDisposeScope == null && other != null) {
+ other.Attach(this);
+ }
+ else {
+ OwningDisposeScope?.MoveToOther(other, this);
}
+ return this;
}
-}
+ }
}
}
}
}
-}
+}
\ No newline at end of file
diff --git a/test/TorchSharpTest/TestDisposeScopesPackedSequence.cs b/test/TorchSharpTest/TestDisposeScopesPackedSequence.cs
new file mode 100644
index 000000000..8d4333cc1
--- /dev/null
+++ b/test/TorchSharpTest/TestDisposeScopesPackedSequence.cs
@@ -0,0 +1,121 @@
+using System.Reflection;
+using TorchSharp;
+using Xunit;
+
+namespace TorchSharpTest;
+
+public class TestDisposeScopesPackedSequence
+{
+ [Fact]
+ public void MoveDisposeScope()
+ {
+ var sequences = CreateTestSequences();
+ torch.nn.utils.rnn.PackedSequence packed_sequence;
+ var otherScope = torch.NewDisposeScope();
+ using (torch.NewDisposeScope())
+ {
+ using (torch.NewDisposeScope())
+ {
+ packed_sequence = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: false);
+ AssertPackedSequenceValid(packed_sequence);
+
+ packed_sequence.MoveToOuterDisposeScope();
+ }
+ AssertPackedSequenceValid(packed_sequence);
+
+ packed_sequence.MoveToOtherDisposeScope(otherScope);
+ }
+
+ AssertPackedSequenceValid(packed_sequence);
+ otherScope.Dispose();
+
+ Assert.True(GetPackedSequenceIsInvalid(packed_sequence));
+ Assert.True(packed_sequence.data.IsInvalid);
+ }
+
+ [Fact]
+ public void DisposablesValidityWhenNotSorted()
+ {
+ var sequences = CreateTestSequences();
+ using var scope = torch.NewDisposeScope();
+ var packed = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: false);
+ Assert.Equal(1, scope.DisposablesCount);
+ AssertPackedSequenceValid(packed);
+ }
+
+ [Fact]
+ public void DisposablesValidityWhenSorted()
+ {
+ var sequences = CreateTestSequences();
+ using var scope = torch.NewDisposeScope();
+ var packed = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true);
+ Assert.Equal(1, scope.DisposablesCount);
+ Assert.False(GetPackedSequenceIsInvalid(packed));
+ Assert.False(packed.batch_sizes.IsInvalid);
+ Assert.False(packed.data.IsInvalid);
+ Assert.True(packed.sorted_indices.IsInvalid);
+ Assert.True(packed.unsorted_indices.IsInvalid);
+ }
+
+ [Fact]
+ public void DisposeScopeStatistics()
+ {
+ DisposeScopeManager.Statistics.Reset();
+ AssertStatCounts(0, 0, 0, 0, 0);
+ var sequences = CreateTestSequences();
+ AssertStatCounts(0, 2, 0, 0, 0);
+ var outOfScope = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true);
+ AssertStatCounts(0, 7, 0, 0, 0);
+ using var scope = torch.NewDisposeScope();
+ AssertStatCounts(0, 7, 0, 0, 0);
+
+ var inScope = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true);
+ AssertStatCounts(5, 7, 4, 0, 1);
+
+ scope.Attach(outOfScope);
+ //Possible subtle bug. When attaching an object that isn't owned by any scope, the count subtracts.
+ AssertStatCounts( 5, 7, 3, 0, 2);
+
+ scope.Detach(inScope);
+ AssertStatCounts( 5, 7, 4, 0, 1);
+
+ outOfScope.Dispose();
+ AssertStatCounts( 5, 7, 4, 5, -4);
+
+ }
+
+ private static void AssertStatCounts(long createdInScope, long createdOutsideScope, long detachedFrom, long disposedIn, long threadTotalLive)
+ {
+ var stats = DisposeScopeManager.Statistics;
+ Assert.Equal(createdInScope, stats.CreatedInScopeCount);
+ Assert.Equal(createdOutsideScope, stats.CreatedOutsideScopeCount);
+ Assert.Equal(detachedFrom, stats.DetachedFromScopeCount);
+ Assert.Equal(disposedIn, stats.DisposedInScopeCount);
+ Assert.Equal(threadTotalLive, stats.ThreadTotalLiveCount);
+ }
+
+ private static torch.Tensor[] CreateTestSequences()
+ {
+ return new[]
+ {
+ torch.tensor(new long[] { 1, 2, 3, 4 }),
+ torch.tensor(new long[] { 5, 6 }),
+ };
+ }
+
+ private static void AssertPackedSequenceValid(torch.nn.utils.rnn.PackedSequence packed_sequence)
+ {
+ Assert.False(GetPackedSequenceIsInvalid(packed_sequence));
+ Assert.False(packed_sequence.batch_sizes.IsInvalid);
+ Assert.False(packed_sequence.data.IsInvalid);
+ Assert.False(packed_sequence.sorted_indices.IsInvalid);
+ Assert.False(packed_sequence.unsorted_indices.IsInvalid);
+ }
+
+ private static bool GetPackedSequenceIsInvalid(torch.nn.utils.rnn.PackedSequence packed_sequence)
+ {
+ //HACK: reflection to avoid exposing internal method IsInvalid in API.
+ var getter = typeof(torch.nn.utils.rnn.PackedSequence).GetProperty("IsInvalid", BindingFlags.Instance | BindingFlags.NonPublic)!;
+ return (bool)getter.GetValue(packed_sequence)!;
+ }
+}
\ No newline at end of file