Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,17 @@ internal static partial bool IsAlgorithmSupportedImpl(CompositeMLDsaAlgorithm al
return CompositeMLDsaManaged.IsAlgorithmSupportedImpl(algorithm);
}

internal static partial CompositeMLDsa GenerateKeyImpl(CompositeMLDsaAlgorithm algorithm) =>
throw new PlatformNotSupportedException();
internal static partial CompositeMLDsa GenerateKeyImpl(CompositeMLDsaAlgorithm algorithm)
{
#if !NETFRAMEWORK
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
throw new PlatformNotSupportedException();
}
#endif

return CompositeMLDsaManaged.GenerateKeyImpl(algorithm);
}

internal static partial CompositeMLDsa ImportCompositeMLDsaPublicKeyImpl(CompositeMLDsaAlgorithm algorithm, ReadOnlySpan<byte> source)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,28 @@ private RsaComponent(RSA rsa, HashAlgorithmName hashAlgorithmName, RSASignatureP
#if NETFRAMEWORK
// RSA-PSS requires RSACng on .NET Framework
private static RSACng CreateRSA() => new RSACng();
private static RSACng CreateRSA(int keySizeInBits) => new RSACng(keySizeInBits);
#elif NETSTANDARD2_0
private static RSA CreateRSA() => RSA.Create();

private static RSA CreateRSA(int keySizeInBits)
{
RSA rsa = RSA.Create();

try
{
rsa.KeySize = keySizeInBits;
return rsa;
}
catch
{
rsa.Dispose();
throw;
}
}
#else
private static RSA CreateRSA() => RSA.Create();
private static RSA CreateRSA(int keySizeInBits) => RSA.Create(keySizeInBits);
#endif

internal override int SignData(
Expand Down Expand Up @@ -80,8 +100,25 @@ internal override bool VerifyData(
#endif
}

public static RsaComponent GenerateKey(RsaAlgorithm algorithm) =>
throw new NotImplementedException();
public static RsaComponent GenerateKey(RsaAlgorithm algorithm)
{
RSA? rsa = null;

try
{
rsa = CreateRSA(algorithm.KeySizeInBits);

// RSA key generation is lazy, so we need to force it to happen eagerly.
_ = rsa.ExportParameters(includePrivateParameters: false);

return new RsaComponent(rsa, algorithm.HashAlgorithmName, algorithm.Padding);
}
catch (CryptographicException)
{
rsa?.Dispose();
throw;
}
}

public static RsaComponent ImportPrivateKey(RsaAlgorithm algorithm, ReadOnlySpan<byte> source)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,82 @@ internal static bool IsAlgorithmSupportedImpl(CompositeMLDsaAlgorithm algorithm)
});
}

internal static CompositeMLDsa GenerateKeyImpl(CompositeMLDsaAlgorithm algorithm) =>
throw new PlatformNotSupportedException();
internal static CompositeMLDsa GenerateKeyImpl(CompositeMLDsaAlgorithm algorithm)
{
Debug.Assert(IsAlgorithmSupportedImpl(algorithm));

AlgorithmMetadata metadata = s_algorithmMetadata[algorithm];

// draft-ietf-lamps-pq-composite-sigs-latest (July 7, 2025), 4.1
// 1. Generate component keys
//
// mldsaSeed = Random(32)
// (mldsaPK, _) = ML-DSA.KeyGen(mldsaSeed)
// (tradPK, tradSK) = Trad.KeyGen()

MLDsa? mldsaKey = null;
ComponentAlgorithm? tradKey = null;

try
{
mldsaKey = MLDsaImplementation.GenerateKey(metadata.MLDsaAlgorithm);
}
catch (CryptographicException)
{
}

try
{
tradKey = metadata.TraditionalAlgorithm switch
{
RsaAlgorithm rsaAlgorithm => RsaComponent.GenerateKey(rsaAlgorithm),
ECDsaAlgorithm ecdsaAlgorithm => ECDsaComponent.GenerateKey(ecdsaAlgorithm),
_ => FailAndGetNull(),
};

static ComponentAlgorithm? FailAndGetNull()
{
Debug.Fail("Only supported algorithms should reach here.");
return null;
}
}
catch (CryptographicException)
{
}

// 2. Check for component key gen failure
//
// if NOT (mldsaPK, mldsaSK) or NOT (tradPK, tradSK):
// output "Key generation error"

[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.NoOptimization)]
static bool KeyGenFailed([NotNullWhen(false)] MLDsa? mldsaKey, [NotNullWhen(false)] ComponentAlgorithm? tradKey) =>
(mldsaKey is null) | (tradKey is null);

if (KeyGenFailed(mldsaKey, tradKey))
{
try
{
Debug.Assert(mldsaKey is null || tradKey is null);

mldsaKey?.Dispose();
tradKey?.Dispose();
}
catch (CryptographicException)
{
}

throw new CryptographicException();
}

// 3. Output the composite public and private keys
//
// pk = SerializePublicKey(mldsaPK, tradPK)
// sk = SerializePrivateKey(mldsaSeed, tradSK)
// return (pk, sk)

return new CompositeMLDsaManaged(algorithm, mldsaKey, tradKey);
}

internal static CompositeMLDsa ImportCompositeMLDsaPublicKeyImpl(CompositeMLDsaAlgorithm algorithm, ReadOnlySpan<byte> source)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public static class CompositeMLDsaFactoryTests
[Fact]
public static void NullArgumentValidation()
{
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => CompositeMLDsa.GenerateKey(null));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => CompositeMLDsa.IsAlgorithmSupported(null));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => CompositeMLDsa.ImportCompositeMLDsaPrivateKey(null, Array.Empty<byte>()));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => CompositeMLDsa.ImportCompositeMLDsaPrivateKey(null, ReadOnlySpan<byte>.Empty));
Expand Down Expand Up @@ -169,6 +170,17 @@ private static void AssertImportBadPublicKey(CompositeMLDsaAlgorithm algorithm,
key);
}

[Theory]
[MemberData(nameof(CompositeMLDsaTestData.SupportedAlgorithmsTestData), MemberType = typeof(CompositeMLDsaTestData))]
public static void AlgorithmMatches_GenerateKey(CompositeMLDsaAlgorithm algorithm)
{
AssertThrowIfNotSupported(() =>
{
using CompositeMLDsa dsa = CompositeMLDsa.GenerateKey(algorithm);
Assert.Equal(algorithm, dsa.Algorithm);
});
}

[Theory]
[MemberData(nameof(CompositeMLDsaTestData.SupportedAlgorithmIetfVectorsTestData), MemberType = typeof(CompositeMLDsaTestData))]
public static void AlgorithmMatches_Import(CompositeMLDsaTestData.CompositeMLDsaTestVector vector)
Expand All @@ -186,7 +198,7 @@ public static void AlgorithmMatches_Import(CompositeMLDsaTestData.CompositeMLDsa
public static void IsSupported_AgreesWithPlatform()
{
// Composites are supported everywhere MLDsa is supported
Assert.Equal(MLDsa.IsSupported && !PlatformDetection.IsLinux, CompositeMLDsa.IsSupported);
Assert.Equal(MLDsa.IsSupported, CompositeMLDsa.IsSupported);
}

[Theory]
Expand All @@ -195,7 +207,7 @@ public static void IsAlgorithmSupported_AgreesWithPlatform(CompositeMLDsaAlgorit
{
bool supported = CompositeMLDsaTestHelpers.ExecuteComponentFunc(
algorithm,
_ => MLDsa.IsSupported && !PlatformDetection.IsLinux,
_ => MLDsa.IsSupported,
_ => false,
_ => false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ namespace System.Security.Cryptography.Tests
[ConditionalClass(typeof(CompositeMLDsa), nameof(CompositeMLDsa.IsSupported))]
public sealed class CompositeMLDsaImplementationTests : CompositeMLDsaTestsBase
{
[Theory]
[MemberData(nameof(CompositeMLDsaTestData.SupportedAlgorithmsTestData), MemberType = typeof(CompositeMLDsaTestData))]
public static void CompositeMLDsaIsOnlyPublicAncestor_GenerateKey(CompositeMLDsaAlgorithm algorithm)
{
AssertCompositeMLDsaIsOnlyPublicAncestor(() => CompositeMLDsa.GenerateKey(algorithm));
}

[Theory]
[MemberData(nameof(CompositeMLDsaTestData.SupportedAlgorithmIetfVectorsTestData), MemberType = typeof(CompositeMLDsaTestData))]
public static void CompositeMLDsaIsOnlyPublicAncestor_Import(CompositeMLDsaTestData.CompositeMLDsaTestVector info)
Expand All @@ -32,6 +39,66 @@ private static void AssertCompositeMLDsaIsOnlyPublicAncestor(Func<CompositeMLDsa
Assert.Equal(typeof(CompositeMLDsa), keyType);
}

#region Roundtrip by exporting then importing

[Theory]
[MemberData(nameof(CompositeMLDsaTestData.SupportedAlgorithmsTestData), MemberType = typeof(CompositeMLDsaTestData))]
public void RoundTrip_Export_Import_PublicKey(CompositeMLDsaAlgorithm algorithm)
{
// Generate new key
using CompositeMLDsa dsa = GenerateKey(algorithm);

CompositeMLDsaTestHelpers.AssertExportPublicKey(
export =>
{
// Roundtrip using public key. First export it.
byte[] exportedPublicKey = export(dsa);
CompositeMLDsaTestHelpers.AssertImportPublicKey(
import =>
{
// Then import it.
using CompositeMLDsa roundTrippedDsa = import();

// Verify the roundtripped object has the same key
Assert.Equal(algorithm, roundTrippedDsa.Algorithm);
AssertExtensions.SequenceEqual(dsa.ExportCompositeMLDsaPublicKey(), roundTrippedDsa.ExportCompositeMLDsaPublicKey());
Assert.Throws<CryptographicException>(() => roundTrippedDsa.ExportCompositeMLDsaPrivateKey());
},
algorithm,
exportedPublicKey);
});
}

[Theory]
[MemberData(nameof(CompositeMLDsaTestData.SupportedAlgorithmsTestData), MemberType = typeof(CompositeMLDsaTestData))]
public void RoundTrip_Export_Import_PrivateKey(CompositeMLDsaAlgorithm algorithm)
{
// Generate new key
using CompositeMLDsa dsa = GenerateKey(algorithm);

CompositeMLDsaTestHelpers.AssertExportPrivateKey(
export =>
{
// Roundtrip using secret key. First export it.
byte[] exportedSecretKey = export(dsa);
CompositeMLDsaTestHelpers.AssertImportPrivateKey(
import =>
{
// Then import it.
using CompositeMLDsa roundTrippedDsa = import();

// Verify the roundtripped object has the same key
Assert.Equal(algorithm, roundTrippedDsa.Algorithm);
AssertExtensions.SequenceEqual(dsa.ExportCompositeMLDsaPrivateKey(), roundTrippedDsa.ExportCompositeMLDsaPrivateKey());
AssertExtensions.SequenceEqual(dsa.ExportCompositeMLDsaPublicKey(), roundTrippedDsa.ExportCompositeMLDsaPublicKey());
},
algorithm,
exportedSecretKey);
});
}

#endregion Roundtrip by exporting then importing

#region Roundtrip by importing then exporting

[Theory]
Expand Down Expand Up @@ -60,6 +127,9 @@ public void RoundTrip_Import_Export_PrivateKey(CompositeMLDsaTestData.CompositeM

#endregion Roundtrip by importing then exporting

protected override CompositeMLDsa GenerateKey(CompositeMLDsaAlgorithm algorithm) =>
CompositeMLDsa.GenerateKey(algorithm);

protected override CompositeMLDsa ImportPublicKey(CompositeMLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
CompositeMLDsa.ImportCompositeMLDsaPublicKey(algorithm, source);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ internal CompositeMLDsaTestVector(string tcId, CompositeMLDsaAlgorithm algo, str
public static IEnumerable<object[]> AllAlgorithmsTestData =>
AllAlgorithms.Select(v => new object[] { v });

public static IEnumerable<object[]> SupportedAlgorithmsTestData =>
AllAlgorithms.Where(CompositeMLDsa.IsAlgorithmSupported).Select(v => new object[] { v });

internal static MLDsaKeyInfo GetMLDsaIetfTestVector(CompositeMLDsaAlgorithm algorithm)
{
MLDsaAlgorithm mldsaAlgorithm = CompositeMLDsaTestHelpers.MLDsaAlgorithms[algorithm];
Expand Down
Loading
Loading