Skip to content

Commit

Permalink
Add DoH support, prevent private address leaking
Browse files Browse the repository at this point in the history
  • Loading branch information
jdomnitz committed Jul 18, 2024
1 parent e4f3a4a commit 3b7901d
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 24 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Tiny DNS

[![Build](https://github.com/SmartHomeOS/TinyDNS/actions/workflows/dotnet.yml/badge.svg)](https://github.com/SmartHomeOS/TinyDNS/actions/workflows/dotnet.yml)
[![Version](https://img.shields.io/nuget/v/TinyDNS.svg)](https://www.nuget.org/packages/TinyDNS)

A small, fast, modern DNS / MDNS client

### Features:
* Recursive resolution from root hints with no DNS servers configured
* Resolution from OS or DHCP configured DNS servers
* Resolution using common public recursive resolvers (Google, CloudFlare, etc.)
* Support for DoH (DNS over HTTPS) with options for secure or insecure lookup
* Leak protection to ensure sensitive queries are not shared with public DNS servers
* Support for async, zerocopy, spans and all the modern .Net performance features
108 changes: 96 additions & 12 deletions TinyDNS/DNSResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// along with this program. If not, see <http://www.gnu.org/licenses/>.

using System.Net;
using System.Net.Http.Headers;
using System.Net.NetworkInformation;
using System.Net.Sockets;
using TinyDNS.Cache;
Expand All @@ -24,14 +25,17 @@ public sealed class DNSResolver
public const int PORT = 53;
private readonly HashSet<IPAddress> globalNameservers = [];
private ResolverCache cache = new ResolverCache();
public DNSResolver()
private ResolutionMode resolutionMode;
public DNSResolver(ResolutionMode mode = ResolutionMode.InsecureOnly)
{
this.resolutionMode = mode;
ReloadNameservers();
NetworkChange.NetworkAddressChanged += (s, e) => ReloadNameservers();
}

public DNSResolver(List<IPAddress> nameservers)
public DNSResolver(List<IPAddress> nameservers, ResolutionMode mode = ResolutionMode.InsecureOnly)
{
this.resolutionMode = mode;
foreach (IPAddress nameserver in nameservers)
this.globalNameservers.Add(nameserver);
}
Expand Down Expand Up @@ -106,6 +110,7 @@ public async Task<List<IPAddress>> ResolveHostV6(string hostname)
{
byte[] addressBytes = address.GetAddressBytes();
List<string> host;
bool privateQuery = IsPrivate(address, addressBytes);
if (address.AddressFamily == AddressFamily.InterNetwork)
{
host = new List<string>(6);
Expand All @@ -121,13 +126,13 @@ public async Task<List<IPAddress>> ResolveHostV6(string hostname)
for (int i = addressBytes.Length - 1; i >= 0; i--)
{
string hex = addressBytes[i].ToString("x2");
host.Add(hex.Substring(1,1));
host.Add(hex.Substring(1, 1));
host.Add(hex.Substring(0, 1));
}
host.Add("IP6");
host.Add("ARPA");
}
Message? response = await ResolveQuery(new QuestionRecord(host, DNSRecordType.PTR, false));
Message? response = await ResolveQuery(new QuestionRecord(host, DNSRecordType.PTR, false), privateQuery);
if (response == null || response.ResponseCode != DNSStatus.OK)
return null;

Expand All @@ -141,15 +146,23 @@ public async Task<List<IPAddress>> ResolveHostV6(string hostname)

public async Task<Message?> ResolveQuery(QuestionRecord question)
{
return await ResolveQueryInternal(question, globalNameservers);
bool privateQuery = (question.Name.Last() == "local");
return await ResolveQueryInternal(question, globalNameservers, privateQuery);
}

private async Task<Message?> ResolveQueryInternal(QuestionRecord question, HashSet<IPAddress> nameservers, int recursionCount = 0)
private async Task<Message?> ResolveQuery(QuestionRecord question, bool privateQuery)
{
return await ResolveQueryInternal(question, globalNameservers, privateQuery);
}

private async Task<Message?> ResolveQueryInternal(QuestionRecord question, HashSet<IPAddress> nameservers, bool privateQuery, int recursionCount = 0)
{
//Check for excessive recursion
recursionCount++;
if (recursionCount > 10)
return null;

//Check for cache hits
ResourceRecord[]? cacheHits = cache.Search(question);
if (cacheHits != null && cacheHits.Length > 0)
{
Expand All @@ -160,6 +173,7 @@ public async Task<List<IPAddress>> ResolveHostV6(string hostname)
return msg;
}

//Otherwise query the nameserver(s)
Socket? socket = null;
try
{
Expand All @@ -170,12 +184,34 @@ public async Task<List<IPAddress>> ResolveHostV6(string hostname)

foreach (IPAddress nsIP in nameservers)
{
//Prevent leaking local domains into the global DNS space
if (privateQuery && IsPrivate(nsIP))
return null;

int bytes;
try
{
int len = query.ToBytes(buffer.Span);
await socket.SendToAsync(buffer.Slice(0, len), SocketFlags.None, new IPEndPoint(nsIP, PORT));
bytes = await socket.ReceiveAsync(buffer, SocketFlags.None, new CancellationTokenSource(3000).Token);
if (resolutionMode == ResolutionMode.InsecureOnly)
bytes = await ResolveUDP(query, buffer, socket, nsIP);
else
{
try
{
bytes = await ResolveHTTPS(query, buffer, nsIP);
}
catch (HttpRequestException)
{
if (resolutionMode == ResolutionMode.SecureOnly)
continue;
bytes = await ResolveUDP(query, buffer, socket, nsIP);
}
catch (OperationCanceledException)
{
if (resolutionMode == ResolutionMode.SecureOnly)
continue;
bytes = await ResolveUDP(query, buffer, socket, nsIP);
}
}
}
catch (SocketException) { continue; }
catch (OperationCanceledException) { continue; }
Expand All @@ -192,10 +228,12 @@ public async Task<List<IPAddress>> ResolveHostV6(string hostname)
if (response.ResponseCode != DNSStatus.OK)
continue;

//Check if we have a valid answer
//Add new info to the cache
cache.Store(response.Answers);
cache.Store(response.Authorities);
cache.Store(response.Additionals);

//Check if we have a valid answer
foreach (ResourceRecord answer in response.Answers)
{
if (answer.Type == question.Type)
Expand All @@ -213,7 +251,7 @@ public async Task<List<IPAddress>> ResolveHostV6(string hostname)
if (answer is CNameRecord cname)
{
question.Name = cname.CNameLabels;
return await ResolveQueryInternal(question, nameservers, recursionCount);
return await ResolveQueryInternal(question, nameservers, privateQuery, recursionCount);
}
}

Expand Down Expand Up @@ -264,7 +302,7 @@ public async Task<List<IPAddress>> ResolveHostV6(string hostname)
}

if (nextNSIPs.Any())
return await ResolveQueryInternal(question, nextNSIPs, recursionCount);
return await ResolveQueryInternal(question, nextNSIPs, privateQuery, recursionCount);
}
}
}
Expand All @@ -277,5 +315,51 @@ public async Task<List<IPAddress>> ResolveHostV6(string hostname)
}
return null;
}

private static bool IsPrivate(IPAddress ip, byte[]? addr = null)
{
if (ip.IsIPv6UniqueLocal || ip.IsIPv6SiteLocal || ip.IsIPv6LinkLocal || IPAddress.IsLoopback(ip))
return true;
if (ip.AddressFamily == AddressFamily.InterNetwork)
{
if (addr == null)
addr = ip.GetAddressBytes();
if ((addr[0] == 169 && addr[1] == 254) || (addr[0] == 192 && addr[1] == 168) ||
(addr[0] == 10) || (addr[0] == 172 && (addr[1] & 0xF0) == 0x10))
return true;
}
return false;
}

private async Task<int> ResolveUDP(Message query, Memory<byte> buffer, Socket socket, IPAddress nameserverIP)
{
int len = query.ToBytes(buffer.Span);
await socket.SendToAsync(buffer.Slice(0, len), SocketFlags.None, new IPEndPoint(nameserverIP, PORT));
return await socket.ReceiveAsync(buffer, SocketFlags.None, new CancellationTokenSource(3000).Token);
}

private async Task<int> ResolveHTTPS(Message query, Memory<byte> buffer, IPAddress nameserverIP)
{
query.TransactionID = 0;
int len = query.ToBytes(buffer.Span);
using (HttpClient httpClient = new HttpClient())
{
ByteArrayContent content = new ByteArrayContent(buffer.Slice(0, len).ToArray());
content.Headers.ContentType = new MediaTypeHeaderValue("application/dns-message");
string hostname = nameserverIP.ToString();
if (nameserverIP.AddressFamily == AddressFamily.InterNetworkV6)
hostname = String.Concat("[", hostname, "]");
HttpRequestMessage request = new HttpRequestMessage(HttpMethod.Post, $"https://{hostname}/dns-query");
request.Content = content;
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/dns-message"));
request.Version = new Version(2, 0);
request.VersionPolicy = HttpVersionPolicy.RequestVersionOrHigher;
var response = await httpClient.SendAsync(request, new CancellationTokenSource(3000).Token);
response.EnsureSuccessStatusCode();
var tempBuff = await response.Content.ReadAsByteArrayAsync();
tempBuff.CopyTo(buffer);
return tempBuff.Length;
}
}
}
}
2 changes: 1 addition & 1 deletion TinyDNS/DomainParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public static List<string> Parse(string domain)
{
//Escaped char follows
if (char.IsAsciiHexDigit(domain[++i]))
label.Append((char)int.Parse(domain.AsSpan().Slice(i++, 2), NumberStyles.HexNumber)); //2 digit char code
label.Append((char)int.Parse(domain.AsSpan().Slice(i++, 2), NumberStyles.HexNumber)); //2 digit hex code
else
label.Append(domain[i]); //single character
}
Expand Down
22 changes: 22 additions & 0 deletions TinyDNS/Enums/ResolutionMode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// TinyDNS Copyright (C) 2024
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or any later version.
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY, without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU Affero General Public License for more details.
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

namespace TinyDNS.Enums
{
public enum ResolutionMode : byte
{
InsecureOnly,
SecureOnly,
SecureWithFallback,

}
}
8 changes: 4 additions & 4 deletions TinyDNS/Message.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace TinyDNS
{
public sealed class Message : IEquatable<Message>

Check warning on line 19 in TinyDNS/Message.cs

View workflow job for this annotation

GitHub Actions / build

'Message' overrides Object.Equals(object o) but does not override Object.GetHashCode()

Check warning on line 19 in TinyDNS/Message.cs

View workflow job for this annotation

GitHub Actions / build

'Message' overrides Object.Equals(object o) but does not override Object.GetHashCode()
{
public ushort transaction;
public ushort TransactionID { get; set; }
public bool Response { get; set; }
public bool RecursionDesired { get; set; }
public bool RecursionAvailable { get; set; }
Expand All @@ -31,7 +31,7 @@ public sealed class Message : IEquatable<Message>
public Message()
{
RecursionDesired = true;
transaction = (ushort)new Random().Next(ushort.MaxValue);
TransactionID = (ushort)new Random().Next(ushort.MaxValue);
}
/// <summary>
/// Create a DNS Message from a byte buffer
Expand All @@ -44,7 +44,7 @@ public Message(Span<byte> buffer)
byte op = buffer[2];
if ((op & 0x2) == 0x2)
throw new InvalidDataException("Message Truncated");
transaction = BinaryPrimitives.ReadUInt16BigEndian(buffer);
TransactionID = BinaryPrimitives.ReadUInt16BigEndian(buffer);
Response = (op & 0x80) == 0x80;
Authoritative = (op & 0x4) == 0x4;
RecursionDesired = (op & 0x1) == 0x1;
Expand Down Expand Up @@ -79,7 +79,7 @@ public Message(Span<byte> buffer)
public ResourceRecord[] Additionals { get; set; } = [];
public int ToBytes(Span<byte> buffer)
{
BinaryPrimitives.WriteUInt16BigEndian(buffer, transaction);
BinaryPrimitives.WriteUInt16BigEndian(buffer, TransactionID);
byte op = (byte)(((byte)Operation & 0xF) << 3);
if (Response)
op |= 0x80;
Expand Down
3 changes: 2 additions & 1 deletion TinyDNS/TinyDNS.csproj
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<TargetFrameworks>net80</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<Version>0.5</Version>
</PropertyGroup>

<ItemGroup>
Expand Down
9 changes: 4 additions & 5 deletions TinyDNSDemo/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,23 @@

using System.Net;
using TinyDNS;
using TinyDNS.Enums;

internal class Program
{
static async Task Main()
{
List<IPAddress> addresses = DNSSources.RootNameservers;
DNSResolver resolver = new DNSResolver(addresses); //From root hints
DNSResolver resolver = new DNSResolver(DNSSources.CloudflareDNSAddresses, ResolutionMode.SecureWithFallback); //From root hints
string host = "google.com";
List<IPAddress> ip = await resolver.ResolveHost(host);
if (ip.Count > 0)
Console.WriteLine($"Resolved {host} as {ip[0]}");
List<IPAddress> ip2 = await resolver.ResolveHost(host);
List<IPAddress> ip2 = await resolver.ResolveHostV6("mail." + host);
if (ip2.Count == 0)
Console.WriteLine("Unable to resolve IPs");
else
{
Console.WriteLine($"Resolved {host} as {ip2[0]}");
//Console.WriteLine($"Resolved {ip[0]} as " + await resolver.ResolveIP(ip[0]));
Console.WriteLine($"Resolved mail.{host} as {ip2[0]}");
}
Console.ReadLine();
}
Expand Down
2 changes: 1 addition & 1 deletion TinyDNSDemo/TinyDNSDemo.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
<TargetFrameworks>net80</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
Expand Down

0 comments on commit 3b7901d

Please sign in to comment.