Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions eng/Versions.props
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
<SystemCodeDomVersion>4.5.0</SystemCodeDomVersion>
<SystemCollectionsImmutableVersion>1.5.0</SystemCollectionsImmutableVersion>
<SystemIOFileSystemAccessControl>4.5.0</SystemIOFileSystemAccessControl>
<SystemMemoryVersion>4.5.3</SystemMemoryVersion>
<SystemMemoryVersion>4.5.5</SystemMemoryVersion>
<SystemReflectionEmitLightweightVersion>4.3.0</SystemReflectionEmitLightweightVersion>
<SystemReflectionEmitVersion>4.3.0</SystemReflectionEmitVersion>
<SystemRuntimeCompilerServicesUnsafeVersion>6.0.0</SystemRuntimeCompilerServicesUnsafeVersion>
Expand Down Expand Up @@ -56,12 +56,12 @@
<NewtonsoftJsonVersion>13.0.1</NewtonsoftJsonVersion>
<ParquetDotNetVersion>2.1.3</ParquetDotNetVersion>
<PlotlyNETCSharpVersion>0.0.1</PlotlyNETCSharpVersion>
<SharpZipLibVersion>1.3.3</SharpZipLibVersion>
<SharpZipLibVersion>1.4.0</SharpZipLibVersion>
<TensorflowDotNETVersion>0.20.1</TensorflowDotNETVersion>
<TensorFlowMajorVersion>2</TensorFlowMajorVersion>
<TensorFlowVersion>2.3.1</TensorFlowVersion>
<TorchSharpVersion>0.98.3</TorchSharpVersion>
<LibTorchVersion>1.11.0.1</LibTorchVersion>
<TorchSharpVersion>0.99.5</TorchSharpVersion>
<LibTorchVersion>1.13.0.1</LibTorchVersion>
<!-- Build/infrastructure Dependencies -->
<CodecovVersion>1.12.4</CodecovVersion>
<CoverletCollectorVersion>3.1.2</CoverletCollectorVersion>
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
<PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
<PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows'))" PrivateAssets="all" />
<PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux'))" PrivateAssets="all" />
Expand Down
2 changes: 0 additions & 2 deletions src/Microsoft.ML.TorchSharp/NasBert/Models/BaseModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ internal abstract class BaseModel : torch.nn.Module<torch.Tensor, torch.Tensor,

#pragma warning disable CA1024 // Use properties where appropriate: Modules should be fields in TorchSharp
public abstract TransformerEncoder GetEncoder();

public abstract BaseHead GetHead();
#pragma warning restore CA1024 // Use properties where appropriate

protected BaseModel(NasBertTrainer.NasBertOptions options)
Expand Down
1 change: 0 additions & 1 deletion src/Microsoft.ML.TorchSharp/NasBert/Models/NasBertModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ internal class NasBertModel : BaseModel
{
private readonly PredictionHead _predictionHead;

public override BaseHead GetHead() => _predictionHead;
public override TransformerEncoder GetEncoder() => Encoder;

protected readonly TransformerEncoder Encoder;
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.TorchSharp/NasBert/Models/PredictionHead.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

namespace Microsoft.ML.TorchSharp.NasBert.Models
{
internal sealed class PredictionHead : BaseHead, torch.nn.IModule<torch.Tensor, torch.Tensor>
internal sealed class PredictionHead : torch.nn.Module<torch.Tensor, torch.Tensor>
{
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:Private field name not in: _camelCase format", Justification = "Has to match TorchSharp model.")]
private readonly Sequential Classifier;
Expand All @@ -34,7 +34,7 @@ public PredictionHead(int inputDim, int numClasses, double dropoutRate)
}

[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp")]
public torch.Tensor forward(torch.Tensor features)
public override torch.Tensor forward(torch.Tensor features)
{
// TODO: try whitening-like techniques
// take <s> token (equiv. to [CLS])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

namespace Microsoft.ML.TorchSharp.NasBert.Models
{
internal sealed class TransformerEncoder : torch.nn.Module, torch.nn.IModule<torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor>
internal sealed class TransformerEncoder : torch.nn.Module<torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor>
{
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format Have to match TorchSharp model

Expand Down Expand Up @@ -159,7 +159,7 @@ public TransformerEncoder(
}

[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
public torch.Tensor forward(
public override torch.Tensor forward(
torch.Tensor tokens,
torch.Tensor segmentLabels = null,
torch.Tensor positions = null)
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.TorchSharp/Utils/ImageUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public static void Postprocess(Tensor imgBatch, Tensor classification, Tensor re

for (int i = 0; i < classification.shape[2]; ++i)
{
var scores1 = torch.squeeze(classification[.., .., i]);
var scores1 = torch.squeeze(classification[.., .., i], null);
var scoresOverThresh = scores1 > 0.05;
if (scoresOverThresh.sum().ToSingle() == 0)
{
Expand All @@ -59,7 +59,7 @@ public static void Postprocess(Tensor imgBatch, Tensor classification, Tensor re
}

var scores = scores1[scoresOverThresh];
var anchorBoxes1 = torch.squeeze(transformedAnchors);
var anchorBoxes1 = torch.squeeze(transformedAnchors, null);
var anchorBoxes = anchorBoxes1[scoresOverThresh];
var anchorsNmsIdx = Nms(anchorBoxes, scores, overlapThreshold);
var finalAnchorBoxesIndexesValue = torch.ones(anchorsNmsIdx.shape[0], dtype: ScalarType.Int64, device: imgBatch.device).multiply(i);
Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@

<ItemGroup Condition="'$(TargetArchitecture)' == 'x64'">
<PackageReference Include="libtorch-cpu-win-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Windows')) AND '$(TargetArchitecture)' == 'x64'" />
<!--<PackageReference Include="TorchSharp-cuda-windows" Version="0.96.8" Condition="$([MSBuild]::IsOSPlatform('Windows'))" />-->
<!--<PackageReference Include="TorchSharp-cuda-windows" Version="0.99.5" Condition="$([MSBuild]::IsOSPlatform('Windows'))" />-->
<PackageReference Include="libtorch-cpu-linux-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('Linux')) AND '$(TargetArchitecture)' == 'x64'" />
<PackageReference Include="libtorch-cpu-osx-x64" Version="$(LibTorchVersion)" Condition="$([MSBuild]::IsOSPlatform('OSX')) AND '$(TargetArchitecture)' == 'x64'" />
</ItemGroup>
Expand Down