Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ffi): screenshot example using C# bindings #437

Merged
merged 18 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ffi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ doctest = false
[dependencies]
diplomat = "0.7.0"
diplomat-runtime = "0.7.0"
ironrdp = { workspace = true, features = ["connector", "dvc", "svc","rdpdr","rdpsnd"] }
ironrdp = { workspace = true, features = ["connector", "dvc", "svc","rdpdr","rdpsnd","graphics","input"] }
sspi = { workspace = true, features = ["network_client"] }
thiserror.workspace = true

Expand Down
3 changes: 2 additions & 1 deletion ffi/dotnet/Devolutions.IronRdp.ConnectExample/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
obj
bin
.vs
.vs
output.bmp
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,10 @@
<ItemGroup>
<ProjectReference Include="../Devolutions.IronRdp/Devolutions.IronRdp.csproj" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Windows.Compatibility" Version="8.0.3" />
irvingoujAtDevolution marked this conversation as resolved.
Show resolved Hide resolved
<PackageReference Include="SixLabors.ImageSharp" Version="3.1.3" />
irvingoujAtDevolution marked this conversation as resolved.
Show resolved Hide resolved
</ItemGroup>

</Project>
305 changes: 92 additions & 213 deletions ffi/dotnet/Devolutions.IronRdp.ConnectExample/Program.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Drawing;
using System.Drawing.Imaging;
using Devolutions.IronRdp;

namespace Devolutions.IronRdp.ConnectExample
{
Expand All @@ -10,9 +10,9 @@ static async Task Main(string[] args)
{
var arguments = ParseArguments(args);

if (arguments == null)
if (arguments == null)
{
return;
return;
}

var serverName = arguments["--serverName"];
Expand All @@ -21,15 +21,95 @@ static async Task Main(string[] args)
var domain = arguments["--domain"];
try
{
await Connect(serverName, username, password, domain);
var (res, framed) = await Connection.Connect(buildConfig(serverName, username, password, domain, 1980, 1080), serverName);
var decodedImage = DecodedImage.New(PixelFormat.RgbA32, res.GetDesktopSize().GetWidth(), res.GetDesktopSize().GetHeight());
var activeState = ActiveStage.New(res);
var keepLooping = true;
while (keepLooping)
{
var readPduTask = framed.ReadPdu();
Action? action = null;
byte[]? payload = null;
if (readPduTask == await Task.WhenAny(readPduTask, Task.Delay(1000)))
{
var pduReadTask = await readPduTask;
action = pduReadTask.Item1;
payload = pduReadTask.Item2;
Console.WriteLine($"Action: {action}");
}
else
{
Console.WriteLine("Timeout");
keepLooping = false;
continue;
irvingoujAtDevolution marked this conversation as resolved.
Show resolved Hide resolved
}
var outputIterator = activeState.Process(decodedImage, action, payload);

while (!outputIterator.IsEmpty())
{
var output = outputIterator.Next()!; // outputIterator.Next() is not null since outputIterator.IsEmpty() is false
Console.WriteLine($"Output type: {output.GetType()}");
if (output.GetType() == ActiveStageOutputType.Terminate)
{
Console.WriteLine("Connection terminated.");
keepLooping = false;
}

if (output.GetType() == ActiveStageOutputType.ResponseFrame)
{
var responseFrame = output.GetResponseFrame()!;
byte[] responseFrameBytes = new byte[responseFrame.GetSize()];
responseFrame.Fill(responseFrameBytes);
await framed.Write(responseFrameBytes);
}
}
}

saveImage(decodedImage, "output.png");

}
catch (Exception e)
catch (Exception e)
{
Console.WriteLine($"An error occurred: {e.Message}");
}
}

static Dictionary<string, string> ParseArguments(string[] args)
private static void saveImage(DecodedImage decodedImage, string v)
{
int width = decodedImage.GetWidth();
int height = decodedImage.GetHeight();
var data = decodedImage.GetData();

var bytes = new byte[data.GetSize()];
data.Fill(bytes);
for (int i = 0; i < bytes.Length; i += 4)
{
byte temp = bytes[i]; // Store the original Blue value
bytes[i] = bytes[i + 2]; // Move Red to Blue's position
bytes[i + 2] = temp; // Move original Blue to Red's position
// Green (bytes[i+1]) and Alpha (bytes[i+3]) remain unchanged
}
using (var bmp = new Bitmap(width, height))
{
// Lock the bits of the bitmap.
var bmpData = bmp.LockBits(new Rectangle(0, 0, bmp.Width, bmp.Height),
ImageLockMode.WriteOnly, System.Drawing.Imaging.PixelFormat.Format32bppArgb);

// Get the address of the first line.
IntPtr ptr = bmpData.Scan0;
// Copy the RGBA values back to the bitmap
System.Runtime.InteropServices.Marshal.Copy(bytes, 0, ptr, bytes.Length);
// Unlock the bits.
bmp.UnlockBits(bmpData);

// Save the bitmap to the specified output path
bmp.Save("./output.bmp", ImageFormat.Bmp);
}


}

static Dictionary<string, string>? ParseArguments(string[] args)
{
if (args.Length == 0 || Array.Exists(args, arg => arg == "--help"))
{
Expand Down Expand Up @@ -97,222 +177,21 @@ static void PrintHelp()
Console.WriteLine(" --help Show this message and exit.");
}

static async Task Connect(String servername, String username, String password, String domain)
{
Config config = buildConfig(servername, username, password, domain);

var stream = await CreateTcpConnection(servername, 3389);
var framed = new Framed<NetworkStream>(stream);

ClientConnector connector = ClientConnector.New(config);

var ip = await Dns.GetHostAddressesAsync(servername);
if (ip.Length == 0)
{
throw new Exception("Could not resolve server address");
}

var socketAddrString = ip[0].ToString()+":3389";
connector.WithServerAddr(socketAddrString);

await connectBegin(framed, connector);
var (serverPublicKey, framedSsl) = await securityUpgrade(servername, framed, connector);
await ConnectFinalize(servername, connector, serverPublicKey, framedSsl);
}

private static async Task<(byte[], Framed<SslStream>)> securityUpgrade(string servername, Framed<NetworkStream> framed, ClientConnector connector)
{
byte[] serverPublicKey;
Framed<SslStream> framedSsl;
var (streamRequireUpgrade, _) = framed.GetInner();
var promise = new TaskCompletionSource<byte[]>();
var sslStream = new SslStream(streamRequireUpgrade, false, (sender, certificate, chain, sslPolicyErrors) =>
{
promise.SetResult(certificate!.GetPublicKey());
return true;
});
await sslStream.AuthenticateAsClientAsync(servername);
serverPublicKey = await promise.Task;
framedSsl = new Framed<SslStream>(sslStream);
connector.MarkSecurityUpgradeAsDone();

return (serverPublicKey, framedSsl);
}

private static async Task connectBegin(Framed<NetworkStream> framed, ClientConnector connector)
{
var writeBuf = WriteBuf.New();
while (!connector.ShouldPerformSecurityUpgrade())
{
await SingleConnectStep(connector, writeBuf, framed);
}
}

private static Config buildConfig(string servername, string username, string password, string domain)
private static Config buildConfig(string servername, string username, string password, string domain, int width, int height)
{
ConfigBuilder configBuilder = ConfigBuilder.New();

configBuilder.WithUsernameAndPasswrord(username, password);
configBuilder.WithUsernameAndPassword(username, password);
configBuilder.SetDomain(domain);
configBuilder.SetDesktopSize(800, 600);
configBuilder.SetDesktopSize((ushort)height, (ushort)width);
configBuilder.SetClientName("IronRdp");
configBuilder.SetClientDir("C:\\");
configBuilder.SetPerformanceFlags(PerformanceFlags.NewDefault());

return configBuilder.Build();
}

private static async Task ConnectFinalize(string servername, ClientConnector connector, byte[] serverpubkey, Framed<SslStream> framedSsl)
{
var writeBuf2 = WriteBuf.New();
if (connector.ShouldPerformCredssp())
{
await PerformCredsspSteps(connector, servername, writeBuf2, framedSsl, serverpubkey);
}
while (!connector.State().IsTerminal())
{
await SingleConnectStep(connector, writeBuf2, framedSsl);
}
}

private static async Task PerformCredsspSteps(ClientConnector connector, string serverName, WriteBuf writeBuf, Framed<SslStream> framedSsl, byte[] serverpubkey)
{
var credsspSequenceInitResult = CredsspSequence.Init(connector, serverName, serverpubkey, null);
var credsspSequence = credsspSequenceInitResult.GetCredsspSequence();
var tsRequest = credsspSequenceInitResult.GetTsRequest();
TcpClient tcpClient = new TcpClient();
while (true)
{
var generator = credsspSequence.ProcessTsRequest(tsRequest);
var clientState = await ResolveGenerator(generator, tcpClient);
writeBuf.Clear();
var written = credsspSequence.HandleProcessResult(clientState, writeBuf);

if (written.GetSize().IsSome())
{
var actualSize = (int)written.GetSize().Get();
var response = new byte[actualSize];
writeBuf.ReadIntoBuf(response);
await framedSsl.Write(response);
}

var pduHint = credsspSequence.NextPduHint()!;
if (pduHint == null)
{
break;
}

var pdu = await framedSsl.ReadByHint(pduHint);
var decoded = credsspSequence.DecodeServerMessage(pdu);

if (null == decoded)
{
break;
}

tsRequest = decoded;
}
}

private static async Task<ClientState> ResolveGenerator(CredsspProcessGenerator generator, TcpClient tcpClient)
{
var state = generator.Start();
NetworkStream stream = null;
while (true)
{
if (state.IsSuspended())
{
var request = state.GetNetworkRequestIfSuspended()!;
var protocol = request.GetProtocol();
var url = request.GetUrl();
var data = request.GetData();
if (null == stream)
{
url = url.Replace("tcp://", "");
var split = url.Split(":");
await tcpClient.ConnectAsync(split[0], int.Parse(split[1]));
stream = tcpClient.GetStream();

}
if (protocol == NetworkRequestProtocol.Tcp)
{
stream.Write(Utils.Vecu8ToByte(data));
var readBuf = new byte[8096];
var readlen = await stream.ReadAsync(readBuf, 0, readBuf.Length);
var actuallyRead = new byte[readlen];
Array.Copy(readBuf, actuallyRead, readlen);
state = generator.Resume(actuallyRead);
}
else
{
throw new Exception("Unimplemented protocol");
}
}
else
{
var client_state = state.GetClientStateIfCompleted();
return client_state;
}
}
}

static async Task SingleConnectStep<T>(ClientConnector connector, WriteBuf buf, Framed<T> framed)
where T : Stream
{
buf.Clear();

var pduHint = connector.NextPduHint();
Written written;
if (pduHint != null)
{
byte[] pdu = await framed.ReadByHint(pduHint);
written = connector.Step(pdu, buf);
}
else
{
written = connector.StepNoInput(buf);
}

if (written.GetWrittenType() == WrittenType.Nothing)
{
return;
}

// will throw if size is not set
var size = written.GetSize().Get();

var response = new byte[size];
buf.ReadIntoBuf(response);

await framed.Write(response);
}

static async Task<NetworkStream> CreateTcpConnection(String servername, int port)
{
IPHostEntry ipHostInfo = await Dns.GetHostEntryAsync(servername);
IPAddress ipAddress = ipHostInfo.AddressList[0];
IPEndPoint ipEndPoint = new(ipAddress, port);

TcpClient client = new TcpClient();

await client.ConnectAsync(ipEndPoint);
NetworkStream stream = client.GetStream();

return stream;
}

}

public static class Utils
{
public static byte[] Vecu8ToByte(VecU8 vecU8)
{
var len = vecU8.GetSize();
byte[] buffer = new byte[len];
vecU8.Fill(buffer);
return buffer;
}
}
}


}
Loading