From 36c50561a419c26f5be4b14622ac68b8a8e48c5a Mon Sep 17 00:00:00 2001 From: MihaZupan Date: Mon, 13 Jul 2020 03:07:32 +0200 Subject: [PATCH 1/5] Add HeaderEncodingSelector to MultipartContent --- .../System.Net.Http/ref/System.Net.Http.cs | 1 + .../src/System/Net/Http/MultipartContent.cs | 99 +++++++++----- .../FunctionalTests/MultipartContentTest.cs | 123 ++++++++++++++++++ 3 files changed, 194 insertions(+), 29 deletions(-) diff --git a/src/libraries/System.Net.Http/ref/System.Net.Http.cs b/src/libraries/System.Net.Http/ref/System.Net.Http.cs index 8f72cd9665e790..19041a5b78f59e 100644 --- a/src/libraries/System.Net.Http/ref/System.Net.Http.cs +++ b/src/libraries/System.Net.Http/ref/System.Net.Http.cs @@ -285,6 +285,7 @@ public partial class MultipartContent : System.Net.Http.HttpContent, System.Coll public MultipartContent() { } public MultipartContent(string subtype) { } public MultipartContent(string subtype, string boundary) { } + public System.Func? HeaderEncodingSelector { get { throw null; } set { } } public virtual void Add(System.Net.Http.HttpContent content) { } protected override System.IO.Stream CreateContentReadStream(System.Threading.CancellationToken cancellationToken) { throw null; } protected override System.Threading.Tasks.Task CreateContentReadStreamAsync() { throw null; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs index 4f1a1dd23acf64..0767e39bc4980a 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs @@ -157,6 +157,8 @@ Collections.IEnumerator Collections.IEnumerable.GetEnumerator() #region Serialization + public Func? HeaderEncodingSelector { get; set; } + // for-each content // write "--" + boundary // for-each content header @@ -171,20 +173,19 @@ protected override void SerializeToStream(Stream stream, TransportContext? conte try { // Write start boundary. - EncodeStringToStream(stream, "--" + _boundary + CrLf); + WriteLatin1ToStream(stream, "--" + _boundary + CrLf); // Write each nested content. - var output = new StringBuilder(); for (int contentIndex = 0; contentIndex < _nestedContent.Count; contentIndex++) { // Write divider, headers, and content. HttpContent content = _nestedContent[contentIndex]; - EncodeStringToStream(stream, SerializeHeadersToString(output, contentIndex, content)); + SerializeHeadersToStream(stream, content, writeDivider: contentIndex != 0); content.CopyTo(stream, context, cancellationToken); } // Write footer boundary. - EncodeStringToStream(stream, CrLf + "--" + _boundary + "--" + CrLf); + WriteLatin1ToStream(stream, CrLf + "--" + _boundary + "--" + CrLf); } catch (Exception ex) { @@ -219,12 +220,17 @@ private protected async Task SerializeToStreamAsyncCore(Stream stream, Transport await EncodeStringToStreamAsync(stream, "--" + _boundary + CrLf, cancellationToken).ConfigureAwait(false); // Write each nested content. - var output = new StringBuilder(); + var output = new MemoryStream(); for (int contentIndex = 0; contentIndex < _nestedContent.Count; contentIndex++) { // Write divider, headers, and content. HttpContent content = _nestedContent[contentIndex]; - await EncodeStringToStreamAsync(stream, SerializeHeadersToString(output, contentIndex, content), cancellationToken).ConfigureAwait(false); + + output.SetLength(0); + SerializeHeadersToStream(output, content, writeDivider: contentIndex != 0); + output.Position = 0; + await output.CopyToAsync(stream, cancellationToken).ConfigureAwait(false); + await content.CopyToAsync(stream, context, cancellationToken).ConfigureAwait(false); } @@ -259,7 +265,6 @@ private async ValueTask CreateContentReadStreamAsyncCore(bool async, Can try { var streams = new Stream[2 + (_nestedContent.Count * 2)]; - var scratch = new StringBuilder(); int streamIndex = 0; // Start boundary. @@ -271,7 +276,7 @@ private async ValueTask CreateContentReadStreamAsyncCore(bool async, Can cancellationToken.ThrowIfCancellationRequested(); HttpContent nestedContent = _nestedContent[contentIndex]; - streams[streamIndex++] = EncodeStringToNewStream(SerializeHeadersToString(scratch, contentIndex, nestedContent)); + streams[streamIndex++] = EncodeHeadersToNewStream(nestedContent, writeDivider: contentIndex != 0); Stream readStream; if (async) @@ -312,43 +317,42 @@ private async ValueTask CreateContentReadStreamAsyncCore(bool async, Can } } - private string SerializeHeadersToString(StringBuilder scratch, int contentIndex, HttpContent content) + private void SerializeHeadersToStream(Stream stream, HttpContent content, bool writeDivider) { - scratch.Clear(); - // Add divider. - if (contentIndex != 0) // Write divider for all but the first content. + if (writeDivider) // Write divider for all but the first content. { - scratch.Append(CrLf + "--"); // const strings - scratch.Append(_boundary); - scratch.Append(CrLf); + WriteLatin1ToStream(stream, CrLf + "--"); // const strings + WriteLatin1ToStream(stream, _boundary); + WriteLatin1ToStream(stream, CrLf); } // Add headers. foreach (KeyValuePair> headerPair in content.Headers) { - scratch.Append(headerPair.Key); - scratch.Append(": "); + Encoding? headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content); + + WriteLatin1ToStream(stream, headerPair.Key); + WriteLatin1ToStream(stream, ": "); string delim = string.Empty; foreach (string value in headerPair.Value) { - scratch.Append(delim); - scratch.Append(value); + WriteLatin1ToStream(stream, delim); + if (headerValueEncoding is null) + { + WriteLatin1ToStream(stream, value); + } + else + { + WriteToStream(stream, value, headerValueEncoding); + } delim = ", "; } - scratch.Append(CrLf); + WriteLatin1ToStream(stream, CrLf); } // Extra CRLF to end headers (even if there are no headers). - scratch.Append(CrLf); - - return scratch.ToString(); - } - - private static void EncodeStringToStream(Stream stream, string input) - { - byte[] buffer = HttpRuleParser.DefaultHttpEncoding.GetBytes(input); - stream.Write(buffer); + WriteLatin1ToStream(stream, CrLf); } private static ValueTask EncodeStringToStreamAsync(Stream stream, string input, CancellationToken cancellationToken) @@ -362,6 +366,14 @@ private static Stream EncodeStringToNewStream(string input) return new MemoryStream(HttpRuleParser.DefaultHttpEncoding.GetBytes(input), writable: false); } + private Stream EncodeHeadersToNewStream(HttpContent content, bool writeDivider) + { + var stream = new MemoryStream(); + SerializeHeadersToStream(stream, content, writeDivider); + stream.Position = 0; + return stream; + } + internal override bool AllowDuplex => false; protected internal override bool TryComputeLength(out long length) @@ -671,6 +683,35 @@ public override void Flush() { } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { throw new NotSupportedException(); } public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { throw new NotSupportedException(); } } + + + private static void WriteToStream(Stream stream, string content, Encoding encoding) + { + if (ReferenceEquals(encoding, Encoding.UTF8) && content.Length <= 128) + { + // Fast-path for the expected common case + Span buffer = stackalloc byte[512]; + int written = encoding.GetBytes(content, buffer); + stream.Write(buffer.Slice(0, written)); + } + else + { + stream.Write(encoding.GetBytes(content)); + } + } + + private static void WriteLatin1ToStream(Stream stream, string content) + { + Span buffer = content.Length <= 512 + ? stackalloc byte[512] + : new byte[content.Length]; + + int written = Encoding.Latin1.GetBytes(content, buffer); + Debug.Assert(written == content.Length); + + stream.Write(buffer.Slice(0, written)); + } + #endregion Serialization } } diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs index 7565e5b698314e..8bf622373ddc91 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -387,6 +388,128 @@ public void ReadAsStream_CreateContentReadStreamThrows() Assert.Throws(() => mc.ReadAsStream()); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadAsStream_CustomEncodingSelector_SelectorIsCalledWithCustomState(bool async) + { + var mc = new MultipartContent(); + + var stringContent = new StringContent("foo"); + stringContent.Headers.Add("StringContent", "foo"); + mc.Add(stringContent); + + var byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("foo")); + byteArrayContent.Headers.Add("ByteArrayContent", "foo"); + mc.Add(byteArrayContent); + + bool seenStringContent = false, seenByteArrayContent = false; + + mc.HeaderEncodingSelector = (name, content) => + { + if (ReferenceEquals(content, stringContent) && name == "StringContent") + { + seenStringContent = true; + } + + if (ReferenceEquals(content, byteArrayContent) && name == "ByteArrayContent") + { + seenByteArrayContent = true; + } + + return null; + }; + + var dummy = new MemoryStream(); + if (async) + { + await (await mc.ReadAsStreamAsync()).CopyToAsync(dummy); + } + else + { + mc.ReadAsStream().CopyTo(dummy); + } + + Assert.True(seenStringContent); + Assert.True(seenByteArrayContent); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadAsStream_CustomEncodingSelector_CustomEncodingIsUsed(bool async) + { + var mc = new MultipartContent("subtype", "fooBoundary"); + + var stringContent = new StringContent("bar1"); + stringContent.Headers.Add("latin1", "\uD83D\uDE00"); + mc.Add(stringContent); + + var byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar2")); + byteArrayContent.Headers.Add("utf8", "\uD83D\uDE00"); + mc.Add(byteArrayContent); + + byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar3")); + byteArrayContent.Headers.Add("ascii", "\uD83D\uDE00"); + mc.Add(byteArrayContent); + + byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar4")); + byteArrayContent.Headers.Add("default", "\uD83D\uDE00"); + mc.Add(byteArrayContent); + + mc.HeaderEncodingSelector = (name, _) => name switch + { + "latin1" => Encoding.Latin1, + "utf8" => Encoding.UTF8, + "ascii" => Encoding.ASCII, + _ => null + }; + + var ms = new MemoryStream(); + if (async) + { + await (await mc.ReadAsStreamAsync()).CopyToAsync(ms); + } + else + { + mc.ReadAsStream().CopyTo(ms); + } + + byte[] expected = Concat( + Encoding.Latin1.GetBytes("--fooBoundary\r\n"), + Encoding.Latin1.GetBytes("Content-Type: text/plain; charset=utf-8\r\n"), + Encoding.Latin1.GetBytes("latin1: "), + Encoding.Latin1.GetBytes("\uD83D\uDE00"), + Encoding.Latin1.GetBytes("\r\n\r\n"), + Encoding.Latin1.GetBytes("bar1"), + Encoding.Latin1.GetBytes("\r\n--fooBoundary\r\n"), + Encoding.Latin1.GetBytes("utf8: "), + Encoding.UTF8.GetBytes("\uD83D\uDE00"), + Encoding.Latin1.GetBytes("\r\n\r\n"), + Encoding.Latin1.GetBytes("bar2"), + Encoding.Latin1.GetBytes("\r\n--fooBoundary\r\n"), + Encoding.Latin1.GetBytes("ascii: "), + Encoding.ASCII.GetBytes("\uD83D\uDE00"), + Encoding.Latin1.GetBytes("\r\n\r\n"), + Encoding.Latin1.GetBytes("bar3"), + Encoding.Latin1.GetBytes("\r\n--fooBoundary\r\n"), + Encoding.Latin1.GetBytes("default: "), + Encoding.Latin1.GetBytes("\uD83D\uDE00"), + Encoding.Latin1.GetBytes("\r\n\r\n"), + Encoding.Latin1.GetBytes("bar4"), + Encoding.Latin1.GetBytes("\r\n--fooBoundary--\r\n")); + + byte[] actual = ms.ToArray(); + + Assert.Equal(expected.Length, actual.Length); + for (int i = 0; i < expected.Length; i++) + { + Assert.Equal(expected[i], actual[i]); + } + + static byte[] Concat(params byte[][] arrays) => arrays.SelectMany(b => b).ToArray(); + } + #region Helpers private static async Task MultipartContentToStringAsync(MultipartContent content, MultipartContentToStringMode mode, bool async) From c98ba087b9af9b7797146e75a0853a661cc4f056 Mon Sep 17 00:00:00 2001 From: MihaZupan Date: Mon, 13 Jul 2020 15:50:18 +0200 Subject: [PATCH 2/5] Test cleanup --- .../tests/FunctionalTests/MultipartContentTest.cs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs index 8bf622373ddc91..a945cc3db3ea97 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/MultipartContentTest.cs @@ -391,7 +391,7 @@ public void ReadAsStream_CreateContentReadStreamThrows() [Theory] [InlineData(true)] [InlineData(false)] - public async Task ReadAsStream_CustomEncodingSelector_SelectorIsCalledWithCustomState(bool async) + public async Task ReadAsStreamAsync_CustomEncodingSelector_SelectorIsCalledWithCustomState(bool async) { var mc = new MultipartContent(); @@ -437,7 +437,7 @@ public async Task ReadAsStream_CustomEncodingSelector_SelectorIsCalledWithCustom [Theory] [InlineData(true)] [InlineData(false)] - public async Task ReadAsStream_CustomEncodingSelector_CustomEncodingIsUsed(bool async) + public async Task ReadAsStreamAsync_CustomEncodingSelector_CustomEncodingIsUsed(bool async) { var mc = new MultipartContent("subtype", "fooBoundary"); @@ -499,13 +499,7 @@ public async Task ReadAsStream_CustomEncodingSelector_CustomEncodingIsUsed(bool Encoding.Latin1.GetBytes("bar4"), Encoding.Latin1.GetBytes("\r\n--fooBoundary--\r\n")); - byte[] actual = ms.ToArray(); - - Assert.Equal(expected.Length, actual.Length); - for (int i = 0; i < expected.Length; i++) - { - Assert.Equal(expected[i], actual[i]); - } + Assert.Equal(expected, ms.ToArray()); static byte[] Concat(params byte[][] arrays) => arrays.SelectMany(b => b).ToArray(); } From 59f8ca08dac135dba877a1130554fecdce6a0142 Mon Sep 17 00:00:00 2001 From: MihaZupan Date: Mon, 13 Jul 2020 20:52:40 +0200 Subject: [PATCH 3/5] Avoid WriteLatin1 logic duplication --- .../src/System/Net/Http/MultipartContent.cs | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs index 0767e39bc4980a..114e250d3f6c30 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs @@ -687,9 +687,10 @@ public override void Flush() { } private static void WriteToStream(Stream stream, string content, Encoding encoding) { - if (ReferenceEquals(encoding, Encoding.UTF8) && content.Length <= 128) + if (ReferenceEquals(encoding, Encoding.Latin1) ? content.Length <= 512 + : (ReferenceEquals(encoding, Encoding.UTF8) && content.Length <= 128)) { - // Fast-path for the expected common case + // Fast-path for the expected common cases Span buffer = stackalloc byte[512]; int written = encoding.GetBytes(content, buffer); stream.Write(buffer.Slice(0, written)); @@ -702,14 +703,9 @@ private static void WriteToStream(Stream stream, string content, Encoding encodi private static void WriteLatin1ToStream(Stream stream, string content) { - Span buffer = content.Length <= 512 - ? stackalloc byte[512] - : new byte[content.Length]; + Debug.Assert(ReferenceEquals(Encoding.Latin1, HttpRuleParser.DefaultHttpEncoding)); - int written = Encoding.Latin1.GetBytes(content, buffer); - Debug.Assert(written == content.Length); - - stream.Write(buffer.Slice(0, written)); + WriteToStream(stream, content, HttpRuleParser.DefaultHttpEncoding); } #endregion Serialization From edc80ba6008626bbf0d25a8be78552b0294290e8 Mon Sep 17 00:00:00 2001 From: MihaZupan Date: Thu, 30 Jul 2020 23:08:16 +0200 Subject: [PATCH 4/5] Move to common HeaderEncodingSelector --- .../System.Net.Http/ref/System.Net.Http.cs | 2 +- .../src/System/Net/Http/MultipartContent.cs | 124 ++++++++---------- .../UnitTests/Headers/MultipartContentTest.cs | 86 ++++++++++++ .../System.Net.Http.Unit.Tests.csproj | 5 +- 4 files changed, 148 insertions(+), 69 deletions(-) create mode 100644 src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs diff --git a/src/libraries/System.Net.Http/ref/System.Net.Http.cs b/src/libraries/System.Net.Http/ref/System.Net.Http.cs index 19041a5b78f59e..69d608956e6a5e 100644 --- a/src/libraries/System.Net.Http/ref/System.Net.Http.cs +++ b/src/libraries/System.Net.Http/ref/System.Net.Http.cs @@ -285,7 +285,7 @@ public partial class MultipartContent : System.Net.Http.HttpContent, System.Coll public MultipartContent() { } public MultipartContent(string subtype) { } public MultipartContent(string subtype, string boundary) { } - public System.Func? HeaderEncodingSelector { get { throw null; } set { } } + public System.Net.Http.HeaderEncodingSelector? HeaderEncodingSelector { get { throw null; } set { } } public virtual void Add(System.Net.Http.HttpContent content) { } protected override System.IO.Stream CreateContentReadStream(System.Threading.CancellationToken cancellationToken) { throw null; } protected override System.Threading.Tasks.Task CreateContentReadStreamAsync() { throw null; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs index 114e250d3f6c30..43f24e41c92225 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -18,10 +19,10 @@ public class MultipartContent : HttpContent, IEnumerable private const string CrLf = "\r\n"; - private static readonly int s_crlfLength = GetEncodedLength(CrLf); - private static readonly int s_dashDashLength = GetEncodedLength("--"); - private static readonly int s_colonSpaceLength = GetEncodedLength(": "); - private static readonly int s_commaSpaceLength = GetEncodedLength(", "); + private const int CrlfLength = 2; + private const int DashDashLength = 2; + private const int ColonSpaceLength = 2; + private const int CommaSpaceLength = 2; private readonly List _nestedContent; private readonly string _boundary; @@ -157,7 +158,11 @@ Collections.IEnumerator Collections.IEnumerable.GetEnumerator() #region Serialization - public Func? HeaderEncodingSelector { get; set; } + /// + /// Gets or sets a callback that returns the to decode the value for the specified response header name, + /// or to use the default behavior. + /// + public HeaderEncodingSelector? HeaderEncodingSelector { get; set; } // for-each content // write "--" + boundary @@ -173,7 +178,7 @@ protected override void SerializeToStream(Stream stream, TransportContext? conte try { // Write start boundary. - WriteLatin1ToStream(stream, "--" + _boundary + CrLf); + WriteToStream(stream, "--" + _boundary + CrLf); // Write each nested content. for (int contentIndex = 0; contentIndex < _nestedContent.Count; contentIndex++) @@ -185,7 +190,7 @@ protected override void SerializeToStream(Stream stream, TransportContext? conte } // Write footer boundary. - WriteLatin1ToStream(stream, CrLf + "--" + _boundary + "--" + CrLf); + WriteToStream(stream, CrLf + "--" + _boundary + "--" + CrLf); } catch (Exception ex) { @@ -322,37 +327,30 @@ private void SerializeHeadersToStream(Stream stream, HttpContent content, bool w // Add divider. if (writeDivider) // Write divider for all but the first content. { - WriteLatin1ToStream(stream, CrLf + "--"); // const strings - WriteLatin1ToStream(stream, _boundary); - WriteLatin1ToStream(stream, CrLf); + WriteToStream(stream, CrLf + "--"); // const strings + WriteToStream(stream, _boundary); + WriteToStream(stream, CrLf); } // Add headers. foreach (KeyValuePair> headerPair in content.Headers) { - Encoding? headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content); + Encoding headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content) ?? HttpRuleParser.DefaultHttpEncoding; - WriteLatin1ToStream(stream, headerPair.Key); - WriteLatin1ToStream(stream, ": "); + WriteToStream(stream, headerPair.Key); + WriteToStream(stream, ": "); string delim = string.Empty; foreach (string value in headerPair.Value) { - WriteLatin1ToStream(stream, delim); - if (headerValueEncoding is null) - { - WriteLatin1ToStream(stream, value); - } - else - { - WriteToStream(stream, value, headerValueEncoding); - } + WriteToStream(stream, delim); + WriteToStream(stream, value, headerValueEncoding); delim = ", "; } - WriteLatin1ToStream(stream, CrLf); + WriteToStream(stream, CrLf); } // Extra CRLF to end headers (even if there are no headers). - WriteLatin1ToStream(stream, CrLf); + WriteToStream(stream, CrLf); } private static ValueTask EncodeStringToStreamAsync(Stream stream, string input, CancellationToken cancellationToken) @@ -378,51 +376,43 @@ private Stream EncodeHeadersToNewStream(HttpContent content, bool writeDivider) protected internal override bool TryComputeLength(out long length) { - int boundaryLength = GetEncodedLength(_boundary); - - long currentLength = 0; - long internalBoundaryLength = s_crlfLength + s_dashDashLength + boundaryLength + s_crlfLength; - // Start Boundary. - currentLength += s_dashDashLength + boundaryLength + s_crlfLength; + long currentLength = DashDashLength + _boundary.Length + CrlfLength; - bool first = true; - foreach (HttpContent content in _nestedContent) + if (_nestedContent.Count > 1) { - if (first) - { - first = false; // First boundary already written. - } - else - { - // Internal Boundary. - currentLength += internalBoundaryLength; - } + // Internal boundaries + currentLength += (_nestedContent.Count - 1) * (CrlfLength + DashDashLength + _boundary.Length + CrlfLength); + } + foreach (HttpContent content in _nestedContent) + { // Headers. foreach (KeyValuePair> headerPair in content.Headers) { - currentLength += GetEncodedLength(headerPair.Key) + s_colonSpaceLength; + currentLength += headerPair.Key.Length + ColonSpaceLength; + + Encoding headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content) ?? HttpRuleParser.DefaultHttpEncoding; int valueCount = 0; foreach (string value in headerPair.Value) { - currentLength += GetEncodedLength(value); + currentLength += headerValueEncoding.GetByteCount(value); valueCount++; } + if (valueCount > 1) { - currentLength += (valueCount - 1) * s_commaSpaceLength; + currentLength += (valueCount - 1) * CommaSpaceLength; } - currentLength += s_crlfLength; + currentLength += CrlfLength; } - currentLength += s_crlfLength; + currentLength += CrlfLength; // Content. - long tempContentLength = 0; - if (!content.TryComputeLength(out tempContentLength)) + if (!content.TryComputeLength(out long tempContentLength)) { length = 0; return false; @@ -431,17 +421,12 @@ protected internal override bool TryComputeLength(out long length) } // Terminating boundary. - currentLength += s_crlfLength + s_dashDashLength + boundaryLength + s_dashDashLength + s_crlfLength; + currentLength += CrlfLength + DashDashLength + _boundary.Length + DashDashLength + CrlfLength; length = currentLength; return true; } - private static int GetEncodedLength(string input) - { - return HttpRuleParser.DefaultHttpEncoding.GetByteCount(input); - } - private sealed class ContentReadStream : Stream { private readonly Stream[] _streams; @@ -685,29 +670,34 @@ public override void Flush() { } } + private static void WriteToStream(Stream stream, string content) => + WriteToStream(stream, content, HttpRuleParser.DefaultHttpEncoding); + private static void WriteToStream(Stream stream, string content, Encoding encoding) { - if (ReferenceEquals(encoding, Encoding.Latin1) ? content.Length <= 512 - : (ReferenceEquals(encoding, Encoding.UTF8) && content.Length <= 128)) + const int StackallocThreshold = 1024; + + int maxLength = encoding.GetMaxByteCount(content.Length); + + byte[]? rentedBuffer = null; + Span buffer = maxLength <= StackallocThreshold + ? stackalloc byte[StackallocThreshold] + : (rentedBuffer = ArrayPool.Shared.Rent(maxLength)); + + try { - // Fast-path for the expected common cases - Span buffer = stackalloc byte[512]; int written = encoding.GetBytes(content, buffer); stream.Write(buffer.Slice(0, written)); } - else + finally { - stream.Write(encoding.GetBytes(content)); + if (rentedBuffer != null) + { + ArrayPool.Shared.Return(rentedBuffer); + } } } - private static void WriteLatin1ToStream(Stream stream, string content) - { - Debug.Assert(ReferenceEquals(Encoding.Latin1, HttpRuleParser.DefaultHttpEncoding)); - - WriteToStream(stream, content, HttpRuleParser.DefaultHttpEncoding); - } - #endregion Serialization } } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs new file mode 100644 index 00000000000000..f77f2b5cd9d2e8 --- /dev/null +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs @@ -0,0 +1,86 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.Http.Tests +{ + public class MultipartContentTest + { + public static IEnumerable MultipartContent_TestData() + { + var multipartContents = new List(); + + var complexContent = new MultipartContent(); + + var stringContent = new StringContent("bar1"); + stringContent.Headers.Add("latin1", "\uD83D\uDE00"); + complexContent.Add(stringContent); + + var byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar2")); + byteArrayContent.Headers.Add("utf8", "\uD83D\uDE00"); + complexContent.Add(byteArrayContent); + + byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar3")); + byteArrayContent.Headers.Add("ascii", "\uD83D\uDE00"); + complexContent.Add(byteArrayContent); + + byteArrayContent = new ByteArrayContent(Encoding.ASCII.GetBytes("bar4")); + byteArrayContent.Headers.Add("default", "\uD83D\uDE00"); + complexContent.Add(byteArrayContent); + + stringContent = new StringContent("bar5"); + stringContent.Headers.Add("foo", "bar"); + complexContent.Add(stringContent); + + multipartContents.Add(complexContent); + multipartContents.Add(new MultipartContent()); + multipartContents.Add(new MultipartFormDataContent()); + + var encodingSelectors = new HeaderEncodingSelector[] + { + (_, _) => null, + (_, _) => Encoding.ASCII, + (_, _) => Encoding.Latin1, + (_, _) => Encoding.UTF8, + (name, _) => name switch + { + "latin1" => Encoding.Latin1, + "utf8" => Encoding.UTF8, + "ascii" => Encoding.ASCII, + _ => null + } + }; + + foreach (MultipartContent multipartContent in multipartContents) + { + foreach (HeaderEncodingSelector encodingSelector in encodingSelectors) + { + multipartContent.HeaderEncodingSelector = encodingSelector; + yield return new object[] { multipartContent }; + } + } + } + + [Theory] + [MemberData(nameof(MultipartContent_TestData))] + public async Task MultipartContent_TryComputeLength_ReturnsSameLengthAsCopyToAsync(MultipartContent multipartContent) + { + Assert.True(multipartContent.TryComputeLength(out long length)); + + var copyToStream = new MemoryStream(); + multipartContent.CopyTo(copyToStream, context: null, cancellationToken: default); + Assert.Equal(length, copyToStream.Length); + + var copyToAsyncStream = new MemoryStream(); + await multipartContent.CopyToAsync(copyToAsyncStream, context: null, cancellationToken: default); + Assert.Equal(length, copyToAsyncStream.Length); + + Assert.Equal(copyToStream.ToArray(), copyToAsyncStream.ToArray()); + } + } +} diff --git a/src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj b/src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj index f5419d0faf487f..19491e8484af9f 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj +++ b/src/libraries/System.Net.Http/tests/UnitTests/System.Net.Http.Unit.Tests.csproj @@ -1,4 +1,4 @@ - + ../../src/Resources/Strings.resx true @@ -90,6 +90,8 @@ Link="ProductionCode\System\Net\Http\EmptyReadStream.cs" /> + + From ac6dda63ff9187249a3ab60c23cbc9686f214f13 Mon Sep 17 00:00:00 2001 From: MihaZupan Date: Wed, 5 Aug 2020 22:07:38 +0200 Subject: [PATCH 5/5] Fix indentation --- .../src/System/Net/Http/MultipartContent.cs | 12 +++++----- .../UnitTests/Headers/MultipartContentTest.cs | 22 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs index 43f24e41c92225..7ea8eb232f0397 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs @@ -19,7 +19,7 @@ public class MultipartContent : HttpContent, IEnumerable private const string CrLf = "\r\n"; - private const int CrlfLength = 2; + private const int CrLfLength = 2; private const int DashDashLength = 2; private const int ColonSpaceLength = 2; private const int CommaSpaceLength = 2; @@ -377,12 +377,12 @@ private Stream EncodeHeadersToNewStream(HttpContent content, bool writeDivider) protected internal override bool TryComputeLength(out long length) { // Start Boundary. - long currentLength = DashDashLength + _boundary.Length + CrlfLength; + long currentLength = DashDashLength + _boundary.Length + CrLfLength; if (_nestedContent.Count > 1) { // Internal boundaries - currentLength += (_nestedContent.Count - 1) * (CrlfLength + DashDashLength + _boundary.Length + CrlfLength); + currentLength += (_nestedContent.Count - 1) * (CrLfLength + DashDashLength + _boundary.Length + CrLfLength); } foreach (HttpContent content in _nestedContent) @@ -406,10 +406,10 @@ protected internal override bool TryComputeLength(out long length) currentLength += (valueCount - 1) * CommaSpaceLength; } - currentLength += CrlfLength; + currentLength += CrLfLength; } - currentLength += CrlfLength; + currentLength += CrLfLength; // Content. if (!content.TryComputeLength(out long tempContentLength)) @@ -421,7 +421,7 @@ protected internal override bool TryComputeLength(out long length) } // Terminating boundary. - currentLength += CrlfLength + DashDashLength + _boundary.Length + DashDashLength + CrlfLength; + currentLength += CrLfLength + DashDashLength + _boundary.Length + DashDashLength + CrLfLength; length = currentLength; return true; diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs index f77f2b5cd9d2e8..e89d38c204bfdb 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/MultipartContentTest.cs @@ -43,17 +43,17 @@ public static IEnumerable MultipartContent_TestData() var encodingSelectors = new HeaderEncodingSelector[] { - (_, _) => null, - (_, _) => Encoding.ASCII, - (_, _) => Encoding.Latin1, - (_, _) => Encoding.UTF8, - (name, _) => name switch - { - "latin1" => Encoding.Latin1, - "utf8" => Encoding.UTF8, - "ascii" => Encoding.ASCII, - _ => null - } + (_, _) => null, + (_, _) => Encoding.ASCII, + (_, _) => Encoding.Latin1, + (_, _) => Encoding.UTF8, + (name, _) => name switch + { + "latin1" => Encoding.Latin1, + "utf8" => Encoding.UTF8, + "ascii" => Encoding.ASCII, + _ => null + } }; foreach (MultipartContent multipartContent in multipartContents)