diff --git a/api/types/desktop.go b/api/types/desktop.go index 74cba39d078f4..8b20ae2331c07 100644 --- a/api/types/desktop.go +++ b/api/types/desktop.go @@ -366,3 +366,11 @@ type ListWindowsDesktopServicesRequest struct { Labels map[string]string SearchKeywords []string } + +// RDPLicenseKey is struct for retrieving licenses from backend cache, used only internally +type RDPLicenseKey struct { + Version uint32 // e.g. 0x000a0002 + Issuer string // e.g. example.com + Company string // e.g. Example Corporation + ProductID string // e.g. A02 +} diff --git a/lib/auth/storage/storage.go b/lib/auth/storage/storage.go index 76db71182e982..625cc393f8698 100644 --- a/lib/auth/storage/storage.go +++ b/lib/auth/storage/storage.go @@ -27,7 +27,9 @@ package storage import ( "context" "encoding/json" + "strconv" "strings" + "time" "github.com/coreos/go-semver/semver" "github.com/gravitational/trace" @@ -233,6 +235,42 @@ func (p *ProcessStorage) WriteTeleportVersion(ctx context.Context, version *semv return nil } +func rdpLicenseKey(key *types.RDPLicenseKey) backend.Key { + return backend.NewKey("rdplicense", key.Issuer, strconv.Itoa(int(key.Version)), key.Company, key.ProductID) +} + +type rdpLicense struct { + Data []byte `json:"data"` +} + +// WriteRDPLicense writes an RDP license to local storage. +func (p *ProcessStorage) WriteRDPLicense(ctx context.Context, key *types.RDPLicenseKey, license []byte) error { + value, err := json.Marshal(rdpLicense{Data: license}) + if err != nil { + return trace.Wrap(err) + } + item := backend.Item{ + Key: rdpLicenseKey(key), + Value: value, + Expires: p.BackendStorage.Clock().Now().Add(28 * 24 * time.Hour), + } + _, err = p.stateStorage.Put(ctx, item) + return trace.Wrap(err) +} + +// ReadRDPLicense reads a previously obtained license from storage. +func (p *ProcessStorage) ReadRDPLicense(ctx context.Context, key *types.RDPLicenseKey) ([]byte, error) { + item, err := p.stateStorage.Get(ctx, rdpLicenseKey(key)) + if err != nil { + return nil, trace.Wrap(err) + } + license := rdpLicense{} + if err := json.Unmarshal(item.Value, &license); err != nil { + return nil, trace.Wrap(err) + } + return license.Data, nil +} + // ReadLocalIdentity reads, parses and returns the given pub/pri key + cert from the // key storage (dataDir). func ReadLocalIdentity(dataDir string, id state.IdentityID) (*state.Identity, error) { diff --git a/lib/auth/storage/storage_test.go b/lib/auth/storage/storage_test.go new file mode 100644 index 0000000000000..42302101c7036 --- /dev/null +++ b/lib/auth/storage/storage_test.go @@ -0,0 +1,72 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package storage + +import ( + "context" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/backend/memory" +) + +func TestRDPLicense(t *testing.T) { + ctx := context.Background() + mem, err := memory.New(memory.Config{}) + require.NoError(t, err) + storage := ProcessStorage{ + BackendStorage: mem, + stateStorage: mem, + } + + _, err = storage.ReadRDPLicense(ctx, &types.RDPLicenseKey{ + Version: 1, + Issuer: "issuer", + Company: "company", + ProductID: "productID", + }) + require.True(t, trace.IsNotFound(err)) + + licenseData := []byte{1, 2, 3} + err = storage.WriteRDPLicense(ctx, &types.RDPLicenseKey{ + Version: 1, + Issuer: "issuer", + Company: "company", + ProductID: "productID", + }, licenseData) + require.NoError(t, err) + + _, err = storage.ReadRDPLicense(ctx, &types.RDPLicenseKey{ + Version: 2, + Issuer: "issuer", + Company: "company", + ProductID: "productID", + }) + require.True(t, trace.IsNotFound(err)) + + license, err := storage.ReadRDPLicense(ctx, &types.RDPLicenseKey{ + Version: 1, + Issuer: "issuer", + Company: "company", + ProductID: "productID", + }) + require.NoError(t, err) + require.Equal(t, licenseData, license) +} diff --git a/lib/service/desktop.go b/lib/service/desktop.go index 9a1d411052ed6..c156858a34e16 100644 --- a/lib/service/desktop.go +++ b/lib/service/desktop.go @@ -212,6 +212,7 @@ func (process *TeleportProcess) initWindowsDesktopServiceRegistered(logger *slog srv, err := desktop.NewWindowsService(desktop.WindowsServiceConfig{ DataDir: process.Config.DataDir, Log: process.log.WithField(teleport.ComponentKey, teleport.Component(teleport.ComponentWindowsDesktop, process.id)), + LicenseStore: process.storage, Clock: process.Clock, Authorizer: authorizer, Emitter: conn.Client, diff --git a/lib/srv/desktop/rdp/rdpclient/client.go b/lib/srv/desktop/rdp/rdpclient/client.go index e194fa96891be..e0498f4fd5395 100644 --- a/lib/srv/desktop/rdp/rdpclient/client.go +++ b/lib/srv/desktop/rdp/rdpclient/client.go @@ -72,6 +72,7 @@ import "C" import ( "context" + "encoding/binary" "fmt" "os" "runtime/cgo" @@ -80,6 +81,7 @@ import ( "time" "unsafe" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/sirupsen/logrus" @@ -304,6 +306,19 @@ func (c *Client) startRustRDP(ctx context.Context) error { return trace.BadParameter("user key was nil") } + hostID, err := uuid.Parse(c.cfg.HostID) + if err != nil { + return trace.Wrap(err) + } + + nextHostID := hostID[:] + cHostID := [4]C.uint32_t{} + for i := 0; i < len(cHostID); i++ { + const uint32Len = 4 + cHostID[i] = (C.uint32_t)(binary.LittleEndian.Uint32(nextHostID[:uint32Len])) + nextHostID = nextHostID[uint32Len:] + } + res := C.client_run( C.uintptr_t(c.handle), C.CGOConnectParams{ @@ -319,6 +334,7 @@ func (c *Client) startRustRDP(ctx context.Context) error { allow_clipboard: C.bool(c.cfg.AllowClipboard), allow_directory_sharing: C.bool(c.cfg.AllowDirectorySharing), show_desktop_wallpaper: C.bool(c.cfg.ShowDesktopWallpaper), + client_id: cHostID, }, ) @@ -716,6 +732,106 @@ func toClient(handle C.uintptr_t) (value *Client, err error) { return cgo.Handle(handle).Value().(*Client), nil } +//export cgo_read_rdp_license +func cgo_read_rdp_license(handle C.uintptr_t, req *C.CGOLicenseRequest, data_out **C.uint8_t, len_out *C.size_t) C.CGOErrCode { + *data_out = nil + *len_out = 0 + + client, err := toClient(handle) + if err != nil { + return C.ErrCodeFailure + } + + issuer := C.GoString(req.issuer) + company := C.GoString(req.company) + productID := C.GoString(req.product_id) + + license, err := client.readRDPLicense(context.Background(), types.RDPLicenseKey{ + Version: uint32(req.version), + Issuer: issuer, + Company: company, + ProductID: productID, + }) + if trace.IsNotFound(err) { + return C.ErrCodeNotFound + } else if err != nil { + return C.ErrCodeFailure + } + + // in this case, we expect the caller to use cgo_free_rdp_license + // when the data is no longer needed + *data_out = (*C.uint8_t)(C.CBytes(license)) + *len_out = C.size_t(len(license)) + return C.ErrCodeSuccess +} + +//export cgo_free_rdp_license +func cgo_free_rdp_license(p *C.uint8_t) { + C.free(unsafe.Pointer(p)) +} + +//export cgo_write_rdp_license +func cgo_write_rdp_license(handle C.uintptr_t, req *C.CGOLicenseRequest, data *C.uint8_t, length C.size_t) C.CGOErrCode { + client, err := toClient(handle) + if err != nil { + return C.ErrCodeFailure + } + + issuer := C.GoString(req.issuer) + company := C.GoString(req.company) + productID := C.GoString(req.product_id) + + licenseData := C.GoBytes(unsafe.Pointer(data), C.int(length)) + + err = client.writeRDPLicense(context.Background(), types.RDPLicenseKey{ + Version: uint32(req.version), + Issuer: issuer, + Company: company, + ProductID: productID, + }, licenseData) + if err != nil { + return C.ErrCodeFailure + } + + return C.ErrCodeSuccess +} + +func (c *Client) readRDPLicense(ctx context.Context, key types.RDPLicenseKey) ([]byte, error) { + log := c.cfg.Log.WithFields(logrus.Fields{ + "issuer": key.Issuer, + "company": key.Company, + "version": key.Version, + "product": key.ProductID, + }) + + license, err := c.cfg.LicenseStore.ReadRDPLicense(ctx, &key) + switch { + case trace.IsNotFound(err): + log.Info("existing RDP license not found") + case err != nil: + log.Error("could not look up existing RDP license", "error", err) + case len(license) > 0: + log.Info("found existing RDP license") + } + + return license, trace.Wrap(err) +} + +func (c *Client) writeRDPLicense(ctx context.Context, key types.RDPLicenseKey, license []byte) error { + log := c.cfg.Log.WithFields(logrus.Fields{ + "issuer": key.Issuer, + "company": key.Company, + "version": key.Version, + "product": key.ProductID, + }) + log.Info("writing RDP license to storage") + err := c.cfg.LicenseStore.WriteRDPLicense(ctx, &key, license) + if err != nil { + log.Error("could not write RDP license", "error", err) + } + return trace.Wrap(err) +} + //export cgo_handle_fastpath_pdu func cgo_handle_fastpath_pdu(handle C.uintptr_t, data *C.uint8_t, length C.uint32_t) C.CGOErrCode { goData := asRustBackedSlice(data, int(length)) diff --git a/lib/srv/desktop/rdp/rdpclient/client_common.go b/lib/srv/desktop/rdp/rdpclient/client_common.go index 9efe72dad2e18..7217813226a73 100644 --- a/lib/srv/desktop/rdp/rdpclient/client_common.go +++ b/lib/srv/desktop/rdp/rdpclient/client_common.go @@ -27,13 +27,25 @@ import ( "github.com/gravitational/trace" "github.com/sirupsen/logrus" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/srv/desktop/tdp" ) +// LicenseStore implements client-side license storage for Microsoft +// Remote Desktop Services (RDS) licenses. +type LicenseStore interface { + WriteRDPLicense(ctx context.Context, key *types.RDPLicenseKey, license []byte) error + ReadRDPLicense(ctx context.Context, key *types.RDPLicenseKey) ([]byte, error) +} + // Config for creating a new Client. type Config struct { // Addr is the network address of the RDP server, in the form host:port. Addr string + + LicenseStore LicenseStore + HostID string + // UserCertGenerator generates user certificates for RDP authentication. GenerateUserCert GenerateUserCertFn CertTTL time.Duration diff --git a/lib/srv/desktop/rdp/rdpclient/src/client.rs b/lib/srv/desktop/rdp/rdpclient/src/client.rs index f5a0c482f7fc0..d6db7d1d052cd 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/client.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/client.rs @@ -67,6 +67,7 @@ use tokio::sync::mpsc::{channel, error::SendError, Receiver, Sender}; use tokio::task::JoinError; // Export this for crate level use. use crate::cliprdr::{ClipboardFn, TeleportCliprdrBackend}; +use crate::license::GoLicenseCache; use crate::rdpdr::scard::SCARD_DEVICE_ID; use crate::rdpdr::TeleportRdpdrBackend; use crate::ssl::TlsStream; @@ -142,7 +143,7 @@ impl Client { let mut rng = rand_chacha::ChaCha20Rng::from_entropy(); let pin = format!("{:08}", rng.gen_range(0i32..=99999999i32)); - let connector_config = create_config(¶ms, pin.clone()); + let connector_config = create_config(¶ms, pin.clone(), cgo_handle); // Create a channel for sending/receiving function calls to/from the Client. let (client_handle, function_receiver) = ClientHandle::new(100); @@ -1385,7 +1386,7 @@ impl FunctionReceiver { type RdpReadStream = Framed>>>; type RdpWriteStream = Framed>>>; -fn create_config(params: &ConnectParams, pin: String) -> Config { +fn create_config(params: &ConnectParams, pin: String, cgo_handle: CgoHandle) -> Config { Config { desktop_size: ironrdp_connector::DesktopSize { width: params.screen_width, @@ -1426,8 +1427,8 @@ fn create_config(params: &ConnectParams, pin: String) -> Config { PerformanceFlags::empty() }, desktop_scale_factor: 0, - license_cache: None, - hardware_id: None, + license_cache: Some(Arc::new(GoLicenseCache { cgo_handle })), + hardware_id: Some(params.client_id), } } @@ -1441,6 +1442,7 @@ pub struct ConnectParams { pub allow_clipboard: bool, pub allow_directory_sharing: bool, pub show_desktop_wallpaper: bool, + pub client_id: [u32; 4], } #[derive(Debug)] diff --git a/lib/srv/desktop/rdp/rdpclient/src/lib.rs b/lib/srv/desktop/rdp/rdpclient/src/lib.rs index c90e9daecfe42..4bdf1096853b8 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/lib.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/lib.rs @@ -45,6 +45,7 @@ use std::ptr; use util::{from_c_string, from_go_array}; pub mod client; mod cliprdr; +mod license; mod piv; mod rdpdr; mod ssl; @@ -103,6 +104,7 @@ pub unsafe extern "C" fn client_run(cgo_handle: CgoHandle, params: CGOConnectPar allow_clipboard: params.allow_clipboard, allow_directory_sharing: params.allow_directory_sharing, show_desktop_wallpaper: params.show_desktop_wallpaper, + client_id: params.client_id, }, ) { Ok(res) => CGOResult { @@ -482,6 +484,7 @@ pub struct CGOConnectParams { allow_clipboard: bool, allow_directory_sharing: bool, show_desktop_wallpaper: bool, + client_id: [u32; 4], } /// CGOKeyboardEvent is a CGO-compatible version of KeyboardEvent that we pass back to Go. @@ -552,6 +555,7 @@ pub enum CGOErrCode { ErrCodeSuccess = 0, ErrCodeFailure = 1, ErrCodeClientPtr = 2, + ErrCodeNotFound = 3, } #[repr(C)] @@ -706,6 +710,19 @@ pub type CGOSharedDirectoryTruncateResponse = SharedDirectoryTruncateResponse; // These functions are defined on the Go side. // Look for functions with '//export funcname' comments. extern "C" { + fn cgo_free_rdp_license(data: *mut u8); + fn cgo_read_rdp_license( + cgo_handle: CgoHandle, + req: *mut CGOLicenseRequest, + data_out: *mut *mut u8, + len_out: *mut usize, + ) -> CGOErrCode; + fn cgo_write_rdp_license( + cgo_handle: CgoHandle, + req: *mut CGOLicenseRequest, + data: *mut u8, + length: usize, + ) -> CGOErrCode; fn cgo_handle_remote_copy(cgo_handle: CgoHandle, data: *mut u8, len: u32) -> CGOErrCode; fn cgo_handle_fastpath_pdu(cgo_handle: CgoHandle, data: *mut u8, len: u32) -> CGOErrCode; fn cgo_handle_rdp_connection_activated( @@ -757,3 +774,11 @@ extern "C" { /// /// [cgo.Handle]: https://pkg.go.dev/runtime/cgo#Handle type CgoHandle = usize; + +#[repr(C)] +pub struct CGOLicenseRequest { + version: u32, + issuer: *const c_char, + company: *const c_char, + product_id: *const c_char, +} diff --git a/lib/srv/desktop/rdp/rdpclient/src/license.rs b/lib/srv/desktop/rdp/rdpclient/src/license.rs new file mode 100644 index 0000000000000..3636d2d4c6eb3 --- /dev/null +++ b/lib/srv/desktop/rdp/rdpclient/src/license.rs @@ -0,0 +1,85 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +use crate::{ + cgo_free_rdp_license, cgo_read_rdp_license, cgo_write_rdp_license, CGOErrCode, + CGOLicenseRequest, CgoHandle, +}; +use ironrdp_connector::{custom_err, general_err, ConnectorError, ConnectorResult, LicenseCache}; +use ironrdp_pdu::rdp::server_license::LicenseInformation; +use picky_krb::negoex::NegoexDataType; +use std::ffi::{CString, NulError}; +use std::{ptr, slice}; + +#[derive(Debug)] +pub(crate) struct GoLicenseCache { + pub(crate) cgo_handle: CgoHandle, +} + +fn conversion_error(e: NulError) -> ConnectorError { + custom_err!("conversion error", e) +} + +impl LicenseCache for GoLicenseCache { + fn get_license(&self, license_info: LicenseInformation) -> ConnectorResult>> { + let issuer = CString::new(license_info.scope).map_err(conversion_error)?; + let company = CString::new(license_info.company_name).map_err(conversion_error)?; + let product_id = CString::new(license_info.product_id).map_err(conversion_error)?; + let mut req = CGOLicenseRequest { + version: license_info.version, + issuer: issuer.as_ptr(), + company: company.as_ptr(), + product_id: product_id.as_ptr(), + }; + let mut data: *mut u8 = ptr::null_mut(); + let mut size = 0usize; + unsafe { + match cgo_read_rdp_license(self.cgo_handle, &mut req, &mut data, &mut size) { + CGOErrCode::ErrCodeSuccess => { + let license = slice::from_raw_parts_mut(data, size).to_vec(); + cgo_free_rdp_license(data); + Ok(Some(license)) + } + CGOErrCode::ErrCodeFailure => Err(general_err!("error retrieving license")), + CGOErrCode::ErrCodeClientPtr => Err(general_err!("invalid client pointer")), + CGOErrCode::ErrCodeNotFound => Ok(None), + } + } + } + + fn store_license(&self, mut license_info: LicenseInformation) -> ConnectorResult<()> { + let issuer = CString::new(license_info.scope).map_err(conversion_error)?; + let company = CString::new(license_info.company_name).map_err(conversion_error)?; + let product_id = CString::new(license_info.product_id).map_err(conversion_error)?; + let mut req = CGOLicenseRequest { + version: license_info.version, + issuer: issuer.as_ptr(), + company: company.as_ptr(), + product_id: product_id.as_ptr(), + }; + unsafe { + match cgo_write_rdp_license( + self.cgo_handle, + &mut req, + license_info.license_info.as_mut_ptr(), + license_info.license_info.size(), + ) { + CGOErrCode::ErrCodeSuccess => Ok(()), + _ => Err(general_err!("error storing license")), + } + } + } +} diff --git a/lib/srv/desktop/windows_server.go b/lib/srv/desktop/windows_server.go index 44e2ebd148d9a..233d65f484ef1 100644 --- a/lib/srv/desktop/windows_server.go +++ b/lib/srv/desktop/windows_server.go @@ -151,8 +151,9 @@ type WindowsServiceConfig struct { // Log is the logging sink for the service. Log logrus.FieldLogger // Clock provides current time. - Clock clockwork.Clock - DataDir string + Clock clockwork.Clock + DataDir string + LicenseStore rdpclient.LicenseStore // Authorizer is used to authorize requests. Authorizer authz.Authorizer // LockWatcher is used to monitor for new locks. @@ -860,7 +861,9 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger, //nolint:staticcheck // SA4023. False positive, depends on build tags. rdpc, err := rdpclient.New(rdpclient.Config{ - Log: log, + Log: log, + LicenseStore: s.cfg.LicenseStore, + HostID: s.cfg.Heartbeat.HostUUID, GenerateUserCert: func(ctx context.Context, username string, ttl time.Duration) (certDER, keyDER []byte, err error) { return s.generateUserCert(ctx, username, ttl, desktop, createUsers, groups) },