Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions src/AutoRest.CSharp/Common/Generation/Types/TypeFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public TypeFactory(OutputLibrary library)
ConstantSchema constantSchema => constantSchema.ValueType is not ChoiceSchema && ToXMsFormatType(format) is Type type ? new CSharpType(type, isNullable) : CreateType(constantSchema.ValueType, isNullable),
BinarySchema _ => new CSharpType(typeof(Stream), isNullable),
ByteArraySchema _ => new CSharpType(typeof(byte[]), isNullable),
ArraySchema array => new CSharpType(typeof(IList<>), isNullable, CreateType(array.ElementType, array.NullableItems ?? false)),
ArraySchema array => new CSharpType(GetListType(schema), isNullable, CreateType(array.ElementType, array.NullableItems ?? false)),
DictionarySchema dictionary => new CSharpType(typeof(IDictionary<,>), isNullable, new CSharpType(typeof(string)), CreateType(dictionary.ElementType, dictionary.NullableItems ?? false)),
CredentialSchema credentialSchema => new CSharpType(typeof(string), isNullable),
NumberSchema number => new CSharpType(ToFrameworkNumericType(number), isNullable),
Expand All @@ -103,10 +103,20 @@ public TypeFactory(OutputLibrary library)
_ => _library.FindTypeForSchema(schema).WithNullable(isNullable)
};

private Type GetListType(Schema schema)
{
return schema.Extensions is not null && schema.Extensions.IsEmbeddingsVector ? typeof(ReadOnlyMemory<>) : typeof(IList<>);
}

public static CSharpType GetImplementationType(CSharpType type)
{
if (type.IsFrameworkType)
{
if (IsReadOnlyMemory(type))
{
return new CSharpType(type.Arguments[0].FrameworkType.MakeArrayType());
}

if (IsList(type))
{
return new CSharpType(typeof(List<>), type.Arguments);
Expand All @@ -125,6 +135,11 @@ public static CSharpType GetPropertyImplementationType(CSharpType type)
{
if (type.IsFrameworkType)
{
if (IsReadOnlyMemory(type))
{
return new CSharpType(typeof(ReadOnlyMemory<>), type.Arguments);
}

if (IsList(type))
{
return new CSharpType(Configuration.ApiTypes.ChangeTrackingListType, type.Arguments);
Expand Down Expand Up @@ -172,6 +187,16 @@ public static CSharpType GetElementType(CSharpType type)
{
if (type.IsFrameworkType)
{
if (type.FrameworkType.IsArray)
{
return new CSharpType(type.FrameworkType.GetElementType()!);
}

if (IsReadOnlyMemory(type))
{
return type.Arguments[0];
}

if (IsList(type))
{
return type.Arguments[0];
Expand Down Expand Up @@ -207,7 +232,10 @@ internal static bool IsReadWriteDictionary(CSharpType type)
=> type.IsFrameworkType && (type.FrameworkType == typeof(IDictionary<,>) || type.FrameworkType == typeof(Dictionary<,>));

internal static bool IsList(CSharpType type)
=> IsReadOnlyList(type) || IsReadWriteList(type);
=> IsReadOnlyList(type) || IsReadWriteList(type) || IsReadOnlyMemory(type);

internal static bool IsReadOnlyMemory(CSharpType type)
=> type.IsFrameworkType && type.FrameworkType == typeof(ReadOnlyMemory<>);

internal static bool IsReadOnlyList(CSharpType type)
=> type.IsFrameworkType &&
Expand Down Expand Up @@ -313,6 +341,11 @@ public static CSharpType GetInputType(CSharpType type)
{
if (type.IsFrameworkType)
{
if (IsReadOnlyMemory(type))
{
return new CSharpType(typeof(ReadOnlyMemory<>), isNullable: type.IsNullable, type.Arguments);
}

if (IsList(type))
{
return new CSharpType(
Expand All @@ -329,6 +362,11 @@ public static CSharpType GetOutputType(CSharpType type)
{
if (type.IsFrameworkType)
{
if (IsReadOnlyMemory(type))
{
return new CSharpType(typeof(ReadOnlyMemory<>), isNullable: type.IsNullable, type.Arguments);
}

if (IsList(type))
{
return new CSharpType(
Expand Down Expand Up @@ -448,5 +486,10 @@ public static bool RequiresToList(CSharpType from, CSharpType to)

return to.FrameworkType == typeof(IReadOnlyList<>) || to.FrameworkType == typeof(IList<>);
}

internal static bool IsArray(CSharpType type)
{
return type.IsFrameworkType && type.FrameworkType.IsArray;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ private static void WriteDeclaration(this CodeWriter writer, DeclarationStatemen
writer.AppendRaw(") => ");
writer.WriteValueExpression(localFunction.Body);
break;
case UnaryOperatorStatement unaryOperatorStatement:
writer.WriteValueExpression(unaryOperatorStatement.Expression);
break;
}

writer.LineRaw(";");
Expand Down Expand Up @@ -509,8 +512,16 @@ public static void WriteValueExpression(this CodeWriter writer, ValueExpression
}
break;

case NewArrayExpression(var type, var items):
if (items is { Elements.Count: > 0 })
case NewArrayExpression(var type, var items, var size):
if (size is not null)
{
writer.Append($"new {type?.FrameworkType.GetElementType()}");
writer.AppendRaw("[");
writer.WriteValueExpression(size);
writer.AppendRaw("]");
break;
}
else if (items is { Elements.Count: > 0 })
{
if (type is null)
{
Expand Down Expand Up @@ -591,6 +602,12 @@ public static void WriteValueExpression(this CodeWriter writer, ValueExpression
case StringLiteralExpression(var literal, false):
writer.Literal(literal);
break;
case ArrayElementExpression(var array, var index):
writer.WriteValueExpression(array);
writer.AppendRaw("[");
writer.WriteValueExpression(index);
writer.AppendRaw("]");
break;
}

static void WriteArguments(CodeWriter writer, IEnumerable<ValueExpression> arguments)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ public static FormattableString GetIdentifiersFormattable(this IEnumerable<strin

public static FormattableString? GetParameterInitializer(this CSharpType parameterType, Constant? defaultValue)
{
if (parameterType.IsValueType)
{
return null;
}

if (TypeFactory.IsCollectionType(parameterType) && (defaultValue == null || TypeFactory.IsCollectionType(defaultValue.Value.Type)))
{
defaultValue = Constant.NewInstanceOf(TypeFactory.GetImplementationType(parameterType).WithNullable(false));
Expand Down
2 changes: 2 additions & 0 deletions src/AutoRest.CSharp/Common/Input/CodeModelPartials.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ internal partial class RecordOfStringAndAny
public string? Namespace => TryGetValue("x-namespace", out object? value) ? value?.ToString() : null;
public string? Usage => TryGetValue("x-csharp-usage", out object? value) ? value?.ToString() : null;

public bool IsEmbeddingsVector => TryGetValue("x-ms-embedding-vector", out var value) && Convert.ToBoolean(value);

public string[] Formats
{
get
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public static IEnumerable<MethodBodyStatement> WriteProperties(Utf8JsonWriterExp
}
else if (property.SerializedType is { IsNullable: true })
{
var checkPropertyIsInitialized = TypeFactory.IsCollectionType(property.SerializedType) && property.IsRequired
var checkPropertyIsInitialized = TypeFactory.IsCollectionType(property.SerializedType) && !TypeFactory.IsReadOnlyMemory(property.SerializedType) && property.IsRequired
? And(NotEqual(property.Value, Null), InvokeOptional.IsCollectionDefined(property.Value))
: NotEqual(property.Value, Null);

Expand All @@ -113,7 +113,7 @@ private static MethodBodyStatement WriteProperty(Utf8JsonWriterExpression utf8Js
utf8JsonWriter.WritePropertyName(serialization.SerializedName),
serialization.CustomSerializationMethodName is {} serializationMethodName
? InvokeCustomSerializationMethod(serializationMethodName, utf8JsonWriter)
: SerializeExpression(utf8JsonWriter, serialization.ValueSerialization, serialization.Value)
: SerializeExpression(utf8JsonWriter, serialization.ValueSerialization, serialization.EnumerableValue ?? serialization.Value)
};
}

Expand Down Expand Up @@ -480,13 +480,18 @@ private static MethodBodyStatement DeserializeIntoObjectProperty(JsonPropertySer
// Reading a property value
if (jsonPropertySerialization.ValueSerialization is not null)
{
return new[]
List<MethodBodyStatement> statements = new List<MethodBodyStatement>
{
CreatePropertyNullCheckStatement(jsonPropertySerialization, jsonProperty, propertyVariables, shouldTreatEmptyStringAsNull),
DeserializeValue(jsonPropertySerialization.ValueSerialization, jsonProperty.Value, out var value),
Assign(propertyVariables[jsonPropertySerialization], value),
Continue
DeserializeValue(jsonPropertySerialization.ValueSerialization, jsonProperty.Value, out var value)
};

AssignValueStatement assignStatement = TypeFactory.IsReadOnlyMemory(jsonPropertySerialization.SerializedType!)
? Assign(propertyVariables[jsonPropertySerialization], New.Instance(jsonPropertySerialization.SerializedType!, value))
: Assign(propertyVariables[jsonPropertySerialization], value);
statements.Add(assignStatement);
statements.Add(Continue);
return statements;
}

// Reading a nested object
Expand Down Expand Up @@ -529,7 +534,7 @@ private static MethodBodyStatement CreatePropertyNullCheckStatement(JsonProperty
};
}

if (jsonPropertySerialization.IsRequired)
if (jsonPropertySerialization.IsRequired && !TypeFactory.IsReadOnlyMemory(serializedType))
{
return new IfStatement(checkEmptyProperty)
{
Expand All @@ -544,7 +549,8 @@ private static MethodBodyStatement CreatePropertyNullCheckStatement(JsonProperty
};
}

if (!jsonPropertySerialization.IsRequired &&
// even if ReadOnlyMemory is required we leave the list empty if the payload doesn't have it
if ((!jsonPropertySerialization.IsRequired || (serializedType is not null && TypeFactory.IsReadOnlyMemory(serializedType))) &&
serializedType?.Equals(typeof(JsonElement)) != true && // JsonElement handles nulls internally
serializedType?.Equals(typeof(string)) != true) //https://github.com/Azure/autorest.csharp/issues/922
{
Expand Down Expand Up @@ -647,15 +653,32 @@ public static MethodBodyStatement DeserializeValue(JsonSerialization serializati
{
switch (serialization)
{
case JsonArraySerialization jsonReadOnlyMemory when TypeFactory.IsArray(jsonReadOnlyMemory.ImplementationType):
var readOnlyMemory = new VariableReference(jsonReadOnlyMemory.ImplementationType, "array");
value = readOnlyMemory;
VariableReference index = new VariableReference(typeof(int), "index");

return new MethodBodyStatement[]
{
Declare(index, Int(0)),
Declare(readOnlyMemory, New.Array(jsonReadOnlyMemory.ImplementationType, element.GetArrayLength())),
new ForeachStatement("item", element.EnumerateArray(), out ValueExpression readOnlyMemoryItem)
{
DeserializeArrayItem(jsonReadOnlyMemory, value, new JsonElementExpression(readOnlyMemoryItem), index),
Increment(index)
}
};

case JsonArraySerialization jsonArray:
var array = new VariableReference(jsonArray.ImplementationType, "array");
value = array;

return new MethodBodyStatement[]
{
Declare(array, New.Instance(jsonArray.ImplementationType)),
new ForeachStatement("item", element.EnumerateArray(), out ValueExpression item)
new ForeachStatement("item", element.EnumerateArray(), out ValueExpression arrayItem)
{
DeserializeArrayItem(jsonArray.ValueSerialization, value, new JsonElementExpression(item))
DeserializeArrayItem(jsonArray, value, new JsonElementExpression(arrayItem)),
}
};

Expand Down Expand Up @@ -685,25 +708,31 @@ public static MethodBodyStatement DeserializeValue(JsonSerialization serializati
}
}

private static MethodBodyStatement DeserializeArrayItem(JsonSerialization serialization, ValueExpression arrayVariable, JsonElementExpression arrayItemVariable)
private static MethodBodyStatement DeserializeArrayItem(JsonArraySerialization serialization, ValueExpression arrayVariable, JsonElementExpression arrayItemVariable, ValueExpression? index = null)
{
if (!CollectionItemRequiresNullCheckInSerialization(serialization))
bool isArray = index is not null;

List<MethodBodyStatement> statements = new List<MethodBodyStatement>();

MethodBodyStatement deserializeAndAssign = new[]
{
return Deserialize(serialization, arrayVariable, arrayItemVariable);
}
DeserializeValue(serialization.ValueSerialization, arrayItemVariable, out var value),
isArray ? InvokeArrayElementAssignment(arrayVariable, index!, value) : InvokeListAdd(arrayVariable, value)
};

return new IfElseStatement(
arrayItemVariable.ValueKindEqualsNull(),
InvokeListAdd(arrayVariable, Null),
Deserialize(serialization, arrayVariable, arrayItemVariable)
);
if (CollectionItemRequiresNullCheckInSerialization(serialization.ValueSerialization))
{
statements.Add(new IfElseStatement(
arrayItemVariable.ValueKindEqualsNull(),
isArray ? InvokeArrayElementAssignment(arrayVariable, index!, Null) : InvokeListAdd(arrayVariable, Null),
deserializeAndAssign));
}
else
{
statements.Add(deserializeAndAssign);
}

static MethodBodyStatement Deserialize(JsonSerialization jsonSerialization, ValueExpression arrayVariable, JsonElementExpression jsonElementExpression)
=> new[]
{
DeserializeValue(jsonSerialization, jsonElementExpression, out var value),
InvokeListAdd(arrayVariable, value)
};
return statements;
}

private static MethodBodyStatement DeserializeDictionaryValue(JsonSerialization serialization, DictionaryExpression dictionary, JsonPropertyExpression property)
Expand Down Expand Up @@ -782,7 +811,7 @@ private static ValueExpression GetOptional(PropertySerialization jsonPropertySer
}

var targetType = jsonPropertySerialization.Value.Type;
if (TypeFactory.IsList(targetType))
if (TypeFactory.IsList(targetType) && !TypeFactory.IsReadOnlyMemory(targetType))
{
return InvokeOptional.ToList(variable);
}
Expand Down Expand Up @@ -893,6 +922,9 @@ public static ValueExpression GetFrameworkTypeValueExpression(Type frameworkType
private static MethodBodyStatement InvokeListAdd(ValueExpression list, ValueExpression value)
=> new InvokeInstanceMethodStatement(list, nameof(List<object>.Add), value);

private static MethodBodyStatement InvokeArrayElementAssignment(ValueExpression array, ValueExpression index, ValueExpression value)
=> new AssignValueStatement(new ArrayElementExpression(array, index), value);

private static ValueExpression InvokeJsonSerializerDeserializeMethod(JsonElementExpression element, CSharpType serializationType, ValueExpression? options = null)
{
var arguments = options is null
Expand Down
13 changes: 11 additions & 2 deletions src/AutoRest.CSharp/Common/Output/Builders/SerializationBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -273,17 +273,26 @@ private IEnumerable<JsonPropertySerialization> GetPropertySerializationsFromBag(
var isReadOnly = schemaProperty.IsReadOnly;
var serialization = BuildSerialization(schemaProperty.Schema, property.Declaration.Type, false);

var memberValueExpression = new TypedMemberExpression(null, property.Declaration.Name, property.Declaration.Type);
TypedMemberExpression? enumerableExpression = null;
if (property.SchemaProperty is not null && property.SchemaProperty.Extensions is not null && property.SchemaProperty.Extensions.IsEmbeddingsVector)
{
enumerableExpression = property.Declaration.Type.IsNullable
? new TypedMemberExpression(null, $"{property.Declaration.Name}.{nameof(Nullable<ReadOnlyMemory<object>>.Value)}.{nameof(ReadOnlyMemory<object>.Span)}", typeof(ReadOnlySpan<>).MakeGenericType(property.Declaration.Type.Arguments[0].FrameworkType))
: new TypedMemberExpression(null, $"{property.Declaration.Name}.{nameof(ReadOnlyMemory<object>.Span)}", typeof(ReadOnlySpan<>).MakeGenericType(property.Declaration.Type.Arguments[0].FrameworkType));
}
yield return new JsonPropertySerialization(
parameter.Name,
new TypedMemberExpression(null, property.Declaration.Name, property.Declaration.Type),
memberValueExpression,
serializedName,
property.ValueType,
serialization,
isRequired,
isReadOnly,
false,
customSerializationMethodName: property.SerializationMapping?.SerializationValueHook,
customDeserializationMethodName: property.SerializationMapping?.DeserializationValueHook);
customDeserializationMethodName: property.SerializationMapping?.DeserializationValueHook,
enumerableExpression: enumerableExpression);
}

foreach ((string name, SerializationPropertyBag innerBag) in propertyBag.Bag)
Expand Down
Loading