Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,10 @@ public virtual void Write(string value)

if (_useFastUtf8)
{
if (value.Length <= 127 / 3)
// If this is a non-derived BinaryWriter, then we can bypass the Write7BitEncodedInt call.
// But when this is a derived instance, call must not bypass it for compatibility reasons
// as it calls the virtual Write(int) overload.
if (GetType() == typeof(BinaryWriter) && value.Length <= 127 / 3)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@EgorBo How is runtime optimizing the logic that uses GetType() == typeof(BinaryWriter) pattern as of today? Will this be always replaced with a constant for each derived type? If not, should we cache the result of the check? Or at least change the order (perform the cheap length check first)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a single pointer check (*object == TypeTableOf(BinaryWriter)) and equivalent to accessing a field. This pattern is used a lot currently.

{
// Max expansion: each char -> 3 bytes, so 127 bytes max of data, +1 for length prefix
Span<byte> buffer = stackalloc byte[128];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Text;
Expand Down Expand Up @@ -188,26 +188,358 @@ protected virtual Stream CreateStream()

private void WriteTest<T>(T[] testElements, Action<BinaryWriter, T> write, Func<BinaryReader, T> read)
{
// Non-derived BinaryWriter/BinaryReader, UTF-8 encoding
using (Stream memStream = CreateStream())
using (BinaryWriter writer = new BinaryWriter(memStream))
using (BinaryReader reader = new BinaryReader(memStream))
using (var writer = new BinaryWriter(memStream))
using (var reader = new BinaryReader(memStream))
{
for (int i = 0; i < testElements.Length; i++)
{
write(writer, testElements[i]);
}
WriteTest(memStream, writer, reader, testElements, write, read);
}

// Derived BinaryWriter/BinaryReader, UTF-8 encoding
using (Stream memStream = CreateStream())
using (var writer = new TestWriter(memStream))
using (var reader = new TestReader(memStream))
{
WriteTest(memStream, writer, reader, testElements, write, read);
}

// Non-derived BinaryWriter/BinaryReader, UTF-16 encoding
using (Stream memStream = CreateStream())
using (var writer = new BinaryWriter(memStream, Encoding.Unicode))
using (var reader = new BinaryReader(memStream, Encoding.Unicode))
{
WriteTest(memStream, writer, reader, testElements, write, read);
}

// Derived BinaryWriter/BinaryReader, UTF-16 encoding
using (Stream memStream = CreateStream())
using (var writer = new TestWriter(memStream, Encoding.Unicode))
using (var reader = new TestReader(memStream, Encoding.Unicode))
{
WriteTest(memStream, writer, reader, testElements, write, read);
}
}

private void WriteTest<T>(Stream stream, BinaryWriter writer, BinaryReader reader, T[] testElements, Action<BinaryWriter, T> write, Func<BinaryReader, T> read)
{
for (int i = 0; i < testElements.Length; i++)
{
write(writer, testElements[i]);
}

writer.Flush();
stream.Position = 0;

for (int i = 0; i < testElements.Length; i++)
{
Assert.Equal(testElements[i], read(reader));
}

if (writer is TestWriter derivedWriter && reader is TestReader derivedReader)
{
// Checking if the internally tracked positions of a derived reader/writer are in sync (#107265)
Assert.Equal(derivedReader.Position, derivedWriter.Position);
}

// We've reached the end of the stream. Check for expected EndOfStreamException
Assert.Throws<EndOfStreamException>(() => read(reader));
}


private class TestWriter : BinaryWriter
{
private readonly Encoding _encoding;
public long Position { get; private set; }

public TestWriter(Stream stream, Encoding? encoding = null)
: base(stream, encoding ?? Encoding.UTF8)
{
_encoding = encoding ?? Encoding.UTF8;
}

public override void Write(bool value)
{
Advance(sizeof(byte));
base.Write(value);
}

public override void Write(byte value)
{
Advance(sizeof(byte));
base.Write(value);
}

public override void Write(byte[] buffer)
{
Advance(buffer.Length);
base.Write(buffer);
}

public override void Write(byte[] buffer, int index, int count)
{
Advance(count);
base.Write(buffer, index, count);
}

public override void Write(char ch)
{
Advance(_encoding.GetBytes([ch]).Length);
base.Write(ch);
}

public override void Write(char[] chars)
{
Advance(_encoding.GetBytes(chars).Length);
base.Write(chars);
}

public override void Write(char[] chars, int index, int count)
{
Advance(_encoding.GetBytes(chars, index, count).Length);
base.Write(chars, index, count);
}

public override void Write(decimal value)
{
Advance(sizeof(decimal));
base.Write(value);
}

public override void Write(double value)
{
Advance(sizeof(double));
base.Write(value);
}

public override void Write(float value)
{
Advance(sizeof(float));
base.Write(value);
}

public override void Write(int value)
{
Advance(sizeof(int));
base.Write(value);
}

public override void Write(long value)
{
Advance(sizeof(long));
base.Write(value);
}

public override void Write(sbyte value)
{
Advance(sizeof(sbyte));
base.Write(value);
}

public override void Write(short value)
{
Advance(sizeof(short));
base.Write(value);
}

public override void Write(string value)
{
Advance(_encoding.GetBytes(value).Length);
base.Write(value);
}

public override void Write(uint value)
{
Advance(sizeof(uint));
base.Write(value);
}

public override void Write(ulong value)
{
Advance(sizeof(ulong));
base.Write(value);
}

public override void Write(ushort value)
{
Advance(sizeof(ushort));
base.Write(value);
}

public override unsafe void Write(Half value)
{
Advance(sizeof(Half));
base.Write(value);
}

public override void Write(ReadOnlySpan<byte> buffer)
{
Advance(buffer.Length);
base.Write(buffer);
}

public override void Write(ReadOnlySpan<char> chars)
{
Advance(_encoding.GetBytes(chars.ToArray()).Length);
base.Write(chars);
}

private void Advance(int offset) => Position += offset;
}

private class TestReader : BinaryReader
{
private readonly Encoding _encoding;
public long Position { get; private set; }

public TestReader(Stream s, Encoding? encoding = null)
: base(s, encoding ?? Encoding.UTF8)
{
_encoding = encoding ?? Encoding.UTF8;
}

public override int Read()
{
var current = BaseStream.Position;
var result = base.Read();
Advance(BaseStream.Position - current);
return result;
}

writer.Flush();
memStream.Position = 0;
public override int Read(byte[] buffer, int index, int count)
{
var result = base.Read(buffer, index, count);
Advance(result);
return result;
}

for (int i = 0; i < testElements.Length; i++)
{
Assert.Equal(testElements[i], read(reader));
}
public override int Read(char[] buffer, int index, int count)
{
var result = base.Read(buffer, index, count);
Advance(_encoding.GetBytes(buffer, 0, result).Length);
return result;
}

// We've reached the end of the stream. Check for expected EndOfStreamException
Assert.Throws<EndOfStreamException>(() => read(reader));
public override bool ReadBoolean()
{
Advance(sizeof(bool));
return base.ReadBoolean();
}

public override byte ReadByte()
{
Advance(sizeof(byte));
return base.ReadByte();
}

public override byte[] ReadBytes(int count)
{
Advance(count);
return base.ReadBytes(count);
}

public override char ReadChar()
{
var result = base.ReadChar();
Advance(_encoding.GetBytes([result]).Length);
return result;
}

public override char[] ReadChars(int count)
{
var result = base.ReadChars(count);
Advance(_encoding.GetBytes(result).Length);
return result;
}

public override decimal ReadDecimal()
{
Advance(sizeof(decimal));
return base.ReadDecimal();
}

public override double ReadDouble()
{
Advance(sizeof(double));
return base.ReadDouble();
}

public override short ReadInt16()
{
Advance(sizeof(short));
return base.ReadInt16();
}

public override int ReadInt32()
{
Advance(sizeof(int));
return base.ReadInt32();
}

public override long ReadInt64()
{
Advance(sizeof(long));
return base.ReadInt64();
}

public override sbyte ReadSByte()
{
Advance(sizeof(sbyte));
return base.ReadSByte();
}

public override float ReadSingle()
{
Advance(sizeof(float));
return base.ReadSingle();
}

public override string ReadString()
{
var result = base.ReadString();
Advance(_encoding.GetBytes(result).Length);
return result;
}

public override ushort ReadUInt16()
{
Advance(sizeof(ushort));
return base.ReadUInt16();
}

public override uint ReadUInt32()
{
Advance(sizeof(uint));
return base.ReadUInt32();
}

public override ulong ReadUInt64()
{
Advance(sizeof(ulong));
return base.ReadUInt64();
}

public override unsafe Half ReadHalf()
{
Advance(sizeof(Half));
return base.ReadHalf();
}

public override int Read(Span<byte> buffer)
{
var result = base.Read(buffer);
Advance(result);
return result;
}

public override int Read(Span<char> buffer)
{
var result = base.Read(buffer);
Advance(_encoding.GetBytes(buffer[..result].ToArray()).Length);
return result;
}

private void Advance(long offset) => Position += offset;
}
}
}
Loading