diff --git a/InterfaceStubGenerator.Shared/AttributeDataAccessor.cs b/InterfaceStubGenerator.Shared/AttributeDataAccessor.cs new file mode 100644 index 000000000..6a4463cb0 --- /dev/null +++ b/InterfaceStubGenerator.Shared/AttributeDataAccessor.cs @@ -0,0 +1,220 @@ +// using System.Reflection; +// +// using Microsoft.CodeAnalysis; +// using Microsoft.CodeAnalysis.CSharp.Syntax; +// using Microsoft.CodeAnalysis.Operations; +// +// namespace Refit.Generator +// { +// public class AttributeDataAccessor(WellKnownTypes symbolAccessor) +// { +// private const string NameOfOperatorName = "nameof"; +// private const char FullNameOfPrefix = '@'; +// +// public TAttribute AccessSingle(ISymbol symbol) +// where TAttribute : Attribute => AccessSingle(symbol); +// +// public TData AccessSingle(ISymbol symbol) +// where TAttribute : Attribute +// where TData : notnull => Access(symbol).Single(); +// +// public TAttribute? AccessFirstOrDefault(ISymbol symbol) +// where TAttribute : Attribute => Access(symbol).FirstOrDefault(); +// +// public TData? AccessFirstOrDefault(ISymbol symbol) +// where TAttribute : Attribute +// where TData : notnull => Access(symbol).FirstOrDefault(); +// +// // public bool HasAttribute(ISymbol symbol) +// // where TAttribute : Attribute => symbolAccessor.HasAttribute(symbol); +// +// public IEnumerable Access(ISymbol symbol) +// where TAttribute : Attribute => Access(symbol); +// +// // public IEnumerable TryAccess(IEnumerable data) +// // where TAttribute : Attribute => TryAccess(data); +// +// public IEnumerable Access(ISymbol symbol) +// where TAttribute : Attribute +// where TData : notnull +// { +// var attrDatas = symbolAccessor.GetAttributes(symbol); +// return Access(attrDatas); +// } +// +// // public IEnumerable TryAccess(IEnumerable attributes) +// // where TAttribute : Attribute +// // where TData : notnull +// // { +// // var attrDatas = symbolAccessor.TryGetAttributes(attributes); +// // return attrDatas.Select(a => Access(a)); +// // } +// +// /// +// /// Reads the attribute data and sets it on a newly created instance of . +// /// If has n type parameters, +// /// needs to have an accessible ctor with the parameters 0 to n-1 to be of type . +// /// needs to have exactly the same constructors as with additional type arguments. +// /// +// /// The attributes data. +// /// The type of the attribute. +// /// The type of the data class. If no type parameters are involved, this is usually the same as . +// /// The attribute data. +// /// If a property or ctor argument of could not be read on the attribute. +// public IEnumerable Access(IEnumerable attributes) +// where TAttribute : Attribute +// where TData : notnull +// { +// foreach (var attrData in symbolAccessor.GetAttributes(attributes)) +// { +// yield return Access(attrData, symbolAccessor); +// } +// } +// +// internal static TData Access(AttributeData attrData, WellKnownTypes? symbolAccessor = null) +// where TAttribute : Attribute +// where TData : notnull +// { +// var attrType = typeof(TAttribute); +// var dataType = typeof(TData); +// +// var syntax = (AttributeSyntax?)attrData.ApplicationSyntaxReference?.GetSyntax(); +// var syntaxArguments = +// (IReadOnlyList?)syntax?.ArgumentList?.Arguments +// ?? new AttributeArgumentSyntax[attrData.ConstructorArguments.Length + attrData.NamedArguments.Length]; +// var typeArguments = (IReadOnlyCollection?)attrData.AttributeClass?.TypeArguments ?? []; +// var attr = Create(typeArguments, attrData.ConstructorArguments, syntaxArguments, symbolAccessor); +// +// var syntaxIndex = attrData.ConstructorArguments.Length; +// var propertiesByName = +// dataType.GetProperties().GroupBy(x => x.Name).ToDictionary(x => x.Key, x => x.First()); +// foreach (var namedArgument in attrData.NamedArguments) +// { +// if (!propertiesByName.TryGetValue(namedArgument.Key, out var prop)) +// throw new InvalidOperationException( +// $"Could not get property {namedArgument.Key} of attribute {attrType.FullName}"); +// +// var value = BuildArgumentValue(namedArgument.Value, prop.PropertyType, syntaxArguments[syntaxIndex], +// symbolAccessor); +// prop.SetValue(attr, value); +// syntaxIndex++; +// } +// +// // if (attr is HasSyntaxReference symbolRefHolder) +// // { +// // symbolRefHolder.SyntaxReference = attrData.ApplicationSyntaxReference?.GetSyntax(); +// // } +// +// return attr; +// } +// +// private static TData Create( +// IReadOnlyCollection typeArguments, +// IReadOnlyCollection constructorArguments, +// IReadOnlyList argumentSyntax, +// WellKnownTypes? symbolAccessor +// ) +// where TData : notnull +// { +// // The data class should have a constructor +// // with generic type parameters of the attribute class +// // as ITypeSymbol parameters followed by all other parameters +// // of the attribute constructor. +// // Multiple attribute class constructors/generic data classes are not yet supported. +// var argCount = typeArguments.Count + constructorArguments.Count; +// foreach (var constructor in typeof(TData).GetConstructors()) +// { +// var parameters = constructor.GetParameters(); +// if (parameters.Length != argCount) +// continue; +// +// var constructorArgumentValues = constructorArguments.Select( +// (arg, i) => BuildArgumentValue(arg, parameters[i + typeArguments.Count].ParameterType, +// argumentSyntax[i], symbolAccessor) +// ); +// var constructorTypeAndValueArguments = typeArguments.Concat(constructorArgumentValues).ToArray(); +// if (!ValidateParameterTypes(constructorTypeAndValueArguments, parameters)) +// continue; +// +// return (TData?)Activator.CreateInstance(typeof(TData), constructorTypeAndValueArguments) +// ?? throw new InvalidOperationException($"Could not create instance of {typeof(TData)}"); +// } +// +// throw new InvalidOperationException( +// $"{typeof(TData)} does not have a constructor with {argCount} parameters and matchable arguments" +// ); +// } +// +// private static object? BuildArgumentValue( +// TypedConstant arg, +// Type targetType, +// AttributeArgumentSyntax? syntax, +// WellKnownTypes? symbolAccessor +// ) +// { +// return arg.Kind switch +// { +// // _ when (targetType == typeof(AttributeValue?) || targetType == typeof(AttributeValue)) && syntax != null => new AttributeValue( +// // arg, +// // syntax.Expression +// // ), +// _ when arg.IsNull => null, +// // _ when targetType == typeof(IMemberPathConfiguration) => CreateMemberPath(arg, syntax, symbolAccessor), +// TypedConstantKind.Enum => GetEnumValue(arg, targetType), +// TypedConstantKind.Array => BuildArrayValue(arg, targetType, symbolAccessor), +// TypedConstantKind.Primitive => arg.Value, +// TypedConstantKind.Type when targetType == typeof(ITypeSymbol) => arg.Value, +// _ => throw new ArgumentOutOfRangeException( +// $"{nameof(AttributeDataAccessor)} does not support constructor arguments of kind {arg.Kind.ToString()} or cannot convert it to {targetType}" +// ), +// }; +// } +// +// private static object?[] BuildArrayValue(TypedConstant arg, Type targetType, WellKnownTypes? symbolAccessor) +// { +// if (!targetType.IsGenericType || targetType.GetGenericTypeDefinition() != typeof(IReadOnlyCollection<>)) +// throw new InvalidOperationException( +// $"{nameof(IReadOnlyCollection)} is the only supported array type"); +// +// var elementTargetType = targetType.GetGenericArguments()[0]; +// return arg.Values.Select(x => BuildArgumentValue(x, elementTargetType, null, symbolAccessor)).ToArray(); +// } +// +// private static object? GetEnumValue(TypedConstant arg, Type targetType) +// { +// if (arg.Value == null) +// return null; +// +// var enumRoslynType = arg.Type ?? throw new InvalidOperationException("Type is null"); +// if (targetType == typeof(IFieldSymbol)) +// return enumRoslynType.GetMembers().Where(x => x.Kind == SymbolKind.Field) +// .First(f => Equals(f.ConstantValue, arg.Value)); +// +// if (targetType.IsConstructedGenericType && targetType.GetGenericTypeDefinition() == typeof(Nullable<>)) +// { +// targetType = Nullable.GetUnderlyingType(targetType)!; +// } +// +// return Enum.ToObject(targetType, arg.Value); +// } +// +// private static bool ValidateParameterTypes(object?[] arguments, ParameterInfo[] parameters) +// { +// if (arguments.Length != parameters.Length) +// return false; +// +// for (var argIdx = 0; argIdx < arguments.Length; argIdx++) +// { +// var value = arguments[argIdx]; +// var param = parameters[argIdx]; +// if (value == null && param.ParameterType.IsValueType) +// return false; +// +// // if (value?.GetType().IsAssignableTo(param.ParameterType) == false) +// // return false; +// } +// +// return true; +// } +// } +// } diff --git a/InterfaceStubGenerator.Shared/Configuration/AliasAsConfiguration.cs b/InterfaceStubGenerator.Shared/Configuration/AliasAsConfiguration.cs new file mode 100644 index 000000000..7d88d26b0 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/AliasAsConfiguration.cs @@ -0,0 +1,9 @@ +namespace Refit.Generator.Configuration; + +// TODO: I hate how I have to duplicate the attributes in this file. +// See if I can remove this. Iirc Mapperly didn't need to do this initially. Might not need all this +// Arguably cleaner doing the current system tho +public class AliasAsConfiguration(string name) +{ + public string Name { get; protected set; } = name; +} diff --git a/InterfaceStubGenerator.Shared/Configuration/AttachmentNameConfiguration.cs b/InterfaceStubGenerator.Shared/Configuration/AttachmentNameConfiguration.cs new file mode 100644 index 000000000..08e392dda --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/AttachmentNameConfiguration.cs @@ -0,0 +1,12 @@ +namespace Refit.Generator.Configuration; + +public class AttachmentNameConfiguration(string name) +{ + /// + /// Gets or sets the name. + /// + /// + /// The name. + /// + public string Name { get; protected set; } = name; +} diff --git a/InterfaceStubGenerator.Shared/Configuration/AuthorizeConfiguration.cs b/InterfaceStubGenerator.Shared/Configuration/AuthorizeConfiguration.cs new file mode 100644 index 000000000..4e92a3346 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/AuthorizeConfiguration.cs @@ -0,0 +1,12 @@ +namespace Refit.Generator.Configuration; + +public class AuthorizeConfiguration(string scheme = "Bearer") : Attribute +{ + /// + /// Gets the scheme. + /// + /// + /// The scheme. + /// + public string Scheme { get; } = scheme; +} diff --git a/InterfaceStubGenerator.Shared/Configuration/BodyConfiguration.cs b/InterfaceStubGenerator.Shared/Configuration/BodyConfiguration.cs new file mode 100644 index 000000000..1c310c8c3 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/BodyConfiguration.cs @@ -0,0 +1,53 @@ +namespace Refit.Generator.Configuration; + +public class BodyConfiguration +{ + /// + /// Initializes a new instance of the class. + /// + public BodyConfiguration() { } + + /// + /// Initializes a new instance of the class. + /// + /// if set to true [buffered]. + public BodyConfiguration(bool buffered) => Buffered = buffered; + + /// + /// Initializes a new instance of the class. + /// + /// The serialization method. + /// if set to true [buffered]. + public BodyConfiguration(BodySerializationMethod serializationMethod, bool buffered) + { + SerializationMethod = serializationMethod; + Buffered = buffered; + } + + /// + /// Initializes a new instance of the class. + /// + /// The serialization method. + public BodyConfiguration( + BodySerializationMethod serializationMethod = BodySerializationMethod.Default + ) + { + SerializationMethod = serializationMethod; + } + + /// + /// Gets or sets the buffered. + /// + /// + /// The buffered. + /// + public bool? Buffered { get; } + + /// + /// Gets or sets the serialization method. + /// + /// + /// The serialization method. + /// + public BodySerializationMethod SerializationMethod { get; } = BodySerializationMethod.Default; +} diff --git a/InterfaceStubGenerator.Shared/Configuration/BodySerializationMethod.cs b/InterfaceStubGenerator.Shared/Configuration/BodySerializationMethod.cs new file mode 100644 index 000000000..0ca7e6554 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/BodySerializationMethod.cs @@ -0,0 +1,29 @@ +#nullable enable +namespace Refit.Generator.Configuration; + +/// +/// Defines methods to serialize HTTP requests' bodies. +/// +public enum BodySerializationMethod +{ + /// + /// Encodes everything using the ContentSerializer in RefitSettings except for strings. Strings are set as-is + /// + Default = 0, + + /// + /// Json encodes everything, including strings + /// + [Obsolete("Use BodySerializationMethod.Serialized instead", false)] + Json = 1, + + /// + /// Form-UrlEncode's the values + /// + UrlEncoded = 2, + + /// + /// Encodes everything using the ContentSerializer in RefitSettings + /// + Serialized = 3, +} diff --git a/InterfaceStubGenerator.Shared/Configuration/CollectionFormat.cs b/InterfaceStubGenerator.Shared/Configuration/CollectionFormat.cs new file mode 100644 index 000000000..b5fc11b7d --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/CollectionFormat.cs @@ -0,0 +1,39 @@ +#nullable enable +namespace Refit.Generator.Configuration; + +/// +/// Collection format defined in https://swagger.io/docs/specification/2-0/describing-parameters/ +/// +public enum CollectionFormat +{ + /// + /// Values formatted with or + /// . + /// + RefitParameterFormatter, + + /// + /// Comma-separated values + /// + Csv, + + /// + /// Space-separated values + /// + Ssv, + + /// + /// Tab-separated values + /// + Tsv, + + /// + /// Pipe-separated values + /// + Pipes, + + /// + /// Multiple parameter instances + /// + Multi, +} diff --git a/InterfaceStubGenerator.Shared/Configuration/HeaderCollectionConfiguration.cs b/InterfaceStubGenerator.Shared/Configuration/HeaderCollectionConfiguration.cs new file mode 100644 index 000000000..6a0cfba66 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/HeaderCollectionConfiguration.cs @@ -0,0 +1,3 @@ +namespace Refit.Generator.Configuration; + +public class HeaderCollectionConfiguration { } diff --git a/InterfaceStubGenerator.Shared/Configuration/HeaderConfiguration.cs b/InterfaceStubGenerator.Shared/Configuration/HeaderConfiguration.cs new file mode 100644 index 000000000..8df4801f6 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/HeaderConfiguration.cs @@ -0,0 +1,6 @@ +namespace Refit.Generator.Configuration; + +public class HeaderConfiguration(string header) : Attribute +{ + public string Header { get; } = header; +} diff --git a/InterfaceStubGenerator.Shared/Configuration/HeadersConfiguration.cs b/InterfaceStubGenerator.Shared/Configuration/HeadersConfiguration.cs new file mode 100644 index 000000000..dce167528 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/HeadersConfiguration.cs @@ -0,0 +1,13 @@ +#nullable enable +namespace Refit.Generator.Configuration; + +public class HeadersConfiguration(params string[] headers) : Attribute +{ + /// + /// Gets the headers. + /// + /// + /// The headers. + /// + public string[] Headers { get; } = headers ?? []; +} diff --git a/InterfaceStubGenerator.Shared/Configuration/HttpMethodConfiguration.cs b/InterfaceStubGenerator.Shared/Configuration/HttpMethodConfiguration.cs new file mode 100644 index 000000000..9eed1511a --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/HttpMethodConfiguration.cs @@ -0,0 +1,20 @@ +namespace Refit.Generator.Configuration; + +public class HttpMethodConfiguration(string path) +{ + /// + /// Gets the method. + /// + /// + /// The method. + /// + public HttpMethod Method { get; set; } + + /// + /// Gets or sets the path. + /// + /// + /// The path. + /// + public virtual string Path { get; protected set; } = path; +} diff --git a/InterfaceStubGenerator.Shared/Configuration/MulitpartConfiguration.cs b/InterfaceStubGenerator.Shared/Configuration/MulitpartConfiguration.cs new file mode 100644 index 000000000..d03d17e5c --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/MulitpartConfiguration.cs @@ -0,0 +1,6 @@ +namespace Refit.Generator.Configuration; + +public class MulitpartConfiguration(string boundaryText = "----MyGreatBoundary") +{ + public string BoundaryText { get; private set; } = boundaryText; +} diff --git a/InterfaceStubGenerator.Shared/Configuration/PropertyConfiguration.cs b/InterfaceStubGenerator.Shared/Configuration/PropertyConfiguration.cs new file mode 100644 index 000000000..727e4a455 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/PropertyConfiguration.cs @@ -0,0 +1,13 @@ +namespace DefaultNamespace; + +public class PropertyConfiguration +{ + public PropertyConfiguration() { } + + public PropertyConfiguration(string key) + { + Key = key; + } + + public string? Key { get; } +} diff --git a/InterfaceStubGenerator.Shared/Configuration/QueryConfiguration.cs b/InterfaceStubGenerator.Shared/Configuration/QueryConfiguration.cs new file mode 100644 index 000000000..02d1cc505 --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/QueryConfiguration.cs @@ -0,0 +1,115 @@ +namespace Refit.Generator.Configuration; + +public class QueryConfiguration +{ + CollectionFormat? collectionFormat; + + /// + /// Initializes a new instance of the class. + /// + public QueryConfiguration() { } + + /// + /// Initializes a new instance of the class. + /// + /// The delimiter. + public QueryConfiguration(string delimiter) + { + Delimiter = delimiter; + } + + /// + /// Initializes a new instance of the class. + /// + /// The delimiter. + /// The prefix. + public QueryConfiguration(string delimiter, string prefix) + { + Delimiter = delimiter; + Prefix = prefix; + } + + /// + /// Initializes a new instance of the class. + /// + /// The delimiter. + /// The prefix. + /// The format. + public QueryConfiguration(string delimiter, string prefix, string format) + { + Delimiter = delimiter; + Prefix = prefix; + Format = format; + } + + /// + /// Initializes a new instance of the class. + /// + /// The collection format. + public QueryConfiguration(CollectionFormat collectionFormat) + { + CollectionFormat = collectionFormat; + } + + /// + /// Used to customize the name of either the query parameter pair or of the form field when form encoding. + /// + /// + public string Delimiter { get; protected set; } = "."; + + /// + /// Used to customize the name of the encoded value. + /// + /// + /// Gets combined with in the format var name = $"{Prefix}{Delimiter}{originalFieldName}" + /// where originalFieldName is the name of the object property or method parameter. + /// + /// + /// + /// class Form + /// { + /// [Query("-", "dontlog")] + /// public string password { get; } + /// } + /// + /// will result in the encoded form having a field named dontlog-password. + /// + public string? Prefix { get; protected set; } + +#pragma warning disable CA1019 // Define accessors for attribute arguments + + /// + /// Used to customize the formatting of the encoded value. + /// + /// + /// + /// interface IServerApi + /// { + /// [Get("/expenses")] + /// Task addExpense([Query(Format="0.00")] double expense); + /// } + /// + /// Calling serverApi.addExpense(5) will result in a URI of {baseUri}/expenses?expense=5.00. + /// + public string? Format { get; set; } + + /// + /// Specifies how the collection should be encoded. + /// + public CollectionFormat CollectionFormat + { + // Cannot make property nullable due to Attribute restrictions + get => collectionFormat.GetValueOrDefault(); + set => collectionFormat = value; + } + +#pragma warning restore CA1019 // Define accessors for attribute arguments + + /// + /// Gets a value indicating whether this instance is collection format specified. + /// + /// + /// true if this instance is collection format specified; otherwise, false. + /// + public bool IsCollectionFormatSpecified => collectionFormat.HasValue; +} diff --git a/InterfaceStubGenerator.Shared/Configuration/QueryUriFormatConfiguration.cs b/InterfaceStubGenerator.Shared/Configuration/QueryUriFormatConfiguration.cs new file mode 100644 index 000000000..7c3d5746f --- /dev/null +++ b/InterfaceStubGenerator.Shared/Configuration/QueryUriFormatConfiguration.cs @@ -0,0 +1,9 @@ +namespace Refit.Generator.Configuration; + +public class QueryUriFormatConfiguration(UriFormat uriFormat) +{ + /// + /// Specifies how the Query Params should be encoded. + /// + public UriFormat UriFormat { get; } = uriFormat; +} diff --git a/InterfaceStubGenerator.Shared/EmitRefitBody.cs b/InterfaceStubGenerator.Shared/EmitRefitBody.cs new file mode 100644 index 000000000..3963afaff --- /dev/null +++ b/InterfaceStubGenerator.Shared/EmitRefitBody.cs @@ -0,0 +1,529 @@ +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Text; +using Microsoft.CodeAnalysis.Text; +using Refit.Generator.Configuration; + +namespace Refit.Generator; + +static class EmitRefitBody +{ + // TODO: replace default with CancellationToken.None + // TODO: should I make this an instance and store information as properties + // Alternatively I could use a context object passed to each method :thinking: + const string InnerMethodName = "Inner"; + const string RequestName = "request"; + const string SettingsExpression = "this.requestBuilder.Settings"; + + // TODO: use UniqueNameBuilder + public static void WriteRefitBody( + StringBuilder source, + MethodModel methodModel, + UniqueNameBuilder uniqueNames + ) + { + var methodScope = uniqueNames.NewScope(); + var innerMethodName = methodScope.New(InnerMethodName); + + WriteCreateRequestMethod(source, methodModel.RefitBody!, methodScope, innerMethodName); + + var requestName = methodScope.New(RequestName); + + source.AppendLine( + $""" + + var {requestName} = {innerMethodName}(); + """ + ); + WriteReturn(source, methodModel, uniqueNames, requestName); + } + + static void WriteCreateRequestMethod( + StringBuilder source, + RefitBodyModel model, + UniqueNameBuilder uniqueNames, + string innerMethodName + ) + { + uniqueNames = uniqueNames.NewScope(); + + var requestName = uniqueNames.New(RequestName); + source.Append( + $$""" + + global::System.Net.Http.HttpRequestMessage {{innerMethodName}}() + { + var {{requestName}} = new global::System.Net.Http.HttpRequestMessage() { Method = {{HttpMethodToEnumString( + model.HttpMethod + )}} }; + + """ + ); + + // TryWriteMultiPartInit(source, model, requestName); + TryWriteBody(source, model, requestName); + // need to run multi part attachment here. + TryWriteHeaders(source, model, requestName); + + WriteProperties(source, model, requestName); + WriteVersion(source, model, requestName); + + WriteBuildUrl(source, model, requestName, uniqueNames); + source.AppendLine( + $$""" + + return {{requestName}}; + } + """ + ); + } + + static string HttpMethodToEnumString(HttpMethod method) + { + if (method == HttpMethod.Get) + { + return "global::System.Net.Http.HttpMethod.Get"; + } + else if (method == HttpMethod.Post) + { + return "global::System.Net.Http.HttpMethod.Post"; + } + else if (method == HttpMethod.Put) + { + return "global::System.Net.Http.HttpMethod.Put"; + } + else if (method == HttpMethod.Delete) + { + return "global::System.Net.Http.HttpMethod.Delete"; + } + else if (method == new HttpMethod("PATCH")) + { + return "global::Refit.RefitHelper.Patch"; + } + else if (method == HttpMethod.Options) + { + return "global::System.Net.Http.HttpMethod.Options"; + } + else if (method == HttpMethod.Head) + { + return "global::System.Net.Http.HttpMethod.Head"; + } + + throw new NotImplementedException(); + } + + // TODO: make into scope names + static string? WriteParameterInfoArray( + StringBuilder source, + RefitBodyModel model, + UniqueNameBuilder uniqueNames + ) + { + throw new NotImplementedException(); + // if no usage of ParameterInfo then exit early + if ( + model.QueryParameters.Count == 0 + && model.UrlFragments.OfType().Count() + == model.UrlFragments.Count + ) + { + return null; + } + + // TODO: implement + var fieldName = uniqueNames.New("__parameterInfo"); + // source.Append($"global::System.Reflection.ParameterInfo[] {fieldName} = {model.}") + return null; + } + + static void TryWriteMultiPartInit( + StringBuilder source, + RefitBodyModel model, + string requestName + ) + { + if (model.MultipartBoundary is null) + return; + + source.Append( + $$""" + throw new NotImplementedException("MultiPart"); + {{requestName}}.Content = new global::System.Net.Http.MultipartFormDataContent({{model.MultipartBoundary}}); + """ + ); + } + + static void TryWriteBody(StringBuilder source, RefitBodyModel model, string requestName) + { + if (model.BodyParameter is null) + return; + var isBuffered = WriteBool(model.BodyParameter.Buffered); + var serializationMethod = model.BodyParameter.SerializationMethod switch + { + BodySerializationMethod.Default => "global::Refit.BodySerializationMethod.Default", + BodySerializationMethod.Json => "global::Refit.BodySerializationMethod.Json", + BodySerializationMethod.UrlEncoded => + "global::Refit.BodySerializationMethod.UrlEncoded", + BodySerializationMethod.Serialized => + "global::Refit.BodySerializationMethod.Serialized", + }; + + // TODO: use full alias for type + source.Append( + $$""" + + global::Refit.RefitHelper.AddBody({{requestName}}, {{SettingsExpression}}, {{model + .BodyParameter + .Parameter}}, {{isBuffered}}, {{serializationMethod}}); + """ + ); + } + + static void TryWriteHeaders(StringBuilder source, RefitBodyModel model, string requestName) + { + if (model.HeaderPs.Count == 0) + { + return; + } + + source.AppendLine( + $$""" + + {{requestName}}.Content = new global::System.Net.Http.ByteArrayContent([]); + """ + ); + + foreach (var headerPs in model.HeaderPs) + { + if (headerPs.Type == HeaderType.Static) + { + source.AppendLine( + $$""" + global::Refit.RefitHelper.AddHeader({{requestName}}, {{headerPs + .Static + .Value + .Key}}, {{headerPs.Static.Value.Value}}.ToString()); + """ + ); + } + else if (headerPs.Type == HeaderType.Collection) + { + source.AppendLine( + $$""" + global::Refit.RefitHelper.AddHeaderCollection({{requestName}}, {{model.HeaderCollectionParam}}); + """ + ); + } + else if (headerPs.Type == HeaderType.Authorise) + { + source.AppendLine( + $$""" + global::Refit.RefitHelper.AddHeader({{requestName}}, "Authorization", $"{{headerPs + .Authorise + .Value + .Scheme}} {{{headerPs.Authorise.Value.Parameter}}.ToString()}"); + """ + ); + } + } + // // TODO: implement + // // if no method headers, parameter headers or header collections don't emit + // if (model.Headers.Count == 0 && model.HeaderParameters.Count == 0) + // { + // if(model.HeaderCollectionParam is null) + // return; + // + // // TODO: ensure that AddHeaderCollection adds content + // source.AppendLine( + // $$""" + // global::Refit.RefitHelper.AddHeaderCollection({{requestName}}, {{model.HeaderCollectionParam}}); + // """); + // return; + // } + // + // // TODO: only emit if http method can have a body + // source.AppendLine( + // $$""" + // global::Refit.RefitHelper.SetContentForHeaders({{requestName}}, ); + // """); + // + // foreach (var methodHeader in model.Headers) + // { + // source.AppendLine( + // $$""" + // global::Refit.RefitHelper.AddHeader({{requestName}}, {{methodHeader.Key}}, {{methodHeader.Value}}); + // """); + // } + // + // foreach (var parameterHeader in model.HeaderParameters) + // { + // source.AppendLine( + // $$""" + // global::Refit.RefitHelper.AddHeader({{requestName}}, {{parameterHeader.HeaderKey}}, {{parameterHeader.Parameter}}); + // """); + // } + } + + static void WriteProperties(StringBuilder source, RefitBodyModel refitModel, string requestName) + { + // add refit settings properties + source.AppendLine( + $""" + + global::Refit.RefitHelper.WriteRefitSettingsProperties({requestName}, {SettingsExpression}); + """ + ); + + // add each property + foreach (var property in refitModel.Properties) + { + source.AppendLine( + $""" global::Refit.RefitHelper.WriteProperty({requestName}, "{property.Key}", {property.Parameter});""" + ); + } + + // TODO: implement add top level types + // TODO: what is a top level type?????? What was I talking about? + // TODO: need to pass down interface type name and create a proprety for the method info :( + // I could prolly create a static instance for the latter + source.AppendLine( + $""" + global::Refit.RefitHelper.AddTopLevelTypes({requestName}, null, null); + """ + ); + } + + static void WriteVersion(StringBuilder source, RefitBodyModel model, string requestName) + { + source.AppendLine( + $""" + global::Refit.RefitHelper.AddVersionToRequest({requestName}, {SettingsExpression}); + """ + ); + } + + static void WriteBuildUrl( + StringBuilder source, + RefitBodyModel model, + string requestName, + UniqueNameBuilder uniqueName + ) + { + // TODO: why is this assertion here + // Debug.Assert(model.UrlFragments.Count > 1); + if (model.UrlFragments.Count == 1 && model.QueryParameters.Count == 0) + { + Debug.Assert(model.UrlFragments[0] is ConstantFragmentModel); + var constant = model.UrlFragments[0] as ConstantFragmentModel; + + // TODO: emit static reusable uri for constant uris + // TODO: uri stripping logic could be improved + // TODO: consider base addresses with path and queries, does it break this + // TODO: do urls containing " break this? + // TODO: get queryUriFormat + source.AppendLine( + $$""" + var basePath = Client.BaseAddress.AbsolutePath == "/" ? string.Empty : Client.BaseAddress.AbsolutePath; + var uri = new UriBuilder(new Uri(global::Refit.RefitHelper.BaseUri, $"{basePath}{{constant!.Value}}")); + + {{requestName}}.RequestUri = new Uri( + uri.Uri.GetComponents(global::System.UriComponents.PathAndQuery, global::System.UriFormat.{{model.UriFormat.ToString()}}), + UriKind.Relative + ); + """ + ); + return; + } + + // TODO: uniqueName for vsb & RefitSettings + // add version to request + source.AppendLine( + $""" + + var vsb = new ValueStringBuilder(stackalloc char[256]); + vsb.Append(Client.BaseAddress.AbsolutePath == "/" ? string.Empty : Client.BaseAddress.AbsolutePath); + """ + ); + // TODO: add initial section to url + // TODO: add get static info + + foreach (var fragment in model.UrlFragments) + { + if (fragment is ConstantFragmentModel constant) + { + source.AppendLine( + $$""" + vsb.Append("{{constant.Value}}"); + """ + ); + continue; + } + + if (fragment is DynamicFragmentModel dynamic) + { + source.AppendLine( + $$""" + global::Refit.RefitHelper.AddUrlFragment(ref vsb, {{dynamic.Access}}, {{SettingsExpression}}, typeof({{dynamic.TypeDeclaration}})); + """ + ); + continue; + } + + if (fragment is DynamicRoundTripFragmentModel roundTrip) + { + source.AppendLine( + $$""" + global::Refit.RefitHelper.AddRoundTripUrlFragment(ref vsb, {{SettingsExpression}}, typeof({{roundTrip.TypeDeclaration}})); + """ + ); + continue; + } + + if (fragment is DynamicPropertyFragmentModel dynamicProperty) + { + source.AppendLine( + $$""" + global::Refit.RefitHelper.AddPropertyFragment(ref vsb, {{SettingsExpression}}, typeof({{dynamicProperty.TypeDeclaration}})); + """ + ); + continue; + } + } + + if (model.QueryParameters.Count > 0) + { + source.AppendLine(""" vsb.Append("?");"""); + } + + for (int i = 0; i < model.QueryParameters.Count; i++) + { + var query = model.QueryParameters[i]; + + if (i > 0) + { + source.AppendLine(""" vsb.Append("&");"""); + } + + // TODO: create overload for a default or non existent QueryAttribute? + // TODO: escape params + source.AppendLine( + $$""" + global::Refit.RefitHelper.AddQueryObject(ref vsb, {{SettingsExpression}}, "{{query.Parameter}}", {{query.Parameter}}); + """ + ); + } + + source.AppendLine( + $$""" + + var uri = new UriBuilder(new Uri(global::Refit.RefitHelper.BaseUri, vsb.ToString())); + + {{requestName}}.RequestUri = new Uri( + uri.Uri.GetComponents(global::System.UriComponents.PathAndQuery, global::System.UriFormat.{{model.UriFormat.ToString()}}), + UriKind.Relative + ); + """ + ); + } + + static void WriteReturn( + StringBuilder source, + MethodModel model, + UniqueNameBuilder uniqueNames, + string requestExpression + ) + { + var refitModel = model.RefitBody!; + + // TODO: return type needs to support the inner type of Task + var responseExpression = uniqueNames.New("response"); + var cacellationTokenExpression = refitModel.CancellationTokenParam ?? "default"; + + if (model.ReturnTypeMetadata == ReturnTypeInfo.AsyncVoid) + { + source.AppendLine( + $""" + await global::Refit.RefitHelper.SendVoidTaskAsync({requestExpression}, Client, {SettingsExpression}, {cacellationTokenExpression}); + """ + ); + } + else if ( + model.ReturnTypeMetadata == ReturnTypeInfo.AsyncResult + && !refitModel.IsApiResponse + ) + { + source.AppendLine( + $""" + return await global::Refit.RefitHelper.SendTaskResultAsync<{refitModel.GenericInnerReturnType}>({requestExpression}, Client, {SettingsExpression}, {WriteBool( + refitModel.BodyParameter?.Buffered + )}, {cacellationTokenExpression}); + """ + ); + } + else if (model.ReturnTypeMetadata == ReturnTypeInfo.AsyncResult && refitModel.IsApiResponse) + { + source.AppendLine( + $""" + return await global::Refit.RefitHelper.SendTaskIApiResultAsync<{refitModel.GenericInnerReturnType}, {refitModel.DeserializedResultType}>({requestExpression}, Client, {SettingsExpression}, {WriteBool( + refitModel.BodyParameter?.Buffered + )}, {cacellationTokenExpression}); + """ + ); + } + else + { + // TODO: this should be an extracted + // TODO: is ReturnTypeMetadata broken? + // TODO: if return insert throw? should this be done in refitmodel and just emit fail + // TODO: use uniqueNameBuilder on all identifiers + // TODO: emit TaskToObservab + source.AppendLine( + $$""" + return new global::Refit.RequestBuilderImplementation.TaskToObservable<{{refitModel.GenericInnerReturnType}}>(ct => + { + """ + ); + + var ctToken = refitModel.CancellationTokenParam is null ? "ct" : "cts"; + if (refitModel.CancellationTokenParam is not null) + { + source.AppendLine( + $$""" + var cts = global::System.Threading.CancellationTokenSource.CreateLinkedTokenSource(methodCt, {{refitModel.CancellationTokenParam}}); + """ + ); + } + + if (refitModel.IsApiResponse) + { + source.AppendLine( + $""" + return global::Refit.RefitHelper.SendTaskIApiResultAsync<{refitModel.GenericInnerReturnType}, {refitModel.DeserializedResultType}>({requestExpression}, Client, {SettingsExpression}, {WriteBool( + refitModel.BodyParameter?.Buffered + )}, {ctToken}); + """ + ); + } + else + { + source.AppendLine( + $""" + return global::Refit.RefitHelper.SendTaskResultAsync<{refitModel.GenericInnerReturnType}>({requestExpression}, Client, {SettingsExpression}, {WriteBool( + refitModel.BodyParameter?.Buffered + )}, {ctToken}); + """ + ); + } + + source.AppendLine(""" });"""); + } + } + + static string WriteBool(bool? value) + { + return value is null ? "false" + : value.Value ? "true" + : "false"; + } +} diff --git a/InterfaceStubGenerator.Shared/Emitter.cs b/InterfaceStubGenerator.Shared/Emitter.cs index a1305fb48..28c568e70 100644 --- a/InterfaceStubGenerator.Shared/Emitter.cs +++ b/InterfaceStubGenerator.Shared/Emitter.cs @@ -18,55 +18,106 @@ Action addSource var attributeText = $$""" - #pragma warning disable - namespace {{model.RefitInternalNamespace}} - { - [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] - [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] - [global::System.AttributeUsage (global::System.AttributeTargets.Class | global::System.AttributeTargets.Struct | global::System.AttributeTargets.Enum | global::System.AttributeTargets.Constructor | global::System.AttributeTargets.Method | global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Event | global::System.AttributeTargets.Interface | global::System.AttributeTargets.Delegate)] - sealed class PreserveAttribute : global::System.Attribute - { - // - // Fields - // - public bool AllMembers; - - public bool Conditional; - } - } - #pragma warning restore - - """; + #pragma warning disable + namespace {{model.RefitInternalNamespace}} + { + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + [global::System.AttributeUsage (global::System.AttributeTargets.Class | global::System.AttributeTargets.Struct | global::System.AttributeTargets.Enum | global::System.AttributeTargets.Constructor | global::System.AttributeTargets.Method | global::System.AttributeTargets.Property | global::System.AttributeTargets.Field | global::System.AttributeTargets.Event | global::System.AttributeTargets.Interface | global::System.AttributeTargets.Delegate)] + sealed class PreserveAttribute : global::System.Attribute + { + // + // Fields + // + public bool AllMembers; + + public bool Conditional; + } + } + #pragma warning restore + + """; // add the attribute text addSource("PreserveAttribute.g.cs", SourceText.From(attributeText, Encoding.UTF8)); var generatedClassText = $$""" - #pragma warning disable - namespace Refit.Implementation - { - - /// - [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] - [global::System.Diagnostics.DebuggerNonUserCode] - [{{model.PreserveAttributeDisplayName}}] - [global::System.Reflection.Obfuscation(Exclude=true)] - [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] - internal static partial class Generated - { - #if NET5_0_OR_GREATER - [System.Runtime.CompilerServices.ModuleInitializer] - [System.Diagnostics.CodeAnalysis.DynamicDependency(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.All, typeof(global::Refit.Implementation.Generated))] - public static void Initialize() - { - } - #endif - } - } - #pragma warning restore - - """; + #pragma warning disable + namespace Refit.Implementation + { + + /// + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.DebuggerNonUserCode] + [{{model.PreserveAttributeDisplayName}}] + [global::System.Reflection.Obfuscation(Exclude=true)] + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static partial class Generated + { + #if NET5_0_OR_GREATER + [System.Runtime.CompilerServices.ModuleInitializer] + [System.Diagnostics.CodeAnalysis.DynamicDependency(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.All, typeof(global::Refit.Implementation.Generated))] + public static void Initialize() + { + } + #endif + } + } + #pragma warning restore + + """; addSource("Generated.g.cs", SourceText.From(generatedClassText, Encoding.UTF8)); + + // TODO: are the attributes correct, should this be an actual file or class we copy? + // TODO: this should eventually emit the helper logic, + // until then I'm going to cheat and write the code in the main library, + // this is a bad idea but makes my life easier + // TODO: emit ValueStringBuilder + var generatedHelpers = $$""" + + #pragma warning disable + namespace Refit.Implementation + { + + /// + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + [global::System.Diagnostics.DebuggerNonUserCode] + [{{model.PreserveAttributeDisplayName}}] + [global::System.Reflection.Obfuscation(Exclude=true)] + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static partial class Generated + { + #if NET5_0_OR_GREATER + [System.Runtime.CompilerServices.ModuleInitializer] + [System.Diagnostics.CodeAnalysis.DynamicDependency(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.All, typeof(global::Refit.Implementation.Generated))] + public static void Initialize() + { + } + #endif + internal static class ____GeneratedHelpers + { + public static void WriteRefitSettingsProperties(global::System.Net.Http.HttpRequestMessage request, global::Refit.RefitSettings settings) + { + // Add RefitSetting.HttpRequestMessageOptions to the HttpRequestMessage + if (settings.HttpRequestMessageOptions != null) + { + foreach (var p in settings.HttpRequestMessageOptions) + { + #if NET6_0_OR_GREATER + ret.Options.Set(new HttpRequestOptionsKey(p.Key), p.Value); + #else + ret.Properties.Add(p); + #endif + } + } + } + } + } + } + #pragma warning restore + + """; + // addSource("GeneratedHelper.g.cs", SourceText.From(generatedHelpers, Encoding.UTF8)); } public static string EmitInterface(InterfaceModel model) @@ -102,12 +153,14 @@ partial class {model.Ns}{model.ClassDeclaration} /// public global::System.Net.Http.HttpClient Client {{ get; }} readonly global::Refit.IRequestBuilder requestBuilder; + readonly global::Refit.RefitSettings settings; /// public {model.Ns}{model.ClassSuffix}(global::System.Net.Http.HttpClient client, global::Refit.IRequestBuilder requestBuilder) {{ Client = client; this.requestBuilder = requestBuilder; + this.settings = requestBuilder.Settings; }} " ); @@ -176,46 +229,74 @@ UniqueNameBuilder uniqueNames ReturnTypeInfo.AsyncVoid => (true, "await (", ").ConfigureAwait(false)"), ReturnTypeInfo.AsyncResult => (true, "return await (", ").ConfigureAwait(false)"), ReturnTypeInfo.Return => (false, "return ", ""), - _ - => throw new ArgumentOutOfRangeException( - nameof(methodModel.ReturnTypeMetadata), - methodModel.ReturnTypeMetadata, - "Unsupported value." - ) + _ => throw new ArgumentOutOfRangeException( + nameof(methodModel.ReturnTypeMetadata), + methodModel.ReturnTypeMetadata, + "Unsupported value." + ), }; WriteMethodOpening(source, methodModel, !isTopLevel, isAsync); - // Build the list of args for the array - var argArray = methodModel - .Parameters.AsArray() - .Select(static param => $"@{param.MetadataName}") - .ToArray(); - - // List of generic arguments - var genericArray = methodModel - .Constraints.AsArray() - .Select(static typeParam => $"typeof({typeParam.DeclaredName})") - .ToArray(); - - var argumentsArrayString = - argArray.Length == 0 - ? "global::System.Array.Empty()" - : $"new object[] {{ {string.Join(", ", argArray)} }}"; + // TODO: unique name builder + foreach (var param in methodModel.Parameters) + { + uniqueNames.Reserve(param.MetadataName); + } - var genericString = - genericArray.Length > 0 - ? $", new global::System.Type[] {{ {string.Join(", ", genericArray)} }}" - : string.Empty; + if (methodModel.RefitBody is null) + { + // Build the list of args for the array + var argArray = methodModel + .Parameters.AsArray() + .Select(static param => $"@{param.MetadataName}") + .ToArray(); + + // List of generic arguments + var genericArray = methodModel + .Constraints.AsArray() + .Select(static typeParam => $"typeof({typeParam.DeclaredName})") + .ToArray(); + + var argumentsArrayString = + argArray.Length == 0 + ? "global::System.Array.Empty()" + : $"new object[] {{ {string.Join(", ", argArray)} }}"; + + var genericString = + genericArray.Length > 0 + ? $", new global::System.Type[] {{ {string.Join(", ", genericArray)} }}" + : string.Empty; + + if (methodModel.Error is not null) + { + source.AppendLine( + @$" + // {methodModel.Error.Replace("\r", "").Replace("\n", "")} + " + ); + } - source.Append( - @$" + source.Append( + @$" var ______arguments = {argumentsArrayString}; var ______func = requestBuilder.BuildRestResultFuncForMethod(""{methodModel.Name}"", {parameterTypesExpression}{genericString} ); {@return}({returnType})______func(this.Client, ______arguments){configureAwait}; " - ); + ); + } + else + { + try + { + EmitRefitBody.WriteRefitBody(source, methodModel, uniqueNames); + } + catch (Exception e) + { + source.AppendLine($"// {e.ToString().Replace("\r", "").Replace("\n", "")}"); + } + } WriteMethodClosing(source); } @@ -242,12 +323,12 @@ private static void WriteDisposableMethod(StringBuilder source) """ - /// - void global::System.IDisposable.Dispose() - { - Client?.Dispose(); - } - """ + /// + void global::System.IDisposable.Dispose() + { + Client?.Dispose(); + } + """ ); } @@ -275,8 +356,8 @@ UniqueNameBuilder uniqueNames $$""" - private static readonly global::System.Type[] {{typeParameterFieldName}} = new global::System.Type[] {{{types}} }; - """ + private static readonly global::System.Type[] {{typeParameterFieldName}} = new global::System.Type[] {{{types}} }; + """ ); return typeParameterFieldName; diff --git a/InterfaceStubGenerator.Shared/ITypeSymbolExtensions.cs b/InterfaceStubGenerator.Shared/ITypeSymbolExtensions.cs index 8ef02f60b..2b1d9890e 100644 --- a/InterfaceStubGenerator.Shared/ITypeSymbolExtensions.cs +++ b/InterfaceStubGenerator.Shared/ITypeSymbolExtensions.cs @@ -3,11 +3,15 @@ using System.Linq; using System.Text; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; namespace Refit.Generator { static class ITypeSymbolExtensions { + internal static IEnumerable GetFields(this ITypeSymbol symbol) => + symbol.GetMembers().OfType(); + public static IEnumerable GetBaseTypesAndThis(this ITypeSymbol? type) { var current = type; @@ -43,5 +47,214 @@ public static bool InheritsFromOrEquals(this ITypeSymbol type, ITypeSymbol baseT return type.GetBaseTypesAndThis() .Any(t => t.Equals(baseType, SymbolEqualityComparer.Default)); } + + public static IEnumerable GetAttributesFor( + this ISymbol type, + ITypeSymbol attributeType + ) + { + return type.GetAttributes() + .Where(t => t.AttributeClass!.InheritsFromOrEquals(attributeType)); + } + + // TODO: most of this stuff isnt needed, I tried using this to supoort HttpMethodAttribute inheritance + // TODO: pretty sure custom HttpMethodAttributes will break the generator, don't think I can prevent his + public static T MapToType(this AttributeData attributeData, WellKnownTypes knownTypes) + { + T attribute; + var dataType = typeof(T); + + var syntax = (AttributeSyntax?)attributeData.ApplicationSyntaxReference?.GetSyntax(); + var syntaxArguments = + (IReadOnlyList?)syntax?.ArgumentList?.Arguments + ?? new AttributeArgumentSyntax[ + attributeData.ConstructorArguments.Length + attributeData.NamedArguments.Length + ]; + + if ( + attributeData.AttributeConstructor != null + && attributeData.ConstructorArguments.Length > 0 + ) + { + attribute = (T) + Activator.CreateInstance( + typeof(T), + attributeData.GetActualConstructorParams().ToArray() + ); + } + else + { + attribute = (T)Activator.CreateInstance(typeof(T)); + } + foreach (var p in attributeData.NamedArguments) + { + typeof(T).GetField(p.Key).SetValue(attribute, p.Value.Value); + } + + var syntaxIndex = attributeData.ConstructorArguments.Length; + + var propertiesByName = dataType + .GetProperties() + .GroupBy(x => x.Name) + .ToDictionary(x => x.Key, x => x.First()); + foreach (var namedArgument in attributeData.NamedArguments) + { + if (!propertiesByName.TryGetValue(namedArgument.Key, out var prop)) + throw new InvalidOperationException( + $"Could not get property {namedArgument.Key} of attribute " + ); + + var value = BuildArgumentValue( + namedArgument.Value, + prop.PropertyType, + syntaxArguments[syntaxIndex], + knownTypes + ); + prop.SetValue(attribute, value); + syntaxIndex++; + } + + return attribute; + } + + private static object? BuildArgumentValue( + TypedConstant arg, + Type targetType, + AttributeArgumentSyntax? syntax, + WellKnownTypes? knownTypes + ) + { + return arg.Kind switch + { + _ + when ( + targetType == typeof(AttributeValue?) + || targetType == typeof(AttributeValue) + ) + && syntax != null => new AttributeValue(arg, syntax.Expression), + _ when arg.IsNull => null, + TypedConstantKind.Enum => GetEnumValue(arg, targetType), + TypedConstantKind.Array => BuildArrayValue(arg, targetType, knownTypes), + TypedConstantKind.Primitive => arg.Value, + TypedConstantKind.Type when targetType == typeof(ITypeSymbol) => arg.Value, + _ => throw new ArgumentOutOfRangeException( + $"{nameof(WellKnownTypes)} does not support constructor arguments of kind {arg.Kind.ToString()} or cannot convert it to {targetType}" + ), + }; + } + + public readonly record struct AttributeValue( + TypedConstant ConstantValue, + ExpressionSyntax Expression + ); + + private static object?[] BuildArrayValue( + TypedConstant arg, + Type targetType, + WellKnownTypes? symbolAccessor + ) + { + if ( + !targetType.IsGenericType + || targetType.GetGenericTypeDefinition() != typeof(IReadOnlyCollection<>) + ) + throw new InvalidOperationException( + $"{nameof(IReadOnlyCollection)} is the only supported array type" + ); + + var elementTargetType = targetType.GetGenericArguments()[0]; + return arg + .Values.Select(x => BuildArgumentValue(x, elementTargetType, null, symbolAccessor)) + .ToArray(); + } + + private static object? GetEnumValue(TypedConstant arg, Type targetType) + { + if (arg.Value == null) + return null; + + var enumRoslynType = arg.Type ?? throw new InvalidOperationException("Type is null"); + if (targetType == typeof(IFieldSymbol)) + return enumRoslynType.GetFields().First(f => Equals(f.ConstantValue, arg.Value)); + + if ( + targetType.IsConstructedGenericType + && targetType.GetGenericTypeDefinition() == typeof(Nullable<>) + ) + { + targetType = Nullable.GetUnderlyingType(targetType)!; + } + + return Enum.ToObject(targetType, arg.Value); + } + + public static IEnumerable GetActualConstructorParams( + this AttributeData attributeData + ) + { + foreach (var arg in attributeData.ConstructorArguments) + { + if (arg.Kind == TypedConstantKind.Array) + { + // Assume they are strings, but the array that we get from this + // should actually be of type of the objects within it, be it strings or ints + // This is definitely possible with reflection, I just don't know how exactly. + yield return arg.Values.Select(a => a.Value).OfType().ToArray(); + } + else + { + yield return arg.Value; + } + } + } + + public static TResult? AccessFirstOrDefault( + this ISymbol symbol, + WellKnownTypes knownTypes + ) + where TAttribute : Attribute + where TResult : class + { + var attributeSymbol = knownTypes.Get(); + var attribute = symbol.GetAttributesFor(attributeSymbol).FirstOrDefault(); + return attribute?.MapToType(knownTypes); + } + + public static TResult? AccessFirstOrDefault( + this ISymbol symbol, + INamedTypeSymbol attributeSymbol, + WellKnownTypes knownTypes + ) + where TResult : class + { + var attribute = symbol.GetAttributesFor(attributeSymbol).FirstOrDefault(); + return attribute?.MapToType(knownTypes); + } + + // public static IEnumerable Access(this ISymbol symbol, WellKnownTypes knownTypes) + // where TAttribute : Attribute + // { + // var attributeSymbol = knownTypes.Get(); + // var attributes = symbol.GetAttributesFor(attributeSymbol); + // + // foreach (var attribute in attributes) + // { + // yield return attribute.MapToType(knownTypes); + // } + // } + + public static IEnumerable Access( + this ISymbol symbol, + INamedTypeSymbol attributeSymbol, + WellKnownTypes knownTypes + ) + { + var attributes = symbol.GetAttributesFor(attributeSymbol); + + foreach (var attribute in attributes) + { + yield return attribute.MapToType(knownTypes); + } + } } } diff --git a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.Shared.projitems b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.Shared.projitems index 8eab43eb8..9b3df81d0 100644 --- a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.Shared.projitems +++ b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.Shared.projitems @@ -9,7 +9,23 @@ InterfaceStubGenerator.Shared + + + + + + + + + + + + + + + + @@ -23,5 +39,7 @@ + + \ No newline at end of file diff --git a/InterfaceStubGenerator.Shared/Models/InterfaceModel.cs b/InterfaceStubGenerator.Shared/Models/InterfaceModel.cs index aa1a875c5..a100d0d51 100644 --- a/InterfaceStubGenerator.Shared/Models/InterfaceModel.cs +++ b/InterfaceStubGenerator.Shared/Models/InterfaceModel.cs @@ -21,5 +21,5 @@ internal enum Nullability : byte { Enabled, Disabled, - None + None, } diff --git a/InterfaceStubGenerator.Shared/Models/MethodModel.cs b/InterfaceStubGenerator.Shared/Models/MethodModel.cs index 513a8b93b..7b2c9f82d 100644 --- a/InterfaceStubGenerator.Shared/Models/MethodModel.cs +++ b/InterfaceStubGenerator.Shared/Models/MethodModel.cs @@ -7,12 +7,105 @@ internal sealed record MethodModel( string DeclaredMethod, ReturnTypeInfo ReturnTypeMetadata, ImmutableEquatableArray Parameters, - ImmutableEquatableArray Constraints + ImmutableEquatableArray Constraints, + RefitBodyModel? RefitBody, + string? Error ); +// TODO: maybe add RXFunc? +// TODO: Add inner return type aka T in Task internal enum ReturnTypeInfo : byte { Return, AsyncVoid, - AsyncResult + AsyncResult, } + +internal record ThrowError(string errorExpression); + +// TODO: rename generic inner +internal sealed record RefitBodyModel( + HttpMethod HttpMethod, + string? GenericInnerReturnType, + string DeserializedResultType, + bool IsApiResponse, + string? CancellationTokenParam, + string? MultipartBoundary, + ImmutableEquatableArray UrlFragments, + ImmutableEquatableArray HeaderPs, + ImmutableEquatableArray Headers, + ImmutableEquatableArray HeaderParameters, + string? HeaderCollectionParam, + ImmutableEquatableArray AuthoriseParameters, + ImmutableEquatableArray Properties, + ImmutableEquatableArray QueryParameters, + BodyModel? BodyParameter, + UriFormat UriFormat +); + +internal record struct HeaderModel(string Key, string Value); + +internal record struct HeaderParameterModel(string Parameter, string HeaderKey); + +internal record struct PropertyModel(string Parameter, string Key); + +internal record struct AuthoriseModel(string Parameter, string Scheme); + +internal record ConstantFragmentModel(string Value) : ParameterFragment; + +internal record DynamicFragmentModel(string Access, int ParameterIndex, string TypeDeclaration) + : ParameterFragment; + +internal record DynamicRoundTripFragmentModel( + string Access, + int ParameterIndex, + string TypeDeclaration +) : ParameterFragment; + +internal record DynamicPropertyFragmentModel( + string Access, + string PropertyName, + string ContainingType, + string TypeDeclaration +) : ParameterFragment; + +internal record QueryModel( + string Parameter, + int ParameterIndex, + Refit.Generator.Configuration.CollectionFormat CollectionFormat, + string Delimiter, + string? Prefix, + string? Format +); + +internal record BodyModel( + string Parameter, + bool Buffered, + Refit.Generator.Configuration.BodySerializationMethod SerializationMethod +); + +// TODO: decide how to handle enum types +internal record HeaderPsModel( + HeaderType Type, + HeaderModel? Static, + HeaderParameterModel? Dynamic, + string? Collection, + AuthoriseModel? Authorise +); + +internal enum BodyParameterType +{ + Content, + Stream, + String, +} + +internal enum HeaderType +{ + Static, + Dynamic, + Collection, + Authorise, +} + +internal record ParameterFragment { } diff --git a/InterfaceStubGenerator.Shared/Models/TypeConstraint.cs b/InterfaceStubGenerator.Shared/Models/TypeConstraint.cs index da386ec23..8d87e21cd 100644 --- a/InterfaceStubGenerator.Shared/Models/TypeConstraint.cs +++ b/InterfaceStubGenerator.Shared/Models/TypeConstraint.cs @@ -15,5 +15,5 @@ internal enum KnownTypeConstraint : byte Unmanaged = 1 << 1, Struct = 1 << 2, NotNull = 1 << 3, - New = 1 << 4 + New = 1 << 4, } diff --git a/InterfaceStubGenerator.Shared/Parser.cs b/InterfaceStubGenerator.Shared/Parser.cs index 2d1dd0515..5407a5b51 100644 --- a/InterfaceStubGenerator.Shared/Parser.cs +++ b/InterfaceStubGenerator.Shared/Parser.cs @@ -186,6 +186,8 @@ sealed class PreserveAttribute : global::System.Attribute ) ); + var wellKnownTypes = new WellKnownTypes(compilation); + // get the newly bound attribute var preserveAttributeSymbol = compilation.GetTypeByMetadataName( $"{refitInternalNamespace}.PreserveAttribute" @@ -222,7 +224,8 @@ sealed class PreserveAttribute : global::System.Attribute disposableInterfaceSymbol, httpMethodBaseAttributeSymbol, supportsNullable, - interfaceToNullableEnabledMap[group.Key] + interfaceToNullableEnabledMap[group.Key], + wellKnownTypes ); interfaceModels.Add(interfaceModel); @@ -245,7 +248,8 @@ static InterfaceModel ProcessInterface( ISymbol disposableInterfaceSymbol, INamedTypeSymbol httpMethodBaseAttributeSymbol, bool supportsNullable, - bool nullableEnabled + bool nullableEnabled, + WellKnownTypes knownTypes ) { // Get the class name with the type parameters, then remove the namespace @@ -287,10 +291,9 @@ bool nullableEnabled .ToList(); // Look for disposable - var disposeMethod = derivedMethods.Find( - m => - m.ContainingType?.Equals(disposableInterfaceSymbol, SymbolEqualityComparer.Default) - == true + var disposeMethod = derivedMethods.Find(m => + m.ContainingType?.Equals(disposableInterfaceSymbol, SymbolEqualityComparer.Default) + == true ); if (disposeMethod != null) { @@ -315,11 +318,11 @@ bool nullableEnabled // Handle Refit Methods var refitMethodsArray = refitMethods - .Select(m => ParseMethod(m, true)) + .Select(m => ParseMethod(m, true, true, knownTypes)) .ToImmutableEquatableArray(); var derivedRefitMethodsArray = refitMethods .Concat(derivedRefitMethods) - .Select(m => ParseMethod(m, false)) + .Select(m => ParseMethod(m, false, true, knownTypes)) .ToImmutableEquatableArray(); // Handle non-refit Methods that aren't static or properties or have a method body @@ -334,7 +337,7 @@ bool nullableEnabled ) // If an interface method has a body, it won't be abstract continue; - nonRefitMethodModelList.Add(ParseNonRefitMethod(method, diagnostics)); + nonRefitMethodModelList.Add(ParseNonRefitMethod(method, diagnostics, knownTypes)); } var nonRefitMethodModels = nonRefitMethodModelList.ToImmutableEquatableArray(); @@ -367,7 +370,8 @@ bool nullableEnabled private static MethodModel ParseNonRefitMethod( IMethodSymbol methodSymbol, - List diagnostics + List diagnostics, + WellKnownTypes knownTypes ) { // report invalid error diagnostic @@ -382,7 +386,7 @@ List diagnostics diagnostics.Add(diagnostic); } - return ParseMethod(methodSymbol, false); + return ParseMethod(methodSymbol, false, false, knownTypes); } private static bool IsRefitMethod( @@ -403,12 +407,8 @@ bool isOverrideOrExplicitImplementation { // Need to loop over the constraints and create them return typeParameters - .Select( - typeParameter => - ParseConstraintsForTypeParameter( - typeParameter, - isOverrideOrExplicitImplementation - ) + .Select(typeParameter => + ParseConstraintsForTypeParameter(typeParameter, isOverrideOrExplicitImplementation) ) .ToImmutableEquatableArray(); } @@ -444,9 +444,8 @@ bool isOverrideOrExplicitImplementation if (!isOverrideOrExplicitImplementation) { constraints = typeParameter - .ConstraintTypes.Select( - typeConstraint => - typeConstraint.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + .ConstraintTypes.Select(typeConstraint => + typeConstraint.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) ) .ToImmutableEquatableArray(); } @@ -489,7 +488,12 @@ private static bool ContainsTypeParameter(ITypeSymbol symbol) return false; } - private static MethodModel ParseMethod(IMethodSymbol methodSymbol, bool isImplicitInterface) + private static MethodModel ParseMethod( + IMethodSymbol methodSymbol, + bool isImplicitInterface, + bool isRefit, + WellKnownTypes knownTypes + ) { var returnType = methodSymbol.ReturnType.ToDisplayString( SymbolDisplayFormat.FullyQualifiedFormat @@ -510,6 +514,26 @@ private static MethodModel ParseMethod(IMethodSymbol methodSymbol, bool isImplic var constraints = GenerateConstraints(methodSymbol.TypeParameters, !isImplicitInterface); + RefitBodyModel refitMethodModel = null; + string? error = null; + if (isRefit) + { + try + { + var restMethodSymbolInternal = new RestMethodSymbolInternal( + methodSymbol, + knownTypes + ); + refitMethodModel = restMethodSymbolInternal.ToRefitBodyModel(); + } + catch (Exception e) + { + // TODO: remove debug stuff + error = e.ToString(); + Console.WriteLine(error); + } + } + return new MethodModel( methodSymbol.Name, returnType, @@ -517,7 +541,9 @@ private static MethodModel ParseMethod(IMethodSymbol methodSymbol, bool isImplic declaredMethod, returnTypeInfo, parameters, - constraints + constraints, + refitMethodModel, + error ); } } diff --git a/InterfaceStubGenerator.Shared/RestMethodSymbolInternal.cs b/InterfaceStubGenerator.Shared/RestMethodSymbolInternal.cs new file mode 100644 index 000000000..80c07e88b --- /dev/null +++ b/InterfaceStubGenerator.Shared/RestMethodSymbolInternal.cs @@ -0,0 +1,1610 @@ +using System.Diagnostics; +using System.Reflection; +using System.Text.RegularExpressions; +using DefaultNamespace; +using Microsoft.CodeAnalysis; +using Refit.Generator.Configuration; + +namespace Refit.Generator; + +/// +/// RestMethodInfo +/// +public record RestMethodSymbol( + string Name, + Type HostingType, + IMethodSymbol MethodSymbol, + string RelativePath, + ITypeSymbol ReturnType +); + +// TODO: most files in Configuration are not needed +[DebuggerDisplay("{MethodInfo}")] +internal class RestMethodSymbolInternal +{ + static readonly QueryConfiguration DefaultQueryAttribute = new(); + + private int HeaderCollectionParameterIndex { get; set; } + public string Name { get; set; } + + // public Type Type { get; set; } + public IMethodSymbol MethodInfo { get; set; } + public HttpMethod HttpMethod { get; set; } + public string RelativePath { get; set; } + public bool IsMultipart { get; private set; } + public string MultipartBoundary { get; private set; } + + // TODO: ensure that an off by one error does not occur because of cancellation token + public IParameterSymbol? CancellationToken { get; set; } + public UriFormat QueryUriFormat { get; set; } + public Dictionary Headers { get; set; } + public Dictionary HeaderParameterMap { get; set; } + public Dictionary PropertyParameterMap { get; set; } + public Tuple? BodyParameterInfo { get; set; } + public Tuple? AuthorizeParameterInfo { get; set; } + public Dictionary QueryParameterMap { get; set; } + public List QueryModels { get; set; } + public Dictionary> AttachmentNameMap { get; set; } + public IParameterSymbol[] ParameterSymbolArray { get; set; } + public Dictionary ParameterMap { get; set; } + public List PathFragments { get; set; } + public ITypeSymbol ReturnType { get; set; } + public ITypeSymbol ReturnResultType { get; set; } + public ITypeSymbol DeserializedResultType { get; set; } + + // TODO: logic associated with RefitSettings has to be moved into runtime + // public RefitSettings RefitSettings { get; set; } + public bool IsApiResponse { get; } + public bool ShouldDisposeResponse { get; private set; } + + static readonly Regex ParameterRegex = new(@"{(.*?)}"); + static readonly Regex ParameterRegex2 = new(@"{(([^/?\r\n])*?)}"); + static readonly HttpMethod PatchMethod = new("PATCH"); + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + public RestMethodSymbolInternal(IMethodSymbol methodSymbol, WellKnownTypes knownTypes) +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + { + // RefitSettings = refitSettings ?? new RefitSettings(); + // Type = targetInterface ?? throw new ArgumentNullException(nameof(targetInterface)); + Name = methodSymbol.Name; + MethodInfo = methodSymbol ?? throw new ArgumentNullException(nameof(methodSymbol)); + + var hma = GetHttpMethod(methodSymbol, knownTypes); + + HttpMethod = hma.Method; + RelativePath = hma.Path; + + var multiPartSymbol = knownTypes.TryGet("Refit.MultipartAttribute"); + var multiPartsAttribute = methodSymbol.AccessFirstOrDefault( + multiPartSymbol, + knownTypes + ); + IsMultipart = multiPartsAttribute is not null; + + // TODO: default boundary + // MultipartBoundary = IsMultipart + // ? multiPartsAttribute?.BoundaryText + // ?? new MultipartAttribute().BoundaryText + // : null; + MultipartBoundary = IsMultipart + ? multiPartsAttribute?.BoundaryText ?? "----MyGreatBoundary" + : null; + + VerifyUrlPathIsSane(RelativePath); + DetermineReturnTypeInfo(methodSymbol, knownTypes); + DetermineIfResponseMustBeDisposed(knownTypes); + + // Exclude cancellation token parameters from this list + var cancellationToken = knownTypes.Get(); + ParameterSymbolArray = methodSymbol + .Parameters.Where(p => !SymbolEqualityComparer.Default.Equals(cancellationToken, p)) + .ToArray(); + (ParameterMap, PathFragments) = BuildParameterMap2( + RelativePath, + ParameterSymbolArray, + knownTypes + ); + BodyParameterInfo = FindBodyParameter( + ParameterSymbolArray, + IsMultipart, + hma.Method, + knownTypes + ); + AuthorizeParameterInfo = FindAuthorizationParameter(ParameterSymbolArray, knownTypes); + + // TODO: make pseudo enum header to represent the 3 types of headers + // initialise the same way refit does + Headers = ParseHeaders(methodSymbol, knownTypes); + HeaderParameterMap = BuildHeaderParameterMap(ParameterSymbolArray, knownTypes); + HeaderCollectionParameterIndex = RestMethodSymbolInternal.GetHeaderCollectionParameterIndex( + ParameterSymbolArray, + knownTypes + ); + PropertyParameterMap = BuildRequestPropertyMap(ParameterSymbolArray, knownTypes); + + // get names for multipart attachments + Dictionary>? attachmentDict = null; + if (IsMultipart) + { + for (var i = 0; i < ParameterSymbolArray.Length; i++) + { + if ( + ParameterMap.ContainsKey(i) + || HeaderParameterMap.ContainsKey(i) + || PropertyParameterMap.ContainsKey(i) + || HeaderCollectionAt(i) + ) + { + continue; + } + + var attachmentName = GetAttachmentNameForParameter( + ParameterSymbolArray[i], + knownTypes + ); + if (attachmentName == null) + continue; + + attachmentDict ??= []; + attachmentDict[i] = Tuple.Create( + attachmentName, + GetUrlNameForParameter(ParameterSymbolArray[i], knownTypes) + ); + } + } + + AttachmentNameMap = attachmentDict ?? new Dictionary>(); + + Dictionary? queryDict = null; + for (var i = 0; i < ParameterSymbolArray.Length; i++) + { + if ( + ParameterMap.ContainsKey(i) + || HeaderParameterMap.ContainsKey(i) + || PropertyParameterMap.ContainsKey(i) + || HeaderCollectionAt(i) + || (BodyParameterInfo != null && BodyParameterInfo.Item3 == i) + || (AuthorizeParameterInfo != null && AuthorizeParameterInfo.Item2 == i) + ) + { + continue; + } + + queryDict ??= []; + queryDict.Add(i, GetUrlNameForParameter(ParameterSymbolArray[i], knownTypes)); + } + + QueryParameterMap = queryDict ?? new Dictionary(); + + var ctParamEnumerable = methodSymbol + .Parameters.Where(p => SymbolEqualityComparer.Default.Equals(p.Type, cancellationToken)) + .ToArray(); + if (ctParamEnumerable.Length > 1) + { + throw new ArgumentException( + $"Argument list to method \"{methodSymbol.Name}\" can only contain a single CancellationToken" + ); + } + + QueryModels = BuildQueryParameterList(knownTypes); + + CancellationToken = ctParamEnumerable.FirstOrDefault(); + + var queryUriAttribute = knownTypes.TryGet("Refit.QueryUriFormatAttribute")!; + QueryUriFormat = + methodSymbol + .AccessFirstOrDefault(queryUriAttribute, knownTypes) + ?.UriFormat ?? UriFormat.UriEscaped; + + var apiResponse = knownTypes.TryGet("Refit.ApiResponse`1"); + var unboundIApiResponse = knownTypes.TryGet("Refit.IApiResponse`1"); + var iApiResponse = knownTypes.TryGet("Refit.IApiResponse"); + + IsApiResponse = + ReturnResultType is INamedTypeSymbol { IsGenericType: true } namedTypeSymbol + && ( + SymbolEqualityComparer.Default.Equals( + namedTypeSymbol.OriginalDefinition, + apiResponse + ) + || namedTypeSymbol.OriginalDefinition.InheritsFromOrEquals(unboundIApiResponse) + ) + || SymbolEqualityComparer.Default.Equals(ReturnResultType, iApiResponse); + } + + public bool HasHeaderCollection => HeaderCollectionParameterIndex >= 0; + + public bool HeaderCollectionAt(int index) => + HeaderCollectionParameterIndex >= 0 && HeaderCollectionParameterIndex == index; + + // TODO: this should be moved to a new class, along with most model logic + public RefitBodyModel ToRefitBodyModel() + { + // TODO: Add Authorise + // // // TODO: should ParseHeaders already add this? + // if (AuthorizeParameterInfo is not null) + // { + // Headers[AuthorizeParameterInfo.] + // } + + // TODO: headercollectionParam is broken + // TODO: Is query model logic correct? + + return new RefitBodyModel( + HttpMethod, + ReturnResultType?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + DeserializedResultType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + IsApiResponse, + CancellationToken?.Name, + MultipartBoundary, + PathFragments.ToImmutableEquatableArray(), + BuildHeaderPsModel().ToImmutableEquatableArray(), + Headers.Select(kp => new HeaderModel(kp.Key, kp.Value)).ToImmutableEquatableArray(), + HeaderParameterMap + .Select(kp => new HeaderParameterModel(ParameterSymbolArray[kp.Key].Name, kp.Value)) + .ToImmutableEquatableArray(), + HeaderCollectionParameterIndex < 0 + ? null + : ParameterSymbolArray[HeaderCollectionParameterIndex].Name, + ImmutableEquatableArray.Empty, + PropertyParameterMap + .Select(kp => new PropertyModel(ParameterSymbolArray[kp.Key].Name, kp.Value)) + .ToImmutableEquatableArray(), + QueryModels.ToImmutableEquatableArray(), + BodyParameterInfo is null + ? null + : new BodyModel( + ParameterSymbolArray[BodyParameterInfo.Item3].Name, + BodyParameterInfo.Item2, + BodyParameterInfo.Item1 + ), + QueryUriFormat + ); + } + + static HttpMethodConfiguration GetHttpMethod( + IMethodSymbol methodSymbol, + WellKnownTypes knownTypes + ) + { + var attributeSymbol = knownTypes.TryGet("Refit.HttpMethodAttribute")!; + var attribute = methodSymbol.GetAttributesFor(attributeSymbol).FirstOrDefault()!; + var hma = attribute?.MapToType(knownTypes)!; + + var getAttribute = knownTypes.TryGet("Refit.GetAttribute"); + var postAttribute = knownTypes.TryGet("Refit.PostAttribute"); + var putAttribute = knownTypes.TryGet("Refit.PutAttribute"); + var deleteAttribute = knownTypes.TryGet("Refit.DeleteAttribute"); + var patchAttribute = knownTypes.TryGet("Refit.PatchAttribute"); + var optionsAttribute = knownTypes.TryGet("Refit.OptionsAttribute"); + var headAttribute = knownTypes.TryGet("Refit.HeadAttribute"); + + if (attribute.AttributeClass.InheritsFromOrEquals(getAttribute)) + { + hma.Method = HttpMethod.Get; + } + else if (attribute.AttributeClass.InheritsFromOrEquals(postAttribute)) + { + hma.Method = HttpMethod.Post; + } + else if (attribute.AttributeClass.InheritsFromOrEquals(putAttribute)) + { + hma.Method = HttpMethod.Put; + } + else if (attribute.AttributeClass.InheritsFromOrEquals(deleteAttribute)) + { + hma.Method = HttpMethod.Delete; + } + else if (attribute.AttributeClass.InheritsFromOrEquals(patchAttribute)) + { + hma.Method = PatchMethod; + } + else if (attribute.AttributeClass.InheritsFromOrEquals(optionsAttribute)) + { + hma.Method = HttpMethod.Options; + } + else if (attribute.AttributeClass.InheritsFromOrEquals(headAttribute)) + { + hma.Method = HttpMethod.Head; + } + else + { + // TODO: need to emit a diagnostic here + // I don't think I can support custom HttpMethodAttributes + } + + return hma; + } + + static int GetHeaderCollectionParameterIndex( + IParameterSymbol[] parameterArray, + WellKnownTypes knownTypes + ) + { + var headerIndex = -1; + + // TODO: convert this into a real string string dictionary + var dictionaryOpenType = knownTypes.Get(typeof(IDictionary<,>)); + var stringType = knownTypes.Get(); + var genericDictionary = dictionaryOpenType.Construct(stringType, stringType); + + for (var i = 0; i < parameterArray.Length; i++) + { + var param = parameterArray[i]; + var headerCollectionSymbol = knownTypes.TryGet("Refit.HeaderCollectionAttribute")!; + var headerCollection = param.AccessFirstOrDefault( + headerCollectionSymbol, + knownTypes + ); + + if (headerCollection == null) + continue; + + // TODO: this check may not work in nullable contexts. + //opted for IDictionary semantics here as opposed to the looser IEnumerable> because IDictionary will enforce uniqueness of keys + if (SymbolEqualityComparer.Default.Equals(param.Type, genericDictionary)) + { + // throw if there is already a HeaderCollection parameter + if (headerIndex >= 0) + throw new ArgumentException( + "Only one parameter can be a HeaderCollection parameter" + ); + + headerIndex = i; + } + else + { + throw new ArgumentException( + $"HeaderCollection parameter of type {param.Type.Name} is not assignable from IDictionary" + ); + } + } + + return headerIndex; + } + + // public RestMethodSymbol ToRestMethodSymbol() => + // new(Name, Type, MethodInfo, RelativePath, ReturnType); + + // TODO: need to escape strings line [Property("hey\"world")] + static Dictionary BuildRequestPropertyMap( + IParameterSymbol[] parameterArray, + WellKnownTypes knownTypes + ) + { + Dictionary? propertyMap = null; + + for (var i = 0; i < parameterArray.Length; i++) + { + var param = parameterArray[i]; + var propertySymbol = knownTypes.TryGet("Refit.PropertyAttribute")!; + var propertyAttribute = param.AccessFirstOrDefault( + propertySymbol, + knownTypes + ); + + if (propertyAttribute != null) + { + var propertyKey = !string.IsNullOrEmpty(propertyAttribute.Key) + ? propertyAttribute.Key + : param.Name!; + propertyMap ??= new Dictionary(); + propertyMap[i] = propertyKey!; + } + } + + return propertyMap ?? new Dictionary(); + } + + List BuildQueryParameterList(WellKnownTypes knownTypes) + { + List queryParamsToAdd = []; + RestMethodParameterInfo? parameterInfo = null; + + for (var i = 0; i < ParameterSymbolArray.Length; i++) + { + var isParameterMappedToRequest = false; + var param = ParameterSymbolArray[i]; + // if part of REST resource URL, substitute it in + if (this.ParameterMap.TryGetValue(i, out var parameterMapValue)) + { + parameterInfo = parameterMapValue; + if (!parameterInfo.IsObjectPropertyParameter) + { + // mark parameter mapped if not an object + // we want objects to fall through so any parameters on this object not bound here get passed as query parameters + isParameterMappedToRequest = true; + } + } + + // if marked as body, add to content + if (this.BodyParameterInfo != null && this.BodyParameterInfo.Item3 == i) + { + // AddBodyToRequest(restMethod, param, ret); + isParameterMappedToRequest = true; + } + + // if header, add to request headers + if (this.HeaderParameterMap.TryGetValue(i, out var headerParameterValue)) + { + isParameterMappedToRequest = true; + } + + //if header collection, add to request headers + if (this.HeaderCollectionAt(i)) + { + isParameterMappedToRequest = true; + } + + //if authorize, add to request headers with scheme + if (this.AuthorizeParameterInfo != null && this.AuthorizeParameterInfo.Item2 == i) + { + isParameterMappedToRequest = true; + } + + //if property, add to populate into HttpRequestMessage.Properties + if (this.PropertyParameterMap.ContainsKey(i)) + { + isParameterMappedToRequest = true; + } + + // ignore nulls and already processed parameters + if (isParameterMappedToRequest || param == null) + continue; + + // for anything that fell through to here, if this is not a multipart method add the parameter to the query string + // or if is an object bound to the path add any non-path bound properties to query string + // or if it's an object with a query attribute + QueryConfiguration queryAttribute = null; + + // var queryAttribute = this + // .ParameterSymbolArray[i] + // .GetCustomAttribute(); + if ( + !this.IsMultipart + || this.ParameterMap.ContainsKey(i) + && this.ParameterMap[i].IsObjectPropertyParameter + || queryAttribute != null + ) + { + queryParamsToAdd ??= []; + AddQueryParameters( + queryAttribute, + param, + queryParamsToAdd, + i, + parameterInfo, + knownTypes + ); + continue; + } + + // AddMultiPart(restMethod, i, param, multiPartContent); + } + + return queryParamsToAdd; + } + + static void AddQueryParameters( + QueryConfiguration? queryAttribute, + IParameterSymbol param, + List queryParamsToAdd, + int i, + RestMethodParameterInfo? parameterInfo, + WellKnownTypes knownTypes + ) + { + queryParamsToAdd.Add(new QueryModel(param.Name, 0, CollectionFormat.Csv, "-", null, null)); + } + + // void AddQueryParameters(QueryConfiguration? queryAttribute, IParameterSymbol param, + // List queryParamsToAdd, int i, RestMethodParameterInfo? parameterInfo, + // WellKnownTypes knownTypes) + // { + // var attr = queryAttribute ?? DefaultQueryAttribute; + // if (DoNotConvertToQueryMap(param, knownTypes)) + // { + // queryParamsToAdd.AddRange( + // ParseQueryParameter( + // param, + // this.ParameterInfoArray[i], + // this.QueryParameterMap[i], + // attr + // ) + // ); + // } + // else + // { + // foreach (var kvp in BuildQueryMap(param, attr.Delimiter, parameterInfo)) + // { + // var path = !string.IsNullOrWhiteSpace(attr.Prefix) + // ? $"{attr.Prefix}{attr.Delimiter}{kvp.Key}" + // : kvp.Key; + // queryParamsToAdd.AddRange( + // ParseQueryParameter( + // kvp.Value, + // this.ParameterInfoArray[i], + // path, + // attr + // ) + // ); + // } + // } + // } + + // TODO: add param nul check to runtime add + static bool DoNotConvertToQueryMap(IParameterSymbol value, WellKnownTypes knownTypes) + { + var type = value.Type; + + // Bail out early & match string + if (ShouldReturn(type, knownTypes)) + return true; + + var iEnumerableSymbol = knownTypes.Get(); + if (type.InheritsFromOrEquals(iEnumerableSymbol)) + return false; + + // Get the element type for enumerables + var iEnumerableTSymbol = knownTypes.Get(typeof(IEnumerable<>)); + // We don't want to enumerate to get the type, so we'll just look for IEnumerable + + foreach (var iface in type.AllInterfaces) + { + // TODO: could probably uncomment + // if (iface.OriginalDefinition.SpecialType == SpecialType.System_Collections_IEnumerable) + // return false; + + if ( + iface.TypeArguments.Length == 1 + && iface.OriginalDefinition.InheritsFromOrEquals(iEnumerableTSymbol) + ) + { + return ShouldReturn(iface.TypeArguments[0], knownTypes); + } + } + + return false; + + // TODO: I assume that NullableGetUnderlyingType is nullable struct types and not all types + // TODO: ensure that this works with char? and string? + // Check if type is a simple string or IFormattable type, check underlying type if Nullable + static bool ShouldReturn(ITypeSymbol typeSymbol, WellKnownTypes knownTypes) + { + if ( + typeSymbol is INamedTypeSymbol namedType + && namedType.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T + && namedType.TypeArguments.Length == 1 + && namedType.TypeArguments[0].IsValueType + ) + { + return ShouldReturn(namedType.TypeArguments[0], knownTypes); + } + + var iFormattableSymbol = knownTypes.Get(); + var uriSymbol = knownTypes.Get(); + + return typeSymbol.SpecialType == SpecialType.System_String + || typeSymbol.SpecialType == SpecialType.System_Boolean + || typeSymbol.SpecialType == SpecialType.System_Char + || typeSymbol.InheritsFromOrEquals(iFormattableSymbol) + || SymbolEqualityComparer.Default.Equals(typeSymbol, uriSymbol); + } + } + + // + // IEnumerable> ParseQueryParameter( + // ITypeSymbol param, + // ParameterInfo parameterInfo, + // string queryPath, + // QueryConfiguration queryAttribute, + // WellKnownTypes knownTypes + // ) + // { + // var iEnumerableSymbol = knownTypes.Get(); + // if (param.SpecialType != SpecialType.System_String && param.InheritsFromOrEquals(iEnumerableSymbol)) + // { + // foreach ( + // var value in ParseEnumerableQueryParameterValue( + // param, + // parameterInfo, + // parameterInfo.ParameterType, + // queryAttribute + // ) + // ) + // { + // yield return new KeyValuePair(queryPath, value); + // } + // } + // else + // { + // throw new NotImplementedException(nameof(ParseQueryParameter)); + // yield return new KeyValuePair( + // queryPath, + // settings.UrlParameterFormatter.Format( + // param, + // parameterInfo, + // parameterInfo.ParameterType + // ) + // ); + // } + // } + // + // IEnumerable ParseEnumerableQueryParameterValue( + // IEnumerable paramValues, + // ICustomAttributeProvider customAttributeProvider, + // Type type, + // QueryConfiguration? queryAttribute + // ) + // { + // // TODO: collection + // var collectionFormat = + // queryAttribute != null && queryAttribute.IsCollectionFormatSpecified + // ? queryAttribute.CollectionFormat + // : settings.CollectionFormat; + // + // switch (collectionFormat) + // { + // case CollectionFormat.Multi: + // foreach (var paramValue in paramValues) + // { + // yield return settings.UrlParameterFormatter.Format( + // paramValue, + // customAttributeProvider, + // type + // ); + // } + // + // break; + // + // default: + // var delimiter = + // collectionFormat switch + // { + // CollectionFormat.Ssv => " ", + // CollectionFormat.Tsv => "\t", + // CollectionFormat.Pipes => "|", + // _ => "," + // }; + // + // // Missing a "default" clause was preventing the collection from serializing at all, as it was hitting "continue" thus causing an off-by-one error + // var formattedValues = paramValues + // .Cast() + // .Select( + // v => + // settings.UrlParameterFormatter.Format( + // v, + // customAttributeProvider, + // type + // ) + // ); + // + // yield return string.Join(delimiter, formattedValues); + // + // break; + // } + // } + + + static IEnumerable GetParameterProperties(IParameterSymbol parameter) + { + return parameter + .Type.GetMembers() + .OfType() + .Where(static p => p.DeclaredAccessibility == Accessibility.Public && !p.IsStatic) + .Where(static p => p.GetMethod is { DeclaredAccessibility: Accessibility.Public }); + } + + static void VerifyUrlPathIsSane(string relativePath) + { + if (string.IsNullOrEmpty(relativePath)) + return; + + if (!relativePath.StartsWith("/")) + throw new ArgumentException( + $"URL path {relativePath} must start with '/' and be of the form '/foo/bar/baz'" + ); + + // CRLF injection protection + if (relativePath.Contains('\r') || relativePath.Contains('\n')) + throw new ArgumentException( + $"URL path {relativePath} must not contain CR or LF characters" + ); + } + + // static Dictionary BuildParameterMap( + // string relativePath, + // IParameterSymbol[] parameterSymbol, + // WellKnownTypes knownTypes + // ) + // { + // var ret = new Dictionary(); + // + // // This section handles pattern matching in the URL. We also need it to add parameter key/values for any attribute with a [Query] + // var parameterizedParts = relativePath + // .Split('/', '?') + // .SelectMany(x => ParameterRegex.Matches(x).Cast()) + // .ToList(); + // + // if (parameterizedParts.Count > 0) + // { + // var paramValidationDict = parameterSymbol.ToDictionary( + // k => GetUrlNameForParameter(k, knownTypes).ToLowerInvariant(), + // v => v + // ); + // //if the param is an lets make a dictionary for all it's potential parameters + // var objectParamValidationDict = parameterSymbol + // .Where(x => x.Type.IsReferenceType) + // .SelectMany(x => GetParameterProperties(x).Select(p => Tuple.Create(x, p))) + // .GroupBy( + // i => $"{i.Item1.Name}.{GetUrlNameForProperty(i.Item2, knownTypes)}".ToLowerInvariant() + // ) + // .ToDictionary(k => k.Key, v => v.First()); + // foreach (var match in parameterizedParts) + // { + // var rawName = match.Groups[1].Value.ToLowerInvariant(); + // var isRoundTripping = rawName.StartsWith("**"); + // string name; + // if (isRoundTripping) + // { + // name = rawName.Substring(2); + // } + // else + // { + // name = rawName; + // } + // + // if (paramValidationDict.TryGetValue(name, out var value)) //if it's a standard parameter + // { + // var paramType = value.Type; + // if (isRoundTripping && paramType.SpecialType != SpecialType.System_String) + // { + // throw new ArgumentException( + // $"URL {relativePath} has round-tripping parameter {rawName}, but the type of matched method parameter is {paramType.Name}. It must be a string." + // ); + // } + // var parameterType = isRoundTripping + // ? ParameterType.RoundTripping + // : ParameterType.Normal; + // var restMethodParameterInfo = new RestMethodParameterInfo(name, value) + // { + // Type = parameterType + // }; + // #if NET6_0_OR_GREATER + // ret.TryAdd( + // Array.IndexOf(parameterInfo, restMethodParameterInfo.ParameterInfo), + // restMethodParameterInfo + // ); + // #else + // var idx = Array.IndexOf(parameterSymbol, restMethodParameterInfo.ParameterInfo); + // if (!ret.ContainsKey(idx)) + // { + // ret.Add(idx, restMethodParameterInfo); + // } + // #endif + // } + // //else if it's a property on a object parameter + // else if ( + // objectParamValidationDict.TryGetValue(name, out var value1) + // && !isRoundTripping + // ) + // { + // var property = value1; + // var parameterIndex = Array.IndexOf(parameterSymbol, property.Item1); + // //If we already have this parameter, add additional ParameterProperty + // if (ret.TryGetValue(parameterIndex, out var value2)) + // { + // if (!value2.IsObjectPropertyParameter) + // { + // throw new ArgumentException( + // $"Parameter {property.Item1.Name} matches both a parameter and nested parameter on a parameter object" + // ); + // } + // + // value2.ParameterProperties.Add( + // new RestMethodParameterProperty(name, property.Item2) + // ); + // } + // else + // { + // var restMethodParameterInfo = new RestMethodParameterInfo( + // true, + // property.Item1 + // ); + // restMethodParameterInfo.ParameterProperties.Add( + // new RestMethodParameterProperty(name, property.Item2) + // ); + // #if NET6_0_OR_GREATER + // ret.TryAdd( + // Array.IndexOf(parameterInfo, restMethodParameterInfo.ParameterInfo), + // restMethodParameterInfo + // ); + // #else + // // Do the contains check + // var idx = Array.IndexOf(parameterSymbol, restMethodParameterInfo.ParameterInfo); + // if (!ret.ContainsKey(idx)) + // { + // ret.Add(idx, restMethodParameterInfo); + // } + // #endif + // } + // } + // else + // { + // throw new ArgumentException( + // $"URL {relativePath} has parameter {rawName}, but no method parameter matches" + // ); + // } + // } + // } + // return ret; + // } + + static ( + Dictionary ret, + List fragmentList + ) BuildParameterMap2( + string relativePath, + IParameterSymbol[] parameterSymbols, + WellKnownTypes knownTypes + ) + { + var ret = new Dictionary(); + + // This section handles pattern matching in the URL. We also need it to add parameter key/values for any attribute with a [Query] + var parameterizedParts = ParameterRegex2.Matches(relativePath).Cast().ToArray(); + + if (parameterizedParts.Length == 0) + { + // TODO: does this handle cases where we start with round tripping? + if (string.IsNullOrEmpty(relativePath)) + return (ret, new List()); + + return (ret, new List() { new ConstantFragmentModel(relativePath) }); + } + + var paramValidationDict = parameterSymbols.ToDictionary( + k => GetUrlNameForParameter(k, knownTypes).ToLowerInvariant(), + v => v + ); + //if the param is an lets make a dictionary for all it's potential parameters + var objectParamValidationDict = parameterSymbols + .Where(x => x.Type.IsReferenceType) + .SelectMany(x => GetParameterProperties(x).Select(p => Tuple.Create(x, p))) + .GroupBy(i => + $"{i.Item1.Name}.{GetUrlNameForProperty(i.Item2, knownTypes)}".ToLowerInvariant() + ) + .ToDictionary(k => k.Key, v => v.First()); + + var fragmentList = new List(); + var index = 0; + foreach (var match in parameterizedParts) + { + // Add constant value from given http path and continue + if (match.Index != index) + { + fragmentList.Add( + new ConstantFragmentModel(relativePath.Substring(index, match.Index - index)) + ); + } + index = match.Index + match.Length; + + var rawName = match.Groups[1].Value.ToLowerInvariant(); + var isRoundTripping = rawName.StartsWith("**"); + var name = isRoundTripping ? rawName.Substring(2) : rawName; + + if (paramValidationDict.TryGetValue(name, out var value)) //if it's a standard parameter + { + var paramType = value.Type; + if (isRoundTripping && paramType.SpecialType == SpecialType.System_String) + { + throw new ArgumentException( + $"URL {relativePath} has round-tripping parameter {rawName}, but the type of matched method parameter is {paramType.ToDisplayString( + SymbolDisplayFormat.FullyQualifiedFormat + )}. It must be a string." + ); + } + var parameterType = isRoundTripping + ? ParameterType.RoundTripping + : ParameterType.Normal; + var restMethodParameterInfo = new RestMethodParameterInfo(name, value) + { + Type = parameterType, + }; + + var paramSymbol = restMethodParameterInfo.ParameterInfo; + var parameterIndex = Array.IndexOf( + parameterSymbols, + restMethodParameterInfo.ParameterInfo + ); + var parameterTypeDeclaration = paramSymbol.Type.ToDisplayString( + SymbolDisplayFormat.FullyQualifiedFormat + ); + fragmentList.Add( + new DynamicFragmentModel( + paramSymbol.Name, + paramSymbol.Ordinal, + parameterTypeDeclaration + ) + ); +#if NET6_0_OR_GREATER + ret.TryAdd(parameterIndex, restMethodParameterInfo); +#else + if (!ret.ContainsKey(parameterIndex)) + { + ret.Add(parameterIndex, restMethodParameterInfo); + } +#endif + } + //else if it's a property on a object parameter + else if ( + objectParamValidationDict.TryGetValue(name, out var value1) && !isRoundTripping + ) + { + var property = value1; + var parameterIndex = Array.IndexOf(parameterSymbols, property.Item1); + //If we already have this parameter, add additional ParameterProperty + if (ret.TryGetValue(parameterIndex, out var value2)) + { + if (!value2.IsObjectPropertyParameter) + { + throw new ArgumentException( + $"Parameter {property.Item1.Name} matches both a parameter and nested parameter on a parameter object" + ); + } + + value2.ParameterProperties.Add( + new RestMethodParameterProperty(name, property.Item2) + ); + + var propertyAccessExpression = $"{property.Item1.Name}.{property.Item2.Name}"; + var containingType = property.Item2.ContainingType.ToDisplayString( + SymbolDisplayFormat.FullyQualifiedFormat + ); + var typeDeclaration = property.Item2.Type.ToDisplayString( + SymbolDisplayFormat.FullyQualifiedFormat + ); + + fragmentList.Add( + new DynamicPropertyFragmentModel( + propertyAccessExpression, + property.Item2.Name, + containingType, + typeDeclaration + ) + ); + } + else + { + var restMethodParameterInfo = new RestMethodParameterInfo(true, property.Item1); + restMethodParameterInfo.ParameterProperties.Add( + new RestMethodParameterProperty(name, property.Item2) + ); + + var idx = Array.IndexOf( + parameterSymbols, + restMethodParameterInfo.ParameterInfo + ); + var propertyAccessExpression = $"{property.Item1.Name}.{property.Item2.Name}"; + var containingType = property.Item2.ContainingType.ToDisplayString( + SymbolDisplayFormat.FullyQualifiedFormat + ); + var typeDeclaration = property.Item2.Type.ToDisplayString( + SymbolDisplayFormat.FullyQualifiedFormat + ); + fragmentList.Add( + new DynamicPropertyFragmentModel( + propertyAccessExpression, + property.Item2.Name, + containingType, + typeDeclaration + ) + ); +#if NET6_0_OR_GREATER + ret.TryAdd(idx, restMethodParameterInfo); +#else + // Do the contains check + if (!ret.ContainsKey(idx)) + { + ret.Add(idx, restMethodParameterInfo); + } +#endif + } + } + else + { + throw new ArgumentException( + $"URL {relativePath} has parameter {rawName}, but no method parameter matches" + ); + } + } + + // add trailing string + if (index < relativePath.Length - 1) + { + var s = relativePath.Substring(index, relativePath.Length - index); + fragmentList.Add(new ConstantFragmentModel(s)); + } + return (ret, fragmentList); + } + + // TODO: could these two methods be merged? + static string GetUrlNameForParameter(IParameterSymbol paramSymbol, WellKnownTypes knownTypes) + { + var aliasAsSymbol = knownTypes.TryGet("Refit.AliasAsAttribute"); + var aliasAttr = paramSymbol.AccessFirstOrDefault( + aliasAsSymbol, + knownTypes + ); + return aliasAttr != null ? aliasAttr.Name : paramSymbol.Name!; + } + + static string GetUrlNameForProperty(IPropertySymbol propSymbol, WellKnownTypes knownTypes) + { + var aliasAsSymbol = knownTypes.TryGet("Refit.AliasAsAttribute"); + var aliasAttr = propSymbol.AccessFirstOrDefault( + aliasAsSymbol, + knownTypes + ); + return aliasAttr != null ? aliasAttr.Name : propSymbol.Name; + } + + static string GetAttachmentNameForParameter( + IParameterSymbol paramSymbol, + WellKnownTypes knownTypes + ) + { + var attachmentNameSymbol = knownTypes.TryGet("Refit.AttachmentNameAttribute"); + var aliasAsSymbol = knownTypes.TryGet("Refit.AliasAsAttribute"); + +#pragma warning disable CS0618 // Type or member is obsolete + var nameAttr = paramSymbol.AccessFirstOrDefault( + attachmentNameSymbol, + knownTypes + ); +#pragma warning restore CS0618 // Type or member is obsolete + + // also check for AliasAs + return nameAttr?.Name + ?? paramSymbol + .AccessFirstOrDefault(aliasAsSymbol, knownTypes) + ?.Name!; + } + + Tuple? FindBodyParameter( + IParameterSymbol[] parameterArray, + bool isMultipart, + HttpMethod method, + WellKnownTypes knownTypes + ) + { + // The body parameter is found using the following logic / order of precedence: + // 1) [Body] attribute + // 2) POST/PUT/PATCH: Reference type other than string + // 3) If there are two reference types other than string, without the body attribute, throw + + var bodySymbol = knownTypes.TryGet("Refit.BodyAttribute"); + var bodyParamEnumerable = parameterArray + .Select(x => + ( + Parameter: x, + BodyAttribute: x.AccessFirstOrDefault(bodySymbol, knownTypes) + ) + ) + .Where(x => x.BodyAttribute != null) + .ToArray(); + + // multipart requests may not contain a body, implicit or explicit + if (isMultipart) + { + if (bodyParamEnumerable.Length > 0) + { + throw new ArgumentException("Multipart requests may not contain a Body parameter"); + } + return null; + } + + if (bodyParamEnumerable.Length > 1) + { + throw new ArgumentException("Only one parameter can be a Body parameter"); + } + + // #1, body attribute wins + if (bodyParamEnumerable.Length == 1) + { + var bodyParam = bodyParamEnumerable.First(); + + // TODO: move logic to runtime + // return Tuple.Create( + // bodyParam!.BodyAttribute!.SerializationMethod, + // bodyParam.BodyAttribute.Buffered ?? RefitSettings.Buffered, + // Array.IndexOf(parameterArray, bodyParam.Parameter) + // ); + // TODO: this default to false + + return Tuple.Create( + bodyParam!.BodyAttribute!.SerializationMethod, + bodyParam.BodyAttribute.Buffered ?? false, + Array.IndexOf(parameterArray, bodyParam.Parameter) + ); + } + + // TODO: no idea if this works with derived attributes + // Not in post/put/patch? bail + if ( + !method.Equals(HttpMethod.Post) + && !method.Equals(HttpMethod.Put) + && !method.Equals(PatchMethod) + ) + { + return null; + } + + var querySymbol = knownTypes.TryGet("Refit.QueryAttribute"); + var headerCollectionSymbol = knownTypes.TryGet("Refit.HeaderCollectionAttribute"); + var propertySymbol = knownTypes.TryGet("Refit.PropertyAttribute"); + + // see if we're a post/put/patch + // explicitly skip [Query], [HeaderCollection], and [Property]-denoted params + var refParamEnumerable = parameterArray + .Where(pi => + !pi.Type.IsValueType + && pi.Type.SpecialType != SpecialType.System_String + && pi.AccessFirstOrDefault(querySymbol, knownTypes) == null + && pi.AccessFirstOrDefault( + headerCollectionSymbol, + knownTypes + ) == null + && pi.AccessFirstOrDefault(propertySymbol, knownTypes) + == null + ) + .ToArray(); + + // Check for rule #3 + if (refParamEnumerable.Length > 1) + { + throw new ArgumentException( + "Multiple complex types found. Specify one parameter as the body using BodyAttribute" + ); + } + + if (refParamEnumerable.Length == 1) + { + var refParam = refParamEnumerable.First(); + // TODO: move RefitSettings logic to runtime. + // return Tuple.Create( + // BodySerializationMethod.Serialized, + // RefitSettings.Buffered, + // Array.IndexOf(parameterArray, refParam!) + // ); + + return Tuple.Create( + BodySerializationMethod.Serialized, + false, + Array.IndexOf(parameterArray, refParam!) + ); + } + + return null; + } + + static Tuple? FindAuthorizationParameter( + IParameterSymbol[] parameterArray, + WellKnownTypes knownTypes + ) + { + var authorizeSymbol = knownTypes.TryGet("Refit.AuthorizeAttribute"); + var authorizeParams = parameterArray + .Select(x => + ( + Parameter: x, + AuthorizeAttribute: x.AccessFirstOrDefault( + authorizeSymbol, + knownTypes + ) + ) + ) + .Where(x => x.AuthorizeAttribute != null) + .ToArray(); + + if (authorizeParams.Length > 1) + { + throw new ArgumentException("Only one parameter can be an Authorize parameter"); + } + + if (authorizeParams.Length == 1) + { + var authorizeParam = authorizeParams.First(); + return Tuple.Create( + authorizeParam!.AuthorizeAttribute!.Scheme, + Array.IndexOf(parameterArray, authorizeParam.Parameter) + ); + } + + return null; + } + + static Dictionary ParseHeaders( + IMethodSymbol methodSymbol, + WellKnownTypes knownTypes + ) + { + var headersSymbol = knownTypes.TryGet("Refit.HeadersAttribute"); + var inheritedAttributes = + methodSymbol.ContainingType != null + ? methodSymbol + .ContainingType.AllInterfaces.SelectMany(i => + i.Access(headersSymbol, knownTypes) + ) + .Reverse() + : []; + + var declaringTypeAttributes = methodSymbol.ContainingType!.Access( + headersSymbol, + knownTypes + ); + + // Headers set on the declaring type have to come first, + // so headers set on the method can replace them. Switching + // the order here will break stuff. + var headers = inheritedAttributes + .Concat(declaringTypeAttributes) + .Concat(methodSymbol.Access(headersSymbol, knownTypes)) + .SelectMany(ha => ha.Headers); + + var ret = new Dictionary(); + + foreach (var header in headers) + { + if (string.IsNullOrWhiteSpace(header)) + continue; + + // NB: Silverlight doesn't have an overload for String.Split() + // with a count parameter, but header values can contain + // ':' so we have to re-join all but the first part to get the + // value. + + var parsedHeader = EnsureSafe(header); + var parts = parsedHeader.Split(':'); + ret[parts[0].Trim()] = parts.Length > 1 ? string.Join(":", parts.Skip(1)).Trim() : null; + } + + return ret; + } + + static Dictionary BuildHeaderParameterMap( + IParameterSymbol[] parameterArray, + WellKnownTypes knownTypes + ) + { + var ret = new System.Collections.Generic.Dictionary(); + var headerSymbol = knownTypes.TryGet("Refit.HeaderAttribute"); + + for (var i = 0; i < parameterArray.Length; i++) + { + var headerAttribute = parameterArray[i] + .AccessFirstOrDefault(headerSymbol, knownTypes); + + var header = headerAttribute?.Header; + + if (!string.IsNullOrWhiteSpace(header)) + { + ret[i] = header!.Trim(); + } + } + + return ret; + } + + // TODO: maybe merge with similar code for query + // TODO: is it safe to use parameters by name without using @? + // TODO: does parsing the Headers interfaces work the same as runtime reflection? + // TODO: properly escape strings + // TODO: adding " here is a bad idea + // TODO: check overlapping strings + + List BuildHeaderPsModel() + { + var headersToAdd = new System.Collections.Generic.List(); + + foreach (var pair in Headers) + { + headersToAdd.Add( + new HeaderPsModel( + HeaderType.Static, + new HeaderModel($""" "{pair.Key}" """, $""" "{pair.Value}" """), + null, + null, + null + ) + ); + } + + for (var i = 0; i < ParameterSymbolArray.Length; i++) + { + var isParameterMappedToRequest = false; + var param = ParameterSymbolArray[i]; + RestMethodParameterInfo? parameterInfo = null; + + // if part of REST resource URL, substitute it in + if (this.ParameterMap.TryGetValue(i, out var parameterMapValue)) + { + parameterInfo = parameterMapValue; + if (!parameterInfo.IsObjectPropertyParameter) + { + // mark parameter mapped if not an object + // we want objects to fall through so any parameters on this object not bound here get passed as query parameters + isParameterMappedToRequest = true; + } + } + + // if marked as body, add to content + if (this.BodyParameterInfo != null && this.BodyParameterInfo.Item3 == i) + { + // AddBodyToRequest(restMethod, param, ret); + isParameterMappedToRequest = true; + } + + // if header, add to request headers + if (this.HeaderParameterMap.TryGetValue(i, out var headerParameterValue)) + { + headersToAdd.Add( + new HeaderPsModel( + HeaderType.Static, + new HeaderModel($""" "{headerParameterValue}" """, param.Name), + null, + null, + null + ) + ); + isParameterMappedToRequest = true; + } + + //if header collection, add to request headers + if (this.HeaderCollectionAt(i)) + { + headersToAdd.Add( + new HeaderPsModel(HeaderType.Collection, null, null, param.Name, null) + ); + + isParameterMappedToRequest = true; + } + + //if authorize, add to request headers with scheme + if (this.AuthorizeParameterInfo != null && this.AuthorizeParameterInfo.Item2 == i) + { + // headersToAdd["Authorization"] = + // $"{this.AuthorizeParameterInfo.Item1} {param}"; + // + headersToAdd.Add( + new HeaderPsModel( + HeaderType.Authorise, + null, + null, + null, + new AuthoriseModel( + ParameterSymbolArray[this.AuthorizeParameterInfo.Item2].Name, + this.AuthorizeParameterInfo.Item1 + ) + ) + ); + + isParameterMappedToRequest = true; + } + } + + return headersToAdd; + } + + void DetermineReturnTypeInfo(IMethodSymbol methodInfo, WellKnownTypes knownTypes) + { + var unboundTaskSymbol = knownTypes.Get(typeof(Task<>)); + var valueTaskSymbol = knownTypes.Get(typeof(ValueTask<>)); + var observableSymbol = knownTypes.Get(typeof(IObservable<>)); + + var taskSymbol = knownTypes.Get(); + + var returnType = methodInfo.ReturnType; + if ( + returnType is INamedTypeSymbol { IsGenericType: true } namedType + && ( + SymbolEqualityComparer.Default.Equals( + namedType.OriginalDefinition, + unboundTaskSymbol + ) + || SymbolEqualityComparer.Default.Equals( + namedType.OriginalDefinition, + valueTaskSymbol + ) + || SymbolEqualityComparer.Default.Equals( + namedType.OriginalDefinition, + observableSymbol + ) + ) + ) + { + ReturnType = returnType; + ReturnResultType = namedType.TypeArguments[0]; + + var unboundApiResponseSymbol = knownTypes.TryGet("Refit.ApiResponse`1"); + var unboundIApiResponseSymbol = knownTypes.TryGet("Refit.IApiResponse`1"); + + var iApiResponseSymbol = knownTypes.TryGet("Refit.IApiResponse"); + + // TODO: maybe use inherits from here? + // Does refit support types inheriting from IApiResponse + if ( + ReturnResultType is INamedTypeSymbol { IsGenericType: true } returnResultNamedType + && ( + SymbolEqualityComparer.Default.Equals( + returnResultNamedType.OriginalDefinition, + unboundApiResponseSymbol + ) + || SymbolEqualityComparer.Default.Equals( + returnResultNamedType.OriginalDefinition, + unboundIApiResponseSymbol + ) + ) + ) + { + DeserializedResultType = returnResultNamedType.TypeArguments[0]; + } + else if (SymbolEqualityComparer.Default.Equals(ReturnResultType, iApiResponseSymbol)) + { + DeserializedResultType = knownTypes.Get(); + } + else + DeserializedResultType = ReturnResultType; + } + else if (SymbolEqualityComparer.Default.Equals(returnType, taskSymbol)) + { + var voidSymbol = knownTypes.Get(typeof(void)); + ReturnType = methodInfo.ReturnType; + ReturnResultType = voidSymbol; + DeserializedResultType = voidSymbol; + } + else + throw new ArgumentException( + $"Method \"{methodInfo.Name}\" is invalid. All REST Methods must return either Task or ValueTask or IObservable" + ); + } + + void DetermineIfResponseMustBeDisposed(WellKnownTypes knownTypes) + { + // Rest method caller will have to dispose if it's one of those 3 + var httpResponseSymbol = knownTypes.Get(); + var httpContentSymbol = knownTypes.Get(); + var streamSymbol = knownTypes.Get(); + + ShouldDisposeResponse = + (!SymbolEqualityComparer.Default.Equals(DeserializedResultType, httpResponseSymbol)) + && (!SymbolEqualityComparer.Default.Equals(DeserializedResultType, httpContentSymbol)) + && (!SymbolEqualityComparer.Default.Equals(DeserializedResultType, streamSymbol)); + } + + static string EnsureSafe(string value) + { + // Remove CR and LF characters +#pragma warning disable CA1307 // Specify StringComparison for clarity + return value.Replace("\r", string.Empty).Replace("\n", string.Empty); +#pragma warning restore CA1307 // Specify StringComparison for clarity + } +} + +/// +/// RestMethodParameterInfo. +/// +public class RestMethodParameterInfo +{ + /// + /// Initializes a new instance of the class. + /// + /// The name. + /// The parameter information. + public RestMethodParameterInfo(string name, IParameterSymbol parameterInfo) + { + Name = name; + ParameterInfo = parameterInfo; + } + + /// + /// Initializes a new instance of the class. + /// + /// if set to true [is object property parameter]. + /// The parameter information. + public RestMethodParameterInfo(bool isObjectPropertyParameter, IParameterSymbol parameterInfo) + { + IsObjectPropertyParameter = isObjectPropertyParameter; + ParameterInfo = parameterInfo; + } + + /// + /// Gets or sets the name. + /// + /// + /// The name. + /// + public string? Name { get; set; } + + /// + /// Gets or sets the parameter information. + /// + /// + /// The parameter information. + /// + public IParameterSymbol ParameterInfo { get; set; } + + /// + /// Gets or sets a value indicating whether this instance is object property parameter. + /// + /// + /// true if this instance is object property parameter; otherwise, false. + /// + public bool IsObjectPropertyParameter { get; set; } + + /// + /// Gets or sets the parameter properties. + /// + /// + /// The parameter properties. + /// + public List ParameterProperties { get; set; } = []; + + /// + /// Gets or sets the type. + /// + /// + /// The type. + /// + public ParameterType Type { get; set; } = ParameterType.Normal; +} + +/// +/// RestMethodParameterProperty. +/// +public class RestMethodParameterProperty +{ + /// + /// Initializes a new instance of the class. + /// + /// The name. + /// The property information. + public RestMethodParameterProperty(string name, IPropertySymbol propertyInfo) + { + Name = name; + PropertyInfo = propertyInfo; + } + + /// + /// Gets or sets the name. + /// + /// + /// The name. + /// + public string Name { get; set; } + + /// + /// Gets or sets the property information. + /// + /// + /// The property information. + /// + public IPropertySymbol PropertyInfo { get; set; } +} + +/// +/// ParameterType. +/// +public enum ParameterType +{ + /// + /// The normal + /// + Normal, + + /// + /// The round tripping + /// + RoundTripping, +} diff --git a/InterfaceStubGenerator.Shared/UniqueNameBuilder.cs b/InterfaceStubGenerator.Shared/UniqueNameBuilder.cs index feed8ecde..9c44ceffb 100644 --- a/InterfaceStubGenerator.Shared/UniqueNameBuilder.cs +++ b/InterfaceStubGenerator.Shared/UniqueNameBuilder.cs @@ -2,16 +2,16 @@ public class UniqueNameBuilder() { - private readonly HashSet _usedNames = new(StringComparer.Ordinal); - private readonly UniqueNameBuilder? _parentScope; + readonly HashSet usedNames = new(StringComparer.Ordinal); + readonly UniqueNameBuilder? parentScope; private UniqueNameBuilder(UniqueNameBuilder parentScope) : this() { - _parentScope = parentScope; + this.parentScope = parentScope; } - public void Reserve(string name) => _usedNames.Add(name); + public void Reserve(string name) => usedNames.Add(name); public UniqueNameBuilder NewScope() => new(this); @@ -25,7 +25,7 @@ public string New(string name) i++; } - _usedNames.Add(uniqueName); + usedNames.Add(uniqueName); return uniqueName; } @@ -34,17 +34,17 @@ public void Reserve(IEnumerable names) { foreach (var name in names) { - _usedNames.Add(name); + usedNames.Add(name); } } - private bool Contains(string name) + bool Contains(string name) { - if (_usedNames.Contains(name)) + if (usedNames.Contains(name)) return true; - if (_parentScope != null) - return _parentScope.Contains(name); + if (parentScope != null) + return parentScope.Contains(name); return false; } diff --git a/InterfaceStubGenerator.Shared/WellKnownTypes.cs b/InterfaceStubGenerator.Shared/WellKnownTypes.cs new file mode 100644 index 000000000..b36fd9f37 --- /dev/null +++ b/InterfaceStubGenerator.Shared/WellKnownTypes.cs @@ -0,0 +1,35 @@ +using Microsoft.CodeAnalysis; + +namespace Refit.Generator; + +public class WellKnownTypes(Compilation compilation) +{ + readonly Dictionary cachedTypes = new(); + + public INamedTypeSymbol Get() => Get(typeof(T)); + + public INamedTypeSymbol Get(Type type) + { + return Get( + type.FullName + ?? throw new InvalidOperationException("Could not get name of type " + type) + ); + } + + public INamedTypeSymbol? TryGet(string typeFullName) + { + if (cachedTypes.TryGetValue(typeFullName, out var typeSymbol)) + { + return typeSymbol; + } + + typeSymbol = compilation.GetTypeByMetadataName(typeFullName); + cachedTypes.Add(typeFullName, typeSymbol); + + return typeSymbol; + } + + private INamedTypeSymbol Get(string typeFullName) => + TryGet(typeFullName) + ?? throw new InvalidOperationException("Could not get type " + typeFullName); +} diff --git a/Refit.Tests/RequestBuilder.cs b/Refit.Tests/RequestBuilder.cs index f0477c5aa..2c167a7f5 100644 --- a/Refit.Tests/RequestBuilder.cs +++ b/Refit.Tests/RequestBuilder.cs @@ -3716,6 +3716,8 @@ public Func BuildRestResultFuncForMethod( CallCount++; return null; } + + public RefitSettings Settings { get; } } [Fact] diff --git a/Refit/CachedRequestBuilderImplementation.cs b/Refit/CachedRequestBuilderImplementation.cs index 1a0908330..2167d4b9b 100644 --- a/Refit/CachedRequestBuilderImplementation.cs +++ b/Refit/CachedRequestBuilderImplementation.cs @@ -17,6 +17,7 @@ public CachedRequestBuilderImplementation(IRequestBuilder innerBuilder) { this.innerBuilder = innerBuilder ?? throw new ArgumentNullException(nameof(innerBuilder)); + this.Settings = innerBuilder.Settings; } readonly IRequestBuilder innerBuilder; @@ -25,6 +26,8 @@ internal readonly ConcurrentDictionary< Func > MethodDictionary = new(); + public RefitSettings Settings { get; } + public Func BuildRestResultFuncForMethod( string methodName, Type[]? parameterTypes = null, diff --git a/Refit/RefitHelper.cs b/Refit/RefitHelper.cs new file mode 100644 index 000000000..87e2db23d --- /dev/null +++ b/Refit/RefitHelper.cs @@ -0,0 +1,518 @@ +using System.Net.Http; +using System.Text; + +namespace Refit; + +// TODO: evaluate use of methodinlining, prolly a good idea for small single line methods +// TODO: when I move this to emit will I need to ensure all parameter names are unique? +public static class RefitHelper +{ + public static Uri BaseUri = new ("http://api"); + + public static global::System.Net.Http.HttpMethod Patch = new ("PATCH"); + + public static void AddUrlFragment(ref ValueStringBuilder vsb, T value, global::Refit.RefitSettings settings, + global::System.Type type) + { + // TODO: implement this properly + vsb.Append(value.ToString()); + } + + public static void AddRoundTripUrlFragment(ref ValueStringBuilder vsb, global::Refit.RefitSettings settings, + global::System.Type type) => throw new NotImplementedException(nameof(AddRoundTripUrlFragment)); + + public static void AddPropertyFragment(ref ValueStringBuilder vsb, global::Refit.RefitSettings settings, + global::System.Type type) => throw new NotImplementedException(nameof(AddPropertyFragment)); + + public static void AddQueryObject(ref ValueStringBuilder vsb, global::Refit.RefitSettings settings, + string key, object value) + { + vsb.Append(key); + vsb.Append('='); + vsb.Append(value.ToString()); + } + + public static void InitialiseHeaders(global::System.Net.Http.HttpRequestMessage request) + { + // TODO: ensure not emitted when body. + request.Content ??= new global::System.Net.Http.ByteArrayContent([]); + } + + public static void AddHeader(global::System.Net.Http.HttpRequestMessage request, string key, string value) + { + SetHeader(request, key, value); + } + + public static void AddHeaderCollection(global::System.Net.Http.HttpRequestMessage request, IDictionary keys) + { + foreach (var pairs in keys) + { + SetHeader(request, pairs.Key, pairs.Value); + } + } + + static void SetHeader(global::System.Net.Http.HttpRequestMessage request, string name, string? value) + { + // Clear any existing version of this header that might be set, because + // we want to allow removal/redefinition of headers. + // We also don't want to double up content headers which may have been + // set for us automatically. + + // NB: We have to enumerate the header names to check existence because + // Contains throws if it's the wrong header type for the collection. + if (request.Headers.Any(x => x.Key == name)) + { + request.Headers.Remove(name); + } + + if (request.Content != null && request.Content.Headers.Any(x => x.Key == name)) + { + request.Content.Headers.Remove(name); + } + + if (value == null) + return; + + // CRLF injection protection + name = EnsureSafe(name); + value = EnsureSafe(value); + + var added = request.Headers.TryAddWithoutValidation(name, value); + + // Don't even bother trying to add the header as a content header + // if we just added it to the other collection. + if (!added && request.Content != null) + { + request.Content.Headers.TryAddWithoutValidation(name, value); + } + } + + static string EnsureSafe(string value) + { + // Remove CR and LF characters +#pragma warning disable CA1307 // Specify StringComparison for clarity + return value.Replace("\r", string.Empty).Replace("\n", string.Empty); +#pragma warning restore CA1307 // Specify StringComparison for clarity + } + + public static void AddBody(global::System.Net.Http.HttpRequestMessage request, + global::Refit.RefitSettings settings, object param, bool isBuffered, BodySerializationMethod serializationMethod) + { + if (param is HttpContent httpContentParam) + { + request.Content = httpContentParam; + } + else if (param is Stream streamParam) + { + request.Content = new StreamContent(streamParam); + } + // Default sends raw strings + else if ( + serializationMethod == BodySerializationMethod.Default + && param is string stringParam + ) + { + request.Content = new StringContent(stringParam); + } + else + { + switch (serializationMethod) + { + case BodySerializationMethod.UrlEncoded: + request.Content = param is string str + ? (HttpContent) + new StringContent( + Uri.EscapeDataString(str), + Encoding.UTF8, + "application/x-www-form-urlencoded" + ) + : new FormUrlEncodedContent( + new FormValueMultimap(param, settings) + ); + break; + case BodySerializationMethod.Default: +#pragma warning disable CS0618 // Type or member is obsolete + case BodySerializationMethod.Json: +#pragma warning restore CS0618 // Type or member is obsolete + case BodySerializationMethod.Serialized: + var content = settings.ContentSerializer.ToHttpContent(param); + switch (isBuffered) + { + case false: + request.Content = new PushStreamContent( +#pragma warning disable IDE1006 // Naming Styles + async (stream, _, __) => +#pragma warning restore IDE1006 // Naming Styles + { + using (stream) + { + await content + .CopyToAsync(stream) + .ConfigureAwait(false); + } + }, + content.Headers.ContentType + ); + break; + case true: + request.Content = content; + break; + } + + break; + } + } + } + + public static void WriteProperty(global::System.Net.Http.HttpRequestMessage request, object key, object value) => + throw new NotImplementedException(nameof(WriteProperty)); + + public static void WriteRefitSettingsProperties(global::System.Net.Http.HttpRequestMessage request, global::Refit.RefitSettings settings) + { + // Add RefitSetting.HttpRequestMessageOptions to the HttpRequestMessage + if (settings.HttpRequestMessageOptions != null) + { + foreach (var p in settings.HttpRequestMessageOptions) + { +#if NET6_0_OR_GREATER + request.Options.Set(new HttpRequestOptionsKey(p.Key), p.Value); +#else + request.Properties.Add(p); +#endif + } + } + } + + // TODO: qualify types, check nullability and use of generics here. Might stop a cheeky box + public static void WriteRefitSettingsProperties(global::System.Net.Http.HttpRequestMessage request, + string? key, object value) + { +#if NET6_0_OR_GREATER + request.Options.Set( + new HttpRequestOptionsKey(key), + value + ); +#else + request.Properties[key] = value; +#endif + } + + // TODO: is RestMethodInfo neeeded here? I feel like it breaks AOT + public static void WriteRefitSettingsProperties(global::System.Net.Http.HttpRequestMessage request, + Type interfaceType, RestMethodInfo restMethodInfo) + { + // Always add the top-level type of the interface to the properties +#if NET6_0_OR_GREATER + request.Options.Set( + new HttpRequestOptionsKey(HttpRequestMessageOptions.InterfaceType), + interfaceType + ); + request.Options.Set( + new HttpRequestOptionsKey( + HttpRequestMessageOptions.RestMethodInfo + ), + restMethodInfo + ); +#else + request.Properties[HttpRequestMessageOptions.InterfaceType] = interfaceType; + request.Properties[HttpRequestMessageOptions.RestMethodInfo] = + restMethodInfo; +#endif + } + + public static void AddVersionToRequest(global::System.Net.Http.HttpRequestMessage request, + global::Refit.RefitSettings settings) + { +#if NET6_0_OR_GREATER + request.Version = settings.Version; + request.VersionPolicy = settings.VersionPolicy; +#endif + } + + // TODO: double check if its an interface type of concrete type + public static void AddTopLevelTypes(global::System.Net.Http.HttpRequestMessage request, + global::System.Type interfaceType, + global::Refit.RestMethodInfo restMethodInfo) + { +// Always add the top-level type of the interface to the properties +#if NET6_0_OR_GREATER + request.Options.Set( + new HttpRequestOptionsKey(HttpRequestMessageOptions.InterfaceType), + interfaceType + ); + request.Options.Set( + new HttpRequestOptionsKey( + HttpRequestMessageOptions.RestMethodInfo + ), + restMethodInfo + ); +#else + request.Properties[HttpRequestMessageOptions.InterfaceType] = interfaceType; + request.Properties[HttpRequestMessageOptions.RestMethodInfo] = restMethodInfo; +#endif + } + + + // TODO: should this be split into methods for T, IApiResponse and ApiResponse?? + public static async Task SendTaskResultAsync(global::System.Net.Http.HttpRequestMessage request, + global::System.Net.Http.HttpClient client, + global::Refit.RefitSettings settings, + bool isBodyBuffered, + global::System.Threading.CancellationToken cancellationToken) + { + global::System.Net.Http.HttpResponseMessage? resp = null; + global::System.Net.Http.HttpContent? content = null; + var disposeResponse = true; + try + { + // Load the data into buffer when body should be buffered. + if (IsBodyBuffered(isBodyBuffered, request)) + { + await request.Content!.LoadIntoBufferAsync().ConfigureAwait(false); + } + resp = await client + .SendAsync(request, global::System.Net.Http.HttpCompletionOption.ResponseHeadersRead, cancellationToken) + .ConfigureAwait(false); + content = resp.Content ?? new global::System.Net.Http.StringContent(string.Empty); + Exception? e = null; + + // TODO: dispose + disposeResponse = true; + // disposeResponse = restMethod.ShouldDisposeResponse; + + if (typeof(T) != typeof(global::System.Net.Http.HttpResponseMessage)) + { + e = await settings.ExceptionFactory(resp).ConfigureAwait(false); + } + + // if (restMethod.IsApiResponse) + // { + // var body = default(TBody); + // + // try + // { + // // Only attempt to deserialize content if no error present for backward-compatibility + // body = + // e == null + // ? await DeserializeContentAsync(resp, content, cancellationToken) + // .ConfigureAwait(false) + // : default; + // } + // catch (Exception ex) + // { + // //if an error occured while attempting to deserialize return the wrapped ApiException + // if (settings.DeserializationExceptionFactory != null) + // e = await settings.DeserializationExceptionFactory(resp, ex).ConfigureAwait(false); + // else + // { + // e = await ApiException.Create( + // "An error occured deserializing the response.", + // resp.RequestMessage!, + // resp.RequestMessage!.Method, + // resp, + // settings, + // ex + // ); + // } + // } + // + // return ApiResponse.Create( + // resp, + // body, + // settings, + // e as ApiException + // ); + // } + if (e != null) + { + disposeResponse = false; // caller has to dispose + throw e; + } + else + { + try + { + return await DeserializeContentAsync(resp, content, cancellationToken) + .ConfigureAwait(false); + } + catch (Exception ex) + { + if (settings.DeserializationExceptionFactory != null) + { + var customEx = await settings.DeserializationExceptionFactory(resp, ex).ConfigureAwait(false); + if (customEx != null) + throw customEx; + return default; + } + else + { + throw await ApiException.Create( + "An error occured deserializing the response.", + resp.RequestMessage!, + resp.RequestMessage!.Method, + resp, + settings, + ex + ); + } + } + } + } + finally + { + // Ensure we clean up the request + // Especially important if it has open files/streams + request.Dispose(); + if (disposeResponse) + { + resp?.Dispose(); + content?.Dispose(); + } + } + } + + private static bool IsBodyBuffered( + bool isBuffered, + HttpRequestMessage? request + ) + { + return isBuffered && (request?.Content != null); + } + + + // TODO: lots of overlap in cod etry and share? + public static async Task SendTaskIApiResultAsync(global::System.Net.Http.HttpRequestMessage request, + global::System.Net.Http.HttpClient client, + global::Refit.RefitSettings settings, + bool isBuffered, + global::System.Threading.CancellationToken cancellationToken) + { + global::System.Net.Http.HttpResponseMessage? resp = null; + global::System.Net.Http.HttpContent? content = null; + var disposeResponse = true; + try + { + // TODO: add isBody buffered + // Load the data into buffer when body should be buffered. + if (IsBodyBuffered(isBuffered, request)) + { + await request.Content!.LoadIntoBufferAsync().ConfigureAwait(false); + } + resp = await client + .SendAsync(request, global::System.Net.Http.HttpCompletionOption.ResponseHeadersRead, cancellationToken) + .ConfigureAwait(false); + content = resp.Content ?? new global::System.Net.Http.StringContent(string.Empty); + Exception? e = null; + + // TODO: dispose + disposeResponse = true; + // disposeResponse = restMethod.ShouldDisposeResponse; + + if (typeof(T) != typeof(global::System.Net.Http.HttpResponseMessage)) + { + e = await settings.ExceptionFactory(resp).ConfigureAwait(false); + } + + var body = default(T); + + try + { + // Only attempt to deserialize content if no error present for backward-compatibility + body = + e == null + ? await DeserializeContentAsync(resp, content, cancellationToken) + .ConfigureAwait(false) + : default; + } + catch (Exception ex) + { + //if an error occured while attempting to deserialize return the wrapped ApiException + if (settings.DeserializationExceptionFactory != null) + e = await settings.DeserializationExceptionFactory(resp, ex).ConfigureAwait(false); + else + { + e = await ApiException.Create( + "An error occured deserializing the response.", + resp.RequestMessage!, + resp.RequestMessage!.Method, + resp, + settings, + ex + ); + } + } + + return ApiResponse.Create( + resp, + body, + settings, + e as ApiException + ); + } + finally + { + // Ensure we clean up the request + // Especially important if it has open files/streams + request.Dispose(); + if (disposeResponse) + { + resp?.Dispose(); + content?.Dispose(); + } + } + } + + static async Task DeserializeContentAsync( + global::System.Net.Http.HttpResponseMessage resp, + global::System.Net.Http.HttpContent content, + CancellationToken cancellationToken + ) + { + T? result; + if (typeof(T) == typeof(global::System.Net.Http.HttpResponseMessage)) + { + // NB: This double-casting manual-boxing hate crime is the only way to make + // this work without a 'class' generic constraint. It could blow up at runtime + // and would be A Bad Idea if we hadn't already vetted the return type. + result = (T)(object)resp; + } + else if (typeof(T) == typeof(global::System.Net.Http.HttpContent)) + { + result = (T)(object)content; + } + else if (typeof(T) == typeof(Stream)) + { + var stream = (object) + await content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + result = (T)stream; + } + else if (typeof(T) == typeof(string)) + { + using var stream = await content + .ReadAsStreamAsync(cancellationToken) + .ConfigureAwait(false); + using var reader = new StreamReader(stream); + var str = (object)await reader.ReadToEndAsync().ConfigureAwait(false); + result = (T)str; + } + else + { + // result = await serializer + // .FromHttpContentAsync(content, cancellationToken) + // .ConfigureAwait(false); + throw new NotImplementedException("serializer"); + } + return result; + } + + public static async Task SendVoidTaskAsync(global::System.Net.Http.HttpRequestMessage request, + global::System.Net.Http.HttpClient httpClient, + global::Refit.RefitSettings settings, + global::System.Threading.CancellationToken cancellationToken) + { + using var response = await httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); + var exception = await settings.ExceptionFactory(response).ConfigureAwait(false); + if(exception != null) + throw exception; + } +} diff --git a/Refit/RequestBuilder.cs b/Refit/RequestBuilder.cs index aa06da38b..9c7eed9b8 100644 --- a/Refit/RequestBuilder.cs +++ b/Refit/RequestBuilder.cs @@ -19,6 +19,8 @@ public interface IRequestBuilder Type[]? parameterTypes = null, Type[]? genericArgumentTypes = null ); + + RefitSettings Settings { get; } } /// diff --git a/Refit/RequestBuilderImplementation.TaskToObservable.cs b/Refit/RequestBuilderImplementation.TaskToObservable.cs index cd0951588..a43b1e7e8 100644 --- a/Refit/RequestBuilderImplementation.TaskToObservable.cs +++ b/Refit/RequestBuilderImplementation.TaskToObservable.cs @@ -1,8 +1,8 @@ namespace Refit { - partial class RequestBuilderImplementation + public partial class RequestBuilderImplementation { - sealed class TaskToObservable : IObservable + public sealed class TaskToObservable : IObservable { readonly Func> taskFactory; diff --git a/Refit/RequestBuilderImplementation.cs b/Refit/RequestBuilderImplementation.cs index eec31a9f9..1b9dbda58 100644 --- a/Refit/RequestBuilderImplementation.cs +++ b/Refit/RequestBuilderImplementation.cs @@ -26,6 +26,8 @@ readonly ConcurrentDictionary< readonly RefitSettings settings; public Type TargetType { get; } + public RefitSettings Settings { get; } + public RequestBuilderImplementation( Type refitInterfaceType, RefitSettings? refitSettings = null @@ -34,6 +36,7 @@ public RequestBuilderImplementation( var targetInterfaceInheritedInterfaces = refitInterfaceType.GetInterfaces(); settings = refitSettings ?? new RefitSettings(); + this.Settings = settings; serializer = settings.ContentSerializer; interfaceGenericHttpMethods = new ConcurrentDictionary(); @@ -174,6 +177,7 @@ RestMethodInfoInternal CloseGenericMethodIfNeeded( Type[]? genericArgumentTypes = null ) { + throw new Exception($"Used fallback {BuildRestResultFuncForMethod} :("); if (!interfaceHttpMethods.ContainsKey(methodName)) { throw new ArgumentException( diff --git a/Refit/ValueStringBuilder.cs b/Refit/ValueStringBuilder.cs index bc180bb67..d410d18aa 100644 --- a/Refit/ValueStringBuilder.cs +++ b/Refit/ValueStringBuilder.cs @@ -5,8 +5,9 @@ namespace Refit; +// TODO: make internal and get generator to create this. // From https://github/dotnet/runtime/blob/main/src/libraries/Common/src/System/Text/ValueStringBuilder.cs -internal ref struct ValueStringBuilder +public ref struct ValueStringBuilder { private char[]? _arrayToReturnToPool; private Span _chars;