Skip to content
Merged
8 changes: 8 additions & 0 deletions api/types/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,11 @@ type ListWindowsDesktopServicesRequest struct {
Labels map[string]string
SearchKeywords []string
}

// RDPLicenseKey is struct for retrieving licenses from backend cache, used only internally
type RDPLicenseKey struct {
Version uint32 // e.g. 0x000a0002
Issuer string // e.g. example.com
Company string // e.g. Example Corporation
ProductID string // e.g. A02
}
38 changes: 38 additions & 0 deletions lib/auth/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ package storage
import (
"context"
"encoding/json"
"strconv"
"strings"
"time"

"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -233,6 +235,42 @@ func (p *ProcessStorage) WriteTeleportVersion(ctx context.Context, version *semv
return nil
}

func rdpLicenseKey(key *types.RDPLicenseKey) backend.Key {
return backend.NewKey("rdplicense", key.Issuer, strconv.Itoa(int(key.Version)), key.Company, key.ProductID)
}

type rdpLicense struct {
Data []byte `json:"data"`
}

// WriteRDPLicense writes an RDP license to local storage.
func (p *ProcessStorage) WriteRDPLicense(ctx context.Context, key *types.RDPLicenseKey, license []byte) error {
value, err := json.Marshal(rdpLicense{Data: license})
if err != nil {
return trace.Wrap(err)
}
item := backend.Item{
Key: rdpLicenseKey(key),
Value: value,
Expires: p.BackendStorage.Clock().Now().Add(28 * 24 * time.Hour),
}
_, err = p.stateStorage.Put(ctx, item)
return trace.Wrap(err)
}

// ReadRDPLicense reads a previously obtained license from storage.
func (p *ProcessStorage) ReadRDPLicense(ctx context.Context, key *types.RDPLicenseKey) ([]byte, error) {
item, err := p.stateStorage.Get(ctx, rdpLicenseKey(key))
if err != nil {
return nil, trace.Wrap(err)
}
license := rdpLicense{}
if err := json.Unmarshal(item.Value, &license); err != nil {
return nil, trace.Wrap(err)
}
return license.Data, nil
}

// ReadLocalIdentity reads, parses and returns the given pub/pri key + cert from the
// key storage (dataDir).
func ReadLocalIdentity(dataDir string, id state.IdentityID) (*state.Identity, error) {
Expand Down
72 changes: 72 additions & 0 deletions lib/auth/storage/storage_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Teleport
// Copyright (C) 2025 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package storage

import (
"context"
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/backend/memory"
)

func TestRDPLicense(t *testing.T) {
ctx := context.Background()
mem, err := memory.New(memory.Config{})
require.NoError(t, err)
storage := ProcessStorage{
BackendStorage: mem,
stateStorage: mem,
}

_, err = storage.ReadRDPLicense(ctx, &types.RDPLicenseKey{
Version: 1,
Issuer: "issuer",
Company: "company",
ProductID: "productID",
})
require.True(t, trace.IsNotFound(err))

licenseData := []byte{1, 2, 3}
err = storage.WriteRDPLicense(ctx, &types.RDPLicenseKey{
Version: 1,
Issuer: "issuer",
Company: "company",
ProductID: "productID",
}, licenseData)
require.NoError(t, err)

_, err = storage.ReadRDPLicense(ctx, &types.RDPLicenseKey{
Version: 2,
Issuer: "issuer",
Company: "company",
ProductID: "productID",
})
require.True(t, trace.IsNotFound(err))

license, err := storage.ReadRDPLicense(ctx, &types.RDPLicenseKey{
Version: 1,
Issuer: "issuer",
Company: "company",
ProductID: "productID",
})
require.NoError(t, err)
require.Equal(t, licenseData, license)
}
1 change: 1 addition & 0 deletions lib/service/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ func (process *TeleportProcess) initWindowsDesktopServiceRegistered(logger *slog
srv, err := desktop.NewWindowsService(desktop.WindowsServiceConfig{
DataDir: process.Config.DataDir,
Log: process.log.WithField(teleport.ComponentKey, teleport.Component(teleport.ComponentWindowsDesktop, process.id)),
LicenseStore: process.storage,
Clock: process.Clock,
Authorizer: authorizer,
Emitter: conn.Client,
Expand Down
116 changes: 116 additions & 0 deletions lib/srv/desktop/rdp/rdpclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ import "C"

import (
"context"
"encoding/binary"
"fmt"
"os"
"runtime/cgo"
Expand All @@ -80,6 +81,7 @@ import (
"time"
"unsafe"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

Expand Down Expand Up @@ -304,6 +306,19 @@ func (c *Client) startRustRDP(ctx context.Context) error {
return trace.BadParameter("user key was nil")
}

hostID, err := uuid.Parse(c.cfg.HostID)
if err != nil {
return trace.Wrap(err)
}

nextHostID := hostID[:]
cHostID := [4]C.uint32_t{}
for i := 0; i < len(cHostID); i++ {
const uint32Len = 4
cHostID[i] = (C.uint32_t)(binary.LittleEndian.Uint32(nextHostID[:uint32Len]))
nextHostID = nextHostID[uint32Len:]
}

res := C.client_run(
C.uintptr_t(c.handle),
C.CGOConnectParams{
Expand All @@ -319,6 +334,7 @@ func (c *Client) startRustRDP(ctx context.Context) error {
allow_clipboard: C.bool(c.cfg.AllowClipboard),
allow_directory_sharing: C.bool(c.cfg.AllowDirectorySharing),
show_desktop_wallpaper: C.bool(c.cfg.ShowDesktopWallpaper),
client_id: cHostID,
},
)

Expand Down Expand Up @@ -716,6 +732,106 @@ func toClient(handle C.uintptr_t) (value *Client, err error) {
return cgo.Handle(handle).Value().(*Client), nil
}

//export cgo_read_rdp_license
func cgo_read_rdp_license(handle C.uintptr_t, req *C.CGOLicenseRequest, data_out **C.uint8_t, len_out *C.size_t) C.CGOErrCode {
*data_out = nil
*len_out = 0

client, err := toClient(handle)
if err != nil {
return C.ErrCodeFailure
}

issuer := C.GoString(req.issuer)
company := C.GoString(req.company)
productID := C.GoString(req.product_id)

license, err := client.readRDPLicense(context.Background(), types.RDPLicenseKey{
Version: uint32(req.version),
Issuer: issuer,
Company: company,
ProductID: productID,
})
if trace.IsNotFound(err) {
return C.ErrCodeNotFound
} else if err != nil {
return C.ErrCodeFailure
}

// in this case, we expect the caller to use cgo_free_rdp_license
// when the data is no longer needed
*data_out = (*C.uint8_t)(C.CBytes(license))
*len_out = C.size_t(len(license))
return C.ErrCodeSuccess
}

//export cgo_free_rdp_license
func cgo_free_rdp_license(p *C.uint8_t) {
C.free(unsafe.Pointer(p))
}

//export cgo_write_rdp_license
func cgo_write_rdp_license(handle C.uintptr_t, req *C.CGOLicenseRequest, data *C.uint8_t, length C.size_t) C.CGOErrCode {
client, err := toClient(handle)
if err != nil {
return C.ErrCodeFailure
}

issuer := C.GoString(req.issuer)
company := C.GoString(req.company)
productID := C.GoString(req.product_id)

licenseData := C.GoBytes(unsafe.Pointer(data), C.int(length))

err = client.writeRDPLicense(context.Background(), types.RDPLicenseKey{
Version: uint32(req.version),
Issuer: issuer,
Company: company,
ProductID: productID,
}, licenseData)
if err != nil {
return C.ErrCodeFailure
}

return C.ErrCodeSuccess
}

func (c *Client) readRDPLicense(ctx context.Context, key types.RDPLicenseKey) ([]byte, error) {
log := c.cfg.Log.WithFields(logrus.Fields{
"issuer": key.Issuer,
"company": key.Company,
"version": key.Version,
"product": key.ProductID,
})

license, err := c.cfg.LicenseStore.ReadRDPLicense(ctx, &key)
switch {
case trace.IsNotFound(err):
log.Info("existing RDP license not found")
case err != nil:
log.Error("could not look up existing RDP license", "error", err)
case len(license) > 0:
log.Info("found existing RDP license")
}

return license, trace.Wrap(err)
}

func (c *Client) writeRDPLicense(ctx context.Context, key types.RDPLicenseKey, license []byte) error {
log := c.cfg.Log.WithFields(logrus.Fields{
"issuer": key.Issuer,
"company": key.Company,
"version": key.Version,
"product": key.ProductID,
})
log.Info("writing RDP license to storage")
err := c.cfg.LicenseStore.WriteRDPLicense(ctx, &key, license)
if err != nil {
log.Error("could not write RDP license", "error", err)
}
return trace.Wrap(err)
}

//export cgo_handle_fastpath_pdu
func cgo_handle_fastpath_pdu(handle C.uintptr_t, data *C.uint8_t, length C.uint32_t) C.CGOErrCode {
goData := asRustBackedSlice(data, int(length))
Expand Down
12 changes: 12 additions & 0 deletions lib/srv/desktop/rdp/rdpclient/client_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,25 @@ import (
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/srv/desktop/tdp"
)

// LicenseStore implements client-side license storage for Microsoft
// Remote Desktop Services (RDS) licenses.
type LicenseStore interface {
WriteRDPLicense(ctx context.Context, key *types.RDPLicenseKey, license []byte) error
ReadRDPLicense(ctx context.Context, key *types.RDPLicenseKey) ([]byte, error)
}

// Config for creating a new Client.
type Config struct {
// Addr is the network address of the RDP server, in the form host:port.
Addr string

LicenseStore LicenseStore
HostID string

// UserCertGenerator generates user certificates for RDP authentication.
GenerateUserCert GenerateUserCertFn
CertTTL time.Duration
Expand Down
10 changes: 6 additions & 4 deletions lib/srv/desktop/rdp/rdpclient/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ use tokio::sync::mpsc::{channel, error::SendError, Receiver, Sender};
use tokio::task::JoinError;
// Export this for crate level use.
use crate::cliprdr::{ClipboardFn, TeleportCliprdrBackend};
use crate::license::GoLicenseCache;
use crate::rdpdr::scard::SCARD_DEVICE_ID;
use crate::rdpdr::TeleportRdpdrBackend;
use crate::ssl::TlsStream;
Expand Down Expand Up @@ -142,7 +143,7 @@ impl Client {
let mut rng = rand_chacha::ChaCha20Rng::from_entropy();
let pin = format!("{:08}", rng.gen_range(0i32..=99999999i32));

let connector_config = create_config(&params, pin.clone());
let connector_config = create_config(&params, pin.clone(), cgo_handle);

// Create a channel for sending/receiving function calls to/from the Client.
let (client_handle, function_receiver) = ClientHandle::new(100);
Expand Down Expand Up @@ -1385,7 +1386,7 @@ impl FunctionReceiver {
type RdpReadStream = Framed<TokioStream<ReadHalf<TlsStream<TokioTcpStream>>>>;
type RdpWriteStream = Framed<TokioStream<WriteHalf<TlsStream<TokioTcpStream>>>>;

fn create_config(params: &ConnectParams, pin: String) -> Config {
fn create_config(params: &ConnectParams, pin: String, cgo_handle: CgoHandle) -> Config {
Config {
desktop_size: ironrdp_connector::DesktopSize {
width: params.screen_width,
Expand Down Expand Up @@ -1426,8 +1427,8 @@ fn create_config(params: &ConnectParams, pin: String) -> Config {
PerformanceFlags::empty()
},
desktop_scale_factor: 0,
license_cache: None,
hardware_id: None,
license_cache: Some(Arc::new(GoLicenseCache { cgo_handle })),
hardware_id: Some(params.client_id),
}
}

Expand All @@ -1441,6 +1442,7 @@ pub struct ConnectParams {
pub allow_clipboard: bool,
pub allow_directory_sharing: bool,
pub show_desktop_wallpaper: bool,
pub client_id: [u32; 4],
}

#[derive(Debug)]
Expand Down
Loading