Skip to content
29 changes: 19 additions & 10 deletions src/AutoRest.CSharp/Mgmt/Models/MgmtRestClientBuilder.cs
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO.MemoryMappedFiles;
using System.Linq;
using AutoRest.CSharp.Common.Input;
using AutoRest.CSharp.Generation.Types;
using AutoRest.CSharp.Generation.Writers;
using AutoRest.CSharp.Input;
using AutoRest.CSharp.Mgmt.AutoRest;
using AutoRest.CSharp.Output.Builders;
using AutoRest.CSharp.Output.Models;
using AutoRest.CSharp.Output.Models.Shared;
using AutoRest.CSharp.Output.Models.Types;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace AutoRest.CSharp.Mgmt.Models
{
internal class MgmtRestClientBuilder : CmcRestClientBuilder
{
private static HashSet<string> AllowedRequestParameterOrigins = new HashSet<string>(StringComparer.OrdinalIgnoreCase) { "modelerfour:synthesized/host", "modelerfour:synthesized/api-version" };
private class ParameterCompareer : IEqualityComparer<RequestParameter>
{
public bool Equals([AllowNull] RequestParameter x, [AllowNull] RequestParameter y)
Expand All @@ -44,11 +41,23 @@ public MgmtRestClientBuilder(OperationGroup operationGroup)
{
}

public static IEnumerable<RequestParameter> GetMgmtParametersFromOperations(ICollection<Operation> operations) =>
operations
.SelectMany(op => op.Parameters.Concat(op.Requests.SelectMany(r => r.Parameters)))
.Where(p => p.Implementation == ImplementationLocation.Client)
.Distinct(new ParameterCompareer());
private static IReadOnlyList<RequestParameter> GetMgmtParametersFromOperations(ICollection<Operation> operations)
{
var parameters = new HashSet<RequestParameter>(new ParameterCompareer());
foreach (var operation in operations)
{
var clientParameters = operation.Parameters.Where(p => p.Implementation == ImplementationLocation.Client);
foreach (var parameter in clientParameters)
{
if (!AllowedRequestParameterOrigins.Contains(parameter.Origin ?? string.Empty))
{
throw new InvalidOperationException($"'{parameter.Language.Default.Name}' with origin '{parameter.Origin}' should be method parameter for operation '{operation.OperationId}'");
}
parameters.Add(parameter);
}
}
return parameters.ToList();
}

public override Parameter BuildConstructorParameter(RequestParameter requestParameter)
{
Expand Down