From b3b59bdf080effeb3f6b49a93a75b2d7fab4bb82 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 13 Feb 2026 10:51:21 +0000 Subject: [PATCH 01/37] [py] BiDi Python code generation from CDDL --- common/bidi/spec/BUILD.bazel | 13 + common/bidi/spec/all.cddl | 2489 +++++++++++++++++ common/bidi/spec/local.cddl | 1331 +++++++++ common/bidi/spec/remote.cddl | 1716 ++++++++++++ py/AGENTS.md | 17 +- py/BUILD.bazel | 19 + py/conftest.py | 8 + py/generate_bidi.py | 1824 ++++++++++++ py/private/BUILD.bazel | 5 + py/private/bidi_enhancements_manifest.py | 1557 +++++++++++ py/private/cdp.py | 515 ++++ py/private/generate_bidi.bzl | 112 + py/requirements_lock.txt | 5 +- py/selenium/common/exceptions.py | 62 +- py/selenium/webdriver/common/bidi/__init__.py | 21 +- py/selenium/webdriver/common/bidi/browser.py | 508 ++-- .../webdriver/common/bidi/browsing_context.py | 1505 +++++----- py/selenium/webdriver/common/bidi/common.py | 27 +- py/selenium/webdriver/common/bidi/console.py | 0 .../webdriver/common/bidi/emulation.py | 835 +++--- py/selenium/webdriver/common/bidi/input.py | 684 +++-- py/selenium/webdriver/common/bidi/log.py | 140 +- py/selenium/webdriver/common/bidi/network.py | 1151 ++++++-- .../webdriver/common/bidi/permissions.py | 85 +- py/selenium/webdriver/common/bidi/py.typed | 0 py/selenium/webdriver/common/bidi/script.py | 1539 +++++++--- py/selenium/webdriver/common/bidi/session.py | 314 ++- py/selenium/webdriver/common/bidi/storage.py | 593 ++-- .../webdriver/common/bidi/webextension.py | 142 +- py/selenium/webdriver/common/by.py | 13 +- py/selenium/webdriver/common/proxy.py | 36 +- py/selenium/webdriver/remote/webdriver.py | 202 +- .../webdriver/remote/websocket_connection.py | 37 +- .../webdriver/common/bidi_browser_tests.py | 11 +- 34 files changed, 14292 insertions(+), 3224 deletions(-) create mode 100644 common/bidi/spec/BUILD.bazel create mode 100644 common/bidi/spec/all.cddl create mode 100644 common/bidi/spec/local.cddl create mode 100644 common/bidi/spec/remote.cddl create mode 100755 py/generate_bidi.py create mode 100644 py/private/bidi_enhancements_manifest.py create mode 100644 py/private/cdp.py create mode 100644 py/private/generate_bidi.bzl mode change 100644 => 100755 py/selenium/webdriver/common/bidi/console.py create mode 100755 py/selenium/webdriver/common/bidi/py.typed diff --git a/common/bidi/spec/BUILD.bazel b/common/bidi/spec/BUILD.bazel new file mode 100644 index 0000000000000..74c3cffd35ed0 --- /dev/null +++ b/common/bidi/spec/BUILD.bazel @@ -0,0 +1,13 @@ +package( + default_visibility = [ + "//py:__pkg__", + ], +) + +exports_files( + srcs = [ + "all.cddl", + "local.cddl", + "remote.cddl", + ], +) diff --git a/common/bidi/spec/all.cddl b/common/bidi/spec/all.cddl new file mode 100644 index 0000000000000..85c4536a2cd10 --- /dev/null +++ b/common/bidi/spec/all.cddl @@ -0,0 +1,2489 @@ +Command = { + id: js-uint, + CommandData, + Extensible, +} + +CommandData = ( + BrowserCommand // + BrowsingContextCommand // + EmulationCommand // + InputCommand // + NetworkCommand // + ScriptCommand // + SessionCommand // + StorageCommand // + WebExtensionCommand +) + +EmptyParams = { + Extensible +} + +Message = ( + CommandResponse / + ErrorResponse / + Event +) + +CommandResponse = { + type: "success", + id: js-uint, + result: ResultData, + Extensible +} + +ErrorResponse = { + type: "error", + id: js-uint / null, + error: ErrorCode, + message: text, + ? stacktrace: text, + Extensible +} + +ResultData = ( + BrowserResult / + BrowsingContextResult / + EmulationResult / + InputResult / + NetworkResult / + ScriptResult / + SessionResult / + StorageResult / + WebExtensionResult +) + +EmptyResult = { + Extensible +} + +Event = { + type: "event", + EventData, + Extensible +} + +EventData = ( + BrowsingContextEvent // + InputEvent // + LogEvent // + NetworkEvent // + ScriptEvent +) + +Extensible = (*text => any) + +js-int = -9007199254740991..9007199254740991 +js-uint = 0..9007199254740991 + +ErrorCode = "invalid argument" / + "invalid selector" / + "invalid session id" / + "invalid web extension" / + "move target out of bounds" / + "no such alert" / + "no such network collector" / + "no such element" / + "no such frame" / + "no such handle" / + "no such history entry" / + "no such intercept" / + "no such network data" / + "no such node" / + "no such request" / + "no such script" / + "no such storage partition" / + "no such user context" / + "no such web extension" / + "session not created" / + "unable to capture screen" / + "unable to close browser" / + "unable to set cookie" / + "unable to set file input" / + "unavailable network data" / + "underspecified storage partition" / + "unknown command" / + "unknown error" / + "unsupported operation" + +SessionCommand = ( + session.End // + session.New // + session.Status // + session.Subscribe // + session.Unsubscribe +) + +SessionResult = ( + session.EndResult / + session.NewResult / + session.StatusResult / + session.SubscribeResult / + session.UnsubscribeResult +) + +session.CapabilitiesRequest = { + ? alwaysMatch: session.CapabilityRequest, + ? firstMatch: [*session.CapabilityRequest] +} + +session.CapabilityRequest = { + ? acceptInsecureCerts: bool, + ? browserName: text, + ? browserVersion: text, + ? platformName: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + Extensible +} + +session.ProxyConfiguration = { + session.AutodetectProxyConfiguration // + session.DirectProxyConfiguration // + session.ManualProxyConfiguration // + session.PacProxyConfiguration // + session.SystemProxyConfiguration +} + +session.AutodetectProxyConfiguration = ( + proxyType: "autodetect", + Extensible +) + +session.DirectProxyConfiguration = ( + proxyType: "direct", + Extensible +) + +session.ManualProxyConfiguration = ( + proxyType: "manual", + ? httpProxy: text, + ? sslProxy: text, + ? session.SocksProxyConfiguration, + ? noProxy: [*text], + Extensible +) + +session.SocksProxyConfiguration = ( + socksProxy: text, + socksVersion: 0..255, +) + +session.PacProxyConfiguration = ( + proxyType: "pac", + proxyAutoconfigUrl: text, + Extensible +) + +session.SystemProxyConfiguration = ( + proxyType: "system", + Extensible +) + + +session.UserPromptHandler = { + ? alert: session.UserPromptHandlerType, + ? beforeUnload: session.UserPromptHandlerType, + ? confirm: session.UserPromptHandlerType, + ? default: session.UserPromptHandlerType, + ? file: session.UserPromptHandlerType, + ? prompt: session.UserPromptHandlerType, +} + +session.UserPromptHandlerType = "accept" / "dismiss" / "ignore"; + +session.Subscription = text + +session.SubscribeParameters = { + events: [+text], + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +session.UnsubscribeByIDRequest = { + subscriptions: [+session.Subscription], +} + +session.UnsubscribeByAttributesRequest = { + events: [+text], +} + +session.Status = ( + method: "session.status", + params: EmptyParams, +) + +session.StatusResult = { + ready: bool, + message: text, +} + +session.New = ( + method: "session.new", + params: session.NewParameters +) + +session.NewParameters = { + capabilities: session.CapabilitiesRequest +} + +session.NewResult = { + sessionId: text, + capabilities: { + acceptInsecureCerts: bool, + browserName: text, + browserVersion: text, + platformName: text, + setWindowRect: bool, + userAgent: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + ? webSocketUrl: text, + Extensible + } +} + +session.End = ( + method: "session.end", + params: EmptyParams +) + + +session.EndResult = EmptyResult + +session.Subscribe = ( + method: "session.subscribe", + params: session.SubscribeParameters +) + +session.SubscribeResult = { + subscription: session.Subscription, +} + +session.Unsubscribe = ( + method: "session.unsubscribe", + params: session.UnsubscribeParameters, +) + +session.UnsubscribeParameters = session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest + +session.UnsubscribeResult = EmptyResult + +BrowserCommand = ( + browser.Close // + browser.CreateUserContext // + browser.GetClientWindows // + browser.GetUserContexts // + browser.RemoveUserContext // + browser.SetClientWindowState // + browser.SetDownloadBehavior +) + +BrowserResult = ( + browser.CloseResult / + browser.CreateUserContextResult / + browser.GetClientWindowsResult / + browser.GetUserContextsResult / + browser.RemoveUserContextResult / + browser.SetClientWindowStateResult / + browser.SetDownloadBehaviorResult +) + +browser.ClientWindow = text; + +browser.ClientWindowInfo = { + active: bool, + clientWindow: browser.ClientWindow, + height: js-uint, + state: "fullscreen" / "maximized" / "minimized" / "normal", + width: js-uint, + x: js-int, + y: js-int, +} + +browser.UserContext = text; + +browser.UserContextInfo = { + userContext: browser.UserContext +} + +browser.Close = ( + method: "browser.close", + params: EmptyParams, +) + +browser.CloseResult = EmptyResult + +browser.CreateUserContext = ( + method: "browser.createUserContext", + params: browser.CreateUserContextParameters, +) + +browser.CreateUserContextParameters = { + ? acceptInsecureCerts: bool, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler +} + +browser.CreateUserContextResult = browser.UserContextInfo + +browser.GetClientWindows = ( + method: "browser.getClientWindows", + params: EmptyParams, +) + +browser.GetClientWindowsResult = { + clientWindows: [ * browser.ClientWindowInfo] +} + +browser.GetUserContexts = ( + method: "browser.getUserContexts", + params: EmptyParams, +) + +browser.GetUserContextsResult = { + userContexts: [ + browser.UserContextInfo] +} + +browser.RemoveUserContext = ( + method: "browser.removeUserContext", + params: browser.RemoveUserContextParameters +) + +browser.RemoveUserContextParameters = { + userContext: browser.UserContext +} + +browser.RemoveUserContextResult = EmptyResult + +browser.SetClientWindowState = ( + method: "browser.setClientWindowState", + params: browser.SetClientWindowStateParameters +) + +browser.SetClientWindowStateParameters = { + clientWindow: browser.ClientWindow, + (browser.ClientWindowNamedState // browser.ClientWindowRectState) +} + +browser.ClientWindowNamedState = ( + state: "fullscreen" / "maximized" / "minimized" +) + +browser.ClientWindowRectState = ( + state: "normal", + ? width: js-uint, + ? height: js-uint, + ? x: js-int, + ? y: js-int, +) + +browser.SetClientWindowStateResult = browser.ClientWindowInfo + +browser.SetDownloadBehavior = ( + method: "browser.setDownloadBehavior", + params: browser.SetDownloadBehaviorParameters +) + +browser.SetDownloadBehaviorParameters = { + downloadBehavior: browser.DownloadBehavior / null, + ? userContexts: [+browser.UserContext] +} + +browser.DownloadBehavior = { + ( + browser.DownloadBehaviorAllowed // + browser.DownloadBehaviorDenied + ) +} + +browser.DownloadBehaviorAllowed = ( + type: "allowed", + destinationFolder: text +) + +browser.DownloadBehaviorDenied = ( + type: "denied" +) + +browser.SetDownloadBehaviorResult = EmptyResult + +BrowsingContextCommand = ( + browsingContext.Activate // + browsingContext.CaptureScreenshot // + browsingContext.Close // + browsingContext.Create // + browsingContext.GetTree // + browsingContext.HandleUserPrompt // + browsingContext.LocateNodes // + browsingContext.Navigate // + browsingContext.Print // + browsingContext.Reload // + browsingContext.SetViewport // + browsingContext.TraverseHistory +) + +BrowsingContextResult = ( + browsingContext.ActivateResult / + browsingContext.CaptureScreenshotResult / + browsingContext.CloseResult / + browsingContext.CreateResult / + browsingContext.GetTreeResult / + browsingContext.HandleUserPromptResult / + browsingContext.LocateNodesResult / + browsingContext.NavigateResult / + browsingContext.PrintResult / + browsingContext.ReloadResult / + browsingContext.SetViewportResult / + browsingContext.TraverseHistoryResult +) + +BrowsingContextEvent = ( + browsingContext.ContextCreated // + browsingContext.ContextDestroyed // + browsingContext.DomContentLoaded // + browsingContext.DownloadEnd // + browsingContext.DownloadWillBegin // + browsingContext.FragmentNavigated // + browsingContext.HistoryUpdated // + browsingContext.Load // + browsingContext.NavigationAborted // + browsingContext.NavigationCommitted // + browsingContext.NavigationFailed // + browsingContext.NavigationStarted // + browsingContext.UserPromptClosed // + browsingContext.UserPromptOpened +) + +browsingContext.BrowsingContext = text; + +browsingContext.InfoList = [*browsingContext.Info] + +browsingContext.Info = { + children: browsingContext.InfoList / null, + clientWindow: browser.ClientWindow, + context: browsingContext.BrowsingContext, + originalOpener: browsingContext.BrowsingContext / null, + url: text, + userContext: browser.UserContext, + ? parent: browsingContext.BrowsingContext / null, +} + +browsingContext.Locator = ( + browsingContext.AccessibilityLocator / + browsingContext.CssLocator / + browsingContext.ContextLocator / + browsingContext.InnerTextLocator / + browsingContext.XPathLocator +) + +browsingContext.AccessibilityLocator = { + type: "accessibility", + value: { + ? name: text, + ? role: text, + } +} + +browsingContext.CssLocator = { + type: "css", + value: text +} + +browsingContext.ContextLocator = { + type: "context", + value: { + context: browsingContext.BrowsingContext, + } +} + +browsingContext.InnerTextLocator = { + type: "innerText", + value: text, + ? ignoreCase: bool + ? matchType: "full" / "partial", + ? maxDepth: js-uint, +} + +browsingContext.XPathLocator = { + type: "xpath", + value: text +} + +browsingContext.Navigation = text; + +browsingContext.BaseNavigationInfo = ( + context: browsingContext.BrowsingContext, + navigation: browsingContext.Navigation / null, + timestamp: js-uint, + url: text, +) + +browsingContext.NavigationInfo = { + browsingContext.BaseNavigationInfo +} + +browsingContext.ReadinessState = "none" / "interactive" / "complete" + +browsingContext.UserPromptType = "alert" / "beforeunload" / "confirm" / "prompt"; + +browsingContext.Activate = ( + method: "browsingContext.activate", + params: browsingContext.ActivateParameters +) + +browsingContext.ActivateParameters = { + context: browsingContext.BrowsingContext +} + +browsingContext.ActivateResult = EmptyResult + +browsingContext.CaptureScreenshot = ( + method: "browsingContext.captureScreenshot", + params: browsingContext.CaptureScreenshotParameters +) + +browsingContext.CaptureScreenshotParameters = { + context: browsingContext.BrowsingContext, + ? origin: ("viewport" / "document") .default "viewport", + ? format: browsingContext.ImageFormat, + ? clip: browsingContext.ClipRectangle, +} + +browsingContext.ImageFormat = { + type: text, + ? quality: 0.0..1.0, +} + +browsingContext.ClipRectangle = ( + browsingContext.BoxClipRectangle / + browsingContext.ElementClipRectangle +) + +browsingContext.ElementClipRectangle = { + type: "element", + element: script.SharedReference +} + +browsingContext.BoxClipRectangle = { + type: "box", + x: float, + y: float, + width: float, + height: float +} + +browsingContext.CaptureScreenshotResult = { + data: text +} + +browsingContext.Close = ( + method: "browsingContext.close", + params: browsingContext.CloseParameters +) + +browsingContext.CloseParameters = { + context: browsingContext.BrowsingContext, + ? promptUnload: bool .default false +} + +browsingContext.CloseResult = EmptyResult + +browsingContext.Create = ( + method: "browsingContext.create", + params: browsingContext.CreateParameters +) + +browsingContext.CreateType = "tab" / "window" + +browsingContext.CreateParameters = { + type: browsingContext.CreateType, + ? referenceContext: browsingContext.BrowsingContext, + ? background: bool .default false, + ? userContext: browser.UserContext +} + +browsingContext.CreateResult = { + context: browsingContext.BrowsingContext +} + +browsingContext.GetTree = ( + method: "browsingContext.getTree", + params: browsingContext.GetTreeParameters +) + +browsingContext.GetTreeParameters = { + ? maxDepth: js-uint, + ? root: browsingContext.BrowsingContext, +} + +browsingContext.GetTreeResult = { + contexts: browsingContext.InfoList +} + +browsingContext.HandleUserPrompt = ( + method: "browsingContext.handleUserPrompt", + params: browsingContext.HandleUserPromptParameters +) + +browsingContext.HandleUserPromptParameters = { + context: browsingContext.BrowsingContext, + ? accept: bool, + ? userText: text, +} + +browsingContext.HandleUserPromptResult = EmptyResult + +browsingContext.LocateNodes = ( + method: "browsingContext.locateNodes", + params: browsingContext.LocateNodesParameters +) + +browsingContext.LocateNodesParameters = { + context: browsingContext.BrowsingContext, + locator: browsingContext.Locator, + ? maxNodeCount: (js-uint .ge 1), + ? serializationOptions: script.SerializationOptions, + ? startNodes: [ + script.SharedReference ] +} + +browsingContext.LocateNodesResult = { + nodes: [ * script.NodeRemoteValue ] +} + +browsingContext.Navigate = ( + method: "browsingContext.navigate", + params: browsingContext.NavigateParameters +) + +browsingContext.NavigateParameters = { + context: browsingContext.BrowsingContext, + url: text, + ? wait: browsingContext.ReadinessState, +} + +browsingContext.NavigateResult = { + navigation: browsingContext.Navigation / null, + url: text, +} + +browsingContext.Print = ( + method: "browsingContext.print", + params: browsingContext.PrintParameters +) + +browsingContext.PrintParameters = { + context: browsingContext.BrowsingContext, + ? background: bool .default false, + ? margin: browsingContext.PrintMarginParameters, + ? orientation: ("portrait" / "landscape") .default "portrait", + ? page: browsingContext.PrintPageParameters, + ? pageRanges: [*(js-uint / text)], + ? scale: (0.1..2.0) .default 1.0, + ? shrinkToFit: bool .default true, +} + +browsingContext.PrintMarginParameters = { + ? bottom: (float .ge 0.0) .default 1.0, + ? left: (float .ge 0.0) .default 1.0, + ? right: (float .ge 0.0) .default 1.0, + ? top: (float .ge 0.0) .default 1.0, +} + +; Minimum size is 1pt x 1pt. Conversion follows from +; https://www.w3.org/TR/css3-values/#absolute-lengths +browsingContext.PrintPageParameters = { + ? height: (float .ge 0.0352) .default 27.94, + ? width: (float .ge 0.0352) .default 21.59, +} + +browsingContext.PrintResult = { + data: text +} + +browsingContext.Reload = ( + method: "browsingContext.reload", + params: browsingContext.ReloadParameters +) + +browsingContext.ReloadParameters = { + context: browsingContext.BrowsingContext, + ? ignoreCache: bool, + ? wait: browsingContext.ReadinessState, +} + +browsingContext.ReloadResult = browsingContext.NavigateResult + +browsingContext.SetViewport = ( + method: "browsingContext.setViewport", + params: browsingContext.SetViewportParameters +) + +browsingContext.SetViewportParameters = { + ? context: browsingContext.BrowsingContext, + ? viewport: browsingContext.Viewport / null, + ? devicePixelRatio: (float .gt 0.0) / null, + ? userContexts: [+browser.UserContext], +} + +browsingContext.Viewport = { + width: js-uint, + height: js-uint, +} + +browsingContext.SetViewportResult = EmptyResult + +browsingContext.TraverseHistory = ( + method: "browsingContext.traverseHistory", + params: browsingContext.TraverseHistoryParameters +) + +browsingContext.TraverseHistoryParameters = { + context: browsingContext.BrowsingContext, + delta: js-int, +} + +browsingContext.TraverseHistoryResult = EmptyResult + +browsingContext.ContextCreated = ( + method: "browsingContext.contextCreated", + params: browsingContext.Info +) + +browsingContext.ContextDestroyed = ( + method: "browsingContext.contextDestroyed", + params: browsingContext.Info +) + +browsingContext.NavigationStarted = ( + method: "browsingContext.navigationStarted", + params: browsingContext.NavigationInfo +) + +browsingContext.FragmentNavigated = ( + method: "browsingContext.fragmentNavigated", + params: browsingContext.NavigationInfo +) + +browsingContext.HistoryUpdated = ( + method: "browsingContext.historyUpdated", + params: browsingContext.HistoryUpdatedParameters +) + +browsingContext.HistoryUpdatedParameters = { + context: browsingContext.BrowsingContext, + timestamp: js-uint, + url: text +} + +browsingContext.DomContentLoaded = ( + method: "browsingContext.domContentLoaded", + params: browsingContext.NavigationInfo +) + +browsingContext.Load = ( + method: "browsingContext.load", + params: browsingContext.NavigationInfo +) + +browsingContext.DownloadWillBegin = ( + method: "browsingContext.downloadWillBegin", + params: browsingContext.DownloadWillBeginParams +) + +browsingContext.DownloadWillBeginParams = { + suggestedFilename: text, + browsingContext.BaseNavigationInfo +} + +browsingContext.DownloadEnd = ( + method: "browsingContext.downloadEnd", + params: browsingContext.DownloadEndParams +) + +browsingContext.DownloadEndParams = { + ( + browsingContext.DownloadCanceledParams // + browsingContext.DownloadCompleteParams + ) +} + +browsingContext.DownloadCanceledParams = ( + status: "canceled", + browsingContext.BaseNavigationInfo +) + +browsingContext.DownloadCompleteParams = ( + status: "complete", + filepath: text / null, + browsingContext.BaseNavigationInfo +) + +browsingContext.NavigationAborted = ( + method: "browsingContext.navigationAborted", + params: browsingContext.NavigationInfo +) + +browsingContext.NavigationCommitted = ( + method: "browsingContext.navigationCommitted", + params: browsingContext.NavigationInfo +) + +browsingContext.NavigationFailed = ( + method: "browsingContext.navigationFailed", + params: browsingContext.NavigationInfo +) + +browsingContext.UserPromptClosed = ( + method: "browsingContext.userPromptClosed", + params: browsingContext.UserPromptClosedParameters +) + +browsingContext.UserPromptClosedParameters = { + context: browsingContext.BrowsingContext, + accepted: bool, + type: browsingContext.UserPromptType, + ? userText: text +} + +browsingContext.UserPromptOpened = ( + method: "browsingContext.userPromptOpened", + params: browsingContext.UserPromptOpenedParameters +) + +browsingContext.UserPromptOpenedParameters = { + context: browsingContext.BrowsingContext, + handler: session.UserPromptHandlerType, + message: text, + type: browsingContext.UserPromptType, + ? defaultValue: text +} + +EmulationCommand = ( + emulation.SetForcedColorsModeThemeOverride // + emulation.SetGeolocationOverride // + emulation.SetLocaleOverride // + emulation.SetNetworkConditions // + emulation.SetScreenOrientationOverride // + emulation.SetScreenSettingsOverride // + emulation.SetScriptingEnabled // + emulation.SetScrollbarTypeOverride // + emulation.SetTimezoneOverride // + emulation.SetTouchOverride // + emulation.SetUserAgentOverride // + emulation.SetViewportMetaOverride +) + + +EmulationResult = ( + emulation.SetForcedColorsModeThemeOverrideResult / + emulation.SetGeolocationOverrideResult / + emulation.SetLocaleOverrideResult / + emulation.SetScreenOrientationOverrideResult / + emulation.SetScriptingEnabledResult / + emulation.SetScrollbarTypeOverrideResult / + emulation.SetTimezoneOverrideResult / + emulation.SetTouchOverrideResult / + emulation.SetUserAgentOverrideResult / + emulation.SetViewportMetaOverrideResult +) + +emulation.SetForcedColorsModeThemeOverride = ( + method: "emulation.setForcedColorsModeThemeOverride", + params: emulation.SetForcedColorsModeThemeOverrideParameters +) + +emulation.SetForcedColorsModeThemeOverrideParameters = { + theme: emulation.ForcedColorsModeTheme / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.ForcedColorsModeTheme = "light" / "dark" + +emulation.SetForcedColorsModeThemeOverrideResult = EmptyResult + +emulation.SetGeolocationOverride = ( + method: "emulation.setGeolocationOverride", + params: emulation.SetGeolocationOverrideParameters +) + +emulation.SetGeolocationOverrideParameters = { + ( + (coordinates: emulation.GeolocationCoordinates / null) // + (error: emulation.GeolocationPositionError) + ), + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.GeolocationCoordinates = { + latitude: -90.0..90.0, + longitude: -180.0..180.0, + ? accuracy: (float .ge 0.0) .default 1.0, + ? altitude: float / null .default null, + ? altitudeAccuracy: (float .ge 0.0) / null .default null, + ? heading: (0.0...360.0) / null .default null, + ? speed: (float .ge 0.0) / null .default null, +} + +emulation.GeolocationPositionError = { + type: "positionUnavailable" +} + +emulation.SetGeolocationOverrideResult = EmptyResult + +emulation.SetLocaleOverride = ( + method: "emulation.setLocaleOverride", + params: emulation.SetLocaleOverrideParameters +) + +emulation.SetLocaleOverrideParameters = { + locale: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetLocaleOverrideResult = EmptyResult + +emulation.SetNetworkConditions = ( + method: "emulation.setNetworkConditions", + params: emulation.setNetworkConditionsParameters +) + +emulation.setNetworkConditionsParameters = { + networkConditions: emulation.NetworkConditions / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.NetworkConditions = emulation.NetworkConditionsOffline + +emulation.NetworkConditionsOffline = { + type: "offline" +} + +emulation.SetNetworkConditionsResult = EmptyResult + +emulation.SetScreenSettingsOverride = ( + method: "emulation.setScreenSettingsOverride", + params: emulation.SetScreenSettingsOverrideParameters +) + +emulation.ScreenArea = { + width: js-uint, + height: js-uint +} + +emulation.SetScreenSettingsOverrideParameters = { + screenArea: emulation.ScreenArea / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScreenSettingsOverrideResult = EmptyResult + +emulation.SetScreenOrientationOverride = ( + method: "emulation.setScreenOrientationOverride", + params: emulation.SetScreenOrientationOverrideParameters +) + +emulation.ScreenOrientationNatural = "portrait" / "landscape" +emulation.ScreenOrientationType = "portrait-primary" / "portrait-secondary" / "landscape-primary" / "landscape-secondary" + +emulation.ScreenOrientation = { + natural: emulation.ScreenOrientationNatural, + type: emulation.ScreenOrientationType +} + +emulation.SetScreenOrientationOverrideParameters = { + screenOrientation: emulation.ScreenOrientation / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScreenOrientationOverrideResult = EmptyResult + +emulation.SetUserAgentOverride = ( + method: "emulation.setUserAgentOverride", + params: emulation.SetUserAgentOverrideParameters +) + +emulation.SetUserAgentOverrideParameters = { + userAgent: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetUserAgentOverrideResult = EmptyResult + +emulation.SetViewportMetaOverride = ( + method: "emulation.setViewportMetaOverride", + params: emulation.SetViewportMetaOverrideParameters +) + +emulation.SetViewportMetaOverrideParameters = { + viewportMeta: true / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetViewportMetaOverrideResult = EmptyResult + +emulation.SetScriptingEnabled = ( + method: "emulation.setScriptingEnabled", + params: emulation.SetScriptingEnabledParameters +) + +emulation.SetScriptingEnabledParameters = { + enabled: false / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScriptingEnabledResult = EmptyResult + +emulation.SetScrollbarTypeOverride = ( + method: "emulation.setScrollbarTypeOverride", + params: emulation.SetScrollbarTypeOverrideParameters +) + +emulation.SetScrollbarTypeOverrideParameters = { + scrollbarType: "classic" / "overlay" / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScrollbarTypeOverrideResult = EmptyResult + +emulation.SetTimezoneOverride = ( + method: "emulation.setTimezoneOverride", + params: emulation.SetTimezoneOverrideParameters +) + +emulation.SetTimezoneOverrideParameters = { + timezone: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetTimezoneOverrideResult = EmptyResult + +emulation.SetTouchOverride = ( + method: "emulation.setTouchOverride", + params: emulation.SetTouchOverrideParameters +) + +emulation.SetTouchOverrideParameters = { + maxTouchPoints: (js-uint .ge 1) / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetTouchOverrideResult = EmptyResult + + +NetworkCommand = ( + network.AddDataCollector // + network.AddIntercept // + network.ContinueRequest // + network.ContinueResponse // + network.ContinueWithAuth // + network.DisownData // + network.FailRequest // + network.GetData // + network.ProvideResponse // + network.RemoveDataCollector // + network.RemoveIntercept // + network.SetCacheBehavior // + network.SetExtraHeaders +) + + + +NetworkResult = ( + network.AddDataCollectorResult / + network.AddInterceptResult / + network.ContinueRequestResult / + network.ContinueResponseResult / + network.ContinueWithAuthResult / + network.DisownDataResult / + network.FailRequestResult / + network.GetDataResult / + network.ProvideResponseResult / + network.RemoveDataCollectorResult / + network.RemoveInterceptResult / + network.SetCacheBehaviorResult / + network.SetExtraHeadersResult +) + +NetworkEvent = ( + network.AuthRequired // + network.BeforeRequestSent // + network.FetchError // + network.ResponseCompleted // + network.ResponseStarted +) + + +network.AuthChallenge = { + scheme: text, + realm: text, +} + +network.AuthCredentials = { + type: "password", + username: text, + password: text, +} + +network.BaseParameters = ( + context: browsingContext.BrowsingContext / null, + isBlocked: bool, + navigation: browsingContext.Navigation / null, + redirectCount: js-uint, + request: network.RequestData, + timestamp: js-uint, + ? intercepts: [+network.Intercept] +) + +network.BytesValue = network.StringValue / network.Base64Value; + +network.StringValue = { + type: "string", + value: text, +} + +network.Base64Value = { + type: "base64", + value: text, +} + +network.Collector = text + +network.CollectorType = "blob" + + +network.SameSite = "strict" / "lax" / "none" / "default" + + +network.Cookie = { + name: text, + value: network.BytesValue, + domain: text, + path: text, + size: js-uint, + httpOnly: bool, + secure: bool, + sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +network.CookieHeader = { + name: text, + value: network.BytesValue, +} + +network.DataType = "request" / "response" + +network.FetchTimingInfo = { + timeOrigin: float, + requestTime: float, + redirectStart: float, + redirectEnd: float, + fetchStart: float, + dnsStart: float, + dnsEnd: float, + connectStart: float, + connectEnd: float, + tlsStart: float, + + requestStart: float, + responseStart: float, + + responseEnd: float, +} + +network.Header = { + name: text, + value: network.BytesValue, +} + +network.Initiator = { + ? columnNumber: js-uint, + ? lineNumber: js-uint, + ? request: network.Request, + ? stackTrace: script.StackTrace, + ? type: "parser" / "script" / "preflight" / "other" +} + +network.Intercept = text + +network.Request = text; + +network.RequestData = { + request: network.Request, + url: text, + method: text, + headers: [*network.Header], + cookies: [*network.Cookie], + headersSize: js-uint, + bodySize: js-uint / null, + destination: text, + initiatorType: text / null, + timings: network.FetchTimingInfo, +} + +network.ResponseContent = { + size: js-uint +} + +network.ResponseData = { + url: text, + protocol: text, + status: js-uint, + statusText: text, + fromCache: bool, + headers: [*network.Header], + mimeType: text, + bytesReceived: js-uint, + headersSize: js-uint / null, + bodySize: js-uint / null, + content: network.ResponseContent, + ?authChallenges: [*network.AuthChallenge], +} + + +network.SetCookieHeader = { + name: text, + value: network.BytesValue, + ? domain: text, + ? httpOnly: bool, + ? expiry: text, + ? maxAge: js-int, + ? path: text, + ? sameSite: network.SameSite, + ? secure: bool, +} + +network.UrlPattern = ( + network.UrlPatternPattern / + network.UrlPatternString +) + +network.UrlPatternPattern = { + type: "pattern", + ?protocol: text, + ?hostname: text, + ?port: text, + ?pathname: text, + ?search: text, +} + + +network.UrlPatternString = { + type: "string", + pattern: text, +} + + +network.AddDataCollector = ( + method: "network.addDataCollector", + params: network.AddDataCollectorParameters +) + +network.AddDataCollectorParameters = { + dataTypes: [+network.DataType], + maxEncodedDataSize: js-uint, + ? collectorType: network.CollectorType .default "blob", + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +network.AddDataCollectorResult = { + collector: network.Collector +} + +network.AddIntercept = ( + method: "network.addIntercept", + params: network.AddInterceptParameters +) + +network.AddInterceptParameters = { + phases: [+network.InterceptPhase], + ? contexts: [+browsingContext.BrowsingContext], + ? urlPatterns: [*network.UrlPattern], +} + +network.InterceptPhase = "beforeRequestSent" / "responseStarted" / + "authRequired" + +network.AddInterceptResult = { + intercept: network.Intercept +} + +network.ContinueRequest = ( + method: "network.continueRequest", + params: network.ContinueRequestParameters +) + +network.ContinueRequestParameters = { + request: network.Request, + ?body: network.BytesValue, + ?cookies: [*network.CookieHeader], + ?headers: [*network.Header], + ?method: text, + ?url: text, +} + +network.ContinueRequestResult = EmptyResult + +network.ContinueResponse = ( + method: "network.continueResponse", + params: network.ContinueResponseParameters +) + +network.ContinueResponseParameters = { + request: network.Request, + ?cookies: [*network.SetCookieHeader] + ?credentials: network.AuthCredentials, + ?headers: [*network.Header], + ?reasonPhrase: text, + ?statusCode: js-uint, +} + +network.ContinueResponseResult = EmptyResult + +network.ContinueWithAuth = ( + method: "network.continueWithAuth", + params: network.ContinueWithAuthParameters +) + +network.ContinueWithAuthParameters = { + request: network.Request, + (network.ContinueWithAuthCredentials // network.ContinueWithAuthNoCredentials) +} + +network.ContinueWithAuthCredentials = ( + action: "provideCredentials", + credentials: network.AuthCredentials +) + +network.ContinueWithAuthNoCredentials = ( + action: "default" / "cancel" +) + +network.ContinueWithAuthResult = EmptyResult + +network.DisownData = ( + method: "network.disownData", + params: network.disownDataParameters +) + +network.disownDataParameters = { + dataType: network.DataType, + collector: network.Collector, + request: network.Request, +} + +network.DisownDataResult = EmptyResult + +network.FailRequest = ( + method: "network.failRequest", + params: network.FailRequestParameters +) + +network.FailRequestParameters = { + request: network.Request, +} + +network.FailRequestResult = EmptyResult + +network.GetData = ( + method: "network.getData", + params: network.GetDataParameters +) + +network.GetDataParameters = { + dataType: network.DataType, + ? collector: network.Collector, + ? disown: bool .default false, + request: network.Request, +} + +network.GetDataResult = { + bytes: network.BytesValue, +} + +network.ProvideResponse = ( + method: "network.provideResponse", + params: network.ProvideResponseParameters +) + +network.ProvideResponseParameters = { + request: network.Request, + ?body: network.BytesValue, + ?cookies: [*network.SetCookieHeader], + ?headers: [*network.Header], + ?reasonPhrase: text, + ?statusCode: js-uint, +} + +network.ProvideResponseResult = EmptyResult + +network.RemoveDataCollector = ( + method: "network.removeDataCollector", + params: network.RemoveDataCollectorParameters +) + +network.RemoveDataCollectorParameters = { + collector: network.Collector +} + +network.RemoveDataCollectorResult = EmptyResult + +network.RemoveIntercept = ( + method: "network.removeIntercept", + params: network.RemoveInterceptParameters +) + +network.RemoveInterceptParameters = { + intercept: network.Intercept +} + +network.RemoveInterceptResult = EmptyResult + +network.SetCacheBehavior = ( + method: "network.setCacheBehavior", + params: network.SetCacheBehaviorParameters +) + +network.SetCacheBehaviorParameters = { + cacheBehavior: "default" / "bypass", + ? contexts: [+browsingContext.BrowsingContext] +} + +network.SetCacheBehaviorResult = EmptyResult + +network.SetExtraHeaders = ( + method: "network.setExtraHeaders", + params: network.SetExtraHeadersParameters +) + +network.SetExtraHeadersParameters = { + headers: [*network.Header] + ? contexts: [+browsingContext.BrowsingContext] + ? userContexts: [+browser.UserContext] +} + +network.SetExtraHeadersResult = EmptyResult + +network.AuthRequired = ( + method: "network.authRequired", + params: network.AuthRequiredParameters +) + +network.AuthRequiredParameters = { + network.BaseParameters, + response: network.ResponseData +} + + network.BeforeRequestSent = ( + method: "network.beforeRequestSent", + params: network.BeforeRequestSentParameters + ) + +network.BeforeRequestSentParameters = { + network.BaseParameters, + ? initiator: network.Initiator, +} + + network.FetchError = ( + method: "network.fetchError", + params: network.FetchErrorParameters + ) + +network.FetchErrorParameters = { + network.BaseParameters, + errorText: text, +} + + network.ResponseCompleted = ( + method: "network.responseCompleted", + params: network.ResponseCompletedParameters + ) + +network.ResponseCompletedParameters = { + network.BaseParameters, + response: network.ResponseData, +} + + network.ResponseStarted = ( + method: "network.responseStarted", + params: network.ResponseStartedParameters + ) + +network.ResponseStartedParameters = { + network.BaseParameters, + response: network.ResponseData, +} + +ScriptCommand = ( + script.AddPreloadScript // + script.CallFunction // + script.Disown // + script.Evaluate // + script.GetRealms // + script.RemovePreloadScript +) + +ScriptResult = ( + script.AddPreloadScriptResult / + script.CallFunctionResult / + script.DisownResult / + script.EvaluateResult / + script.GetRealmsResult / + script.RemovePreloadScriptResult +) + +ScriptEvent = ( + script.Message // + script.RealmCreated // + script.RealmDestroyed +) + +script.Channel = text; + +script.ChannelValue = { + type: "channel", + value: script.ChannelProperties, +} + +script.ChannelProperties = { + channel: script.Channel, + ? serializationOptions: script.SerializationOptions, + ? ownership: script.ResultOwnership, +} + +script.EvaluateResult = ( + script.EvaluateResultSuccess / + script.EvaluateResultException +) + +script.EvaluateResultSuccess = { + type: "success", + result: script.RemoteValue, + realm: script.Realm +} + +script.EvaluateResultException = { + type: "exception", + exceptionDetails: script.ExceptionDetails + realm: script.Realm +} + +script.ExceptionDetails = { + columnNumber: js-uint, + exception: script.RemoteValue, + lineNumber: js-uint, + stackTrace: script.StackTrace, + text: text, +} + +script.Handle = text; + +script.InternalId = text; + +script.LocalValue = ( + script.RemoteReference / + script.PrimitiveProtocolValue / + script.ChannelValue / + script.ArrayLocalValue / + { script.DateLocalValue } / + script.MapLocalValue / + script.ObjectLocalValue / + { script.RegExpLocalValue } / + script.SetLocalValue +) + +script.ListLocalValue = [*script.LocalValue]; + +script.ArrayLocalValue = { + type: "array", + value: script.ListLocalValue, +} + +script.DateLocalValue = ( + type: "date", + value: text +) + +script.MappingLocalValue = [*[(script.LocalValue / text), script.LocalValue]]; + +script.MapLocalValue = { + type: "map", + value: script.MappingLocalValue, +} + +script.ObjectLocalValue = { + type: "object", + value: script.MappingLocalValue, +} + +script.RegExpValue = { + pattern: text, + ? flags: text, +} + +script.RegExpLocalValue = ( + type: "regexp", + value: script.RegExpValue, +) + +script.SetLocalValue = { + type: "set", + value: script.ListLocalValue, +} + +script.PreloadScript = text; + +script.Realm = text; + +script.PrimitiveProtocolValue = ( + script.UndefinedValue / + script.NullValue / + script.StringValue / + script.NumberValue / + script.BooleanValue / + script.BigIntValue +) + +script.UndefinedValue = { + type: "undefined", +} + +script.NullValue = { + type: "null", +} + +script.StringValue = { + type: "string", + value: text, +} + +script.SpecialNumber = "NaN" / "-0" / "Infinity" / "-Infinity"; + +script.NumberValue = { + type: "number", + value: number / script.SpecialNumber, +} + +script.BooleanValue = { + type: "boolean", + value: bool, +} + +script.BigIntValue = { + type: "bigint", + value: text, +} + +script.RealmInfo = ( + script.WindowRealmInfo / + script.DedicatedWorkerRealmInfo / + script.SharedWorkerRealmInfo / + script.ServiceWorkerRealmInfo / + script.WorkerRealmInfo / + script.PaintWorkletRealmInfo / + script.AudioWorkletRealmInfo / + script.WorkletRealmInfo +) + +script.BaseRealmInfo = ( + realm: script.Realm, + origin: text +) + +script.WindowRealmInfo = { + script.BaseRealmInfo, + type: "window", + context: browsingContext.BrowsingContext, + ? sandbox: text +} + +script.DedicatedWorkerRealmInfo = { + script.BaseRealmInfo, + type: "dedicated-worker", + owners: [script.Realm] +} + +script.SharedWorkerRealmInfo = { + script.BaseRealmInfo, + type: "shared-worker" +} + +script.ServiceWorkerRealmInfo = { + script.BaseRealmInfo, + type: "service-worker" +} + +script.WorkerRealmInfo = { + script.BaseRealmInfo, + type: "worker" +} + +script.PaintWorkletRealmInfo = { + script.BaseRealmInfo, + type: "paint-worklet" +} + +script.AudioWorkletRealmInfo = { + script.BaseRealmInfo, + type: "audio-worklet" +} + +script.WorkletRealmInfo = { + script.BaseRealmInfo, + type: "worklet" +} + +script.RealmType = "window" / "dedicated-worker" / "shared-worker" / "service-worker" / + "worker" / "paint-worklet" / "audio-worklet" / "worklet" + + + +script.RemoteReference = ( + script.SharedReference / + script.RemoteObjectReference +) + +script.SharedReference = { + sharedId: script.SharedId + + ? handle: script.Handle, + Extensible +} + +script.RemoteObjectReference = { + handle: script.Handle, + + ? sharedId: script.SharedId + Extensible +} + +script.RemoteValue = ( + script.PrimitiveProtocolValue / + script.SymbolRemoteValue / + script.ArrayRemoteValue / + script.ObjectRemoteValue / + script.FunctionRemoteValue / + script.RegExpRemoteValue / + script.DateRemoteValue / + script.MapRemoteValue / + script.SetRemoteValue / + script.WeakMapRemoteValue / + script.WeakSetRemoteValue / + script.GeneratorRemoteValue / + script.ErrorRemoteValue / + script.ProxyRemoteValue / + script.PromiseRemoteValue / + script.TypedArrayRemoteValue / + script.ArrayBufferRemoteValue / + script.NodeListRemoteValue / + script.HTMLCollectionRemoteValue / + script.NodeRemoteValue / + script.WindowProxyRemoteValue +) + +script.ListRemoteValue = [*script.RemoteValue]; + +script.MappingRemoteValue = [*[(script.RemoteValue / text), script.RemoteValue]]; + +script.SymbolRemoteValue = { + type: "symbol", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayRemoteValue = { + type: "array", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.ObjectRemoteValue = { + type: "object", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.FunctionRemoteValue = { + type: "function", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.RegExpRemoteValue = { + script.RegExpLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.DateRemoteValue = { + script.DateLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.MapRemoteValue = { + type: "map", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.SetRemoteValue = { + type: "set", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue +} + +script.WeakMapRemoteValue = { + type: "weakmap", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.WeakSetRemoteValue = { + type: "weakset", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.GeneratorRemoteValue = { + type: "generator", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ErrorRemoteValue = { + type: "error", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ProxyRemoteValue = { + type: "proxy", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.PromiseRemoteValue = { + type: "promise", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.TypedArrayRemoteValue = { + type: "typedarray", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayBufferRemoteValue = { + type: "arraybuffer", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.NodeListRemoteValue = { + type: "nodelist", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.HTMLCollectionRemoteValue = { + type: "htmlcollection", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.NodeRemoteValue = { + type: "node", + ? sharedId: script.SharedId, + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.NodeProperties, +} + +script.NodeProperties = { + nodeType: js-uint, + childNodeCount: js-uint, + ? attributes: {*text => text}, + ? children: [*script.NodeRemoteValue], + ? localName: text, + ? mode: "open" / "closed", + ? namespaceURI: text, + ? nodeValue: text, + ? shadowRoot: script.NodeRemoteValue / null, +} + +script.WindowProxyRemoteValue = { + type: "window", + value: script.WindowProxyProperties, + ? handle: script.Handle, + ? internalId: script.InternalId +} + +script.WindowProxyProperties = { + context: browsingContext.BrowsingContext +} + +script.ResultOwnership = "root" / "none" + +script.SerializationOptions = { + ? maxDomDepth: (js-uint / null) .default 0, + ? maxObjectDepth: (js-uint / null) .default null, + ? includeShadowTree: ("none" / "open" / "all") .default "none", +} + +script.SharedId = text; + +script.StackFrame = { + columnNumber: js-uint, + functionName: text, + lineNumber: js-uint, + url: text, +} + +script.StackTrace = { + callFrames: [*script.StackFrame], +} + +script.Source = { + realm: script.Realm, + ? context: browsingContext.BrowsingContext +} + +script.RealmTarget = { + realm: script.Realm +} + +script.ContextTarget = { + context: browsingContext.BrowsingContext, + ? sandbox: text +} + +script.Target = ( + script.ContextTarget / + script.RealmTarget +) + +script.AddPreloadScript = ( + method: "script.addPreloadScript", + params: script.AddPreloadScriptParameters +) + +script.AddPreloadScriptParameters = { + functionDeclaration: text, + ? arguments: [*script.ChannelValue], + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], + ? sandbox: text +} + +script.AddPreloadScriptResult = { + script: script.PreloadScript +} + +script.Disown = ( + method: "script.disown", + params: script.DisownParameters +) + +script.DisownParameters = { + handles: [*script.Handle] + target: script.Target; +} + +script.DisownResult = EmptyResult + +script.CallFunction = ( + method: "script.callFunction", + params: script.CallFunctionParameters +) + +script.CallFunctionParameters = { + functionDeclaration: text, + awaitPromise: bool, + target: script.Target, + ? arguments: [*script.LocalValue], + ? resultOwnership: script.ResultOwnership, + ? serializationOptions: script.SerializationOptions, + ? this: script.LocalValue, + ? userActivation: bool .default false, +} + +script.CallFunctionResult = script.EvaluateResult + +script.Evaluate = ( + method: "script.evaluate", + params: script.EvaluateParameters +) + +script.EvaluateParameters = { + expression: text, + target: script.Target, + awaitPromise: bool, + ? resultOwnership: script.ResultOwnership, + ? serializationOptions: script.SerializationOptions, + ? userActivation: bool .default false, +} + +script.GetRealms = ( + method: "script.getRealms", + params: script.GetRealmsParameters +) + +script.GetRealmsParameters = { + ? context: browsingContext.BrowsingContext, + ? type: script.RealmType, +} + +script.GetRealmsResult = { + realms: [*script.RealmInfo] +} + +script.RemovePreloadScript = ( + method: "script.removePreloadScript", + params: script.RemovePreloadScriptParameters +) + +script.RemovePreloadScriptParameters = { + script: script.PreloadScript +} + +script.RemovePreloadScriptResult = EmptyResult + + script.Message = ( + method: "script.message", + params: script.MessageParameters + ) + +script.MessageParameters = { + channel: script.Channel, + data: script.RemoteValue, + source: script.Source, +} + +script.RealmCreated = ( + method: "script.realmCreated", + params: script.RealmInfo +) + +script.RealmDestroyed = ( + method: "script.realmDestroyed", + params: script.RealmDestroyedParameters +) + +script.RealmDestroyedParameters = { + realm: script.Realm +} + + +StorageCommand = ( + storage.DeleteCookies // + storage.GetCookies // + storage.SetCookie +) + +StorageResult = ( + storage.DeleteCookiesResult / + storage.GetCookiesResult / + storage.SetCookieResult +) + +storage.PartitionKey = { + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.GetCookies = ( + method: "storage.getCookies", + params: storage.GetCookiesParameters +) + + +storage.CookieFilter = { + ? name: text, + ? value: network.BytesValue, + ? domain: text, + ? path: text, + ? size: js-uint, + ? httpOnly: bool, + ? secure: bool, + ? sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +storage.BrowsingContextPartitionDescriptor = { + type: "context", + context: browsingContext.BrowsingContext +} + +storage.StorageKeyPartitionDescriptor = { + type: "storageKey", + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.PartitionDescriptor = ( + storage.BrowsingContextPartitionDescriptor / + storage.StorageKeyPartitionDescriptor +) + +storage.GetCookiesParameters = { + ? filter: storage.CookieFilter, + ? partition: storage.PartitionDescriptor, +} + +storage.GetCookiesResult = { + cookies: [*network.Cookie], + partitionKey: storage.PartitionKey, +} + +storage.SetCookie = ( + method: "storage.setCookie", + params: storage.SetCookieParameters, +) + + +storage.PartialCookie = { + name: text, + value: network.BytesValue, + domain: text, + ? path: text, + ? httpOnly: bool, + ? secure: bool, + ? sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +storage.SetCookieParameters = { + cookie: storage.PartialCookie, + ? partition: storage.PartitionDescriptor, +} + +storage.SetCookieResult = { + partitionKey: storage.PartitionKey +} + +storage.DeleteCookies = ( + method: "storage.deleteCookies", + params: storage.DeleteCookiesParameters, +) + +storage.DeleteCookiesParameters = { + ? filter: storage.CookieFilter, + ? partition: storage.PartitionDescriptor, +} + +storage.DeleteCookiesResult = { + partitionKey: storage.PartitionKey +} + +LogEvent = ( + log.EntryAdded +) + +log.Level = "debug" / "info" / "warn" / "error" + +log.Entry = ( + log.GenericLogEntry / + log.ConsoleLogEntry / + log.JavascriptLogEntry +) + +log.BaseLogEntry = ( + level: log.Level, + source: script.Source, + text: text / null, + timestamp: js-uint, + ? stackTrace: script.StackTrace, +) + +log.GenericLogEntry = { + log.BaseLogEntry, + type: text, +} + +log.ConsoleLogEntry = { + log.BaseLogEntry, + type: "console", + method: text, + args: [*script.RemoteValue], +} + +log.JavascriptLogEntry = { + log.BaseLogEntry, + type: "javascript", +} + +log.EntryAdded = ( + method: "log.entryAdded", + params: log.Entry, +) + +InputCommand = ( + input.PerformActions // + input.ReleaseActions // + input.SetFiles +) + +InputResult = ( + input.PerformActionsResult / + input.ReleaseActionsResult / + input.SetFilesResult +) + + +InputEvent = ( + input.FileDialogOpened +) + +input.ElementOrigin = { + type: "element", + element: script.SharedReference +} + +input.PerformActions = ( + method: "input.performActions", + params: input.PerformActionsParameters +) + +input.PerformActionsParameters = { + context: browsingContext.BrowsingContext, + actions: [*input.SourceActions] +} + +input.SourceActions = ( + input.NoneSourceActions / + input.KeySourceActions / + input.PointerSourceActions / + input.WheelSourceActions +) + +input.NoneSourceActions = { + type: "none", + id: text, + actions: [*input.NoneSourceAction] +} + +input.NoneSourceAction = input.PauseAction + +input.KeySourceActions = { + type: "key", + id: text, + actions: [*input.KeySourceAction] +} + +input.KeySourceAction = ( + input.PauseAction / + input.KeyDownAction / + input.KeyUpAction +) + +input.PointerSourceActions = { + type: "pointer", + id: text, + ? parameters: input.PointerParameters, + actions: [*input.PointerSourceAction] +} + +input.PointerType = "mouse" / "pen" / "touch" + +input.PointerParameters = { + ? pointerType: input.PointerType .default "mouse" +} + +input.PointerSourceAction = ( + input.PauseAction / + input.PointerDownAction / + input.PointerUpAction / + input.PointerMoveAction +) + +input.WheelSourceActions = { + type: "wheel", + id: text, + actions: [*input.WheelSourceAction] +} + +input.WheelSourceAction = ( + input.PauseAction / + input.WheelScrollAction +) + +input.PauseAction = { + type: "pause", + ? duration: js-uint +} + +input.KeyDownAction = { + type: "keyDown", + value: text +} + +input.KeyUpAction = { + type: "keyUp", + value: text +} + +input.PointerUpAction = { + type: "pointerUp", + button: js-uint, +} + +input.PointerDownAction = { + type: "pointerDown", + button: js-uint, + input.PointerCommonProperties +} + +input.PointerMoveAction = { + type: "pointerMove", + x: float, + y: float, + ? duration: js-uint, + ? origin: input.Origin, + input.PointerCommonProperties +} + +input.WheelScrollAction = { + type: "scroll", + x: js-int, + y: js-int, + deltaX: js-int, + deltaY: js-int, + ? duration: js-uint, + ? origin: input.Origin .default "viewport", +} + +input.PointerCommonProperties = ( + ? width: js-uint .default 1, + ? height: js-uint .default 1, + ? pressure: float .default 0.0, + ? tangentialPressure: float .default 0.0, + ? twist: (0..359) .default 0, + ; 0 .. Math.PI / 2 + ? altitudeAngle: (0.0..1.5707963267948966) .default 0.0, + ; 0 .. 2 * Math.PI + ? azimuthAngle: (0.0..6.283185307179586) .default 0.0, +) + +input.Origin = "viewport" / "pointer" / input.ElementOrigin + +input.PerformActionsResult = EmptyResult + +input.ReleaseActions = ( + method: "input.releaseActions", + params: input.ReleaseActionsParameters +) + +input.ReleaseActionsParameters = { + context: browsingContext.BrowsingContext, +} + +input.ReleaseActionsResult = EmptyResult + +input.SetFiles = ( + method: "input.setFiles", + params: input.SetFilesParameters +) + +input.SetFilesParameters = { + context: browsingContext.BrowsingContext, + element: script.SharedReference, + files: [*text] +} + +input.SetFilesResult = EmptyResult + +input.FileDialogOpened = ( + method: "input.fileDialogOpened", + params: input.FileDialogInfo +) + +input.FileDialogInfo = { + context: browsingContext.BrowsingContext, + ? element: script.SharedReference, + multiple: bool, +} + +WebExtensionCommand = ( + webExtension.Install // + webExtension.Uninstall +) + +WebExtensionResult = ( + webExtension.InstallResult / + webExtension.UninstallResult +) + +webExtension.Extension = text + +webExtension.Install = ( + method: "webExtension.install", + params: webExtension.InstallParameters +) + +webExtension.InstallParameters = { + extensionData: webExtension.ExtensionData, +} + +webExtension.ExtensionData = ( + webExtension.ExtensionArchivePath / + webExtension.ExtensionBase64Encoded / + webExtension.ExtensionPath +) + +webExtension.ExtensionPath = { + type: "path", + path: text, +} + +webExtension.ExtensionArchivePath = { + type: "archivePath", + path: text, +} + +webExtension.ExtensionBase64Encoded = { + type: "base64", + value: text, +} + +webExtension.InstallResult = { + extension: webExtension.Extension +} + +webExtension.Uninstall = ( + method: "webExtension.uninstall", + params: webExtension.UninstallParameters +) + +webExtension.UninstallParameters = { + extension: webExtension.Extension, +} + +webExtension.UninstallResult = EmptyResult diff --git a/common/bidi/spec/local.cddl b/common/bidi/spec/local.cddl new file mode 100644 index 0000000000000..d43af0ae11b03 --- /dev/null +++ b/common/bidi/spec/local.cddl @@ -0,0 +1,1331 @@ +Message = ( + CommandResponse / + ErrorResponse / + Event +) + +CommandResponse = { + type: "success", + id: js-uint, + result: ResultData, + Extensible +} + +ErrorResponse = { + type: "error", + id: js-uint / null, + error: ErrorCode, + message: text, + ? stacktrace: text, + Extensible +} + +ResultData = ( + BrowserResult / + BrowsingContextResult / + EmulationResult / + InputResult / + NetworkResult / + ScriptResult / + SessionResult / + StorageResult / + WebExtensionResult +) + +EmptyResult = { + Extensible +} + +Event = { + type: "event", + EventData, + Extensible +} + +EventData = ( + BrowsingContextEvent // + InputEvent // + LogEvent // + NetworkEvent // + ScriptEvent +) + +Extensible = (*text => any) + +js-int = -9007199254740991..9007199254740991 +js-uint = 0..9007199254740991 + +ErrorCode = "invalid argument" / + "invalid selector" / + "invalid session id" / + "invalid web extension" / + "move target out of bounds" / + "no such alert" / + "no such network collector" / + "no such element" / + "no such frame" / + "no such handle" / + "no such history entry" / + "no such intercept" / + "no such network data" / + "no such node" / + "no such request" / + "no such script" / + "no such storage partition" / + "no such user context" / + "no such web extension" / + "session not created" / + "unable to capture screen" / + "unable to close browser" / + "unable to set cookie" / + "unable to set file input" / + "unavailable network data" / + "underspecified storage partition" / + "unknown command" / + "unknown error" / + "unsupported operation" + +SessionResult = ( + session.EndResult / + session.NewResult / + session.StatusResult / + session.SubscribeResult / + session.UnsubscribeResult +) + +session.CapabilitiesRequest = { + ? alwaysMatch: session.CapabilityRequest, + ? firstMatch: [*session.CapabilityRequest] +} + +session.CapabilityRequest = { + ? acceptInsecureCerts: bool, + ? browserName: text, + ? browserVersion: text, + ? platformName: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + Extensible +} + +session.ProxyConfiguration = { + session.AutodetectProxyConfiguration // + session.DirectProxyConfiguration // + session.ManualProxyConfiguration // + session.PacProxyConfiguration // + session.SystemProxyConfiguration +} + +session.AutodetectProxyConfiguration = ( + proxyType: "autodetect", + Extensible +) + +session.DirectProxyConfiguration = ( + proxyType: "direct", + Extensible +) + +session.ManualProxyConfiguration = ( + proxyType: "manual", + ? httpProxy: text, + ? sslProxy: text, + ? session.SocksProxyConfiguration, + ? noProxy: [*text], + Extensible +) + +session.SocksProxyConfiguration = ( + socksProxy: text, + socksVersion: 0..255, +) + +session.PacProxyConfiguration = ( + proxyType: "pac", + proxyAutoconfigUrl: text, + Extensible +) + +session.SystemProxyConfiguration = ( + proxyType: "system", + Extensible +) + + +session.UserPromptHandler = { + ? alert: session.UserPromptHandlerType, + ? beforeUnload: session.UserPromptHandlerType, + ? confirm: session.UserPromptHandlerType, + ? default: session.UserPromptHandlerType, + ? file: session.UserPromptHandlerType, + ? prompt: session.UserPromptHandlerType, +} + +session.UserPromptHandlerType = "accept" / "dismiss" / "ignore"; + +session.Subscription = text + +session.StatusResult = { + ready: bool, + message: text, +} + +session.NewResult = { + sessionId: text, + capabilities: { + acceptInsecureCerts: bool, + browserName: text, + browserVersion: text, + platformName: text, + setWindowRect: bool, + userAgent: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + ? webSocketUrl: text, + Extensible + } +} + +session.EndResult = EmptyResult + +session.SubscribeResult = { + subscription: session.Subscription, +} + +session.UnsubscribeResult = EmptyResult + +BrowserResult = ( + browser.CloseResult / + browser.CreateUserContextResult / + browser.GetClientWindowsResult / + browser.GetUserContextsResult / + browser.RemoveUserContextResult / + browser.SetClientWindowStateResult / + browser.SetDownloadBehaviorResult +) + +browser.ClientWindow = text; + +browser.ClientWindowInfo = { + active: bool, + clientWindow: browser.ClientWindow, + height: js-uint, + state: "fullscreen" / "maximized" / "minimized" / "normal", + width: js-uint, + x: js-int, + y: js-int, +} + +browser.UserContext = text; + +browser.UserContextInfo = { + userContext: browser.UserContext +} + +browser.CloseResult = EmptyResult + +browser.CreateUserContextResult = browser.UserContextInfo + +browser.GetClientWindowsResult = { + clientWindows: [ * browser.ClientWindowInfo] +} + +browser.GetUserContextsResult = { + userContexts: [ + browser.UserContextInfo] +} + +browser.RemoveUserContextResult = EmptyResult + +browser.SetClientWindowStateResult = browser.ClientWindowInfo + +browser.SetDownloadBehaviorResult = EmptyResult + +BrowsingContextResult = ( + browsingContext.ActivateResult / + browsingContext.CaptureScreenshotResult / + browsingContext.CloseResult / + browsingContext.CreateResult / + browsingContext.GetTreeResult / + browsingContext.HandleUserPromptResult / + browsingContext.LocateNodesResult / + browsingContext.NavigateResult / + browsingContext.PrintResult / + browsingContext.ReloadResult / + browsingContext.SetViewportResult / + browsingContext.TraverseHistoryResult +) + +BrowsingContextEvent = ( + browsingContext.ContextCreated // + browsingContext.ContextDestroyed // + browsingContext.DomContentLoaded // + browsingContext.DownloadEnd // + browsingContext.DownloadWillBegin // + browsingContext.FragmentNavigated // + browsingContext.HistoryUpdated // + browsingContext.Load // + browsingContext.NavigationAborted // + browsingContext.NavigationCommitted // + browsingContext.NavigationFailed // + browsingContext.NavigationStarted // + browsingContext.UserPromptClosed // + browsingContext.UserPromptOpened +) + +browsingContext.BrowsingContext = text; + +browsingContext.InfoList = [*browsingContext.Info] + +browsingContext.Info = { + children: browsingContext.InfoList / null, + clientWindow: browser.ClientWindow, + context: browsingContext.BrowsingContext, + originalOpener: browsingContext.BrowsingContext / null, + url: text, + userContext: browser.UserContext, + ? parent: browsingContext.BrowsingContext / null, +} + +browsingContext.Locator = ( + browsingContext.AccessibilityLocator / + browsingContext.CssLocator / + browsingContext.ContextLocator / + browsingContext.InnerTextLocator / + browsingContext.XPathLocator +) + +browsingContext.AccessibilityLocator = { + type: "accessibility", + value: { + ? name: text, + ? role: text, + } +} + +browsingContext.CssLocator = { + type: "css", + value: text +} + +browsingContext.ContextLocator = { + type: "context", + value: { + context: browsingContext.BrowsingContext, + } +} + +browsingContext.InnerTextLocator = { + type: "innerText", + value: text, + ? ignoreCase: bool + ? matchType: "full" / "partial", + ? maxDepth: js-uint, +} + +browsingContext.XPathLocator = { + type: "xpath", + value: text +} + +browsingContext.Navigation = text; + +browsingContext.BaseNavigationInfo = ( + context: browsingContext.BrowsingContext, + navigation: browsingContext.Navigation / null, + timestamp: js-uint, + url: text, +) + +browsingContext.NavigationInfo = { + browsingContext.BaseNavigationInfo +} + +browsingContext.UserPromptType = "alert" / "beforeunload" / "confirm" / "prompt"; + +browsingContext.ActivateResult = EmptyResult + +browsingContext.CaptureScreenshotResult = { + data: text +} + +browsingContext.CloseResult = EmptyResult + +browsingContext.CreateResult = { + context: browsingContext.BrowsingContext +} + +browsingContext.GetTreeResult = { + contexts: browsingContext.InfoList +} + +browsingContext.HandleUserPromptResult = EmptyResult + +browsingContext.LocateNodesResult = { + nodes: [ * script.NodeRemoteValue ] +} + +browsingContext.NavigateResult = { + navigation: browsingContext.Navigation / null, + url: text, +} + +browsingContext.PrintResult = { + data: text +} + +browsingContext.ReloadResult = browsingContext.NavigateResult + +browsingContext.SetViewportResult = EmptyResult + +browsingContext.TraverseHistoryResult = EmptyResult + +browsingContext.ContextCreated = ( + method: "browsingContext.contextCreated", + params: browsingContext.Info +) + +browsingContext.ContextDestroyed = ( + method: "browsingContext.contextDestroyed", + params: browsingContext.Info +) + +browsingContext.NavigationStarted = ( + method: "browsingContext.navigationStarted", + params: browsingContext.NavigationInfo +) + +browsingContext.FragmentNavigated = ( + method: "browsingContext.fragmentNavigated", + params: browsingContext.NavigationInfo +) + +browsingContext.HistoryUpdated = ( + method: "browsingContext.historyUpdated", + params: browsingContext.HistoryUpdatedParameters +) + +browsingContext.HistoryUpdatedParameters = { + context: browsingContext.BrowsingContext, + timestamp: js-uint, + url: text +} + +browsingContext.DomContentLoaded = ( + method: "browsingContext.domContentLoaded", + params: browsingContext.NavigationInfo +) + +browsingContext.Load = ( + method: "browsingContext.load", + params: browsingContext.NavigationInfo +) + +browsingContext.DownloadWillBegin = ( + method: "browsingContext.downloadWillBegin", + params: browsingContext.DownloadWillBeginParams +) + +browsingContext.DownloadWillBeginParams = { + suggestedFilename: text, + browsingContext.BaseNavigationInfo +} + +browsingContext.DownloadEnd = ( + method: "browsingContext.downloadEnd", + params: browsingContext.DownloadEndParams +) + +browsingContext.DownloadEndParams = { + ( + browsingContext.DownloadCanceledParams // + browsingContext.DownloadCompleteParams + ) +} + +browsingContext.DownloadCanceledParams = ( + status: "canceled", + browsingContext.BaseNavigationInfo +) + +browsingContext.DownloadCompleteParams = ( + status: "complete", + filepath: text / null, + browsingContext.BaseNavigationInfo +) + +browsingContext.NavigationAborted = ( + method: "browsingContext.navigationAborted", + params: browsingContext.NavigationInfo +) + +browsingContext.NavigationCommitted = ( + method: "browsingContext.navigationCommitted", + params: browsingContext.NavigationInfo +) + +browsingContext.NavigationFailed = ( + method: "browsingContext.navigationFailed", + params: browsingContext.NavigationInfo +) + +browsingContext.UserPromptClosed = ( + method: "browsingContext.userPromptClosed", + params: browsingContext.UserPromptClosedParameters +) + +browsingContext.UserPromptClosedParameters = { + context: browsingContext.BrowsingContext, + accepted: bool, + type: browsingContext.UserPromptType, + ? userText: text +} + +browsingContext.UserPromptOpened = ( + method: "browsingContext.userPromptOpened", + params: browsingContext.UserPromptOpenedParameters +) + +browsingContext.UserPromptOpenedParameters = { + context: browsingContext.BrowsingContext, + handler: session.UserPromptHandlerType, + message: text, + type: browsingContext.UserPromptType, + ? defaultValue: text +} + +EmulationResult = ( + emulation.SetForcedColorsModeThemeOverrideResult / + emulation.SetGeolocationOverrideResult / + emulation.SetLocaleOverrideResult / + emulation.SetScreenOrientationOverrideResult / + emulation.SetScriptingEnabledResult / + emulation.SetScrollbarTypeOverrideResult / + emulation.SetTimezoneOverrideResult / + emulation.SetTouchOverrideResult / + emulation.SetUserAgentOverrideResult / + emulation.SetViewportMetaOverrideResult +) + +emulation.SetForcedColorsModeThemeOverrideResult = EmptyResult + +emulation.SetGeolocationOverrideResult = EmptyResult + +emulation.SetLocaleOverrideResult = EmptyResult + +emulation.SetNetworkConditionsResult = EmptyResult + +emulation.SetScreenSettingsOverrideResult = EmptyResult + +emulation.SetScreenOrientationOverrideResult = EmptyResult + +emulation.SetUserAgentOverrideResult = EmptyResult + +emulation.SetViewportMetaOverrideResult = EmptyResult + +emulation.SetScriptingEnabledResult = EmptyResult + +emulation.SetScrollbarTypeOverrideResult = EmptyResult + +emulation.SetTimezoneOverrideResult = EmptyResult + +emulation.SetTouchOverrideResult = EmptyResult + + +NetworkResult = ( + network.AddDataCollectorResult / + network.AddInterceptResult / + network.ContinueRequestResult / + network.ContinueResponseResult / + network.ContinueWithAuthResult / + network.DisownDataResult / + network.FailRequestResult / + network.GetDataResult / + network.ProvideResponseResult / + network.RemoveDataCollectorResult / + network.RemoveInterceptResult / + network.SetCacheBehaviorResult / + network.SetExtraHeadersResult +) + +NetworkEvent = ( + network.AuthRequired // + network.BeforeRequestSent // + network.FetchError // + network.ResponseCompleted // + network.ResponseStarted +) + + +network.AuthChallenge = { + scheme: text, + realm: text, +} + +network.BaseParameters = ( + context: browsingContext.BrowsingContext / null, + isBlocked: bool, + navigation: browsingContext.Navigation / null, + redirectCount: js-uint, + request: network.RequestData, + timestamp: js-uint, + ? intercepts: [+network.Intercept] +) + +network.BytesValue = network.StringValue / network.Base64Value; + +network.StringValue = { + type: "string", + value: text, +} + +network.Base64Value = { + type: "base64", + value: text, +} + +network.Collector = text + +network.CollectorType = "blob" + + +network.SameSite = "strict" / "lax" / "none" / "default" + + +network.Cookie = { + name: text, + value: network.BytesValue, + domain: text, + path: text, + size: js-uint, + httpOnly: bool, + secure: bool, + sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +network.DataType = "request" / "response" + +network.FetchTimingInfo = { + timeOrigin: float, + requestTime: float, + redirectStart: float, + redirectEnd: float, + fetchStart: float, + dnsStart: float, + dnsEnd: float, + connectStart: float, + connectEnd: float, + tlsStart: float, + + requestStart: float, + responseStart: float, + + responseEnd: float, +} + +network.Header = { + name: text, + value: network.BytesValue, +} + +network.Initiator = { + ? columnNumber: js-uint, + ? lineNumber: js-uint, + ? request: network.Request, + ? stackTrace: script.StackTrace, + ? type: "parser" / "script" / "preflight" / "other" +} + +network.Intercept = text + +network.Request = text; + +network.RequestData = { + request: network.Request, + url: text, + method: text, + headers: [*network.Header], + cookies: [*network.Cookie], + headersSize: js-uint, + bodySize: js-uint / null, + destination: text, + initiatorType: text / null, + timings: network.FetchTimingInfo, +} + +network.ResponseContent = { + size: js-uint +} + +network.ResponseData = { + url: text, + protocol: text, + status: js-uint, + statusText: text, + fromCache: bool, + headers: [*network.Header], + mimeType: text, + bytesReceived: js-uint, + headersSize: js-uint / null, + bodySize: js-uint / null, + content: network.ResponseContent, + ?authChallenges: [*network.AuthChallenge], +} + +network.AddDataCollectorResult = { + collector: network.Collector +} + +network.AddInterceptResult = { + intercept: network.Intercept +} + +network.ContinueRequestResult = EmptyResult + +network.ContinueResponseResult = EmptyResult + +network.ContinueWithAuthResult = EmptyResult + +network.DisownDataResult = EmptyResult + +network.FailRequestResult = EmptyResult + +network.GetDataResult = { + bytes: network.BytesValue, +} + +network.ProvideResponseResult = EmptyResult + +network.RemoveDataCollectorResult = EmptyResult + +network.RemoveInterceptResult = EmptyResult + +network.SetCacheBehaviorResult = EmptyResult + +network.SetExtraHeadersResult = EmptyResult + +network.AuthRequired = ( + method: "network.authRequired", + params: network.AuthRequiredParameters +) + +network.AuthRequiredParameters = { + network.BaseParameters, + response: network.ResponseData +} + + network.BeforeRequestSent = ( + method: "network.beforeRequestSent", + params: network.BeforeRequestSentParameters + ) + +network.BeforeRequestSentParameters = { + network.BaseParameters, + ? initiator: network.Initiator, +} + + network.FetchError = ( + method: "network.fetchError", + params: network.FetchErrorParameters + ) + +network.FetchErrorParameters = { + network.BaseParameters, + errorText: text, +} + + network.ResponseCompleted = ( + method: "network.responseCompleted", + params: network.ResponseCompletedParameters + ) + +network.ResponseCompletedParameters = { + network.BaseParameters, + response: network.ResponseData, +} + + network.ResponseStarted = ( + method: "network.responseStarted", + params: network.ResponseStartedParameters + ) + +network.ResponseStartedParameters = { + network.BaseParameters, + response: network.ResponseData, +} + +ScriptResult = ( + script.AddPreloadScriptResult / + script.CallFunctionResult / + script.DisownResult / + script.EvaluateResult / + script.GetRealmsResult / + script.RemovePreloadScriptResult +) + +ScriptEvent = ( + script.Message // + script.RealmCreated // + script.RealmDestroyed +) + +script.Channel = text; + +script.ChannelValue = { + type: "channel", + value: script.ChannelProperties, +} + +script.ChannelProperties = { + channel: script.Channel, + ? serializationOptions: script.SerializationOptions, + ? ownership: script.ResultOwnership, +} + +script.EvaluateResult = ( + script.EvaluateResultSuccess / + script.EvaluateResultException +) + +script.EvaluateResultSuccess = { + type: "success", + result: script.RemoteValue, + realm: script.Realm +} + +script.EvaluateResultException = { + type: "exception", + exceptionDetails: script.ExceptionDetails + realm: script.Realm +} + +script.ExceptionDetails = { + columnNumber: js-uint, + exception: script.RemoteValue, + lineNumber: js-uint, + stackTrace: script.StackTrace, + text: text, +} + +script.Handle = text; + +script.InternalId = text; + +script.LocalValue = ( + script.RemoteReference / + script.PrimitiveProtocolValue / + script.ChannelValue / + script.ArrayLocalValue / + { script.DateLocalValue } / + script.MapLocalValue / + script.ObjectLocalValue / + { script.RegExpLocalValue } / + script.SetLocalValue +) + +script.ListLocalValue = [*script.LocalValue]; + +script.ArrayLocalValue = { + type: "array", + value: script.ListLocalValue, +} + +script.DateLocalValue = ( + type: "date", + value: text +) + +script.MappingLocalValue = [*[(script.LocalValue / text), script.LocalValue]]; + +script.MapLocalValue = { + type: "map", + value: script.MappingLocalValue, +} + +script.ObjectLocalValue = { + type: "object", + value: script.MappingLocalValue, +} + +script.RegExpValue = { + pattern: text, + ? flags: text, +} + +script.RegExpLocalValue = ( + type: "regexp", + value: script.RegExpValue, +) + +script.SetLocalValue = { + type: "set", + value: script.ListLocalValue, +} + +script.PreloadScript = text; + +script.Realm = text; + +script.PrimitiveProtocolValue = ( + script.UndefinedValue / + script.NullValue / + script.StringValue / + script.NumberValue / + script.BooleanValue / + script.BigIntValue +) + +script.UndefinedValue = { + type: "undefined", +} + +script.NullValue = { + type: "null", +} + +script.StringValue = { + type: "string", + value: text, +} + +script.SpecialNumber = "NaN" / "-0" / "Infinity" / "-Infinity"; + +script.NumberValue = { + type: "number", + value: number / script.SpecialNumber, +} + +script.BooleanValue = { + type: "boolean", + value: bool, +} + +script.BigIntValue = { + type: "bigint", + value: text, +} + +script.RealmInfo = ( + script.WindowRealmInfo / + script.DedicatedWorkerRealmInfo / + script.SharedWorkerRealmInfo / + script.ServiceWorkerRealmInfo / + script.WorkerRealmInfo / + script.PaintWorkletRealmInfo / + script.AudioWorkletRealmInfo / + script.WorkletRealmInfo +) + +script.BaseRealmInfo = ( + realm: script.Realm, + origin: text +) + +script.WindowRealmInfo = { + script.BaseRealmInfo, + type: "window", + context: browsingContext.BrowsingContext, + ? sandbox: text +} + +script.DedicatedWorkerRealmInfo = { + script.BaseRealmInfo, + type: "dedicated-worker", + owners: [script.Realm] +} + +script.SharedWorkerRealmInfo = { + script.BaseRealmInfo, + type: "shared-worker" +} + +script.ServiceWorkerRealmInfo = { + script.BaseRealmInfo, + type: "service-worker" +} + +script.WorkerRealmInfo = { + script.BaseRealmInfo, + type: "worker" +} + +script.PaintWorkletRealmInfo = { + script.BaseRealmInfo, + type: "paint-worklet" +} + +script.AudioWorkletRealmInfo = { + script.BaseRealmInfo, + type: "audio-worklet" +} + +script.WorkletRealmInfo = { + script.BaseRealmInfo, + type: "worklet" +} + +script.RealmType = "window" / "dedicated-worker" / "shared-worker" / "service-worker" / + "worker" / "paint-worklet" / "audio-worklet" / "worklet" + + + +script.RemoteReference = ( + script.SharedReference / + script.RemoteObjectReference +) + +script.SharedReference = { + sharedId: script.SharedId + + ? handle: script.Handle, + Extensible +} + +script.RemoteObjectReference = { + handle: script.Handle, + + ? sharedId: script.SharedId + Extensible +} + +script.RemoteValue = ( + script.PrimitiveProtocolValue / + script.SymbolRemoteValue / + script.ArrayRemoteValue / + script.ObjectRemoteValue / + script.FunctionRemoteValue / + script.RegExpRemoteValue / + script.DateRemoteValue / + script.MapRemoteValue / + script.SetRemoteValue / + script.WeakMapRemoteValue / + script.WeakSetRemoteValue / + script.GeneratorRemoteValue / + script.ErrorRemoteValue / + script.ProxyRemoteValue / + script.PromiseRemoteValue / + script.TypedArrayRemoteValue / + script.ArrayBufferRemoteValue / + script.NodeListRemoteValue / + script.HTMLCollectionRemoteValue / + script.NodeRemoteValue / + script.WindowProxyRemoteValue +) + +script.ListRemoteValue = [*script.RemoteValue]; + +script.MappingRemoteValue = [*[(script.RemoteValue / text), script.RemoteValue]]; + +script.SymbolRemoteValue = { + type: "symbol", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayRemoteValue = { + type: "array", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.ObjectRemoteValue = { + type: "object", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.FunctionRemoteValue = { + type: "function", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.RegExpRemoteValue = { + script.RegExpLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.DateRemoteValue = { + script.DateLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.MapRemoteValue = { + type: "map", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.SetRemoteValue = { + type: "set", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue +} + +script.WeakMapRemoteValue = { + type: "weakmap", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.WeakSetRemoteValue = { + type: "weakset", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.GeneratorRemoteValue = { + type: "generator", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ErrorRemoteValue = { + type: "error", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ProxyRemoteValue = { + type: "proxy", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.PromiseRemoteValue = { + type: "promise", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.TypedArrayRemoteValue = { + type: "typedarray", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayBufferRemoteValue = { + type: "arraybuffer", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.NodeListRemoteValue = { + type: "nodelist", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.HTMLCollectionRemoteValue = { + type: "htmlcollection", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.NodeRemoteValue = { + type: "node", + ? sharedId: script.SharedId, + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.NodeProperties, +} + +script.NodeProperties = { + nodeType: js-uint, + childNodeCount: js-uint, + ? attributes: {*text => text}, + ? children: [*script.NodeRemoteValue], + ? localName: text, + ? mode: "open" / "closed", + ? namespaceURI: text, + ? nodeValue: text, + ? shadowRoot: script.NodeRemoteValue / null, +} + +script.WindowProxyRemoteValue = { + type: "window", + value: script.WindowProxyProperties, + ? handle: script.Handle, + ? internalId: script.InternalId +} + +script.WindowProxyProperties = { + context: browsingContext.BrowsingContext +} + +script.ResultOwnership = "root" / "none" + +script.SerializationOptions = { + ? maxDomDepth: (js-uint / null) .default 0, + ? maxObjectDepth: (js-uint / null) .default null, + ? includeShadowTree: ("none" / "open" / "all") .default "none", +} + +script.SharedId = text; + +script.StackFrame = { + columnNumber: js-uint, + functionName: text, + lineNumber: js-uint, + url: text, +} + +script.StackTrace = { + callFrames: [*script.StackFrame], +} + +script.Source = { + realm: script.Realm, + ? context: browsingContext.BrowsingContext +} + +script.AddPreloadScriptResult = { + script: script.PreloadScript +} + +script.DisownResult = EmptyResult + +script.CallFunctionResult = script.EvaluateResult + +script.GetRealmsResult = { + realms: [*script.RealmInfo] +} + +script.RemovePreloadScriptResult = EmptyResult + + script.Message = ( + method: "script.message", + params: script.MessageParameters + ) + +script.MessageParameters = { + channel: script.Channel, + data: script.RemoteValue, + source: script.Source, +} + +script.RealmCreated = ( + method: "script.realmCreated", + params: script.RealmInfo +) + +script.RealmDestroyed = ( + method: "script.realmDestroyed", + params: script.RealmDestroyedParameters +) + +script.RealmDestroyedParameters = { + realm: script.Realm +} + + +StorageResult = ( + storage.DeleteCookiesResult / + storage.GetCookiesResult / + storage.SetCookieResult +) + +storage.PartitionKey = { + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.GetCookiesResult = { + cookies: [*network.Cookie], + partitionKey: storage.PartitionKey, +} + +storage.SetCookieResult = { + partitionKey: storage.PartitionKey +} + +storage.DeleteCookiesResult = { + partitionKey: storage.PartitionKey +} + +LogEvent = ( + log.EntryAdded +) + +log.Level = "debug" / "info" / "warn" / "error" + +log.Entry = ( + log.GenericLogEntry / + log.ConsoleLogEntry / + log.JavascriptLogEntry +) + +log.BaseLogEntry = ( + level: log.Level, + source: script.Source, + text: text / null, + timestamp: js-uint, + ? stackTrace: script.StackTrace, +) + +log.GenericLogEntry = { + log.BaseLogEntry, + type: text, +} + +log.ConsoleLogEntry = { + log.BaseLogEntry, + type: "console", + method: text, + args: [*script.RemoteValue], +} + +log.JavascriptLogEntry = { + log.BaseLogEntry, + type: "javascript", +} + +log.EntryAdded = ( + method: "log.entryAdded", + params: log.Entry, +) + + +InputEvent = ( + input.FileDialogOpened +) + +input.PerformActionsResult = EmptyResult + +input.ReleaseActionsResult = EmptyResult + +input.SetFilesResult = EmptyResult + +input.FileDialogOpened = ( + method: "input.fileDialogOpened", + params: input.FileDialogInfo +) + +input.FileDialogInfo = { + context: browsingContext.BrowsingContext, + ? element: script.SharedReference, + multiple: bool, +} + +WebExtensionResult = ( + webExtension.InstallResult / + webExtension.UninstallResult +) + +webExtension.Extension = text + +webExtension.InstallResult = { + extension: webExtension.Extension +} + +webExtension.UninstallResult = EmptyResult diff --git a/common/bidi/spec/remote.cddl b/common/bidi/spec/remote.cddl new file mode 100644 index 0000000000000..a98859a021e12 --- /dev/null +++ b/common/bidi/spec/remote.cddl @@ -0,0 +1,1716 @@ +Command = { + id: js-uint, + CommandData, + Extensible, +} + +CommandData = ( + BrowserCommand // + BrowsingContextCommand // + EmulationCommand // + InputCommand // + NetworkCommand // + ScriptCommand // + SessionCommand // + StorageCommand // + WebExtensionCommand +) + +EmptyParams = { + Extensible +} + +Extensible = (*text => any) + +js-int = -9007199254740991..9007199254740991 +js-uint = 0..9007199254740991 + +SessionCommand = ( + session.End // + session.New // + session.Status // + session.Subscribe // + session.Unsubscribe +) + +session.CapabilitiesRequest = { + ? alwaysMatch: session.CapabilityRequest, + ? firstMatch: [*session.CapabilityRequest] +} + +session.CapabilityRequest = { + ? acceptInsecureCerts: bool, + ? browserName: text, + ? browserVersion: text, + ? platformName: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + Extensible +} + +session.ProxyConfiguration = { + session.AutodetectProxyConfiguration // + session.DirectProxyConfiguration // + session.ManualProxyConfiguration // + session.PacProxyConfiguration // + session.SystemProxyConfiguration +} + +session.AutodetectProxyConfiguration = ( + proxyType: "autodetect", + Extensible +) + +session.DirectProxyConfiguration = ( + proxyType: "direct", + Extensible +) + +session.ManualProxyConfiguration = ( + proxyType: "manual", + ? httpProxy: text, + ? sslProxy: text, + ? session.SocksProxyConfiguration, + ? noProxy: [*text], + Extensible +) + +session.SocksProxyConfiguration = ( + socksProxy: text, + socksVersion: 0..255, +) + +session.PacProxyConfiguration = ( + proxyType: "pac", + proxyAutoconfigUrl: text, + Extensible +) + +session.SystemProxyConfiguration = ( + proxyType: "system", + Extensible +) + + +session.UserPromptHandler = { + ? alert: session.UserPromptHandlerType, + ? beforeUnload: session.UserPromptHandlerType, + ? confirm: session.UserPromptHandlerType, + ? default: session.UserPromptHandlerType, + ? file: session.UserPromptHandlerType, + ? prompt: session.UserPromptHandlerType, +} + +session.UserPromptHandlerType = "accept" / "dismiss" / "ignore"; + +session.Subscription = text + +session.SubscribeParameters = { + events: [+text], + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +session.UnsubscribeByIDRequest = { + subscriptions: [+session.Subscription], +} + +session.UnsubscribeByAttributesRequest = { + events: [+text], +} + +session.Status = ( + method: "session.status", + params: EmptyParams, +) + +session.New = ( + method: "session.new", + params: session.NewParameters +) + +session.NewParameters = { + capabilities: session.CapabilitiesRequest +} + +session.End = ( + method: "session.end", + params: EmptyParams +) + + +session.Subscribe = ( + method: "session.subscribe", + params: session.SubscribeParameters +) + +session.Unsubscribe = ( + method: "session.unsubscribe", + params: session.UnsubscribeParameters, +) + +session.UnsubscribeParameters = session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest + +BrowserCommand = ( + browser.Close // + browser.CreateUserContext // + browser.GetClientWindows // + browser.GetUserContexts // + browser.RemoveUserContext // + browser.SetClientWindowState // + browser.SetDownloadBehavior +) + +browser.ClientWindow = text; + +browser.ClientWindowInfo = { + active: bool, + clientWindow: browser.ClientWindow, + height: js-uint, + state: "fullscreen" / "maximized" / "minimized" / "normal", + width: js-uint, + x: js-int, + y: js-int, +} + +browser.UserContext = text; + +browser.UserContextInfo = { + userContext: browser.UserContext +} + +browser.Close = ( + method: "browser.close", + params: EmptyParams, +) + +browser.CreateUserContext = ( + method: "browser.createUserContext", + params: browser.CreateUserContextParameters, +) + +browser.CreateUserContextParameters = { + ? acceptInsecureCerts: bool, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler +} + +browser.GetClientWindows = ( + method: "browser.getClientWindows", + params: EmptyParams, +) + +browser.GetUserContexts = ( + method: "browser.getUserContexts", + params: EmptyParams, +) + +browser.RemoveUserContext = ( + method: "browser.removeUserContext", + params: browser.RemoveUserContextParameters +) + +browser.RemoveUserContextParameters = { + userContext: browser.UserContext +} + +browser.SetClientWindowState = ( + method: "browser.setClientWindowState", + params: browser.SetClientWindowStateParameters +) + +browser.SetClientWindowStateParameters = { + clientWindow: browser.ClientWindow, + (browser.ClientWindowNamedState // browser.ClientWindowRectState) +} + +browser.ClientWindowNamedState = ( + state: "fullscreen" / "maximized" / "minimized" +) + +browser.ClientWindowRectState = ( + state: "normal", + ? width: js-uint, + ? height: js-uint, + ? x: js-int, + ? y: js-int, +) + +browser.SetDownloadBehavior = ( + method: "browser.setDownloadBehavior", + params: browser.SetDownloadBehaviorParameters +) + +browser.SetDownloadBehaviorParameters = { + downloadBehavior: browser.DownloadBehavior / null, + ? userContexts: [+browser.UserContext] +} + +browser.DownloadBehavior = { + ( + browser.DownloadBehaviorAllowed // + browser.DownloadBehaviorDenied + ) +} + +browser.DownloadBehaviorAllowed = ( + type: "allowed", + destinationFolder: text +) + +browser.DownloadBehaviorDenied = ( + type: "denied" +) + +BrowsingContextCommand = ( + browsingContext.Activate // + browsingContext.CaptureScreenshot // + browsingContext.Close // + browsingContext.Create // + browsingContext.GetTree // + browsingContext.HandleUserPrompt // + browsingContext.LocateNodes // + browsingContext.Navigate // + browsingContext.Print // + browsingContext.Reload // + browsingContext.SetViewport // + browsingContext.TraverseHistory +) + +browsingContext.BrowsingContext = text; + +browsingContext.Locator = ( + browsingContext.AccessibilityLocator / + browsingContext.CssLocator / + browsingContext.ContextLocator / + browsingContext.InnerTextLocator / + browsingContext.XPathLocator +) + +browsingContext.AccessibilityLocator = { + type: "accessibility", + value: { + ? name: text, + ? role: text, + } +} + +browsingContext.CssLocator = { + type: "css", + value: text +} + +browsingContext.ContextLocator = { + type: "context", + value: { + context: browsingContext.BrowsingContext, + } +} + +browsingContext.InnerTextLocator = { + type: "innerText", + value: text, + ? ignoreCase: bool + ? matchType: "full" / "partial", + ? maxDepth: js-uint, +} + +browsingContext.XPathLocator = { + type: "xpath", + value: text +} + +browsingContext.Navigation = text; + +browsingContext.ReadinessState = "none" / "interactive" / "complete" + +browsingContext.UserPromptType = "alert" / "beforeunload" / "confirm" / "prompt"; + +browsingContext.Activate = ( + method: "browsingContext.activate", + params: browsingContext.ActivateParameters +) + +browsingContext.ActivateParameters = { + context: browsingContext.BrowsingContext +} + +browsingContext.CaptureScreenshot = ( + method: "browsingContext.captureScreenshot", + params: browsingContext.CaptureScreenshotParameters +) + +browsingContext.CaptureScreenshotParameters = { + context: browsingContext.BrowsingContext, + ? origin: ("viewport" / "document") .default "viewport", + ? format: browsingContext.ImageFormat, + ? clip: browsingContext.ClipRectangle, +} + +browsingContext.ImageFormat = { + type: text, + ? quality: 0.0..1.0, +} + +browsingContext.ClipRectangle = ( + browsingContext.BoxClipRectangle / + browsingContext.ElementClipRectangle +) + +browsingContext.ElementClipRectangle = { + type: "element", + element: script.SharedReference +} + +browsingContext.BoxClipRectangle = { + type: "box", + x: float, + y: float, + width: float, + height: float +} + +browsingContext.Close = ( + method: "browsingContext.close", + params: browsingContext.CloseParameters +) + +browsingContext.CloseParameters = { + context: browsingContext.BrowsingContext, + ? promptUnload: bool .default false +} + +browsingContext.Create = ( + method: "browsingContext.create", + params: browsingContext.CreateParameters +) + +browsingContext.CreateType = "tab" / "window" + +browsingContext.CreateParameters = { + type: browsingContext.CreateType, + ? referenceContext: browsingContext.BrowsingContext, + ? background: bool .default false, + ? userContext: browser.UserContext +} + +browsingContext.GetTree = ( + method: "browsingContext.getTree", + params: browsingContext.GetTreeParameters +) + +browsingContext.GetTreeParameters = { + ? maxDepth: js-uint, + ? root: browsingContext.BrowsingContext, +} + +browsingContext.HandleUserPrompt = ( + method: "browsingContext.handleUserPrompt", + params: browsingContext.HandleUserPromptParameters +) + +browsingContext.HandleUserPromptParameters = { + context: browsingContext.BrowsingContext, + ? accept: bool, + ? userText: text, +} + +browsingContext.LocateNodes = ( + method: "browsingContext.locateNodes", + params: browsingContext.LocateNodesParameters +) + +browsingContext.LocateNodesParameters = { + context: browsingContext.BrowsingContext, + locator: browsingContext.Locator, + ? maxNodeCount: (js-uint .ge 1), + ? serializationOptions: script.SerializationOptions, + ? startNodes: [ + script.SharedReference ] +} + +browsingContext.Navigate = ( + method: "browsingContext.navigate", + params: browsingContext.NavigateParameters +) + +browsingContext.NavigateParameters = { + context: browsingContext.BrowsingContext, + url: text, + ? wait: browsingContext.ReadinessState, +} + +browsingContext.Print = ( + method: "browsingContext.print", + params: browsingContext.PrintParameters +) + +browsingContext.PrintParameters = { + context: browsingContext.BrowsingContext, + ? background: bool .default false, + ? margin: browsingContext.PrintMarginParameters, + ? orientation: ("portrait" / "landscape") .default "portrait", + ? page: browsingContext.PrintPageParameters, + ? pageRanges: [*(js-uint / text)], + ? scale: (0.1..2.0) .default 1.0, + ? shrinkToFit: bool .default true, +} + +browsingContext.PrintMarginParameters = { + ? bottom: (float .ge 0.0) .default 1.0, + ? left: (float .ge 0.0) .default 1.0, + ? right: (float .ge 0.0) .default 1.0, + ? top: (float .ge 0.0) .default 1.0, +} + +; Minimum size is 1pt x 1pt. Conversion follows from +; https://www.w3.org/TR/css3-values/#absolute-lengths +browsingContext.PrintPageParameters = { + ? height: (float .ge 0.0352) .default 27.94, + ? width: (float .ge 0.0352) .default 21.59, +} + +browsingContext.Reload = ( + method: "browsingContext.reload", + params: browsingContext.ReloadParameters +) + +browsingContext.ReloadParameters = { + context: browsingContext.BrowsingContext, + ? ignoreCache: bool, + ? wait: browsingContext.ReadinessState, +} + +browsingContext.SetViewport = ( + method: "browsingContext.setViewport", + params: browsingContext.SetViewportParameters +) + +browsingContext.SetViewportParameters = { + ? context: browsingContext.BrowsingContext, + ? viewport: browsingContext.Viewport / null, + ? devicePixelRatio: (float .gt 0.0) / null, + ? userContexts: [+browser.UserContext], +} + +browsingContext.Viewport = { + width: js-uint, + height: js-uint, +} + +browsingContext.TraverseHistory = ( + method: "browsingContext.traverseHistory", + params: browsingContext.TraverseHistoryParameters +) + +browsingContext.TraverseHistoryParameters = { + context: browsingContext.BrowsingContext, + delta: js-int, +} + +EmulationCommand = ( + emulation.SetForcedColorsModeThemeOverride // + emulation.SetGeolocationOverride // + emulation.SetLocaleOverride // + emulation.SetNetworkConditions // + emulation.SetScreenOrientationOverride // + emulation.SetScreenSettingsOverride // + emulation.SetScriptingEnabled // + emulation.SetScrollbarTypeOverride // + emulation.SetTimezoneOverride // + emulation.SetTouchOverride // + emulation.SetUserAgentOverride // + emulation.SetViewportMetaOverride +) + + +emulation.SetForcedColorsModeThemeOverride = ( + method: "emulation.setForcedColorsModeThemeOverride", + params: emulation.SetForcedColorsModeThemeOverrideParameters +) + +emulation.SetForcedColorsModeThemeOverrideParameters = { + theme: emulation.ForcedColorsModeTheme / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.ForcedColorsModeTheme = "light" / "dark" + +emulation.SetGeolocationOverride = ( + method: "emulation.setGeolocationOverride", + params: emulation.SetGeolocationOverrideParameters +) + +emulation.SetGeolocationOverrideParameters = { + ( + (coordinates: emulation.GeolocationCoordinates / null) // + (error: emulation.GeolocationPositionError) + ), + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.GeolocationCoordinates = { + latitude: -90.0..90.0, + longitude: -180.0..180.0, + ? accuracy: (float .ge 0.0) .default 1.0, + ? altitude: float / null .default null, + ? altitudeAccuracy: (float .ge 0.0) / null .default null, + ? heading: (0.0...360.0) / null .default null, + ? speed: (float .ge 0.0) / null .default null, +} + +emulation.GeolocationPositionError = { + type: "positionUnavailable" +} + +emulation.SetLocaleOverride = ( + method: "emulation.setLocaleOverride", + params: emulation.SetLocaleOverrideParameters +) + +emulation.SetLocaleOverrideParameters = { + locale: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetNetworkConditions = ( + method: "emulation.setNetworkConditions", + params: emulation.setNetworkConditionsParameters +) + +emulation.setNetworkConditionsParameters = { + networkConditions: emulation.NetworkConditions / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.NetworkConditions = emulation.NetworkConditionsOffline + +emulation.NetworkConditionsOffline = { + type: "offline" +} + +emulation.SetScreenSettingsOverride = ( + method: "emulation.setScreenSettingsOverride", + params: emulation.SetScreenSettingsOverrideParameters +) + +emulation.ScreenArea = { + width: js-uint, + height: js-uint +} + +emulation.SetScreenSettingsOverrideParameters = { + screenArea: emulation.ScreenArea / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScreenOrientationOverride = ( + method: "emulation.setScreenOrientationOverride", + params: emulation.SetScreenOrientationOverrideParameters +) + +emulation.ScreenOrientationNatural = "portrait" / "landscape" +emulation.ScreenOrientationType = "portrait-primary" / "portrait-secondary" / "landscape-primary" / "landscape-secondary" + +emulation.ScreenOrientation = { + natural: emulation.ScreenOrientationNatural, + type: emulation.ScreenOrientationType +} + +emulation.SetScreenOrientationOverrideParameters = { + screenOrientation: emulation.ScreenOrientation / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetUserAgentOverride = ( + method: "emulation.setUserAgentOverride", + params: emulation.SetUserAgentOverrideParameters +) + +emulation.SetUserAgentOverrideParameters = { + userAgent: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetViewportMetaOverride = ( + method: "emulation.setViewportMetaOverride", + params: emulation.SetViewportMetaOverrideParameters +) + +emulation.SetViewportMetaOverrideParameters = { + viewportMeta: true / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScriptingEnabled = ( + method: "emulation.setScriptingEnabled", + params: emulation.SetScriptingEnabledParameters +) + +emulation.SetScriptingEnabledParameters = { + enabled: false / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScrollbarTypeOverride = ( + method: "emulation.setScrollbarTypeOverride", + params: emulation.SetScrollbarTypeOverrideParameters +) + +emulation.SetScrollbarTypeOverrideParameters = { + scrollbarType: "classic" / "overlay" / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetTimezoneOverride = ( + method: "emulation.setTimezoneOverride", + params: emulation.SetTimezoneOverrideParameters +) + +emulation.SetTimezoneOverrideParameters = { + timezone: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetTouchOverride = ( + method: "emulation.setTouchOverride", + params: emulation.SetTouchOverrideParameters +) + +emulation.SetTouchOverrideParameters = { + maxTouchPoints: (js-uint .ge 1) / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + + +NetworkCommand = ( + network.AddDataCollector // + network.AddIntercept // + network.ContinueRequest // + network.ContinueResponse // + network.ContinueWithAuth // + network.DisownData // + network.FailRequest // + network.GetData // + network.ProvideResponse // + network.RemoveDataCollector // + network.RemoveIntercept // + network.SetCacheBehavior // + network.SetExtraHeaders +) + + +network.AuthCredentials = { + type: "password", + username: text, + password: text, +} + +network.BytesValue = network.StringValue / network.Base64Value; + +network.StringValue = { + type: "string", + value: text, +} + +network.Base64Value = { + type: "base64", + value: text, +} + +network.Collector = text + +network.CollectorType = "blob" + + +network.SameSite = "strict" / "lax" / "none" / "default" + + +network.Cookie = { + name: text, + value: network.BytesValue, + domain: text, + path: text, + size: js-uint, + httpOnly: bool, + secure: bool, + sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +network.CookieHeader = { + name: text, + value: network.BytesValue, +} + +network.DataType = "request" / "response" + +network.Header = { + name: text, + value: network.BytesValue, +} + +network.Intercept = text + +network.Request = text; + + +network.SetCookieHeader = { + name: text, + value: network.BytesValue, + ? domain: text, + ? httpOnly: bool, + ? expiry: text, + ? maxAge: js-int, + ? path: text, + ? sameSite: network.SameSite, + ? secure: bool, +} + +network.UrlPattern = ( + network.UrlPatternPattern / + network.UrlPatternString +) + +network.UrlPatternPattern = { + type: "pattern", + ?protocol: text, + ?hostname: text, + ?port: text, + ?pathname: text, + ?search: text, +} + + +network.UrlPatternString = { + type: "string", + pattern: text, +} + + +network.AddDataCollector = ( + method: "network.addDataCollector", + params: network.AddDataCollectorParameters +) + +network.AddDataCollectorParameters = { + dataTypes: [+network.DataType], + maxEncodedDataSize: js-uint, + ? collectorType: network.CollectorType .default "blob", + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +network.AddIntercept = ( + method: "network.addIntercept", + params: network.AddInterceptParameters +) + +network.AddInterceptParameters = { + phases: [+network.InterceptPhase], + ? contexts: [+browsingContext.BrowsingContext], + ? urlPatterns: [*network.UrlPattern], +} + +network.InterceptPhase = "beforeRequestSent" / "responseStarted" / + "authRequired" + +network.ContinueRequest = ( + method: "network.continueRequest", + params: network.ContinueRequestParameters +) + +network.ContinueRequestParameters = { + request: network.Request, + ?body: network.BytesValue, + ?cookies: [*network.CookieHeader], + ?headers: [*network.Header], + ?method: text, + ?url: text, +} + +network.ContinueResponse = ( + method: "network.continueResponse", + params: network.ContinueResponseParameters +) + +network.ContinueResponseParameters = { + request: network.Request, + ?cookies: [*network.SetCookieHeader] + ?credentials: network.AuthCredentials, + ?headers: [*network.Header], + ?reasonPhrase: text, + ?statusCode: js-uint, +} + +network.ContinueWithAuth = ( + method: "network.continueWithAuth", + params: network.ContinueWithAuthParameters +) + +network.ContinueWithAuthParameters = { + request: network.Request, + (network.ContinueWithAuthCredentials // network.ContinueWithAuthNoCredentials) +} + +network.ContinueWithAuthCredentials = ( + action: "provideCredentials", + credentials: network.AuthCredentials +) + +network.ContinueWithAuthNoCredentials = ( + action: "default" / "cancel" +) + +network.DisownData = ( + method: "network.disownData", + params: network.disownDataParameters +) + +network.disownDataParameters = { + dataType: network.DataType, + collector: network.Collector, + request: network.Request, +} + +network.FailRequest = ( + method: "network.failRequest", + params: network.FailRequestParameters +) + +network.FailRequestParameters = { + request: network.Request, +} + +network.GetData = ( + method: "network.getData", + params: network.GetDataParameters +) + +network.GetDataParameters = { + dataType: network.DataType, + ? collector: network.Collector, + ? disown: bool .default false, + request: network.Request, +} + +network.ProvideResponse = ( + method: "network.provideResponse", + params: network.ProvideResponseParameters +) + +network.ProvideResponseParameters = { + request: network.Request, + ?body: network.BytesValue, + ?cookies: [*network.SetCookieHeader], + ?headers: [*network.Header], + ?reasonPhrase: text, + ?statusCode: js-uint, +} + +network.RemoveDataCollector = ( + method: "network.removeDataCollector", + params: network.RemoveDataCollectorParameters +) + +network.RemoveDataCollectorParameters = { + collector: network.Collector +} + +network.RemoveIntercept = ( + method: "network.removeIntercept", + params: network.RemoveInterceptParameters +) + +network.RemoveInterceptParameters = { + intercept: network.Intercept +} + +network.SetCacheBehavior = ( + method: "network.setCacheBehavior", + params: network.SetCacheBehaviorParameters +) + +network.SetCacheBehaviorParameters = { + cacheBehavior: "default" / "bypass", + ? contexts: [+browsingContext.BrowsingContext] +} + +network.SetExtraHeaders = ( + method: "network.setExtraHeaders", + params: network.SetExtraHeadersParameters +) + +network.SetExtraHeadersParameters = { + headers: [*network.Header] + ? contexts: [+browsingContext.BrowsingContext] + ? userContexts: [+browser.UserContext] +} + +ScriptCommand = ( + script.AddPreloadScript // + script.CallFunction // + script.Disown // + script.Evaluate // + script.GetRealms // + script.RemovePreloadScript +) + +script.Channel = text; + +script.ChannelValue = { + type: "channel", + value: script.ChannelProperties, +} + +script.ChannelProperties = { + channel: script.Channel, + ? serializationOptions: script.SerializationOptions, + ? ownership: script.ResultOwnership, +} + +script.EvaluateResult = ( + script.EvaluateResultSuccess / + script.EvaluateResultException +) + +script.EvaluateResultSuccess = { + type: "success", + result: script.RemoteValue, + realm: script.Realm +} + +script.EvaluateResultException = { + type: "exception", + exceptionDetails: script.ExceptionDetails + realm: script.Realm +} + +script.ExceptionDetails = { + columnNumber: js-uint, + exception: script.RemoteValue, + lineNumber: js-uint, + stackTrace: script.StackTrace, + text: text, +} + +script.Handle = text; + +script.InternalId = text; + +script.LocalValue = ( + script.RemoteReference / + script.PrimitiveProtocolValue / + script.ChannelValue / + script.ArrayLocalValue / + { script.DateLocalValue } / + script.MapLocalValue / + script.ObjectLocalValue / + { script.RegExpLocalValue } / + script.SetLocalValue +) + +script.ListLocalValue = [*script.LocalValue]; + +script.ArrayLocalValue = { + type: "array", + value: script.ListLocalValue, +} + +script.DateLocalValue = ( + type: "date", + value: text +) + +script.MappingLocalValue = [*[(script.LocalValue / text), script.LocalValue]]; + +script.MapLocalValue = { + type: "map", + value: script.MappingLocalValue, +} + +script.ObjectLocalValue = { + type: "object", + value: script.MappingLocalValue, +} + +script.RegExpValue = { + pattern: text, + ? flags: text, +} + +script.RegExpLocalValue = ( + type: "regexp", + value: script.RegExpValue, +) + +script.SetLocalValue = { + type: "set", + value: script.ListLocalValue, +} + +script.PreloadScript = text; + +script.Realm = text; + +script.PrimitiveProtocolValue = ( + script.UndefinedValue / + script.NullValue / + script.StringValue / + script.NumberValue / + script.BooleanValue / + script.BigIntValue +) + +script.UndefinedValue = { + type: "undefined", +} + +script.NullValue = { + type: "null", +} + +script.StringValue = { + type: "string", + value: text, +} + +script.SpecialNumber = "NaN" / "-0" / "Infinity" / "-Infinity"; + +script.NumberValue = { + type: "number", + value: number / script.SpecialNumber, +} + +script.BooleanValue = { + type: "boolean", + value: bool, +} + +script.BigIntValue = { + type: "bigint", + value: text, +} + +script.RealmType = "window" / "dedicated-worker" / "shared-worker" / "service-worker" / + "worker" / "paint-worklet" / "audio-worklet" / "worklet" + + + +script.RemoteReference = ( + script.SharedReference / + script.RemoteObjectReference +) + +script.SharedReference = { + sharedId: script.SharedId + + ? handle: script.Handle, + Extensible +} + +script.RemoteObjectReference = { + handle: script.Handle, + + ? sharedId: script.SharedId + Extensible +} + +script.RemoteValue = ( + script.PrimitiveProtocolValue / + script.SymbolRemoteValue / + script.ArrayRemoteValue / + script.ObjectRemoteValue / + script.FunctionRemoteValue / + script.RegExpRemoteValue / + script.DateRemoteValue / + script.MapRemoteValue / + script.SetRemoteValue / + script.WeakMapRemoteValue / + script.WeakSetRemoteValue / + script.GeneratorRemoteValue / + script.ErrorRemoteValue / + script.ProxyRemoteValue / + script.PromiseRemoteValue / + script.TypedArrayRemoteValue / + script.ArrayBufferRemoteValue / + script.NodeListRemoteValue / + script.HTMLCollectionRemoteValue / + script.NodeRemoteValue / + script.WindowProxyRemoteValue +) + +script.ListRemoteValue = [*script.RemoteValue]; + +script.MappingRemoteValue = [*[(script.RemoteValue / text), script.RemoteValue]]; + +script.SymbolRemoteValue = { + type: "symbol", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayRemoteValue = { + type: "array", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.ObjectRemoteValue = { + type: "object", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.FunctionRemoteValue = { + type: "function", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.RegExpRemoteValue = { + script.RegExpLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.DateRemoteValue = { + script.DateLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.MapRemoteValue = { + type: "map", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.SetRemoteValue = { + type: "set", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue +} + +script.WeakMapRemoteValue = { + type: "weakmap", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.WeakSetRemoteValue = { + type: "weakset", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.GeneratorRemoteValue = { + type: "generator", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ErrorRemoteValue = { + type: "error", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ProxyRemoteValue = { + type: "proxy", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.PromiseRemoteValue = { + type: "promise", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.TypedArrayRemoteValue = { + type: "typedarray", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayBufferRemoteValue = { + type: "arraybuffer", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.NodeListRemoteValue = { + type: "nodelist", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.HTMLCollectionRemoteValue = { + type: "htmlcollection", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.NodeRemoteValue = { + type: "node", + ? sharedId: script.SharedId, + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.NodeProperties, +} + +script.NodeProperties = { + nodeType: js-uint, + childNodeCount: js-uint, + ? attributes: {*text => text}, + ? children: [*script.NodeRemoteValue], + ? localName: text, + ? mode: "open" / "closed", + ? namespaceURI: text, + ? nodeValue: text, + ? shadowRoot: script.NodeRemoteValue / null, +} + +script.WindowProxyRemoteValue = { + type: "window", + value: script.WindowProxyProperties, + ? handle: script.Handle, + ? internalId: script.InternalId +} + +script.WindowProxyProperties = { + context: browsingContext.BrowsingContext +} + +script.ResultOwnership = "root" / "none" + +script.SerializationOptions = { + ? maxDomDepth: (js-uint / null) .default 0, + ? maxObjectDepth: (js-uint / null) .default null, + ? includeShadowTree: ("none" / "open" / "all") .default "none", +} + +script.SharedId = text; + +script.StackFrame = { + columnNumber: js-uint, + functionName: text, + lineNumber: js-uint, + url: text, +} + +script.StackTrace = { + callFrames: [*script.StackFrame], +} + +script.RealmTarget = { + realm: script.Realm +} + +script.ContextTarget = { + context: browsingContext.BrowsingContext, + ? sandbox: text +} + +script.Target = ( + script.ContextTarget / + script.RealmTarget +) + +script.AddPreloadScript = ( + method: "script.addPreloadScript", + params: script.AddPreloadScriptParameters +) + +script.AddPreloadScriptParameters = { + functionDeclaration: text, + ? arguments: [*script.ChannelValue], + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], + ? sandbox: text +} + +script.Disown = ( + method: "script.disown", + params: script.DisownParameters +) + +script.DisownParameters = { + handles: [*script.Handle] + target: script.Target; +} + +script.CallFunction = ( + method: "script.callFunction", + params: script.CallFunctionParameters +) + +script.CallFunctionParameters = { + functionDeclaration: text, + awaitPromise: bool, + target: script.Target, + ? arguments: [*script.LocalValue], + ? resultOwnership: script.ResultOwnership, + ? serializationOptions: script.SerializationOptions, + ? this: script.LocalValue, + ? userActivation: bool .default false, +} + +script.Evaluate = ( + method: "script.evaluate", + params: script.EvaluateParameters +) + +script.EvaluateParameters = { + expression: text, + target: script.Target, + awaitPromise: bool, + ? resultOwnership: script.ResultOwnership, + ? serializationOptions: script.SerializationOptions, + ? userActivation: bool .default false, +} + +script.GetRealms = ( + method: "script.getRealms", + params: script.GetRealmsParameters +) + +script.GetRealmsParameters = { + ? context: browsingContext.BrowsingContext, + ? type: script.RealmType, +} + +script.RemovePreloadScript = ( + method: "script.removePreloadScript", + params: script.RemovePreloadScriptParameters +) + +script.RemovePreloadScriptParameters = { + script: script.PreloadScript +} + +StorageCommand = ( + storage.DeleteCookies // + storage.GetCookies // + storage.SetCookie +) + +storage.PartitionKey = { + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.GetCookies = ( + method: "storage.getCookies", + params: storage.GetCookiesParameters +) + + +storage.CookieFilter = { + ? name: text, + ? value: network.BytesValue, + ? domain: text, + ? path: text, + ? size: js-uint, + ? httpOnly: bool, + ? secure: bool, + ? sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +storage.BrowsingContextPartitionDescriptor = { + type: "context", + context: browsingContext.BrowsingContext +} + +storage.StorageKeyPartitionDescriptor = { + type: "storageKey", + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.PartitionDescriptor = ( + storage.BrowsingContextPartitionDescriptor / + storage.StorageKeyPartitionDescriptor +) + +storage.GetCookiesParameters = { + ? filter: storage.CookieFilter, + ? partition: storage.PartitionDescriptor, +} + +storage.SetCookie = ( + method: "storage.setCookie", + params: storage.SetCookieParameters, +) + + +storage.PartialCookie = { + name: text, + value: network.BytesValue, + domain: text, + ? path: text, + ? httpOnly: bool, + ? secure: bool, + ? sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +storage.SetCookieParameters = { + cookie: storage.PartialCookie, + ? partition: storage.PartitionDescriptor, +} + +storage.DeleteCookies = ( + method: "storage.deleteCookies", + params: storage.DeleteCookiesParameters, +) + +storage.DeleteCookiesParameters = { + ? filter: storage.CookieFilter, + ? partition: storage.PartitionDescriptor, +} + +InputCommand = ( + input.PerformActions // + input.ReleaseActions // + input.SetFiles +) + +InputResult = ( + input.PerformActionsResult / + input.ReleaseActionsResult / + input.SetFilesResult +) + +input.ElementOrigin = { + type: "element", + element: script.SharedReference +} + +input.PerformActions = ( + method: "input.performActions", + params: input.PerformActionsParameters +) + +input.PerformActionsParameters = { + context: browsingContext.BrowsingContext, + actions: [*input.SourceActions] +} + +input.SourceActions = ( + input.NoneSourceActions / + input.KeySourceActions / + input.PointerSourceActions / + input.WheelSourceActions +) + +input.NoneSourceActions = { + type: "none", + id: text, + actions: [*input.NoneSourceAction] +} + +input.NoneSourceAction = input.PauseAction + +input.KeySourceActions = { + type: "key", + id: text, + actions: [*input.KeySourceAction] +} + +input.KeySourceAction = ( + input.PauseAction / + input.KeyDownAction / + input.KeyUpAction +) + +input.PointerSourceActions = { + type: "pointer", + id: text, + ? parameters: input.PointerParameters, + actions: [*input.PointerSourceAction] +} + +input.PointerType = "mouse" / "pen" / "touch" + +input.PointerParameters = { + ? pointerType: input.PointerType .default "mouse" +} + +input.PointerSourceAction = ( + input.PauseAction / + input.PointerDownAction / + input.PointerUpAction / + input.PointerMoveAction +) + +input.WheelSourceActions = { + type: "wheel", + id: text, + actions: [*input.WheelSourceAction] +} + +input.WheelSourceAction = ( + input.PauseAction / + input.WheelScrollAction +) + +input.PauseAction = { + type: "pause", + ? duration: js-uint +} + +input.KeyDownAction = { + type: "keyDown", + value: text +} + +input.KeyUpAction = { + type: "keyUp", + value: text +} + +input.PointerUpAction = { + type: "pointerUp", + button: js-uint, +} + +input.PointerDownAction = { + type: "pointerDown", + button: js-uint, + input.PointerCommonProperties +} + +input.PointerMoveAction = { + type: "pointerMove", + x: float, + y: float, + ? duration: js-uint, + ? origin: input.Origin, + input.PointerCommonProperties +} + +input.WheelScrollAction = { + type: "scroll", + x: js-int, + y: js-int, + deltaX: js-int, + deltaY: js-int, + ? duration: js-uint, + ? origin: input.Origin .default "viewport", +} + +input.PointerCommonProperties = ( + ? width: js-uint .default 1, + ? height: js-uint .default 1, + ? pressure: float .default 0.0, + ? tangentialPressure: float .default 0.0, + ? twist: (0..359) .default 0, + ; 0 .. Math.PI / 2 + ? altitudeAngle: (0.0..1.5707963267948966) .default 0.0, + ; 0 .. 2 * Math.PI + ? azimuthAngle: (0.0..6.283185307179586) .default 0.0, +) + +input.Origin = "viewport" / "pointer" / input.ElementOrigin + +input.ReleaseActions = ( + method: "input.releaseActions", + params: input.ReleaseActionsParameters +) + +input.ReleaseActionsParameters = { + context: browsingContext.BrowsingContext, +} + +input.SetFiles = ( + method: "input.setFiles", + params: input.SetFilesParameters +) + +input.SetFilesParameters = { + context: browsingContext.BrowsingContext, + element: script.SharedReference, + files: [*text] +} + +input.FileDialogOpened = ( + method: "input.fileDialogOpened", + params: input.FileDialogInfo +) + +input.FileDialogInfo = { + context: browsingContext.BrowsingContext, + ? element: script.SharedReference, + multiple: bool, +} + +WebExtensionCommand = ( + webExtension.Install // + webExtension.Uninstall +) + +webExtension.Extension = text + +webExtension.Install = ( + method: "webExtension.install", + params: webExtension.InstallParameters +) + +webExtension.InstallParameters = { + extensionData: webExtension.ExtensionData, +} + +webExtension.ExtensionData = ( + webExtension.ExtensionArchivePath / + webExtension.ExtensionBase64Encoded / + webExtension.ExtensionPath +) + +webExtension.ExtensionPath = { + type: "path", + path: text, +} + +webExtension.ExtensionArchivePath = { + type: "archivePath", + path: text, +} + +webExtension.ExtensionBase64Encoded = { + type: "base64", + value: text, +} + +webExtension.Uninstall = ( + method: "webExtension.uninstall", + params: webExtension.UninstallParameters +) + +webExtension.UninstallParameters = { + extension: webExtension.Extension, +} diff --git a/py/AGENTS.md b/py/AGENTS.md index 57a5e819e1a8e..27c1aaac41e9a 100644 --- a/py/AGENTS.md +++ b/py/AGENTS.md @@ -51,24 +51,25 @@ def method(param: str | None) -> int | None: pass # Avoid -from typing import Optional def method(param: Optional[str]) -> Optional[int]: pass ``` ### Python version -Code must work with Python 3.10 or later. Use modern syntax features available in 3.10+. +Code must work with Python 3.10 or later. Use modern syntax features available in 3.10+: -See the **Type hints** section for guidance on preferred type annotation syntax (including unions). +- Use `|` for union types instead of `Union[]` +- Use `X | None` instead of `Optional[X]` -For testing: use `bazel test //py/...` which employs a hermetic Python 3.10+ toolchain (see `/AGENTS.md`). - -For ad-hoc scripts, check your Python version locally before running: +When running tests or code in the terminal, explicitly use `python3.10` or later: ```bash -python --version -# Ensure you have 3.10+; on macOS/Linux use python3.10+ or on Windows py -3.10 +# Use explicitly +python3.10 -c "..." +python3.11 -c "..." + +# Avoid relying on `python3` as it may be 3.9 or earlier ``` ### Documentation diff --git a/py/BUILD.bazel b/py/BUILD.bazel index 2f68f61b77021..d43994c15c531 100644 --- a/py/BUILD.bazel +++ b/py/BUILD.bazel @@ -10,6 +10,7 @@ load("@rules_python//sphinxdocs:sphinx.bzl", "sphinx_build_binary", "sphinx_docs load("//common:defs.bzl", "copy_file") load("//py:defs.bzl", "generate_devtools", "generate_devtools_latest", "py_test_suite") load("//py/private:browsers.bzl", "BROWSERS") +load("//py/private:generate_bidi.bzl", "generate_bidi") load("//py/private:import.bzl", "py_import") exports_files( @@ -596,6 +597,12 @@ py_binary( deps = [requirement("inflection")], ) +py_binary( + name = "generate_bidi", + srcs = ["generate_bidi.py"], + srcs_version = "PY3", +) + [generate_devtools( name = "create-cdp-srcs-{}".format(devtools_version), browser_protocol = "//common/devtools/chromium/{}:browser_protocol".format(devtools_version), @@ -610,6 +617,17 @@ generate_devtools_latest( browser_versions = BROWSER_VERSIONS, ) +# Pilot BiDi code generation from CDDL specification +generate_bidi( + name = "create-bidi-src", + cddl_file = "//common/bidi/spec:all.cddl", + enhancements_manifest = "//py/private:bidi_enhancements_manifest.py", + extra_srcs = ["//py/private:cdp.py"], + generator = ":generate_bidi", + module_name = "selenium/webdriver/common/bidi", + spec_version = "1.0", +) + py_test_suite( name = "unit", size = "small", @@ -789,6 +807,7 @@ BROWSER_TESTS = { ] ] + test_suite( name = "test-remote", tags = ["remote"], diff --git a/py/conftest.py b/py/conftest.py index 0b93b35721f62..f36a27faea84f 100644 --- a/py/conftest.py +++ b/py/conftest.py @@ -118,6 +118,14 @@ def pytest_addoption(parser): metavar="DRIVER", help="Driver to run tests against ({})".format(", ".join(drivers)), ) + parser.addoption( + "--browser", + action="append", + choices=drivers, + dest="drivers", + metavar="BROWSER", + help="Browser to run tests against (alias for --driver)", + ) parser.addoption( "--browser-binary", action="store", diff --git a/py/generate_bidi.py b/py/generate_bidi.py new file mode 100755 index 0000000000000..1770cf436bef1 --- /dev/null +++ b/py/generate_bidi.py @@ -0,0 +1,1824 @@ +#!/usr/bin/env python3 +""" +Generate Python WebDriver BiDi command modules from CDDL specification. + +This generator reads CDDL (Concise Data Definition Language) specification files +and produces Python type definitions and command classes that conform to the +WebDriver BiDi protocol. + +Usage: + python generate_bidi.py + +Example: + python generate_bidi.py local.cddl ./selenium/webdriver/common/bidi 1.0 +""" + +import argparse +import importlib.util +import logging +import re +import sys +from collections import defaultdict +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from textwrap import dedent, indent as tw_indent +from typing import Any, Dict, List, Optional, Set, Tuple + +__version__ = "1.0.0" + +# Logging setup +log_level = logging.INFO +logging.basicConfig(level=log_level) +logger = logging.getLogger("generate_bidi") + +# File headers +SHARED_HEADER = """# DO NOT EDIT THIS FILE! +# +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules.""" + +MODULE_HEADER = f"""{SHARED_HEADER} +# +# WebDriver BiDi module: {{}} +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +""" + + +def indent(s: str, n: int) -> str: + """Indent a string by n spaces.""" + return tw_indent(s, n * " ") + + +def load_enhancements_manifest(manifest_path: Optional[str]) -> Dict[str, Any]: + """Load enhancement manifest from a Python file. + + Args: + manifest_path: Path to Python file containing ENHANCEMENTS dict + + Returns: + Dictionary with enhancement rules, or empty dict if no manifest provided + """ + if not manifest_path: + return {} + + manifest_file = Path(manifest_path) + if not manifest_file.exists(): + logger.warning(f"Enhancement manifest not found: {manifest_path}") + return {} + + try: + spec = importlib.util.spec_from_file_location( + "bidi_enhancements", manifest_file + ) + if spec is None or spec.loader is None: + logger.warning(f"Could not load manifest: {manifest_path}") + return {} + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + enhancements = getattr(module, "ENHANCEMENTS", {}) + dataclass_methods = getattr(module, "DATACLASS_METHOD_TEMPLATES", {}) + method_docstrings = getattr(module, "DATACLASS_METHOD_DOCSTRINGS", {}) + + logger.info(f"Loaded enhancement manifest from: {manifest_path}") + logger.debug(f"Enhancements for modules: {list(enhancements.keys())}") + + return { + "enhancements": enhancements, + "dataclass_methods": dataclass_methods, + "method_docstrings": method_docstrings, + } + except Exception as e: + logger.error(f"Failed to load enhancement manifest: {e}", exc_info=True) + return {} + + +class CddlType(Enum): + """CDDL type mappings to Python types.""" + + TSTR = "str" # text string + TEXT = "str" # text (alias) + UINT = "int" # unsigned integer + INT = "int" # signed integer + NINT = "int" # negative integer + BOOL = "bool" # boolean + NULL = "None" # null + ANY = "Any" # any type + + @classmethod + def get_annotation(cls, cddl_type: str) -> str: + """Get Python type annotation for a CDDL type.""" + cddl_type = cddl_type.strip().lower() + + # Handle basic types + for member in cls: + if cddl_type == member.name.lower(): + return member.value + + # Handle composite types + if cddl_type.startswith("["): # Array + inner = cddl_type.strip("[]+ ") + inner_type = cls.get_annotation(inner) + return f"List[{inner_type}]" + + if cddl_type.startswith("{"): # Map/Dict + return "Dict[str, Any]" + + # Default to Any for unknown types + return "Any" + + +@dataclass +class CddlCommand: + """Represents a CDDL command definition.""" + + module: str + name: str + params: Dict[str, str] = field(default_factory=dict) + result: Optional[str] = None + description: str = "" + + def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + """Generate Python method code for this command. + + Args: + enhancements: Dictionary with enhancement rules for this method + """ + enhancements = enhancements or {} + method_name = self._camel_to_snake(self.name) + + # Build parameter list with type hints + # Check if there's a params_override for user-friendly named arguments + params_to_use = self.params + if "params_override" in enhancements: + params_to_use = enhancements["params_override"] + + param_strs = [] + param_names = [] # Keep track of parameter names for later use + for param_name, param_type in params_to_use.items(): + if param_type in ["bool", "str", "int"]: + python_type = param_type + else: + python_type = CddlType.get_annotation(param_type) + snake_param = self._camel_to_snake(param_name) + param_names.append((param_name, snake_param)) + param_strs.append(f"{snake_param}: {python_type} | None = None") + + if param_strs: + param_list = "self, " + ", ".join(param_strs) + else: + param_list = "self" + + # Build method body + body = f" def {method_name}({param_list}):\n" + body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' + + # Add validation if specified + if "validate" in enhancements: + validate_func = enhancements["validate"] + # Build parameter list for validation function + param_args = ", ".join(f"{snake}={snake}" for _, snake in param_names) + body += f" {validate_func}({param_args})\n" + body += "\n" + + # Add transformation and preprocessing + # First, check if any transform is needed + if "transform" in enhancements: + transform_spec = enhancements["transform"] + if isinstance(transform_spec, dict): + # New format with explicit function and result parameter + transform_func = transform_spec.get("func") + result_param = transform_spec.get("result_param", "params") + input_params = [ + transform_spec.get(k) + for k in ["allowed", "destination_folder"] + if transform_spec.get(k) + ] + + if transform_func and result_param: + body += f" {result_param} = None\n" + param_args = ", ".join(input_params) + body += f" {result_param} = {transform_func}({param_args})\n" + body += "\n" + else: + # Legacy format for backward compatibility + transform_func = transform_spec + if self.name == "setDownloadBehavior": + body += " download_behavior = None\n" + body += f" download_behavior = {transform_func}(allowed, destination_folder)\n" + body += "\n" + + # Add preprocessing for serialization (check for to_bidi_dict method) + if "preprocess" in enhancements: + preprocess_rules = enhancements["preprocess"] + for param_name, preprocess_type in preprocess_rules.items(): + snake_param = self._camel_to_snake(param_name) + if preprocess_type == "check_serialize_method": + body += f" if {snake_param} and hasattr({snake_param}, 'to_bidi_dict'):\n" + body += ( + f" {snake_param} = {snake_param}.to_bidi_dict()\n" + ) + body += "\n" + + # Build params dict + body += " params = {\n" + + # If there's a transform with a result parameter, map it to the BiDi protocol name + if "transform" in enhancements and isinstance(enhancements["transform"], dict): + transform_spec = enhancements["transform"] + result_param = transform_spec.get("result_param") + + # Map the result parameter to the original CDDL parameter name + if result_param == "download_behavior": + body += ' "downloadBehavior": download_behavior,\n' + # Add remaining parameters that weren't part of the transform + override_params = enhancements.get("params_override", {}) + for cddl_param_name in self.params: + if cddl_param_name not in ["downloadBehavior"]: + snake_name = self._camel_to_snake(cddl_param_name) + body += f' "{cddl_param_name}": {snake_name},\n' + else: + # Standard parameter mapping from CDDL + for param_name, snake_param in param_names: + body += f' "{param_name}": {snake_param},\n' + + body += " }\n" + body += " params = {k: v for k, v in params.items() if v is not None}\n" + body += f' cmd = command_builder("{self.module}.{self.name}", params)\n' + body += " result = self._conn.execute(cmd)\n" + + # Add response handling for extraction/deserialization + if "extract_field" in enhancements: + extract_field = enhancements["extract_field"] + extract_property = enhancements.get("extract_property") + + # Check if we also need to deserialize the extracted field + deserialize_rules = enhancements.get("deserialize", {}) + + if extract_property: + # Extract property from list items + body += f' if result and "{extract_field}" in result:\n' + body += f' items = result.get("{extract_field}", [])\n' + body += f" return [\n" + body += f' item.get("{extract_property}")\n' + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" + elif extract_field in deserialize_rules: + # Extract field and deserialize to typed objects + type_name = deserialize_rules[extract_field] + body += f' if result and "{extract_field}" in result:\n' + body += f' items = result.get("{extract_field}", [])\n' + body += f" return [\n" + body += f" {type_name}(\n" + body += self._generate_field_args(extract_field, type_name) + body += f" )\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" + else: + # Simple field extraction (return the value directly, not wrapped in result dict) + body += f' if result and "{extract_field}" in result:\n' + body += f' extracted = result.get("{extract_field}")\n' + body += f" return extracted\n" + body += f" return result\n" + elif "deserialize" in enhancements: + # Deserialize response to typed objects (legacy, without extract_field) + deserialize_rules = enhancements["deserialize"] + for response_field, type_name in deserialize_rules.items(): + body += f' if result and "{response_field}" in result:\n' + body += f' items = result.get("{response_field}", [])\n' + body += f" return [\n" + body += f" {type_name}(\n" + body += self._generate_field_args(response_field, type_name) + body += f" )\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" + else: + # No special response handling, just return the result + body += " return result\n" + + return body + + def _generate_field_args(self, response_field: str, type_name: str) -> str: + """Generate constructor arguments for deserializing response objects. + + For now, this handles ClientWindowInfo and Info specifically. + Could be extended to be more generic. + """ + if type_name == "ClientWindowInfo": + return ( + ' active=item.get("active"),\n' + ' client_window=item.get("clientWindow"),\n' + ' height=item.get("height"),\n' + ' state=item.get("state"),\n' + ' width=item.get("width"),\n' + ' x=item.get("x"),\n' + ' y=item.get("y")\n' + ) + elif type_name == "Info": + return ( + ' children=_deserialize_info_list(item.get("children", [])),\n' + ' client_window=item.get("clientWindow"),\n' + ' context=item.get("context"),\n' + ' original_opener=item.get("originalOpener"),\n' + ' url=item.get("url"),\n' + ' user_context=item.get("userContext"),\n' + ' parent=item.get("parent")\n' + ) + # For other types, return empty + return "" + + @staticmethod + def _camel_to_snake(name: str) -> str: + """Convert camelCase to snake_case.""" + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + + +@dataclass +class CddlTypeDefinition: + """Represents a CDDL type definition.""" + + module: str + name: str + fields: Dict[str, str] = field(default_factory=dict) + description: str = "" + + def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + """Generate Python dataclass code for this type. + + Args: + enhancements: Dictionary containing dataclass_methods and method_docstrings + """ + enhancements = enhancements or {} + dataclass_methods = enhancements.get("dataclass_methods", {}) + method_docstrings = enhancements.get("method_docstrings", {}) + + # Generate class name from type name (keep it as-is, don't split on underscores) + class_name = self.name + code = f"@dataclass\n" + code += f"class {class_name}:\n" + code += f' """{self.description or self.name}."""\n\n' + + if not self.fields: + code += " pass\n" + else: + for field_name, field_type in self.fields.items(): + # Convert CDDL type to Python type + python_type = self._get_python_type(field_type) + snake_name = CddlCommand._camel_to_snake(field_name) + + # Check if the CDDL field type is a quoted string literal (e.g., type: "key") + # These are discriminant fields: auto-populate and exclude from __init__ + # so callers don't need to pass them as positional or keyword arguments. + literal_match = re.match(r'^"([^"]+)"$', field_type.strip()) + if literal_match: + literal_value = literal_match.group(1) + code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' + # Check if this field is a list type + elif "List[" in python_type: + code += f" {snake_name}: {python_type} = field(default_factory=list)\n" + else: + code += f" {snake_name}: {python_type} = None\n" + + # Add custom methods if defined for this class + if class_name in dataclass_methods: + code += "\n" + methods_dict = dataclass_methods[class_name] + docstrings_dict = method_docstrings.get(class_name, {}) + + for method_name in methods_dict: + method_impl = methods_dict[method_name] + docstring = docstrings_dict.get(method_name, "") + code += f" def {method_name}(self):\n" + if docstring: + code += f' """{docstring}"""\n' + code += f" {method_impl}\n" + code += "\n" + + return code + + @staticmethod + def _get_python_type(cddl_type: str) -> str: + """Convert CDDL type to Python type annotation using Python 3.10+ syntax.""" + cddl_type = cddl_type.strip().lower() + + # Handle basic types + type_mapping = { + "tstr": "str", + "text": "str", + "uint": "int", + "int": "int", + "nint": "int", + "bool": "bool", + "null": "None", + } + + for cddl, python in type_mapping.items(): + if cddl_type == cddl: + # Use Python 3.10+ union syntax: type | None + return f"{python} | None" + + # Handle arrays + if cddl_type.startswith("["): + inner = cddl_type.strip("[]+ ") + inner_type = CddlTypeDefinition._get_python_type(inner) + # Remove " | None" from inner type since it might be wrapped + if " | None" in inner_type: + inner_base = inner_type.replace(" | None", "") + return f"list[{inner_base} | None] | None" + return f"list[{inner_type}] | None" + + # Handle maps/dicts + if cddl_type.startswith("{"): + return "dict[str, Any] | None" + + # Default to Any for unknown/complex types + return "Any | None" + + +@dataclass +class CddlEnum: + """Represents a CDDL enum definition (string union).""" + + module: str + name: str + values: List[str] = field(default_factory=list) + description: str = "" + + def to_python_class(self) -> str: + """Generate Python enum class code. + + Generates a simple class with string constants to match the existing + pattern in the codebase (e.g., ClientWindowState). + """ + class_name = self.name + code = f"class {class_name}:\n" + code += f' """{self.description or self.name}."""\n\n' + + for value in self.values: + # Convert value to UPPER_SNAKE_CASE constant name + const_name = self._value_to_const_name(value) + code += f' {const_name} = "{value}"\n' + + return code + + @staticmethod + def _value_to_const_name(value: str) -> str: + """Convert enum string value to constant name. + + Examples: + "none" -> "NONE" + "portrait-primary" -> "PORTRAIT_PRIMARY" + "interactive" -> "INTERACTIVE" + """ + # Replace hyphens with underscores + const_name = value.replace("-", "_") + # Convert to uppercase + return const_name.upper() + + +@dataclass +class CddlEvent: + """Represents a CDDL event definition (incoming message from browser).""" + + module: str + name: str + method: str + params_type: str + description: str = "" + + def to_python_dataclass(self) -> str: + """Generate Python dataclass code for the event info type. + + Returns a dataclass code that attempts to use globals().get() for safety. + """ + class_name = self.name + + # Extract the type name from params_type (e.g., "browsingContext.Info" -> "Info") + # The params_type comes from the CDDL and includes module prefix + type_name = ( + self.params_type.split(".")[-1] + if "." in self.params_type + else self.params_type + ) + + # Special case: if the type is BaseNavigationInfo, use BaseNavigationInfo directly + # (NavigationInfo will be created as an alias to it) + if type_name == "NavigationInfo": + type_name = "BaseNavigationInfo" + + # Generate type alias using globals().get() for safety + code = f"# Event: {self.method}\n" + code += f"{class_name} = globals().get('{type_name}', dict) # Fallback to dict if type not defined\n" + + return code + + +@dataclass +class CddlModule: + """Represents a CDDL module (e.g., script, network, browsing_context).""" + + name: str + commands: List[CddlCommand] = field(default_factory=list) + types: List[CddlTypeDefinition] = field(default_factory=list) + enums: List[CddlEnum] = field(default_factory=list) + events: List[CddlEvent] = field(default_factory=list) + + @staticmethod + def _convert_method_to_event_name(method_suffix: str) -> str: + """Convert BiDi method suffix to friendly event name. + + Examples: + "contextCreated" -> "context_created" + "navigationStarted" -> "navigation_started" + "userPromptOpened" -> "user_prompt_opened" + """ + # Convert camelCase to snake_case + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + """Generate Python code for this module. + + Args: + enhancements: Dictionary with module-level enhancements + """ + enhancements = enhancements or {} + code = MODULE_HEADER.format(self.name) + + # Add imports if needed + if self.types: + code += "from dataclasses import field\n" + if self.commands or self.types: + code += "from typing import Generator\n" + code += "from dataclasses import dataclass\n" + + # Add imports for event handling if needed + if self.events: + code += "import threading\n" + code += "from collections.abc import Callable\n" + code += "from dataclasses import dataclass\n" + code += "from selenium.webdriver.common.bidi.session import Session\n" + + code += "\n\n" + + # Add helper function definitions from enhancements + # Collect all referenced helper functions (validate, transform) + helper_funcs_to_add = set() + for cmd in self.commands: + method_name_snake = cmd._camel_to_snake(cmd.name) + method_enhancements = enhancements.get(method_name_snake, {}) + if "validate" in method_enhancements: + helper_funcs_to_add.add(("validate", method_enhancements["validate"])) + if "transform" in method_enhancements and isinstance( + method_enhancements["transform"], dict + ): + transform_spec = method_enhancements["transform"] + if "func" in transform_spec: + helper_funcs_to_add.add(("transform", transform_spec["func"])) + + # Generate helper functions if needed + if helper_funcs_to_add: + for func_type, func_name in sorted(helper_funcs_to_add): + if ( + func_type == "validate" + and func_name == "validate_download_behavior" + ): + code += """def validate_download_behavior( + allowed: bool | None, + destination_folder: str | None, + user_contexts: Any | None = None, +) -> None: + \"\"\"Validate download behavior parameters. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads + user_contexts: Optional list of user contexts + + Raises: + ValueError: If parameters are invalid + \"\"\" + if allowed is True and not destination_folder: + raise ValueError("destination_folder is required when allowed=True") + if allowed is False and destination_folder: + raise ValueError("destination_folder should not be provided when allowed=False") + + +""" + elif ( + func_type == "transform" + and func_name == "transform_download_params" + ): + code += """def transform_download_params( + allowed: bool | None, + destination_folder: str | None, +) -> dict[str, Any] | None: + \"\"\"Transform download parameters into download_behavior object. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads (accepts str or + pathlib.Path; will be coerced to str) + + Returns: + Dictionary representing the download_behavior object, or None if allowed is None + \"\"\" + if allowed is True: + return { + "type": "allowed", + # Coerce pathlib.Path (or any path-like) to str so the BiDi + # protocol always receives a plain JSON string. + "destinationFolder": str(destination_folder) if destination_folder is not None else None, + } + elif allowed is False: + return {"type": "denied"} + else: # None — reset to browser default (sent as JSON null) + return None + + +""" + + # Generate enums first + for enum_def in self.enums: + code += enum_def.to_python_class() + code += "\n\n" + + # Emit module-level aliases from enhancement manifest (e.g. LogLevel = Level) + for alias, target in enhancements.get("aliases", {}).items(): + code += f"{alias} = {target}\n\n" + + # Generate type dataclasses, skipping any overridden by extra_dataclasses + exclude_types = set(enhancements.get("exclude_types", [])) + for type_def in self.types: + if type_def.name in exclude_types: + continue + code += type_def.to_python_dataclass(enhancements) + code += "\n\n" + + # Emit extra dataclasses from enhancement manifest (non-CDDL additions) + for extra_cls in enhancements.get("extra_dataclasses", []): + code += extra_cls + code += "\n\n" + + # NOTE: Don't generate event type aliases here - they reference types that may not be defined yet + # They will be generated after the class definition instead + + # Generate EVENT_NAME_MAPPING for the module (before the module class) + if self.events: + # Generate EVENT_NAME_MAPPING for the module + code += "# BiDi Event Name to Parameter Type Mapping\n" + code += "EVENT_NAME_MAPPING = {\n" + for event_def in self.events: + # Convert method name to user-friendly event name + # e.g., "browsingContext.contextCreated" -> "context_created" + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + code += f' "{event_name}": "{event_def.method}",\n' + # Extra events not in the CDDL spec (e.g. Chromium-specific events) + for extra_evt in enhancements.get("extra_events", []): + code += ( + f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n' + ) + code += "}\n\n" + + # Add custom method function definitions before the class (for browsingContext) + if self.name == "browsingContext": + # Add helper function for recursive Info deserialization + code += """def _deserialize_info_list(items: list) -> list | None: + \"\"\"Recursively deserialize a list of dicts to Info objects. + + Args: + items: List of dicts from the API response + + Returns: + List of Info objects with properly nested children, or None if empty + \"\"\" + if not items or not isinstance(items, list): + return None + + result = [] + for item in items: + if isinstance(item, dict): + # Recursively deserialize children only if the key exists in response + children_list = None + if "children" in item: + children_list = _deserialize_info_list(item.get("children", [])) + info = Info( + children=children_list, + client_window=item.get("clientWindow"), + context=item.get("context"), + original_opener=item.get("originalOpener"), + url=item.get("url"), + user_context=item.get("userContext"), + parent=item.get("parent"), + ) + result.append(info) + return result if result else None + + +""" + code += "\n\n" + + # Generate EventConfig and _EventManager for modules with events + if self.events: + # Generate EventConfig dataclass + code += """@dataclass +class EventConfig: + \"\"\"Configuration for a BiDi event.\"\"\" + event_key: str + bidi_event: str + event_class: type + + +""" + + # Generate _EventManager class + code += """class _EventWrapper: + \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + \"\"\"Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + \"\"\" + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, \"from_json\") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend([\"_\", char.lower()]) + else: + result.append(char) + return \"\".join(result) + + +class _EventManager: + \"\"\"Manages event subscriptions and callbacks.\"\"\" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get(\"subscription\") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + \"callbacks\": [], + \"subscription_id\": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry[\"callbacks\"]: + session = Session(self.conn) + sub_id = entry.get(\"subscription_id\") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry[\"callbacks\"]: + entry[\"callbacks\"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get(\"subscription_id\") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + +""" + code += "\n\n" + + # Generate class + # Convert module name (camelCase or snake_case) to proper class name (PascalCase) + class_name = module_name_to_class_name(self.name) + code += f"class {class_name}:\n" + code += f' """WebDriver BiDi {self.name} module."""\n\n' + + # Add EVENT_CONFIGS dict if there are events + if self.events: + code += ( + " EVENT_CONFIGS = {}\n" # Will be populated after types are defined + ) + + if self.name == "script": + code += " def __init__(self, conn, driver=None) -> None:\n" + code += " self._conn = conn\n" + code += " self._driver = driver\n" + else: + code += " def __init__(self, conn) -> None:\n" + code += " self._conn = conn\n" + + # Initialize _event_manager if there are events + if self.events: + code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n" + + # Append extra init code from enhancements (e.g. self.intercepts = []) + for init_line in enhancements.get("extra_init_code", []): + code += f" {init_line}\n" + + code += "\n" + + # Generate command methods + exclude_methods = enhancements.get("exclude_methods", []) + if self.commands: + for command in self.commands: + # Get method-specific enhancements + # Convert command name to snake_case to match enhancement manifest keys + method_name_snake = command._camel_to_snake(command.name) + if method_name_snake in exclude_methods: + continue + method_enhancements = enhancements.get(method_name_snake, {}) + code += command.to_python_method(method_enhancements) + code += "\n" + else: + code += " pass\n" + + # Emit extra methods from enhancement manifest + for extra_method in enhancements.get("extra_methods", []): + code += extra_method + code += "\n" + + # Add delegating event handler methods if events are present + if self.events: + code += """ + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + \"\"\"Add an event handler. + + Args: + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). + + Returns: + The callback ID. + \"\"\" + return self._event_manager.add_event_handler(event, callback, contexts) + + def remove_event_handler(self, event: str, callback_id: int) -> None: + \"\"\"Remove an event handler. + + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + \"\"\" + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + return self._event_manager.clear_event_handlers() +""" + + # Generate event info type aliases AFTER the class definition + # This ensures all types are available when we create the aliases + if self.events: + code += "\n# Event Info Type Aliases\n" + for event_def in self.events: + code += event_def.to_python_dataclass() + code += "\n" + + # Now populate EVENT_CONFIGS after the aliases are defined + code += f"\n# Populate EVENT_CONFIGS with event configuration mappings\n" + # Use globals() to look up types dynamically to handle missing types gracefully + code += f"_globals = globals()\n" + code += f"{class_name}.EVENT_CONFIGS = {{\n" + for event_def in self.events: + # Convert method name to user-friendly event name + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + # The event class is the event name (e.g., ContextCreated) + # Try to get it from globals, default to dict if not found + code += f' "{event_name}": (EventConfig("{event_name}", "{event_def.method}", _globals.get("{event_def.name}", dict)) if _globals.get("{event_def.name}") else EventConfig("{event_name}", "{event_def.method}", dict)),\n' + # Extra events not in the CDDL spec + for extra_evt in enhancements.get("extra_events", []): + ek = extra_evt["event_key"] + be = extra_evt["bidi_event"] + ec = extra_evt["event_class"] + code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' + code += "}\n" + + return code + + +class CddlParser: + """Parse CDDL specification files.""" + + def __init__(self, cddl_path: str): + """Initialize parser with CDDL file path.""" + self.cddl_path = Path(cddl_path) + self.content = "" + self.modules: Dict[str, CddlModule] = {} + self.definitions: Dict[str, str] = {} + self.event_names: Set[str] = set() # Names of definitions that are events + self._read_file() + + def _read_file(self) -> None: + """Read and preprocess CDDL file.""" + if not self.cddl_path.exists(): + raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") + + with open(self.cddl_path, "r", encoding="utf-8") as f: + self.content = f.read() + + logger.info(f"Loaded CDDL file: {self.cddl_path}") + + def parse(self) -> Dict[str, CddlModule]: + """Parse CDDL content and return modules.""" + # Remove comments + content = self._remove_comments(self.content) + + # Extract all definitions + self._extract_definitions(content) + + # Extract event names from event union definitions + self._extract_event_names() + + # Extract type definitions by module + self._extract_types() + + # Extract event definitions by module + self._extract_events() + + # Extract command definitions by module + self._extract_commands() + + # If no modules found, create a default one from the filename + if not self.modules: + module_name = self.cddl_path.stem + default_module = CddlModule(name=module_name) + self.modules[module_name] = default_module + logger.warning(f"No modules found in CDDL, creating default: {module_name}") + + return self.modules + + def _remove_comments(self, content: str) -> str: + """Remove comments from CDDL content.""" + # CDDL uses ; for comments to end of line + lines = content.split("\n") + cleaned = [] + for line in lines: + if ";" in line and not line.strip().startswith(";"): + line = line[: line.index(";")] + elif line.strip().startswith(";"): + continue + cleaned.append(line) + return "\n".join(cleaned) + + def _extract_definitions(self, content: str) -> None: + """Extract CDDL definitions (type definitions, commands, etc.).""" + # Match pattern: Name = Definition + # Handles multiline definitions properly + pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)?\s*=|\Z)" + + for match in re.finditer(pattern, content, re.DOTALL): + name = match.group(1).strip() + definition = match.group(2).strip() + self.definitions[name] = definition + logger.debug(f"Extracted definition: {name}") + + def _extract_event_names(self) -> None: + """Extract event names from event union definitions. + + Event union definitions follow pattern: + module.ModuleEvent = ( + module.EventName1 // + module.EventName2 // + ... + ) + """ + # Look for definitions like "BrowsingContextEvent", "SessionEvent", etc. + event_union_pattern = re.compile(r"(\w+\.)?(\w+)Event") + + for def_name, def_content in self.definitions.items(): + # Check if this looks like an event union (name ends with "Event") and + # contains a module-qualified reference like "module.EventName". + # Handles both single-item (no //) and multi-item (// separated) unions. + if "Event" in def_name and re.search(r"\w+\.\w+", def_content): + # Extract event names from the union (works for single and multi-item) + event_refs = re.findall(r"(\w+\.\w+)", def_content) + for event_ref in event_refs: + self.event_names.add(event_ref) + logger.debug(f"Identified event: {event_ref} (from {def_name})") + + def _extract_types(self) -> None: + """Extract type definitions from parsed definitions.""" + # Type definitions follow pattern: module.TypeName = { field: type, ... } + # They have dots in the name and curly braces in the content + # But they DON'T have method: "..." pattern (which means it's not a command) + # Enums follow pattern: module.EnumName = "value1" / "value2" / ... + + for def_name, def_content in self.definitions.items(): + # Skip if not a namespaced name (e.g., skip "EmptyParams", "Extensible") + if "." not in def_name: + continue + + # Skip if it's a command (contains method: pattern) + if "method:" in def_content: + continue + + # Extract module.TypeName + if "." in def_name: + module_name, type_name = def_name.rsplit(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Check if this is an enum (string union with /) + if self._is_enum_definition(def_content): + # Extract enum values + values = self._extract_enum_values(def_content) + if values: + enum_def = CddlEnum( + module=module_name, + name=type_name, + values=values, + description=f"{type_name}", + ) + self.modules[module_name].enums.append(enum_def) + logger.debug( + f"Found enum: {def_name} with {len(values)} values" + ) + else: + # Extract fields from type definition + fields = self._extract_type_fields(def_content) + + if fields: # Only create type if it has fields + type_def = CddlTypeDefinition( + module=module_name, + name=type_name, + fields=fields, + description=f"{type_name}", + ) + self.modules[module_name].types.append(type_def) + logger.debug( + f"Found type: {def_name} with {len(fields)} fields" + ) + + def _is_enum_definition(self, definition: str) -> bool: + """Check if a definition is an enum (string union with /). + + Enums are defined as: "value1" / "value2" / "value3" + """ + # Clean whitespace + clean_def = definition.strip() + + # Must not have curly braces (that would be a type definition) + if "{" in clean_def or "}" in clean_def: + return False + + # Must contain the union operator / surrounded by quotes + # Pattern: "something" / "something_else" + return " / " in clean_def and '"' in clean_def + + def _extract_enum_values(self, enum_definition: str) -> List[str]: + """Extract individual values from an enum definition. + + Enums are defined as: "value1" / "value2" / "value3" + Can span multiple lines. + """ + values = [] + + # Clean the definition and extract quoted strings + # Split by / and extract quoted values + parts = enum_definition.split("/") + + for part in parts: + part = part.strip() + + # Extract quoted string - use search instead of match to find quotes anywhere + match = re.search(r'"([^"]*)"', part) + if match: + value = match.group(1) + values.append(value) + logger.debug(f"Extracted enum value: {value}") + + return values + + @staticmethod + def _normalize_cddl_type(field_type: str) -> str: + """Normalize a CDDL type expression to a simple Python-compatible form. + + Strips CDDL control operators (.ge, .le, .gt, .lt, .default, etc.) and + replaces interval/constraint expressions with their base types so that + the caller can safely check for nested struct syntax. + + Examples: + '(float .ge 0.0) .default 1.0' -> 'float' + '(float .ge 0.0) / null' -> 'float / null' + '(0.0...360.0) / null' -> 'float / null' + '-90.0..90.0' -> 'float' + 'float / null .default null' -> 'float / null' + """ + result = field_type + # Remove trailing .default annotations + result = re.sub(r"\s*\.default\s+\S+", "", result) + # Replace parenthesised constraint expressions: (baseType .operator ...) -> baseType + result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result) + # Replace parenthesised numeric interval types: (0.0...360.0) -> float + result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result) + # Replace bare numeric interval types: -90.0..90.0 -> float + result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) + return result.strip() + + def _extract_type_fields(self, type_definition: str) -> Dict[str, str]: + """Extract fields from a type definition block.""" + fields = {} + + # Remove outer braces + clean_def = type_definition.strip() + if clean_def.startswith("{"): + clean_def = clean_def[1:] + if clean_def.endswith("}"): + clean_def = clean_def[:-1] + + # Parse each line for field: type patterns + for line in clean_def.split("\n"): + line = line.strip() + if not line or "Extensible" in line or line.startswith("//"): + continue + + # Match pattern: [?] fieldName: type + match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + if not match: + # Try without optional marker + match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + + if match: + field_name = match.group(1).strip() + field_type = match.group(2).strip() + normalized_type = self._normalize_cddl_type(field_type) + + # Skip lines that are part of nested definitions + if "{" not in normalized_type and "(" not in normalized_type: + fields[field_name] = normalized_type + logger.debug(f"Extracted field {field_name}: {normalized_type}") + + return fields + + def _extract_events(self) -> None: + """Extract event definitions from parsed definitions. + + Events are definitions that: + 1. Are listed in an event union (e.g., BrowsingContextEvent) + 2. Have method: "..." and params: ... fields + + Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) + """ + # Find definitions that are in the event_names set + event_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) + + for def_name, def_content in self.definitions.items(): + # Skip if not identified as an event + if def_name not in self.event_names: + continue + + # Extract method and params + match = event_pattern.search(def_content) + if match: + method = match.group(1) # e.g., "browsingContext.contextCreated" + params_type = match.group(2) # e.g., "browsingContext.Info" + + # Extract module name from method + if "." in method: + module_name, _ = method.split(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Extract event name from definition name (e.g., browsingContext.ContextCreated) + _, event_name = def_name.rsplit(".", 1) + + # Create event + event = CddlEvent( + module=module_name, + name=event_name, + method=method, + params_type=params_type, + description=f"Event: {method}", + ) + + self.modules[module_name].events.append(event) + logger.debug( + f"Found event: {def_name} (method={method}, params={params_type})" + ) + + def _extract_commands(self) -> None: + """Extract command definitions from parsed definitions.""" + # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) + command_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) + + for def_name, def_content in self.definitions.items(): + # Skip definitions that are events (they share the same pattern) + if def_name in self.event_names: + continue + matches = list(command_pattern.finditer(def_content)) + if matches: + for match in matches: + method = match.group(1) # e.g., "session.new" + params_type = match.group(2) # e.g., "session.NewParameters" + + # Extract module name from method + if "." in method: + module_name, command_name = method.split(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Extract parameters + params = self._extract_parameters(params_type) + + # Create command + cmd = CddlCommand( + module=module_name, + name=command_name, + params=params, + description=f"Execute {method}", + ) + + self.modules[module_name].commands.append(cmd) + logger.debug( + f"Found command: {method} with params {params_type}" + ) + + def _extract_parameters( + self, params_type: str, _seen: Optional[Set[str]] = None + ) -> Dict[str, str]: + """Extract parameters from a parameter type definition. + + Handles both struct types ({...}) and top-level union types (TypeA / TypeB), + merging all fields from each alternative as optional parameters. + """ + params = {} + + if _seen is None: + _seen = set() + if params_type in _seen: + return params + _seen.add(params_type) + + if params_type not in self.definitions: + logger.debug(f"Parameter type not found: {params_type}") + return params + + definition = self.definitions[params_type] + + # Handle top-level type alias that is a union of other named types: + # e.g. session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest + # These definitions contain a single line with "/" separating type names + # (not the double-slash "//" used for command unions). + stripped = definition.strip() + if not stripped.startswith("{") and "/" in stripped and "//" not in stripped: + # Each token separated by "/" should be a named type reference + alternatives = [a.strip() for a in stripped.split("/") if a.strip()] + all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) + if all_named: + for alt_type in alternatives: + alt_params = self._extract_parameters(alt_type, _seen) + params.update(alt_params) + return params + + # Remove the outer curly braces and split by comma + # Then parse each line for key: type patterns + clean_def = stripped + if clean_def.startswith("{"): + clean_def = clean_def[1:] + if clean_def.endswith("}"): + clean_def = clean_def[:-1] + + # Split by newlines and process each line + for line in clean_def.split("\n"): + line = line.strip() + if not line or "Extensible" in line: + continue + + # Match pattern: [?] name: type + # Using a simple pattern that handles optional prefix + match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + if not match: + # Try without optional marker + match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + + if match: + param_name = match.group(1).strip() + param_type = match.group(2).strip() + normalized_type = self._normalize_cddl_type(param_type) + + # Skip lines that are part of nested definitions + if "{" not in normalized_type and "(" not in normalized_type: + params[param_name] = normalized_type + logger.debug( + f"Extracted param {param_name}: {normalized_type} from {params_type}" + ) + + return params + + +def module_name_to_class_name(module_name: str) -> str: + """Convert module name to class name (PascalCase). + + Handles both camelCase (browsingContext) and snake_case (browsing_context). + """ + if "_" in module_name: + # Snake_case: browsing_context -> BrowsingContext + return "".join(word.capitalize() for word in module_name.split("_")) + else: + # CamelCase: browsingContext -> BrowsingContext + return module_name[0].upper() + module_name[1:] if module_name else "" + + +def module_name_to_filename(module_name: str) -> str: + """Convert module name to Python filename (snake_case). + + Handles both camelCase (browsingContext) and snake_case (browsing_context). + Special cases: + - browsingContext -> browsing_context + - webExtension -> webextension + """ + # Handle explicit mappings for known camelCase names + camel_to_snake_map = { + "browsingContext": "browsing_context", + "webExtension": "webextension", + } + + if module_name in camel_to_snake_map: + return camel_to_snake_map[module_name] + + if "_" in module_name: + # Already snake_case + return module_name + else: + # Convert camelCase to snake_case for other cases + # This handles cases like "myModuleName" -> "my_module_name" + import re + + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", module_name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> None: + """Generate __init__.py file for the module.""" + init_path = output_path / "__init__.py" + + code = f"""{SHARED_HEADER} + +from __future__ import annotations + +""" + + for module_name in sorted(modules.keys()): + class_name = module_name_to_class_name(module_name) + filename = module_name_to_filename(module_name) + code += f"from .{filename} import {class_name}\n" + + code += f"\n__all__ = [\n" + for module_name in sorted(modules.keys()): + class_name = module_name_to_class_name(module_name) + code += f' "{class_name}",\n' + code += "]\n" + + with open(init_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {init_path}") + + +def generate_common_file(output_path: Path) -> None: + """Generate common.py file with shared utilities.""" + common_path = output_path / "common.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + '"""Common utilities for BiDi command construction."""\n' + "\n" + "from typing import Any, Dict, Generator\n" + "\n" + "\n" + "def command_builder(\n" + " method: str, params: Dict[str, Any]\n" + ") -> Generator[Dict[str, Any], Any, Any]:\n" + ' """Build a BiDi command generator.\n' + "\n" + " Args:\n" + ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' + " params: The parameters for the command\n" + "\n" + " Yields:\n" + " A dictionary representing the BiDi command\n" + "\n" + " Returns:\n" + " The result from the BiDi command execution\n" + ' """\n' + ' result = yield {"method": method, "params": params}\n' + " return result\n" + ) + + with open(common_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {common_path}") + + +def generate_console_file(output_path: Path) -> None: + """Generate console.py file with Console enum helper.""" + console_path = output_path / "console.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + "from enum import Enum\n" + "\n" + "\n" + "class Console(Enum):\n" + ' ALL = "all"\n' + ' LOG = "log"\n' + ' ERROR = "error"\n' + ) + + with open(console_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {console_path}") + + +def generate_permissions_file(output_path: Path) -> None: + """Generate permissions.py file with permission-related classes.""" + permissions_path = output_path / "permissions.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + '"""WebDriver BiDi Permissions module."""\n' + "\n" + "from __future__ import annotations\n" + "\n" + "from enum import Enum\n" + "from typing import Any, Optional, Union\n" + "\n" + "from .common import command_builder\n" + "\n" + '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' + "\n" + "\n" + "class PermissionState(str, Enum):\n" + ' """Permission state enumeration."""\n' + "\n" + ' GRANTED = "granted"\n' + ' DENIED = "denied"\n' + ' PROMPT = "prompt"\n' + "\n" + "\n" + "class PermissionDescriptor:\n" + ' """Descriptor for a permission."""\n' + "\n" + " def __init__(self, name: str) -> None:\n" + ' """Initialize a PermissionDescriptor.\n' + "\n" + " Args:\n" + " name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera')\n" + ' """\n' + " self.name = name\n" + "\n" + " def __repr__(self) -> str:\n" + " return f\"PermissionDescriptor('{self.name}')\"\n" + "\n" + "\n" + "class Permissions:\n" + ' """WebDriver BiDi Permissions module."""\n' + "\n" + " def __init__(self, websocket_connection: Any) -> None:\n" + ' """Initialize the Permissions module.\n' + "\n" + " Args:\n" + " websocket_connection: The WebSocket connection for sending BiDi commands\n" + ' """\n' + " self._conn = websocket_connection\n" + "\n" + " def set_permission(\n" + " self,\n" + " descriptor: Union[PermissionDescriptor, str],\n" + " state: Union[PermissionState, str],\n" + " origin: Optional[str] = None,\n" + " user_context: Optional[str] = None,\n" + " ) -> None:\n" + ' """Set a permission for a given origin.\n' + "\n" + " Args:\n" + " descriptor: The permission descriptor or permission name as a string\n" + " state: The desired permission state\n" + " origin: The origin for which to set the permission\n" + " user_context: Optional user context ID to scope the permission\n" + "\n" + " Raises:\n" + " ValueError: If the state is not a valid permission state\n" + ' """\n' + " state_value = state.value if isinstance(state, PermissionState) else state\n" + " if state_value not in _VALID_PERMISSION_STATES:\n" + " raise ValueError(\n" + ' f"Invalid permission state: {state_value!r}. "\n' + ' f"Must be one of {sorted(_VALID_PERMISSION_STATES)}"\n' + " )\n" + "\n" + " if isinstance(descriptor, str):\n" + ' descriptor_dict = {"name": descriptor}\n' + " else:\n" + ' descriptor_dict = {"name": descriptor.name}\n' + "\n" + " params: dict[str, Any] = {\n" + ' "descriptor": descriptor_dict,\n' + ' "state": state_value,\n' + " }\n" + " if origin is not None:\n" + ' params["origin"] = origin\n' + " if user_context is not None:\n" + ' params["userContext"] = user_context\n' + "\n" + ' cmd = command_builder("permissions.setPermission", params)\n' + " self._conn.execute(cmd)\n" + ) + + with open(permissions_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {permissions_path}") + + +def main( + cddl_file: str, + output_dir: str, + spec_version: str = "1.0", + enhancements_manifest: Optional[str] = None, +) -> None: + """Main entry point. + + Args: + cddl_file: Path to CDDL specification file + output_dir: Output directory for generated modules + spec_version: BiDi spec version + enhancements_manifest: Path to enhancement manifest Python file + """ + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + + logger.info(f"WebDriver BiDi Code Generator v{__version__}") + logger.info(f"Input CDDL: {cddl_file}") + logger.info(f"Output directory: {output_path}") + logger.info(f"Spec version: {spec_version}") + + # Load enhancement manifest + manifest = load_enhancements_manifest(enhancements_manifest) + if manifest: + logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}") + + # Parse CDDL + parser = CddlParser(cddl_file) + modules = parser.parse() + + logger.info(f"Parsed {len(modules)} modules") + + # Clean up existing generated files + for file_path in output_path.glob("*.py"): + if file_path.name != "py.typed" and not file_path.name.startswith("_"): + file_path.unlink() + logger.debug(f"Removed: {file_path}") + + # Generate module files using snake_case filenames + for module_name, module in sorted(modules.items()): + filename = module_name_to_filename(module_name) + module_path = output_path / f"{filename}.py" + + # Get module-specific enhancements (merge with dataclass templates) + module_enhancements = manifest.get("enhancements", {}).get(module_name, {}) + + # Add dataclass methods and docstrings to the enhancement data for this module + full_module_enhancements = { + **module_enhancements, + "dataclass_methods": manifest.get("dataclass_methods", {}), + "method_docstrings": manifest.get("method_docstrings", {}), + } + + with open(module_path, "w", encoding="utf-8") as f: + f.write(module.generate_code(full_module_enhancements)) + logger.info(f"Generated: {module_path}") + + # Generate __init__.py + generate_init_file(output_path, modules) + + # Generate common.py + generate_common_file(output_path) + + # Generate permissions.py + generate_permissions_file(output_path) + + # Generate console.py + generate_console_file(output_path) + + # Create py.typed marker + py_typed_path = output_path / "py.typed" + py_typed_path.touch() + logger.info(f"Generated type marker: {py_typed_path}") + + logger.info("Code generation complete!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate Python WebDriver BiDi modules from CDDL specification" + ) + parser.add_argument( + "cddl_file", + help="Path to CDDL specification file", + ) + parser.add_argument( + "output_dir", + help="Output directory for generated Python modules", + ) + parser.add_argument( + "--version", + default="1.0", + help="BiDi spec version (default: 1.0)", + ) + parser.add_argument( + "--enhancements-manifest", + default=None, + help="Path to enhancement manifest Python file (optional)", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger("generate_bidi").setLevel(logging.DEBUG) + + try: + main( + args.cddl_file, + args.output_dir, + args.version, + args.enhancements_manifest, + ) + sys.exit(0) + except Exception as e: + logger.error(f"Generation failed: {e}", exc_info=True) + sys.exit(1) diff --git a/py/private/BUILD.bazel b/py/private/BUILD.bazel index 8b02ac341a0dc..88acc9d2aba11 100644 --- a/py/private/BUILD.bazel +++ b/py/private/BUILD.bazel @@ -1,5 +1,10 @@ load("@rules_python//python:defs.bzl", "py_binary") +exports_files([ + "bidi_enhancements_manifest.py", + "cdp.py", +]) + py_binary( name = "untar", srcs = [ diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py new file mode 100644 index 0000000000000..ae7229f6ddebd --- /dev/null +++ b/py/private/bidi_enhancements_manifest.py @@ -0,0 +1,1557 @@ +""" +Enhancement manifest for BiDi code generation. + +This file defines custom enhancements applied to generated BiDi modules, +including custom dataclass methods, parameter validation/transformation, +response deserialization, and field extraction. + +All code must be compatible with Python 3.10+. +""" + +from __future__ import annotations + +from typing import Any + +# ============================================================================ +# Format Guide +# ============================================================================ +# Each module in ENHANCEMENTS specifies enhancement rules for methods: +# +# 'module_name': { +# 'method_name': { +# 'dataclass_methods': { # For dataclass enhancements +# 'ClassName': ['method1', 'method2', ...] +# }, +# 'preprocess': { # Pre-processing on parameters +# 'param_name': 'check_serialize_method' +# }, +# 'deserialize': { # Deserialize response to typed objects +# 'response_field': 'TypeName', +# }, +# 'extract_field': str, # Extract nested field from response +# 'extract_property': str, # Extract property from extracted items +# 'validate': str, # Validation function name +# 'transform': str, # Transformation function name +# } +# } +# ============================================================================ + +ENHANCEMENTS: dict[str, dict[str, Any]] = { + "browser": { + # Dataclass custom methods + "__dataclass_methods__": { + "ClientWindowInfo": [ + "get_client_window", + "get_state", + "get_width", + "get_height", + "is_active", + "get_x", + "get_y", + ], + }, + # Method enhancements + "create_user_context": { + "preprocess": { + "proxy": "check_serialize_method", + "unhandled_prompt_behavior": "check_serialize_method", + }, + "extract_field": "userContext", + }, + "get_client_windows": { + "deserialize": { + "clientWindows": "ClientWindowInfo", + }, + }, + "get_user_contexts": { + "extract_field": "userContexts", + "extract_property": "userContext", + }, + "set_download_behavior": { + "params_override": { + "allowed": "bool", + "destination_folder": "str", + "userContexts": "[*browser.UserContext]", + }, + "validate": "validate_download_behavior", + "transform": { + "allowed": "allowed", + "destination_folder": "destination_folder", + "func": "transform_download_params", + "result_param": "download_behavior", + }, + }, + # Override the generator-produced set_download_behavior so that + # downloadBehavior is never stripped by the generic None filter. + # The BiDi spec marks it as required (can be null, but must be present). + "extra_methods": [ + ''' def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + """Set the download behavior for the browser. + + Args: + allowed: ``True`` to allow downloads, ``False`` to deny, or ``None`` + to reset to browser default (sends ``null`` to the protocol). + destination_folder: Destination folder for downloads. Required when + ``allowed=True``. Accepts a string or :class:`pathlib.Path`. + user_contexts: Optional list of user context IDs. + + Raises: + ValueError: If *allowed* is ``True`` and *destination_folder* is + omitted, or ``False`` and *destination_folder* is provided. + """ + validate_download_behavior( + allowed=allowed, + destination_folder=destination_folder, + user_contexts=user_contexts, + ) + download_behavior = transform_download_params(allowed, destination_folder) + # downloadBehavior is a REQUIRED field in the BiDi spec (can be null but + # must be present). Do NOT use a generic None-filter on it. + params: dict = {"downloadBehavior": download_behavior} + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("browser.setDownloadBehavior", params) + return self._conn.execute(cmd)''', + ], + }, + "browsingContext": { + # Method enhancements + "create": { + "extract_field": "context", + }, + "get_tree": { + "extract_field": "contexts", + "deserialize": { + "contexts": "Info", + }, + }, + "capture_screenshot": { + "extract_field": "data", + "params_override": { + "context": "str", + "format": "ImageFormat", + "clip": "BoxClipRectangle", + "origin": "str", + }, + }, + "print": { + "extract_field": "data", + }, + "locate_nodes": { + "extract_field": "nodes", + "params_override": { + "context": "str", + "locator": "dict", + "serializationOptions": "dict", + "startNodes": "list", + "maxNodeCount": "int", + }, + }, + "set_viewport": { + "params_override": { + "context": "str", + "viewport": "dict", + "userContexts": "list", + "devicePixelRatio": "float", + }, + }, + # Non-CDDL download event dataclasses (Chromium-specific) + "extra_dataclasses": [ + '''@dataclass +class DownloadWillBeginParams: + """DownloadWillBeginParams.""" + + suggested_filename: str | None = None''', + '''@dataclass +class DownloadCanceledParams: + """DownloadCanceledParams.""" + + status: Any | None = None''', + '''@dataclass +class DownloadParams: + """DownloadParams - fields shared by all download end event variants.""" + + status: str | None = None + context: Any | None = None + navigation: Any | None = None + timestamp: Any | None = None + url: str | None = None + filepath: str | None = None''', + '''@dataclass +class DownloadEndParams: + """DownloadEndParams - params for browsingContext.downloadEnd event.""" + + download_params: "DownloadParams | None" = None + + @classmethod + def from_json(cls, params: dict) -> "DownloadEndParams": + """Deserialize from BiDi wire-level params dict.""" + dp = DownloadParams( + status=params.get("status"), + context=params.get("context"), + navigation=params.get("navigation"), + timestamp=params.get("timestamp"), + url=params.get("url"), + filepath=params.get("filepath"), + ) + return cls(download_params=dp)''', + ], + # Non-CDDL download events (Chromium-specific, not in the BiDi spec) + "extra_events": [ + { + "event_key": "download_will_begin", + "bidi_event": "browsingContext.downloadWillBegin", + "event_class": "DownloadWillBeginParams", + }, + { + "event_key": "download_end", + "bidi_event": "browsingContext.downloadEnd", + "event_class": "DownloadEndParams", + }, + ], + }, + "log": { + # Make LogLevel an alias for Level so existing code using LogLevel works + "aliases": {"LogLevel": "Level"}, + # Replace the minimal CDDL-generated versions with richer ones that have from_json + "exclude_types": ["JavascriptLogEntry"], + "extra_dataclasses": [ + '''@dataclass +class ConsoleLogEntry: + """ConsoleLogEntry - a console log entry from the browser.""" + + type_: str | None = None + method: str | None = None + args: list | None = None + level: Any | None = None + text: Any | None = None + source: Any | None = None + timestamp: Any | None = None + stack_trace: Any | None = None + + @classmethod + def from_json(cls, params: dict) -> "ConsoleLogEntry": + """Deserialize from BiDi params dict.""" + return cls( + type_=params.get("type"), + method=params.get("method"), + args=params.get("args"), + level=params.get("level"), + text=params.get("text"), + source=params.get("source"), + timestamp=params.get("timestamp"), + stack_trace=params.get("stackTrace"), + )''', + '''@dataclass +class JavascriptLogEntry: + """JavascriptLogEntry - a JavaScript error log entry from the browser.""" + + type_: str | None = None + level: Any | None = None + text: Any | None = None + source: Any | None = None + timestamp: Any | None = None + stacktrace: Any | None = None + + @classmethod + def from_json(cls, params: dict) -> "JavascriptLogEntry": + """Deserialize from BiDi params dict.""" + return cls( + type_=params.get("type"), + level=params.get("level"), + text=params.get("text"), + source=params.get("source"), + timestamp=params.get("timestamp"), + stacktrace=params.get("stackTrace"), + )''', + ], + }, + "emulation": { + "extra_methods": [ + ''' def set_geolocation_override( + self, + coordinates=None, + error=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setGeolocationOverride. + + Sets or clears the geolocation override for specified browsing or user contexts. + + Args: + coordinates: A GeolocationCoordinates instance (or dict) to override the + position, or ``None`` to clear a previously-set override. + error: A GeolocationPositionError instance (or dict) to simulate a + position-unavailable error. Mutually exclusive with *coordinates*. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params = {} + if coordinates is not None: + if isinstance(coordinates, dict): + coords_dict = coordinates + else: + coords_dict = {} + if coordinates.latitude is not None: + coords_dict["latitude"] = coordinates.latitude + if coordinates.longitude is not None: + coords_dict["longitude"] = coordinates.longitude + if coordinates.accuracy is not None: + coords_dict["accuracy"] = coordinates.accuracy + if coordinates.altitude is not None: + coords_dict["altitude"] = coordinates.altitude + if coordinates.altitude_accuracy is not None: + coords_dict["altitudeAccuracy"] = coordinates.altitude_accuracy + if coordinates.heading is not None: + coords_dict["heading"] = coordinates.heading + if coordinates.speed is not None: + coords_dict["speed"] = coordinates.speed + params["coordinates"] = coords_dict + if error is not None: + if isinstance(error, dict): + params["error"] = error + else: + params["error"] = { + "type": error.type if error.type is not None else "positionUnavailable" + } + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setGeolocationOverride", params) + result = self._conn.execute(cmd) + return result''', + ''' def set_timezone_override( + self, + timezone=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setTimezoneOverride. + + Sets or clears the timezone override for specified browsing or user contexts. + Pass ``timezone=None`` (or omit it) to clear a previously-set override. + + Args: + timezone: IANA timezone string (e.g. ``"America/New_York"``) or ``None`` + to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params = {"timezone": timezone} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setTimezoneOverride", params) + return self._conn.execute(cmd)''', + ''' def set_scripting_enabled( + self, + enabled=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setScriptingEnabled. + + Enables or disables scripting for specified browsing or user contexts. + Pass ``enabled=None`` to restore the default behaviour. + + Args: + enabled: ``True`` to enable scripting, ``False`` to disable it, or + ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params = {"enabled": enabled} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScriptingEnabled", params) + return self._conn.execute(cmd)''', + ''' def set_user_agent_override( + self, + user_agent=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setUserAgentOverride. + + Overrides the User-Agent string for specified browsing or user contexts. + Pass ``user_agent=None`` to clear a previously-set override. + + Args: + user_agent: Custom User-Agent string, or ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params = {"userAgent": user_agent} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setUserAgentOverride", params) + return self._conn.execute(cmd)''', + ''' def set_screen_orientation_override( + self, + screen_orientation=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setScreenOrientationOverride. + + Sets or clears the screen orientation override for specified browsing or + user contexts. + + Args: + screen_orientation: A :class:`ScreenOrientation` instance (or dict with + ``natural`` and ``type`` keys) to lock the orientation, or ``None`` + to clear a previously-set override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + if screen_orientation is None: + so_value = None + elif isinstance(screen_orientation, dict): + so_value = screen_orientation + else: + natural = screen_orientation.natural + orientation_type = screen_orientation.type + so_value = { + "natural": natural.lower() if isinstance(natural, str) else natural, + "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, + } + params = {"screenOrientation": so_value} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScreenOrientationOverride", params) + return self._conn.execute(cmd)''', + ''' def set_network_conditions( + self, + network_conditions=None, + offline: bool | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setNetworkConditions. + + Sets or clears network condition emulation for specified browsing or user + contexts. + + Args: + network_conditions: A dict with the raw ``networkConditions`` value + (e.g. ``{"type": "offline"}``), or ``None`` to clear the override. + Mutually exclusive with *offline*. + offline: Convenience bool — ``True`` sets offline conditions, + ``False`` clears them (sends ``null``). When provided, this takes + precedence over *network_conditions*. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + if offline is not None: + nc_value = {"type": "offline"} if offline else None + else: + nc_value = network_conditions + params = {"networkConditions": nc_value} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setNetworkConditions", params) + return self._conn.execute(cmd)''', + ], + }, + "script": { + "extra_methods": [ + ''' def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: + """Execute a function declaration in the browser context. + + Args: + function_declaration: The function as a string, e.g. ``"() => document.title"``. + *args: Optional Python values to pass as arguments to the function. + Each value is serialised to a BiDi ``LocalValue`` automatically. + Supported types: ``None``, ``bool``, ``int``, ``float`` + (including ``NaN`` and ``Infinity``), ``str``, ``list``, + ``dict``, and ``datetime.datetime``. + context_id: The browsing context ID to run in. Defaults to the + driver\'s current window handle when a driver was provided. + + Returns: + The inner RemoteValue result dict, or raises WebDriverException on exception. + """ + import math as _math + import datetime as _datetime + from selenium.common.exceptions import WebDriverException as _WebDriverException + + def _serialize_arg(value): + """Serialise a Python value to a BiDi LocalValue dict.""" + if value is None: + return {"type": "null"} + if isinstance(value, bool): + return {"type": "boolean", "value": value} + if isinstance(value, _datetime.datetime): + return {"type": "date", "value": value.isoformat()} + if isinstance(value, float): + if _math.isnan(value): + return {"type": "number", "value": "NaN"} + if _math.isinf(value): + return {"type": "number", "value": "Infinity" if value > 0 else "-Infinity"} + return {"type": "number", "value": value} + if isinstance(value, int): + _MAX_SAFE_INT = 9007199254740991 + if abs(value) > _MAX_SAFE_INT: + return {"type": "bigint", "value": str(value)} + return {"type": "number", "value": value} + if isinstance(value, str): + return {"type": "string", "value": value} + if isinstance(value, list): + return {"type": "array", "value": [_serialize_arg(v) for v in value]} + if isinstance(value, dict): + return {"type": "object", "value": [[str(k), _serialize_arg(v)] for k, v in value.items()]} + return value + + if context_id is None and self._driver is not None: + try: + context_id = self._driver.current_window_handle + except Exception: + pass + target = {"context": context_id} if context_id else {} + serialized_args = [_serialize_arg(a) for a in args] if args else None + raw = self.call_function( + function_declaration=function_declaration, + await_promise=True, + target=target, + arguments=serialized_args, + ) + if isinstance(raw, dict): + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails", {}) + msg = exc.get("text", str(exc)) if isinstance(exc, dict) else str(exc) + raise _WebDriverException(msg) + if raw.get("type") == "success": + return raw.get("result") + return raw''', + ''' def _add_preload_script(self, function_declaration, arguments=None, contexts=None, user_contexts=None, sandbox=None): + """Add a preload script with validation. + + Args: + function_declaration: The JS function to run on page load. + arguments: Optional list of BiDi arguments. + contexts: Optional list of browsing context IDs. + user_contexts: Optional list of user context IDs. + sandbox: Optional sandbox name. + + Returns: + script_id: The ID of the added preload script (str). + + Raises: + ValueError: If both contexts and user_contexts are specified. + """ + if contexts is not None and user_contexts is not None: + raise ValueError("Cannot specify both contexts and user_contexts") + result = self.add_preload_script( + function_declaration=function_declaration, + arguments=arguments, + contexts=contexts, + user_contexts=user_contexts, + sandbox=sandbox, + ) + if isinstance(result, dict): + return result.get("script") + return result''', + ''' def _remove_preload_script(self, script_id): + """Remove a preload script by ID. + + Args: + script_id: The ID of the preload script to remove. + """ + return self.remove_preload_script(script=script_id)''', + ''' def pin(self, function_declaration): + """Pin (add) a preload script that runs on every page load. + + Args: + function_declaration: The JS function to execute on page load. + + Returns: + script_id: The ID of the pinned script (str). + """ + return self._add_preload_script(function_declaration)''', + ''' def unpin(self, script_id): + """Unpin (remove) a previously pinned preload script. + + Args: + script_id: The ID returned by pin(). + """ + return self._remove_preload_script(script_id=script_id)''', + ''' def _evaluate(self, expression, target, await_promise, result_ownership=None, serialization_options=None, user_activation=None): + """Evaluate a script expression and return a structured result. + + Args: + expression: The JavaScript expression to evaluate. + target: A dict like {"context": } or {"realm": }. + await_promise: Whether to await a returned promise. + result_ownership: Optional result ownership setting. + serialization_options: Optional serialization options dict. + user_activation: Optional user activation flag. + + Returns: + An object with .realm, .result (dict or None), and .exception_details (or None). + """ + class _EvalResult: + def __init__(self2, realm, result, exception_details): + self2.realm = realm + self2.result = result + self2.exception_details = exception_details + + raw = self.evaluate( + expression=expression, + target=target, + await_promise=await_promise, + result_ownership=result_ownership, + serialization_options=serialization_options, + user_activation=user_activation, + ) + if isinstance(raw, dict): + realm = raw.get("realm") + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails") + return _EvalResult(realm=realm, result=None, exception_details=exc) + return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) + return _EvalResult(realm=None, result=raw, exception_details=None)''', + ''' def _call_function(self, function_declaration, await_promise, target, arguments=None, result_ownership=None, this=None, user_activation=None, serialization_options=None): + """Call a function and return a structured result. + + Args: + function_declaration: The JS function string. + await_promise: Whether to await the return value. + target: A dict like {"context": }. + arguments: Optional list of BiDi arguments. + result_ownership: Optional result ownership. + this: Optional \'this\' binding. + user_activation: Optional user activation flag. + serialization_options: Optional serialization options dict. + + Returns: + An object with .result (dict or None) and .exception_details (or None). + """ + class _CallResult: + def __init__(self2, result, exception_details): + self2.result = result + self2.exception_details = exception_details + + raw = self.call_function( + function_declaration=function_declaration, + await_promise=await_promise, + target=target, + arguments=arguments, + result_ownership=result_ownership, + this=this, + user_activation=user_activation, + serialization_options=serialization_options, + ) + if isinstance(raw, dict): + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails") + return _CallResult(result=None, exception_details=exc) + if raw.get("type") == "success": + return _CallResult(result=raw.get("result"), exception_details=None) + return _CallResult(result=raw, exception_details=None)''', + ''' def _get_realms(self, context=None, type=None): + """Get all realms, optionally filtered by context and type. + + Args: + context: Optional browsing context ID to filter by. + type: Optional realm type string to filter by (e.g. RealmType.WINDOW). + + Returns: + List of realm info objects with .realm, .origin, .type, .context attributes. + """ + class _RealmInfo: + def __init__(self2, realm, origin, type_, context): + self2.realm = realm + self2.origin = origin + self2.type = type_ + self2.context = context + + raw = self.get_realms(context=context, type=type) + realms_list = raw.get("realms", []) if isinstance(raw, dict) else [] + result = [] + for r in realms_list: + if isinstance(r, dict): + result.append(_RealmInfo( + realm=r.get("realm"), + origin=r.get("origin"), + type_=r.get("type"), + context=r.get("context"), + )) + return result''', + ''' def _disown(self, handles, target): + """Disown handles in a browsing context. + + Args: + handles: List of handle strings to disown. + target: A dict like {"context": }. + """ + return self.disown(handles=handles, target=target)''', + ''' def _subscribe_log_entry(self, callback, entry_type_filter=None): + """Subscribe to log.entryAdded BiDi events with optional type filtering.""" + import threading as _threading + from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + + bidi_event = "log.entryAdded" + + if not hasattr(self, "_log_subscriptions"): + self._log_subscriptions = {} + self._log_lock = _threading.Lock() + + def _deserialize(params): + t = params.get("type") if isinstance(params, dict) else None + if t == "console": + cls = getattr(_log_mod, "ConsoleLogEntry", None) + if cls is not None and hasattr(cls, "from_json"): + try: + return cls.from_json(params) + except Exception: + pass + elif t == "javascript": + cls = getattr(_log_mod, "JavascriptLogEntry", None) + if cls is not None and hasattr(cls, "from_json"): + try: + return cls.from_json(params) + except Exception: + pass + return params + + def _wrapped(raw): + entry = _deserialize(raw) + if entry_type_filter is None: + callback(entry) + else: + t = getattr(entry, "type_", None) or ( + entry.get("type") if isinstance(entry, dict) else None + ) + if t == entry_type_filter: + callback(entry) + + class _BidiRef: + event_class = bidi_event + + def from_json(self2, p): + return p + + _wrapper = _BidiRef() + callback_id = self._conn.add_callback(_wrapper, _wrapped) + with self._log_lock: + if bidi_event not in self._log_subscriptions: + session = _Session(self._conn) + result = session.subscribe([bidi_event]) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self._log_subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + self._log_subscriptions[bidi_event]["callbacks"].append(callback_id) + return callback_id''', + ''' def _unsubscribe_log_entry(self, callback_id): + """Unsubscribe a log entry callback by ID.""" + from selenium.webdriver.common.bidi.session import Session as _Session + + bidi_event = "log.entryAdded" + if not hasattr(self, "_log_subscriptions"): + return + + class _BidiRef: + event_class = bidi_event + + def from_json(self2, p): + return p + + _wrapper = _BidiRef() + self._conn.remove_callback(_wrapper, callback_id) + with self._log_lock: + entry = self._log_subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + if entry is not None and not entry["callbacks"]: + session = _Session(self._conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self._log_subscriptions[bidi_event]''', + ''' def add_console_message_handler(self, callback: Callable) -> int: + """Add a handler for console log messages (log.entryAdded type=console). + + Args: + callback: Function called with a ConsoleLogEntry on each console message. + + Returns: + callback_id for use with remove_console_message_handler. + """ + return self._subscribe_log_entry(callback, entry_type_filter="console")''', + ''' def remove_console_message_handler(self, callback_id: int) -> None: + """Remove a console message handler by callback ID.""" + self._unsubscribe_log_entry(callback_id)''', + ''' def add_javascript_error_handler(self, callback: Callable) -> int: + """Add a handler for JavaScript error log messages (log.entryAdded type=javascript). + + Args: + callback: Function called with a JavascriptLogEntry on each JS error. + + Returns: + callback_id for use with remove_javascript_error_handler. + """ + return self._subscribe_log_entry(callback, entry_type_filter="javascript")''', + ''' def remove_javascript_error_handler(self, callback_id: int) -> None: + """Remove a JavaScript error handler by callback ID.""" + self._unsubscribe_log_entry(callback_id)''', + ], + }, + "network": { + # Initialize intercepts tracking list in __init__ + "extra_init_code": ["self.intercepts = []"], + # Request class wraps a beforeRequestSent event params and provides actions + "extra_dataclasses": [ + '''class BytesValue: + """A string or base64-encoded bytes value used in cookie operations. + + This corresponds to network.BytesValue in the WebDriver BiDi specification, + wrapping either a plain string or a base64-encoded binary value. + """ + + TYPE_STRING = "string" + TYPE_BASE64 = "base64" + + def __init__(self, type: str, value: str) -> None: + self.type = type + self.value = value + + def to_bidi_dict(self) -> dict: + return {"type": self.type, "value": self.value}''', + '''class Request: + """Wraps a BiDi network request event params and provides request action methods.""" + + def __init__(self, conn, params): + self._conn = conn + self._params = params if isinstance(params, dict) else {} + req = self._params.get("request", {}) or {} + self.url = req.get("url", "") + self._request_id = req.get("request") + + def continue_request(self, **kwargs): + """Continue the intercepted request.""" + from selenium.webdriver.common.bidi.common import command_builder as _cb + + params = {"request": self._request_id} + params.update(kwargs) + self._conn.execute(_cb("network.continueRequest", params))''', + ], + # Add before_request event (maps to network.beforeRequestSent) + "extra_events": [ + { + "event_key": "before_request", + "bidi_event": "network.beforeRequestSent", + "event_class": "dict", + }, + ], + "extra_methods": [ + ''' def _add_intercept(self, phases=None, url_patterns=None): + """Add a low-level network intercept. + + Args: + phases: list of intercept phases (default: ["beforeRequestSent"]) + url_patterns: optional URL patterns to filter + + Returns: + dict with "intercept" key containing the intercept ID + """ + from selenium.webdriver.common.bidi.common import command_builder as _cb + + if phases is None: + phases = ["beforeRequestSent"] + params = {"phases": phases} + if url_patterns: + params["urlPatterns"] = url_patterns + result = self._conn.execute(_cb("network.addIntercept", params)) + if result: + intercept_id = result.get("intercept") + if intercept_id and intercept_id not in self.intercepts: + self.intercepts.append(intercept_id) + return result''', + ''' def _remove_intercept(self, intercept_id): + """Remove a low-level network intercept.""" + from selenium.webdriver.common.bidi.common import command_builder as _cb + + self._conn.execute(_cb("network.removeIntercept", {"intercept": intercept_id})) + if intercept_id in self.intercepts: + self.intercepts.remove(intercept_id)''', + ''' def add_request_handler(self, event, callback, url_patterns=None): + """Add a handler for network requests at the specified phase. + + Args: + event: Event name, e.g. ``"before_request"``. + callback: Callable receiving a :class:`Request` instance. + url_patterns: optional list of URL pattern dicts to filter. + + Returns: + callback_id int for later removal via remove_request_handler. + """ + phase_map = { + "before_request": "beforeRequestSent", + "before_request_sent": "beforeRequestSent", + "response_started": "responseStarted", + "auth_required": "authRequired", + } + phase = phase_map.get(event, "beforeRequestSent") + self._add_intercept(phases=[phase], url_patterns=url_patterns) + + def _request_callback(params): + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request = Request(self._conn, raw) + callback(request) + + return self.add_event_handler(event, _request_callback)''', + ''' def remove_request_handler(self, event, callback_id): + """Remove a network request handler. + + Args: + event: The event name used when adding the handler. + callback_id: The int returned by add_request_handler. + """ + self.remove_event_handler(event, callback_id)''', + ''' def clear_request_handlers(self): + """Clear all request handlers and remove all tracked intercepts.""" + self.clear_event_handlers() + for intercept_id in list(self.intercepts): + self._remove_intercept(intercept_id)''', + ''' def add_auth_handler(self, username, password): + """Add an auth handler that automatically provides credentials. + + Args: + username: The username for basic authentication. + password: The password for basic authentication. + + Returns: + callback_id int for later removal via remove_auth_handler. + """ + from selenium.webdriver.common.bidi.common import command_builder as _cb + + def _auth_callback(params): + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request_id = ( + raw.get("request", {}).get("request") + if isinstance(raw, dict) + else None + ) + if request_id: + self._conn.execute( + _cb( + "network.continueWithAuth", + { + "request": request_id, + "action": "provideCredentials", + "credentials": { + "type": "password", + "username": username, + "password": password, + }, + }, + ) + ) + + return self.add_event_handler("auth_required", _auth_callback)''', + ''' def remove_auth_handler(self, callback_id): + """Remove an auth handler by callback ID.""" + self.remove_event_handler("auth_required", callback_id)''', + ], + }, + "storage": { + # Exclude auto-generated dataclasses that need custom to_bidi_dict() + # for JSON-over-WebSocket serialization, or custom constructors. + "exclude_types": [ + "CookieFilter", + "PartialCookie", + "BrowsingContextPartitionDescriptor", + "StorageKeyPartitionDescriptor", + ], + "extra_dataclasses": [ + # Re-export network types used in cookie operations so they can be + # imported from selenium.webdriver.common.bidi.storage alongside + # the storage-specific classes. + '''class BytesValue: + """A string or base64-encoded bytes value used in cookie operations. + + This corresponds to network.BytesValue in the WebDriver BiDi specification, + wrapping either a plain string or a base64-encoded binary value. + """ + + TYPE_STRING = "string" + TYPE_BASE64 = "base64" + + def __init__(self, type: str, value: str) -> None: + self.type = type + self.value = value + + def to_bidi_dict(self) -> dict: + return {"type": self.type, "value": self.value}''', + '''class SameSite: + """SameSite cookie attribute values.""" + + STRICT = "strict" + LAX = "lax" + NONE = "none" + DEFAULT = "default"''', + # Helper: cookie object returned inside a GetCookiesResult response + '''@dataclass +class StorageCookie: + """A cookie object returned by storage.getCookies.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + @classmethod + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + """Deserialize a wire-level cookie dict to a StorageCookie.""" + value_raw = raw.get("value") + if isinstance(value_raw, dict): + value = BytesValue(value_raw.get("type"), value_raw.get("value")) + else: + value = value_raw + return cls( + name=raw.get("name"), + value=value, + domain=raw.get("domain"), + path=raw.get("path"), + size=raw.get("size"), + http_only=raw.get("httpOnly"), + secure=raw.get("secure"), + same_site=raw.get("sameSite"), + expiry=raw.get("expiry"), + )''', + # Custom CookieFilter with camelCase serialization + '''@dataclass +class CookieFilter: + """CookieFilter.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {} + if self.name is not None: + result["name"] = self.name + if self.value is not None: + result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value + if self.domain is not None: + result["domain"] = self.domain + if self.path is not None: + result["path"] = self.path + if self.size is not None: + result["size"] = self.size + if self.http_only is not None: + result["httpOnly"] = self.http_only + if self.secure is not None: + result["secure"] = self.secure + if self.same_site is not None: + result["sameSite"] = self.same_site + if self.expiry is not None: + result["expiry"] = self.expiry + return result''', + # Custom PartialCookie with camelCase serialization + '''@dataclass +class PartialCookie: + """PartialCookie.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {} + if self.name is not None: + result["name"] = self.name + if self.value is not None: + result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value + if self.domain is not None: + result["domain"] = self.domain + if self.path is not None: + result["path"] = self.path + if self.http_only is not None: + result["httpOnly"] = self.http_only + if self.secure is not None: + result["secure"] = self.secure + if self.same_site is not None: + result["sameSite"] = self.same_site + if self.expiry is not None: + result["expiry"] = self.expiry + return result''', + # BrowsingContextPartitionDescriptor: first positional arg is *context* + # (the auto-generated dataclass had `type` first, breaking positional + # usage like BrowsingContextPartitionDescriptor(driver.current_window_handle)) + '''class BrowsingContextPartitionDescriptor: + """BrowsingContextPartitionDescriptor. + + The first positional argument is *context* (a browsing-context ID / window + handle), mirroring how the class is used throughout the test suite: + ``BrowsingContextPartitionDescriptor(driver.current_window_handle)``. + """ + + def __init__(self, context: Any = None, type: str = "context") -> None: + self.context = context + self.type = type + + def to_bidi_dict(self) -> dict: + return {"type": "context", "context": self.context}''', + # StorageKeyPartitionDescriptor with camelCase serialization + '''@dataclass +class StorageKeyPartitionDescriptor: + """StorageKeyPartitionDescriptor.""" + + type: Any | None = "storageKey" + user_context: str | None = None + source_origin: str | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {"type": "storageKey"} + if self.user_context is not None: + result["userContext"] = self.user_context + if self.source_origin is not None: + result["sourceOrigin"] = self.source_origin + return result''', + ], + # Override the generated Storage class methods (Python's last-definition- + # wins semantics means these extra_methods shadow the generated ones). + "extra_methods": [ + ''' def get_cookies(self, filter=None, partition=None): + """Execute storage.getCookies and return a GetCookiesResult.""" + if filter and hasattr(filter, "to_bidi_dict"): + filter = filter.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.getCookies", params) + result = self._conn.execute(cmd) + if result and "cookies" in result: + cookies = [ + StorageCookie.from_bidi_dict(c) + for c in result.get("cookies", []) + if isinstance(c, dict) + ] + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None + ) + return GetCookiesResult(cookies=cookies, partition_key=pk) + return GetCookiesResult(cookies=[], partition_key=None)''', + ''' def set_cookie(self, cookie=None, partition=None): + """Execute storage.setCookie.""" + if cookie and hasattr(cookie, "to_bidi_dict"): + cookie = cookie.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "cookie": cookie, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.setCookie", params) + result = self._conn.execute(cmd) + return result''', + ''' def delete_cookies(self, filter=None, partition=None): + """Execute storage.deleteCookies.""" + if filter and hasattr(filter, "to_bidi_dict"): + filter = filter.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.deleteCookies", params) + result = self._conn.execute(cmd) + return result''', + ], + }, + "session": { + # Override UserPromptHandler to add to_bidi_dict() for JSON serialization + "exclude_types": ["UserPromptHandler"], + "extra_dataclasses": [ + '''@dataclass +class UserPromptHandler: + """UserPromptHandler.""" + + alert: Any | None = None + before_unload: Any | None = None + confirm: Any | None = None + default: Any | None = None + file: Any | None = None + prompt: Any | None = None + + def to_bidi_dict(self) -> dict: + """Convert to BiDi protocol dict with camelCase keys.""" + result = {} + if self.alert is not None: + result["alert"] = self.alert + if self.before_unload is not None: + result["beforeUnload"] = self.before_unload + if self.confirm is not None: + result["confirm"] = self.confirm + if self.default is not None: + result["default"] = self.default + if self.file is not None: + result["file"] = self.file + if self.prompt is not None: + result["prompt"] = self.prompt + return result''', + ], + }, + "webExtension": { + # Suppress the raw generated stubs; hand-written versions follow below + "exclude_methods": ["install", "uninstall"], + "extra_methods": [ + ''' def install(self, path: str | None = None, archive_path: str | None = None, base64_value: str | None = None): + """Install a web extension. + + Exactly one of the three keyword arguments must be provided. + + Args: + path: Directory path to an unpacked extension (also accepted for + signed ``.xpi`` / ``.crx`` archive files on Firefox). + archive_path: File-system path to a packed extension archive. + base64_value: Base64-encoded extension archive string. + + Returns: + The raw result dict from the BiDi ``webExtension.install`` command + (contains at least an ``"extension"`` key with the extension ID). + + Raises: + ValueError: If more than one, or none, of the arguments is provided. + """ + provided = [k for k, v in {"path": path, "archive_path": archive_path, "base64_value": base64_value}.items() if v is not None] + if len(provided) != 1: + raise ValueError( + f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" + ) + if path is not None: + extension_data = {"type": "path", "path": path} + elif archive_path is not None: + extension_data = {"type": "archivePath", "path": archive_path} + else: + extension_data = {"type": "base64", "value": base64_value} + params = {"extensionData": extension_data} + cmd = command_builder("webExtension.install", params) + return self._conn.execute(cmd)''', + ''' def uninstall(self, extension: Any | None = None): + """Uninstall a web extension. + + Args: + extension: Either the extension ID string returned by ``install``, + or the full result dict returned by ``install`` (the + ``"extension"`` value is extracted automatically). + """ + if isinstance(extension, dict): + extension = extension.get("extension") + params = {"extension": extension} + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("webExtension.uninstall", params) + return self._conn.execute(cmd)''', + ], + }, + "input": { + # FileDialogInfo needs from_json for event deserialization + "exclude_types": ["FileDialogInfo", "PointerMoveAction", "PointerDownAction"], + "extra_dataclasses": [ + '''@dataclass +class FileDialogInfo: + """FileDialogInfo - parameters for the input.fileDialogOpened event.""" + + context: Any | None = None + element: Any | None = None + multiple: bool | None = None + + @classmethod + def from_json(cls, params: dict) -> "FileDialogInfo": + """Deserialize event params into FileDialogInfo.""" + return cls( + context=params.get("context"), + element=params.get("element"), + multiple=params.get("multiple"), + )''', + '''@dataclass +class PointerMoveAction: + """PointerMoveAction.""" + + type: str = field(default="pointerMove", init=False) + x: Any | None = None + y: Any | None = None + duration: Any | None = None + origin: Any | None = None + properties: Any | None = None''', + '''@dataclass +class PointerDownAction: + """PointerDownAction.""" + + type: str = field(default="pointerDown", init=False) + button: Any | None = None + properties: Any | None = None''', + ], + "extra_methods": [ + ''' def add_file_dialog_handler(self, callback) -> int: + """Subscribe to the input.fileDialogOpened event. + + Args: + callback: Callable invoked with a FileDialogInfo when a file dialog opens. + + Returns: + A handler ID that can be passed to remove_file_dialog_handler. + """ + return self._event_manager.add_event_handler("file_dialog_opened", callback) + + def remove_file_dialog_handler(self, handler_id: int) -> None: + """Unsubscribe a previously registered file dialog event handler. + + Args: + handler_id: The handler ID returned by add_file_dialog_handler. + """ + return self._event_manager.remove_event_handler("file_dialog_opened", handler_id)''', + ], + }, +} + + +# ============================================================================ +# Pre-processing Functions +# ============================================================================ + + +def check_serialize_method(obj: Any) -> Any: + """Check if object has to_bidi_dict() method and use it for serialization.""" + if obj and hasattr(obj, "to_bidi_dict"): + return obj.to_bidi_dict() + return obj + + +# ============================================================================ +# Validation Functions +# ============================================================================ + + +def validate_download_behavior( + allowed: bool | None, + destination_folder: str | None, + user_contexts: Any | None = None, +) -> None: + """Validate download behavior parameters. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads + user_contexts: Optional list of user contexts (ignored for validation) + + Raises: + ValueError: If parameters are invalid + """ + if allowed is True and not destination_folder: + raise ValueError("destination_folder is required when allowed=True") + if allowed is False and destination_folder: + raise ValueError("destination_folder should not be provided when allowed=False") + + +# ============================================================================ +# Transformation Functions +# ============================================================================ + + +def transform_download_params( + allowed: bool | None, + destination_folder: str | None, +) -> dict[str, Any]: + """Transform download parameters into download_behavior object. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads + + Returns: + Dictionary representing the download_behavior object, or None if allowed is None + """ + if allowed is True: + return { + "type": "allowed", + # Convert pathlib.Path (or any path-like) to str so the BiDi + # protocol always receives a plain JSON string. + "destinationFolder": ( + str(destination_folder) if destination_folder is not None else None + ), + } + elif allowed is False: + return {"type": "denied"} + else: # None — reset to browser default (sent as JSON null) + return None + + +# ============================================================================ +# Dataclass Method Templates +# ============================================================================ + +DATACLASS_METHOD_TEMPLATES: dict[str, dict[str, str]] = { + "ClientWindowInfo": { + "get_client_window": "return self.client_window", + "get_state": "return self.state", + "get_width": "return self.width", + "get_height": "return self.height", + "is_active": "return self.active", + "get_x": "return self.x", + "get_y": "return self.y", + }, + "BrowsingContext": { + "add_event_handler": "_add_event_handler_impl", + "remove_event_handler": "_remove_event_handler_impl", + }, +} + +DATACLASS_METHOD_DOCSTRINGS: dict[str, dict[str, str]] = { + "ClientWindowInfo": { + "get_client_window": "Get the client window ID.", + "get_state": "Get the client window state.", + "get_width": "Get the client window width.", + "get_height": "Get the client window height.", + "is_active": "Check if the client window is active.", + "get_x": "Get the client window X position.", + "get_y": "Get the client window Y position.", + }, + "BrowsingContext": { + "add_event_handler": "Add an event handler for browsing context events.", + "remove_event_handler": "Remove an event handler by callback ID.", + }, +} + +# ============================================================================ +# Event Handler Support for BrowsingContext +# ============================================================================ + + +def _add_event_handler( + self, + event_name: str, + callback: callable, + contexts: list[str] | None = None, +) -> str: + """Add an event handler for a browsing context event. + + Supported events: + - 'context_created' + - 'context_destroyed' + - 'navigation_started' + - 'navigation_committed' + - 'navigation_failed' + - 'dom_content_loaded' + - 'load' + - 'fragment_navigated' + - 'user_prompt_opened' + - 'user_prompt_closed' + - 'download_will_begin' + - 'download_end' + - 'history_updated' + + Args: + event_name: The name of the event to subscribe to + callback: Callback function to invoke when event occurs + contexts: Optional list of context IDs to limit event subscription + + Returns: + A callback ID that can be used to unsubscribe the handler + """ + if not hasattr(self, "_event_handlers"): + self._event_handlers = {} + self._event_callback_id_counter = 0 + + # Generate unique callback ID + self._event_callback_id_counter += 1 + callback_id = f"callback_{self._event_callback_id_counter}" + + # Store the handler + self._event_handlers[callback_id] = { + "event": event_name, + "callback": callback, + "contexts": contexts, + } + + # Subscribe via the driver's event listening mechanism + if hasattr(self._driver, "_subscribe_event"): + self._driver._subscribe_event(event_name, callback, contexts) + + return callback_id + + +def _remove_event_handler( + self, + callback_id: str, +) -> None: + """Remove an event handler by its callback ID. + + Args: + callback_id: The callback ID returned from add_event_handler + """ + if not hasattr(self, "_event_handlers"): + return + + if callback_id in self._event_handlers: + handler_info = self._event_handlers[callback_id] + + # Unsubscribe from the driver + if hasattr(self._driver, "_unsubscribe_event"): + self._driver._unsubscribe_event( + handler_info["event"], + handler_info["callback"], + handler_info["contexts"], + ) + + del self._event_handlers[callback_id] diff --git a/py/private/cdp.py b/py/private/cdp.py new file mode 100644 index 0000000000000..b097762fe50cd --- /dev/null +++ b/py/private/cdp.py @@ -0,0 +1,515 @@ +# The MIT License(MIT) +# +# Copyright(c) 2018 Hyperion Gray +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp + +import contextvars +import importlib +import itertools +import json +import logging +import pathlib +from collections import defaultdict +from collections.abc import AsyncGenerator, AsyncIterator, Generator +from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass +from typing import Any, TypeVar + +import trio +from trio_websocket import ConnectionClosed as WsConnectionClosed +from trio_websocket import connect_websocket_url + +logger = logging.getLogger("trio_cdp") +T = TypeVar("T") +MAX_WS_MESSAGE_SIZE = 2**24 + +devtools = None +version = None + + +def import_devtools(ver): + """Attempt to load the current latest available devtools into the module cache for use later.""" + global devtools + global version + version = ver + base = "selenium.webdriver.common.devtools.v" + try: + devtools = importlib.import_module(f"{base}{ver}") + return devtools + except ModuleNotFoundError: + # Attempt to parse and load the 'most recent' devtools module. This is likely + # because cdp has been updated but selenium python has not been released yet. + devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") + versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) + latest = max(int(x[1:]) for x in versions) + selenium_logger = logging.getLogger(__name__) + selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) + devtools = importlib.import_module(f"{base}{latest}") + return devtools + + +_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context") +_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") + + +def get_connection_context(fn_name): + """Look up the current connection. + + If there is no current connection, raise a ``RuntimeError`` with a + helpful message. + """ + try: + return _connection_context.get() + except LookupError: + raise RuntimeError(f"{fn_name}() must be called in a connection context.") + + +def get_session_context(fn_name): + """Look up the current session. + + If there is no current session, raise a ``RuntimeError`` with a + helpful message. + """ + try: + return _session_context.get() + except LookupError: + raise RuntimeError(f"{fn_name}() must be called in a session context.") + + +@contextmanager +def connection_context(connection): + """Context manager installs ``connection`` as the session context for the current Trio task.""" + token = _connection_context.set(connection) + try: + yield + finally: + _connection_context.reset(token) + + +@contextmanager +def session_context(session): + """Context manager installs ``session`` as the session context for the current Trio task.""" + token = _session_context.set(session) + try: + yield + finally: + _session_context.reset(token) + + +def set_global_connection(connection): + """Install ``connection`` in the root context so that it will become the default connection for all tasks. + + This is generally not recommended, except it may be necessary in + certain use cases such as running inside Jupyter notebook. + """ + global _connection_context + _connection_context = contextvars.ContextVar("_connection_context", default=connection) + + +def set_global_session(session): + """Install ``session`` in the root context so that it will become the default session for all tasks. + + This is generally not recommended, except it may be necessary in + certain use cases such as running inside Jupyter notebook. + """ + global _session_context + _session_context = contextvars.ContextVar("_session_context", default=session) + + +class BrowserError(Exception): + """This exception is raised when the browser's response to a command indicates that an error occurred.""" + + def __init__(self, obj): + self.code = obj.get("code") + self.message = obj.get("message") + self.detail = obj.get("data") + + def __str__(self): + return f"BrowserError {self.detail}" + + +class CdpConnectionClosed(WsConnectionClosed): + """Raised when a public method is called on a closed CDP connection.""" + + def __init__(self, reason): + """Constructor. + + Args: + reason: wsproto.frame_protocol.CloseReason + """ + self.reason = reason + + def __repr__(self): + """Return representation.""" + return f"{self.__class__.__name__}<{self.reason}>" + + +class InternalError(Exception): + """This exception is only raised when there is faulty logic in TrioCDP or the integration with PyCDP.""" + + pass + + +@dataclass +class CmEventProxy: + """A proxy object returned by :meth:`CdpBase.wait_for()``. + + After the context manager executes, this proxy object will have a + value set that contains the returned event. + """ + + value: Any = None + + +class CdpBase: + def __init__(self, ws, session_id, target_id): + self.ws = ws + self.session_id = session_id + self.target_id = target_id + self.channels = defaultdict(set) + self.id_iter = itertools.count() + self.inflight_cmd = {} + self.inflight_result = {} + + async def execute(self, cmd: Generator[dict, T, Any]) -> T: + """Execute a command on the server and wait for the result. + + Args: + cmd: any CDP command + + Returns: + a CDP result + """ + cmd_id = next(self.id_iter) + cmd_event = trio.Event() + self.inflight_cmd[cmd_id] = cmd, cmd_event + request = next(cmd) + request["id"] = cmd_id + if self.session_id: + request["sessionId"] = self.session_id + request_str = json.dumps(request) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}") + try: + await self.ws.send_message(request_str) + except WsConnectionClosed as wcc: + raise CdpConnectionClosed(wcc.reason) from None + await cmd_event.wait() + response = self.inflight_result.pop(cmd_id) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Received CDP message: {response}") + if isinstance(response, Exception): + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}") + raise response + return response + + def listen(self, *event_types, buffer_size=10): + """Listen for events. + + Returns: + An async iterator that iterates over events matching the indicated types. + """ + sender, receiver = trio.open_memory_channel(buffer_size) + for event_type in event_types: + self.channels[event_type].add(sender) + return receiver + + @asynccontextmanager + async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]: + """Wait for an event of the given type and return it. + + This is an async context manager, so you should open it inside + an async with block. The block will not exit until the indicated + event is received. + """ + sender: trio.MemorySendChannel + receiver: trio.MemoryReceiveChannel + sender, receiver = trio.open_memory_channel(buffer_size) + self.channels[event_type].add(sender) + proxy = CmEventProxy() + yield proxy + async with receiver: + event = await receiver.receive() + proxy.value = event + + def _handle_data(self, data): + """Handle incoming WebSocket data. + + Args: + data: a JSON dictionary + """ + if "id" in data: + self._handle_cmd_response(data) + else: + self._handle_event(data) + + def _handle_cmd_response(self, data: dict): + """Handle a response to a command. + + This will set an event flag that will return control to the + task that called the command. + + Args: + data: response as a JSON dictionary + """ + cmd_id = data["id"] + try: + cmd, event = self.inflight_cmd.pop(cmd_id) + except KeyError: + logger.warning("Got a message with a command ID that does not exist: %s", data) + return + if "error" in data: + # If the server reported an error, convert it to an exception and do + # not process the response any further. + self.inflight_result[cmd_id] = BrowserError(data["error"]) + else: + # Otherwise, continue the generator to parse the JSON result + # into a CDP object. + try: + _ = cmd.send(data["result"]) + raise InternalError("The command's generator function did not exit when expected!") + except StopIteration as exit: + return_ = exit.value + self.inflight_result[cmd_id] = return_ + event.set() + + def _handle_event(self, data: dict): + """Handle an event. + + Args: + data: event as a JSON dictionary + """ + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + event = devtools.util.parse_json_event(data) + logger.debug("Received event: %s", event) + to_remove = set() + for sender in self.channels[type(event)]: + try: + sender.send_nowait(event) + except trio.WouldBlock: + logger.error('Unable to send event "%r" due to full channel %s', event, sender) + except trio.BrokenResourceError: + to_remove.add(sender) + if to_remove: + self.channels[type(event)] -= to_remove + + +class CdpSession(CdpBase): + """Contains the state for a CDP session. + + Generally you should not instantiate this object yourself; you should call + :meth:`CdpConnection.open_session`. + """ + + def __init__(self, ws, session_id, target_id): + """Constructor. + + Args: + ws: trio_websocket.WebSocketConnection + session_id: devtools.target.SessionID + target_id: devtools.target.TargetID + """ + super().__init__(ws, session_id, target_id) + + self._dom_enable_count = 0 + self._dom_enable_lock = trio.Lock() + self._page_enable_count = 0 + self._page_enable_lock = trio.Lock() + + @asynccontextmanager + async def dom_enable(self): + """Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``. + + This keeps track of concurrent callers and only disables DOM + events when all callers have exited. + """ + global devtools + async with self._dom_enable_lock: + self._dom_enable_count += 1 + if self._dom_enable_count == 1: + await self.execute(devtools.dom.enable()) + + yield + + async with self._dom_enable_lock: + self._dom_enable_count -= 1 + if self._dom_enable_count == 0: + await self.execute(devtools.dom.disable()) + + @asynccontextmanager + async def page_enable(self): + """Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits. + + This keeps track of concurrent callers and only disables page + events when all callers have exited. + """ + global devtools + async with self._page_enable_lock: + self._page_enable_count += 1 + if self._page_enable_count == 1: + await self.execute(devtools.page.enable()) + + yield + + async with self._page_enable_lock: + self._page_enable_count -= 1 + if self._page_enable_count == 0: + await self.execute(devtools.page.disable()) + + +class CdpConnection(CdpBase, trio.abc.AsyncResource): + """Contains the connection state for a Chrome DevTools Protocol server. + + CDP can multiplex multiple "sessions" over a single connection. This + class corresponds to the "root" session, i.e. the implicitly created + session that has no session ID. This class is responsible for + reading incoming WebSocket messages and forwarding them to the + corresponding session, as well as handling messages targeted at the + root session itself. You should generally call the + :func:`open_cdp()` instead of instantiating this class directly. + """ + + def __init__(self, ws): + """Constructor. + + Args: + ws: trio_websocket.WebSocketConnection + """ + super().__init__(ws, session_id=None, target_id=None) + self.sessions = {} + + async def aclose(self): + """Close the underlying WebSocket connection. + + This will cause the reader task to gracefully exit when it tries + to read the next message from the WebSocket. All of the public + APIs (``execute()``, ``listen()``, etc.) will raise + ``CdpConnectionClosed`` after the CDP connection is closed. It + is safe to call this multiple times. + """ + await self.ws.aclose() + + @asynccontextmanager + async def open_session(self, target_id) -> AsyncIterator[CdpSession]: + """Context manager opens a session and enables the "simple" style of calling CDP APIs. + + For example, inside a session context, you can call ``await + dom.get_document()`` and it will execute on the current session + automatically. + """ + session = await self.connect_session(target_id) + with session_context(session): + yield session + + async def connect_session(self, target_id) -> "CdpSession": + """Returns a new :class:`CdpSession` connected to the specified target.""" + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) + session = CdpSession(self.ws, session_id, target_id) + self.sessions[session_id] = session + return session + + async def _reader_task(self): + """Runs in the background and handles incoming messages. + + Dispatches responses to commands and events to listeners. + """ + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + while True: + try: + message = await self.ws.get_message() + except WsConnectionClosed: + # If the WebSocket is closed, we don't want to throw an + # exception from the reader task. Instead we will throw + # exceptions from the public API methods, and we can quietly + # exit the reader task here. + break + try: + data = json.loads(message) + except json.JSONDecodeError: + raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message}) + logger.debug("Received message %r", data) + if "sessionId" in data: + session_id = devtools.target.SessionID(data["sessionId"]) + try: + session = self.sessions[session_id] + except KeyError: + raise BrowserError( + { + "code": -32700, + "message": "Browser sent a message for an invalid session", + "data": f"{session_id!r}", + } + ) + session._handle_data(data) + else: + self._handle_data(data) + + for _, session in self.sessions.items(): + for _, senders in session.channels.items(): + for sender in senders: + sender.close() + + +@asynccontextmanager +async def open_cdp(url) -> AsyncIterator[CdpConnection]: + """Async context manager opens a connection to the browser then closes the connection when the block exits. + + The context manager also sets the connection as the default + connection for the current task, so that commands like ``await + target.get_targets()`` will run on this connection automatically. If + you want to use multiple connections concurrently, it is recommended + to open each on in a separate task. + """ + async with trio.open_nursery() as nursery: + conn = await connect_cdp(nursery, url) + try: + with connection_context(conn): + yield conn + finally: + await conn.aclose() + + +async def connect_cdp(nursery, url) -> CdpConnection: + """Connect to the browser specified by ``url`` and spawn a background task in the specified nursery. + + The ``open_cdp()`` context manager is preferred in most situations. + You should only use this function if you need to specify a custom + nursery. This connection is not automatically closed! You can either + use the connection object as a context manager (``async with + conn:``) or else call ``await conn.aclose()`` on it when you are + done with it. If ``set_context`` is True, then the returned + connection will be installed as the default connection for the + current task. This argument is for unusual use cases, such as + running inside of a notebook. + """ + ws = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE) + cdp_conn = CdpConnection(ws) + nursery.start_soon(cdp_conn._reader_task) + return cdp_conn diff --git a/py/private/generate_bidi.bzl b/py/private/generate_bidi.bzl new file mode 100644 index 0000000000000..c11b6efe4735f --- /dev/null +++ b/py/private/generate_bidi.bzl @@ -0,0 +1,112 @@ +"""Bazel rule for generating WebDriver BiDi Python modules from CDDL specification.""" + +def _generate_bidi_impl(ctx): + """Implementation of the generate_bidi rule.""" + + cddl_file = ctx.file.cddl_file + manifest_file = ctx.file.enhancements_manifest + generator = ctx.executable.generator + output_dir = ctx.attr.module_name + spec_version = ctx.attr.spec_version + + # The generator creates BiDi modules from the CDDL spec + # Using snake_case naming convention for Python files + module_names = [ + "browser", + "browsing_context", + "common", + "console", + "emulation", + "input", + "log", + "network", + "permissions", + "script", + "session", + "storage", + "webextension", + ] + + # Declare all output files + module_files = [ + ctx.actions.declare_file(output_dir + "/" + name + ".py") + for name in module_names + ] + init_file = ctx.actions.declare_file(output_dir + "/__init__.py") + py_typed = ctx.actions.declare_file(output_dir + "/py.typed") + + gen_outputs = module_files + [init_file, py_typed] + + # Copy static extra_srcs into the output directory + extra_outputs = [] + for src in ctx.files.extra_srcs: + out = ctx.actions.declare_file(output_dir + "/" + src.basename) + ctx.actions.symlink(output = out, target_file = src) + extra_outputs.append(out) + + outputs = gen_outputs + extra_outputs + + # Output directory for the generator + output_base = init_file.dirname + + # Build the command to run the generator + args = [ + cddl_file.path, + output_base, + "--version", + spec_version, + ] + + # Add enhancement manifest if provided + inputs = [cddl_file] + if manifest_file: + args.extend(["--enhancements-manifest", manifest_file.path]) + inputs.append(manifest_file) + + ctx.actions.run( + inputs = inputs, + outputs = gen_outputs, + executable = generator, + arguments = args, + use_default_shell_env = True, + ) + + return [DefaultInfo(files = depset(outputs))] + + +generate_bidi = rule( + implementation = _generate_bidi_impl, + attrs = { + "cddl_file": attr.label( + allow_single_file = [".cddl"], + mandatory = True, + doc = "CDDL specification file", + ), + "enhancements_manifest": attr.label( + allow_single_file = [".py"], + mandatory = False, + doc = "Enhancement manifest Python file (optional)", + ), + "extra_srcs": attr.label_list( + allow_files = [".py"], + mandatory = False, + default = [], + doc = "Additional static Python files to copy verbatim into the output directory", + ), + "generator": attr.label( + executable = True, + cfg = "exec", + mandatory = True, + doc = "Generator script (e.g., generate_bidi.py)", + ), + "module_name": attr.string( + mandatory = True, + doc = "Name of the module being generated (e.g., 'selenium/webdriver/common/bidi')", + ), + "spec_version": attr.string( + default = "1.0", + doc = "WebDriver BiDi specification version", + ), + }, + doc = "Generates Python WebDriver BiDi modules from CDDL specification", +) diff --git a/py/requirements_lock.txt b/py/requirements_lock.txt index 68f8d858bb6f4..c58f4b1c76fe6 100644 --- a/py/requirements_lock.txt +++ b/py/requirements_lock.txt @@ -461,7 +461,6 @@ jeepney==0.9.0 \ --hash=sha256:cf0e9e845622b81e4a28df94c40345400256ec608d0e55bb8a3feaa9163f5732 # via # -r py/requirements.txt - # keyring # secretstorage jinja2==3.1.6 \ --hash=sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d \ @@ -1038,9 +1037,7 @@ rich==14.3.3 \ secretstorage==3.5.0 \ --hash=sha256:0ce65888c0725fcb2c5bc0fdb8e5438eece02c523557ea40ce0703c266248137 \ --hash=sha256:f04b8e4689cbce351744d5537bf6b1329c6fc68f91fa666f60a380edddcd11be - # via - # -r py/requirements.txt - # keyring + # via -r py/requirements.txt sniffio==1.3.1 \ --hash=sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2 \ --hash=sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc diff --git a/py/selenium/common/exceptions.py b/py/selenium/common/exceptions.py index c45f530002a8f..7ec809eb20b18 100644 --- a/py/selenium/common/exceptions.py +++ b/py/selenium/common/exceptions.py @@ -27,7 +27,10 @@ class WebDriverException(Exception): """Base webdriver exception.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: super().__init__() self.msg = msg @@ -73,7 +76,10 @@ class NoSuchElementException(WebDriverException): """ def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#nosuchelementexception" @@ -111,9 +117,14 @@ class StaleElementReferenceException(WebDriverException): """ def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: - with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception" + with_support = ( + f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception" + ) super().__init__(with_support, screen, stacktrace) @@ -161,7 +172,10 @@ class ElementNotVisibleException(InvalidElementStateException): """ def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotvisibleexception" @@ -172,9 +186,14 @@ class ElementNotInteractableException(InvalidElementStateException): """Thrown when element interactions will hit another element due to paint order.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: - with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception" + with_support = ( + f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception" + ) super().__init__(with_support, screen, stacktrace) @@ -213,7 +232,10 @@ class InvalidSelectorException(WebDriverException): """ def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#invalidselectorexception" @@ -252,9 +274,14 @@ class ElementClickInterceptedException(WebDriverException): """Thrown when element click fails because another element obscures it.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: - with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception" + with_support = ( + f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception" + ) super().__init__(with_support, screen, stacktrace) @@ -271,7 +298,10 @@ class InvalidSessionIdException(WebDriverException): """Thrown when the given session id is not in the list of active sessions.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#invalidsessionidexception" @@ -282,7 +312,10 @@ class SessionNotCreatedException(WebDriverException): """A new session could not be created.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#sessionnotcreatedexception" @@ -297,7 +330,10 @@ class NoSuchDriverException(WebDriverException): """Raised when driver is not specified and cannot be located.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}/driver_location" diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index a5b1e6f85a09e..ab96f2d81e292 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -1,16 +1,7 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. + +from __future__ import annotations + diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 5b449ae69276a..ed6a4d8f33bc5 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -1,280 +1,330 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import os -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi.session import UserPromptHandler -from selenium.webdriver.common.proxy import Proxy - - -class ClientWindowState: - """Represents a window state.""" +# WebDriver BiDi module: browser +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass + + +def transform_download_params( + allowed: bool | None, + destination_folder: str | None, +) -> dict[str, Any] | None: + """Transform download parameters into download_behavior object. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads (accepts str or + pathlib.Path; will be coerced to str) + + Returns: + Dictionary representing the download_behavior object, or None if allowed is None + """ + if allowed is True: + return { + "type": "allowed", + # Coerce pathlib.Path (or any path-like) to str so the BiDi + # protocol always receives a plain JSON string. + "destinationFolder": str(destination_folder) if destination_folder is not None else None, + } + elif allowed is False: + return {"type": "denied"} + else: # None — reset to browser default (sent as JSON null) + return None + + +def validate_download_behavior( + allowed: bool | None, + destination_folder: str | None, + user_contexts: Any | None = None, +) -> None: + """Validate download behavior parameters. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads + user_contexts: Optional list of user contexts + + Raises: + ValueError: If parameters are invalid + """ + if allowed is True and not destination_folder: + raise ValueError("destination_folder is required when allowed=True") + if allowed is False and destination_folder: + raise ValueError("destination_folder should not be provided when allowed=False") + + +class ClientWindowNamedState: + """ClientWindowNamedState.""" FULLSCREEN = "fullscreen" MAXIMIZED = "maximized" MINIMIZED = "minimized" - NORMAL = "normal" - - VALID_STATES = {FULLSCREEN, MAXIMIZED, MINIMIZED, NORMAL} +@dataclass class ClientWindowInfo: - """Represents a client window information.""" - - def __init__( - self, - client_window: str, - state: str, - width: int, - height: int, - x: int, - y: int, - active: bool, - ): - self.client_window = client_window - self.state = state - self.width = width - self.height = height - self.x = x - self.y = y - self.active = active - - def get_state(self) -> str: - """Gets the state of the client window. - - Returns: - str: The state of the client window (one of the ClientWindowState constants). - """ - return self.state - - def get_client_window(self) -> str: - """Gets the client window identifier. - - Returns: - str: The client window identifier. - """ + """ClientWindowInfo.""" + + active: bool | None = None + client_window: Any | None = None + height: Any | None = None + state: Any | None = None + width: Any | None = None + x: Any | None = None + y: Any | None = None + + def get_client_window(self): + """Get the client window ID.""" return self.client_window - def get_width(self) -> int: - """Gets the width of the client window. + def get_state(self): + """Get the client window state.""" + return self.state - Returns: - int: The width of the client window. - """ + def get_width(self): + """Get the client window width.""" return self.width - def get_height(self) -> int: - """Gets the height of the client window. - - Returns: - int: The height of the client window. - """ + def get_height(self): + """Get the client window height.""" return self.height - def get_x(self) -> int: - """Gets the x coordinate of the client window. + def is_active(self): + """Check if the client window is active.""" + return self.active - Returns: - int: The x coordinate of the client window. - """ + def get_x(self): + """Get the client window X position.""" return self.x - def get_y(self) -> int: - """Gets the y coordinate of the client window. - - Returns: - int: The y coordinate of the client window. - """ + def get_y(self): + """Get the client window Y position.""" return self.y - def is_active(self) -> bool: - """Checks if the client window is active. - Returns: - bool: True if the client window is active, False otherwise. - """ - return self.active - @classmethod - def from_dict(cls, data: dict) -> "ClientWindowInfo": - """Creates a ClientWindowInfo instance from a dictionary. +@dataclass +class UserContextInfo: + """UserContextInfo.""" - Args: - data: A dictionary containing the client window information. + user_context: Any | None = None - Returns: - ClientWindowInfo: A new instance of ClientWindowInfo. - Raises: - ValueError: If required fields are missing or have invalid types. - """ - try: - client_window = data["clientWindow"] - if not isinstance(client_window, str): - raise ValueError("clientWindow must be a string") - - state = data["state"] - if not isinstance(state, str): - raise ValueError("state must be a string") - if state not in ClientWindowState.VALID_STATES: - raise ValueError(f"Invalid state: {state}. Must be one of {ClientWindowState.VALID_STATES}") - - width = data["width"] - if not isinstance(width, int) or width < 0: - raise ValueError(f"width must be a non-negative integer, got {width}") - - height = data["height"] - if not isinstance(height, int) or height < 0: - raise ValueError(f"height must be a non-negative integer, got {height}") - - x = data["x"] - if not isinstance(x, int): - raise ValueError(f"x must be an integer, got {type(x).__name__}") - - y = data["y"] - if not isinstance(y, int): - raise ValueError(f"y must be an integer, got {type(y).__name__}") - - active = data["active"] - if not isinstance(active, bool): - raise ValueError("active must be a boolean") - - return cls( - client_window=client_window, - state=state, - width=width, - height=height, - x=x, - y=y, - active=active, - ) - except (KeyError, TypeError) as e: - raise ValueError(f"Invalid data format for ClientWindowInfo: {e}") from e +@dataclass +class CreateUserContextParameters: + """CreateUserContextParameters.""" + accept_insecure_certs: bool | None = None + proxy: Any | None = None + unhandled_prompt_behavior: Any | None = None -class Browser: - """BiDi implementation of the browser module.""" - def __init__(self, conn): - self.conn = conn +@dataclass +class GetClientWindowsResult: + """GetClientWindowsResult.""" - def create_user_context( - self, - accept_insecure_certs: bool | None = None, - proxy: Proxy | None = None, - unhandled_prompt_behavior: UserPromptHandler | None = None, - ) -> str: - """Creates a new user context. + client_windows: list[Any | None] | None = None - Args: - accept_insecure_certs: Optional flag to accept insecure TLS certificates. - proxy: Optional proxy configuration for the user context. - unhandled_prompt_behavior: Optional configuration for handling user prompts. - Returns: - str: The ID of the created user context. - """ - params: dict[str, Any] = {} +@dataclass +class GetUserContextsResult: + """GetUserContextsResult.""" - if accept_insecure_certs is not None: - params["acceptInsecureCerts"] = accept_insecure_certs + user_contexts: list[Any | None] | None = None - if proxy is not None: - params["proxy"] = proxy.to_bidi_dict() - if unhandled_prompt_behavior is not None: - params["unhandledPromptBehavior"] = unhandled_prompt_behavior.to_dict() +@dataclass +class RemoveUserContextParameters: + """RemoveUserContextParameters.""" - result = self.conn.execute(command_builder("browser.createUserContext", params)) - return result["userContext"] + user_context: Any | None = None - def get_user_contexts(self) -> list[str]: - """Gets all user contexts. - Returns: - List[str]: A list of user context IDs. - """ - result = self.conn.execute(command_builder("browser.getUserContexts", {})) - return [context_info["userContext"] for context_info in result["userContexts"]] +@dataclass +class SetClientWindowStateParameters: + """SetClientWindowStateParameters.""" - def remove_user_context(self, user_context_id: str) -> None: - """Removes a user context. + client_window: Any | None = None - Args: - user_context_id: The ID of the user context to remove. - Raises: - ValueError: If the user context ID is "default" or does not exist. - """ - if user_context_id == "default": - raise ValueError("Cannot remove the default user context") +@dataclass +class ClientWindowRectState: + """ClientWindowRectState.""" - params = {"userContext": user_context_id} - self.conn.execute(command_builder("browser.removeUserContext", params)) + state: str = field(default="normal", init=False) + width: Any | None = None + height: Any | None = None + x: Any | None = None + y: Any | None = None - def get_client_windows(self) -> list[ClientWindowInfo]: - """Gets all client windows. - Returns: - List[ClientWindowInfo]: A list of client window information. - """ - result = self.conn.execute(command_builder("browser.getClientWindows", {})) - return [ClientWindowInfo.from_dict(window) for window in result["clientWindows"]] - - def set_download_behavior( - self, - *, - allowed: bool | None = None, - destination_folder: str | os.PathLike | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set the download behavior for the browser or specific user contexts. +@dataclass +class SetDownloadBehaviorParameters: + """SetDownloadBehaviorParameters.""" + + download_behavior: Any | None = None + user_contexts: list[Any | None] | None = None + + +@dataclass +class DownloadBehaviorAllowed: + """DownloadBehaviorAllowed.""" + + type: str = field(default="allowed", init=False) + destination_folder: str | None = None + + +@dataclass +class DownloadBehaviorDenied: + """DownloadBehaviorDenied.""" + + type: str = field(default="denied", init=False) + + +class Browser: + """WebDriver BiDi browser module.""" + + def __init__(self, conn) -> None: + self._conn = conn + + def close(self): + """Execute browser.close.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.close", params) + result = self._conn.execute(cmd) + return result + + def create_user_context(self, accept_insecure_certs: bool | None = None, proxy: Any | None = None, unhandled_prompt_behavior: Any | None = None): + """Execute browser.createUserContext.""" + if proxy and hasattr(proxy, 'to_bidi_dict'): + proxy = proxy.to_bidi_dict() + + if unhandled_prompt_behavior and hasattr(unhandled_prompt_behavior, 'to_bidi_dict'): + unhandled_prompt_behavior = unhandled_prompt_behavior.to_bidi_dict() + + params = { + "acceptInsecureCerts": accept_insecure_certs, + "proxy": proxy, + "unhandledPromptBehavior": unhandled_prompt_behavior, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.createUserContext", params) + result = self._conn.execute(cmd) + if result and "userContext" in result: + extracted = result.get("userContext") + return extracted + return result + + def get_client_windows(self): + """Execute browser.getClientWindows.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.getClientWindows", params) + result = self._conn.execute(cmd) + if result and "clientWindows" in result: + items = result.get("clientWindows", []) + return [ + ClientWindowInfo( + active=item.get("active"), + client_window=item.get("clientWindow"), + height=item.get("height"), + state=item.get("state"), + width=item.get("width"), + x=item.get("x"), + y=item.get("y") + ) + for item in items + if isinstance(item, dict) + ] + return [] + + def get_user_contexts(self): + """Execute browser.getUserContexts.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.getUserContexts", params) + result = self._conn.execute(cmd) + if result and "userContexts" in result: + items = result.get("userContexts", []) + return [ + item.get("userContext") + for item in items + if isinstance(item, dict) + ] + return [] + + def remove_user_context(self, user_context: Any | None = None): + """Execute browser.removeUserContext.""" + params = { + "userContext": user_context, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.removeUserContext", params) + result = self._conn.execute(cmd) + return result + + def set_client_window_state(self, client_window: Any | None = None): + """Execute browser.setClientWindowState.""" + params = { + "clientWindow": client_window, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.setClientWindowState", params) + result = self._conn.execute(cmd) + return result + + def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + """Execute browser.setDownloadBehavior.""" + validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) + + download_behavior = None + download_behavior = transform_download_params(allowed, destination_folder) + + params = { + "downloadBehavior": download_behavior, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.setDownloadBehavior", params) + result = self._conn.execute(cmd) + return result + + def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + """Set the download behavior for the browser. Args: - allowed: True to allow downloads, False to deny downloads, or None to - clear download behavior (revert to default). - destination_folder: Required when allowed is True. Specifies the folder - to store downloads in. - user_contexts: Optional list of user context IDs to apply this - behavior to. If omitted, updates the default behavior. + allowed: ``True`` to allow downloads, ``False`` to deny, or ``None`` + to reset to browser default (sends ``null`` to the protocol). + destination_folder: Destination folder for downloads. Required when + ``allowed=True``. Accepts a string or :class:`pathlib.Path`. + user_contexts: Optional list of user context IDs. Raises: - ValueError: If allowed=True and destination_folder is missing, or if - allowed=False and destination_folder is provided. + ValueError: If *allowed* is ``True`` and *destination_folder* is + omitted, or ``False`` and *destination_folder* is provided. """ - params: dict[str, Any] = {} - - if allowed is None: - params["downloadBehavior"] = None - else: - if allowed: - if not destination_folder: - raise ValueError("destination_folder is required when allowed=True.") - params["downloadBehavior"] = { - "type": "allowed", - "destinationFolder": os.fspath(destination_folder), - } - else: - if destination_folder: - raise ValueError("destination_folder should not be provided when allowed=False.") - params["downloadBehavior"] = {"type": "denied"} - + validate_download_behavior( + allowed=allowed, + destination_folder=destination_folder, + user_contexts=user_contexts, + ) + download_behavior = transform_download_params(allowed, destination_folder) + # downloadBehavior is a REQUIRED field in the BiDi spec (can be null but + # must be present). Do NOT use a generic None-filter on it. + params: dict = {"downloadBehavior": download_behavior} if user_contexts is not None: params["userContexts"] = user_contexts - - self.conn.execute(command_builder("browser.setDownloadBehavior", params)) + cmd = command_builder("browser.setDownloadBehavior", params) + return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index e8ae150342bda..35aea615d1780 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -1,35 +1,24 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: browsingContext +from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable from dataclasses import dataclass -from typing import Any - -from typing_extensions import Sentinel - -from selenium.webdriver.common.bidi.common import command_builder from selenium.webdriver.common.bidi.session import Session -UNDEFINED = Sentinel("UNDEFINED") - class ReadinessState: - """Represents the stage of document loading at which a navigation command will return.""" + """ReadinessState.""" NONE = "none" INTERACTIVE = "interactive" @@ -37,576 +26,509 @@ class ReadinessState: class UserPromptType: - """Represents the possible user prompt types.""" + """UserPromptType.""" ALERT = "alert" - BEFORE_UNLOAD = "beforeunload" + BEFOREUNLOAD = "beforeunload" CONFIRM = "confirm" PROMPT = "prompt" -class NavigationInfo: - """Provides details of an ongoing navigation.""" +class CreateType: + """CreateType.""" - def __init__( - self, - context: str, - navigation: str | None, - timestamp: int, - url: str, - ): - self.context = context - self.navigation = navigation - self.timestamp = timestamp - self.url = url + TAB = "tab" + WINDOW = "window" - @classmethod - def from_json(cls, json: dict) -> "NavigationInfo": - """Creates a NavigationInfo instance from a dictionary. - Args: - json: A dictionary containing the navigation information. +class DownloadCompleteParams: + """DownloadCompleteParams.""" - Returns: - A new instance of NavigationInfo. - """ - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - navigation = json.get("navigation") - if navigation is not None and not isinstance(navigation, str): - raise ValueError("navigation must be a string") - - timestamp = json.get("timestamp") - if timestamp is None or not isinstance(timestamp, int) or timestamp < 0: - raise ValueError("timestamp is required and must be a non-negative integer") - - url = json.get("url") - if url is None or not isinstance(url, str): - raise ValueError("url is required and must be a string") - - return cls(context, navigation, timestamp, url) - - -class BrowsingContextInfo: - """Represents the properties of a navigable.""" - - def __init__( - self, - context: str, - url: str, - children: list["BrowsingContextInfo"] | None, - client_window: str, - user_context: str, - parent: str | None = None, - original_opener: str | None = None, - ): - self.context = context - self.url = url - self.children = children - self.parent = parent - self.user_context = user_context - self.original_opener = original_opener - self.client_window = client_window + COMPLETE = "complete" - @classmethod - def from_json(cls, json: dict) -> "BrowsingContextInfo": - """Creates a BrowsingContextInfo instance from a dictionary. - Args: - json: A dictionary containing the browsing context information. +@dataclass +class Info: + """Info.""" - Returns: - A new instance of BrowsingContextInfo. - """ - children = None - raw_children = json.get("children") - if raw_children is not None: - if not isinstance(raw_children, list): - raise ValueError("children must be a list if provided") - - children = [] - for child in raw_children: - if not isinstance(child, dict): - raise ValueError(f"Each child must be a dictionary, got {type(child)}") - children.append(BrowsingContextInfo.from_json(child)) - - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - url = json.get("url") - if url is None or not isinstance(url, str): - raise ValueError("url is required and must be a string") - - parent = json.get("parent") - if parent is not None and not isinstance(parent, str): - raise ValueError("parent must be a string if provided") - - user_context = json.get("userContext") - if user_context is None or not isinstance(user_context, str): - raise ValueError("userContext is required and must be a string") - - original_opener = json.get("originalOpener") - if original_opener is not None and not isinstance(original_opener, str): - raise ValueError("originalOpener must be a string if provided") - - client_window = json.get("clientWindow") - if client_window is None or not isinstance(client_window, str): - raise ValueError("clientWindow is required and must be a string") - - return cls( - context=context, - url=url, - children=children, - client_window=client_window, - user_context=user_context, - parent=parent, - original_opener=original_opener, - ) + children: Any | None = None + client_window: Any | None = None + context: Any | None = None + original_opener: Any | None = None + url: str | None = None + user_context: Any | None = None + parent: Any | None = None -class DownloadWillBeginParams(NavigationInfo): - """Parameters for the downloadWillBegin event.""" +@dataclass +class AccessibilityLocator: + """AccessibilityLocator.""" - def __init__( - self, - context: str, - navigation: str | None, - timestamp: int, - url: str, - suggested_filename: str, - ): - super().__init__(context, navigation, timestamp, url) - self.suggested_filename = suggested_filename + type: str = field(default="accessibility", init=False) + name: str | None = None + role: str | None = None - @classmethod - def from_json(cls, json: dict) -> "DownloadWillBeginParams": - nav_info = NavigationInfo.from_json(json) - - suggested_filename = json.get("suggestedFilename") - if suggested_filename is None or not isinstance(suggested_filename, str): - raise ValueError("suggestedFilename is required and must be a string") - - return cls( - context=nav_info.context, - navigation=nav_info.navigation, - timestamp=nav_info.timestamp, - url=nav_info.url, - suggested_filename=suggested_filename, - ) +@dataclass +class CssLocator: + """CssLocator.""" -class UserPromptOpenedParams: - """Parameters for the userPromptOpened event.""" + type: str = field(default="css", init=False) + value: str | None = None - def __init__( - self, - context: str, - handler: str, - message: str, - type: str, - default_value: str | None = None, - ): - self.context = context - self.handler = handler - self.message = message - self.type = type - self.default_value = default_value - @classmethod - def from_json(cls, json: dict) -> "UserPromptOpenedParams": - """Creates a UserPromptOpenedParams instance from a dictionary. +@dataclass +class ContextLocator: + """ContextLocator.""" - Args: - json: A dictionary containing the user prompt parameters. + type: str = field(default="context", init=False) + context: Any | None = None - Returns: - A new instance of UserPromptOpenedParams. - """ - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - handler = json.get("handler") - if handler is None or not isinstance(handler, str): - raise ValueError("handler is required and must be a string") - - message = json.get("message") - if message is None or not isinstance(message, str): - raise ValueError("message is required and must be a string") - - type_value = json.get("type") - if type_value is None or not isinstance(type_value, str): - raise ValueError("type is required and must be a string") - - default_value = json.get("defaultValue") - if default_value is not None and not isinstance(default_value, str): - raise ValueError("defaultValue must be a string if provided") - - return cls( - context=context, - handler=handler, - message=message, - type=type_value, - default_value=default_value, - ) +@dataclass +class InnerTextLocator: + """InnerTextLocator.""" -class UserPromptClosedParams: - """Parameters for the userPromptClosed event.""" + type: str = field(default="innerText", init=False) + value: str | None = None + ignore_case: bool | None = None + match_type: Any | None = None + max_depth: Any | None = None - def __init__( - self, - context: str, - accepted: bool, - type: str, - user_text: str | None = None, - ): - self.context = context - self.accepted = accepted - self.type = type - self.user_text = user_text - @classmethod - def from_json(cls, json: dict) -> "UserPromptClosedParams": - """Creates a UserPromptClosedParams instance from a dictionary. +@dataclass +class XPathLocator: + """XPathLocator.""" - Args: - json: A dictionary containing the user prompt closed parameters. + type: str = field(default="xpath", init=False) + value: str | None = None - Returns: - A new instance of UserPromptClosedParams. - """ - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - accepted = json.get("accepted") - if accepted is None or not isinstance(accepted, bool): - raise ValueError("accepted is required and must be a boolean") - - type_value = json.get("type") - if type_value is None or not isinstance(type_value, str): - raise ValueError("type is required and must be a string") - - user_text = json.get("userText") - if user_text is not None and not isinstance(user_text, str): - raise ValueError("userText must be a string if provided") - - return cls( - context=context, - accepted=accepted, - type=type_value, - user_text=user_text, - ) +@dataclass +class BaseNavigationInfo: + """BaseNavigationInfo.""" + + context: Any | None = None + navigation: Any | None = None + timestamp: Any | None = None + url: str | None = None -class HistoryUpdatedParams: - """Parameters for the historyUpdated event.""" - def __init__( - self, - context: str, - timestamp: int, - url: str, - ): - self.context = context - self.timestamp = timestamp - self.url = url +@dataclass +class ActivateParameters: + """ActivateParameters.""" - @classmethod - def from_json(cls, json: dict) -> "HistoryUpdatedParams": - """Creates a HistoryUpdatedParams instance from a dictionary. + context: Any | None = None - Args: - json: A dictionary containing the history updated parameters. - Returns: - A new instance of HistoryUpdatedParams. - """ - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - timestamp = json.get("timestamp") - if timestamp is None or not isinstance(timestamp, int) or timestamp < 0: - raise ValueError("timestamp is required and must be a non-negative integer") - - url = json.get("url") - if url is None or not isinstance(url, str): - raise ValueError("url is required and must be a string") - - return cls( - context=context, - timestamp=timestamp, - url=url, - ) +@dataclass +class CaptureScreenshotParameters: + """CaptureScreenshotParameters.""" + context: Any | None = None + format: Any | None = None + clip: Any | None = None -class DownloadCanceledParams(NavigationInfo): - def __init__( - self, - context: str, - navigation: str | None, - timestamp: int, - url: str, - status: str = "canceled", - ): - super().__init__(context, navigation, timestamp, url) - self.status = status - @classmethod - def from_json(cls, json: dict) -> "DownloadCanceledParams": - nav_info = NavigationInfo.from_json(json) - - status = json.get("status") - if status is None or status != "canceled": - raise ValueError("status is required and must be 'canceled'") - - return cls( - context=nav_info.context, - navigation=nav_info.navigation, - timestamp=nav_info.timestamp, - url=nav_info.url, - status=status, - ) +@dataclass +class ImageFormat: + """ImageFormat.""" + type: str | None = None + quality: Any | None = None -class DownloadCompleteParams(NavigationInfo): - def __init__( - self, - context: str, - navigation: str | None, - timestamp: int, - url: str, - status: str = "complete", - filepath: str | None = None, - ): - super().__init__(context, navigation, timestamp, url) - self.status = status - self.filepath = filepath - @classmethod - def from_json(cls, json: dict) -> "DownloadCompleteParams": - nav_info = NavigationInfo.from_json(json) - - status = json.get("status") - if status is None or status != "complete": - raise ValueError("status is required and must be 'complete'") - - filepath = json.get("filepath") - if filepath is not None and not isinstance(filepath, str): - raise ValueError("filepath must be a string if provided") - - return cls( - context=nav_info.context, - navigation=nav_info.navigation, - timestamp=nav_info.timestamp, - url=nav_info.url, - status=status, - filepath=filepath, - ) +@dataclass +class ElementClipRectangle: + """ElementClipRectangle.""" + type: str = field(default="element", init=False) + element: Any | None = None -class DownloadEndParams: - """Parameters for the downloadEnd event.""" - def __init__( - self, - download_params: DownloadCanceledParams | DownloadCompleteParams, - ): - self.download_params = download_params +@dataclass +class BoxClipRectangle: + """BoxClipRectangle.""" - @classmethod - def from_json(cls, json: dict) -> "DownloadEndParams": - status = json.get("status") - if status == "canceled": - return cls(DownloadCanceledParams.from_json(json)) - elif status == "complete": - return cls(DownloadCompleteParams.from_json(json)) - else: - raise ValueError("status must be either 'canceled' or 'complete'") + type: str = field(default="box", init=False) + x: Any | None = None + y: Any | None = None + width: Any | None = None + height: Any | None = None -class ContextCreated: - """Event class for browsingContext.contextCreated event.""" +@dataclass +class CaptureScreenshotResult: + """CaptureScreenshotResult.""" - event_class = "browsingContext.contextCreated" + data: str | None = None - @classmethod - def from_json(cls, json: dict): - if isinstance(json, BrowsingContextInfo): - return json - return BrowsingContextInfo.from_json(json) +@dataclass +class CloseParameters: + """CloseParameters.""" -class ContextDestroyed: - """Event class for browsingContext.contextDestroyed event.""" + context: Any | None = None + prompt_unload: bool | None = None - event_class = "browsingContext.contextDestroyed" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, BrowsingContextInfo): - return json - return BrowsingContextInfo.from_json(json) +@dataclass +class CreateParameters: + """CreateParameters.""" + type: Any | None = None + reference_context: Any | None = None + background: bool | None = None + user_context: Any | None = None -class NavigationStarted: - """Event class for browsingContext.navigationStarted event.""" - event_class = "browsingContext.navigationStarted" +@dataclass +class CreateResult: + """CreateResult.""" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) + context: Any | None = None -class NavigationCommitted: - """Event class for browsingContext.navigationCommitted event.""" +@dataclass +class GetTreeParameters: + """GetTreeParameters.""" - event_class = "browsingContext.navigationCommitted" + max_depth: Any | None = None + root: Any | None = None - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) +@dataclass +class GetTreeResult: + """GetTreeResult.""" -class NavigationFailed: - """Event class for browsingContext.navigationFailed event.""" + contexts: Any | None = None - event_class = "browsingContext.navigationFailed" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) +@dataclass +class HandleUserPromptParameters: + """HandleUserPromptParameters.""" + context: Any | None = None + accept: bool | None = None + user_text: str | None = None -class NavigationAborted: - """Event class for browsingContext.navigationAborted event.""" - event_class = "browsingContext.navigationAborted" +@dataclass +class LocateNodesParameters: + """LocateNodesParameters.""" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) + context: Any | None = None + locator: Any | None = None + serialization_options: Any | None = None + start_nodes: list[Any | None] | None = None -class DomContentLoaded: - """Event class for browsingContext.domContentLoaded event.""" +@dataclass +class LocateNodesResult: + """LocateNodesResult.""" - event_class = "browsingContext.domContentLoaded" + nodes: list[Any | None] | None = None - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) +@dataclass +class NavigateParameters: + """NavigateParameters.""" -class Load: - """Event class for browsingContext.load event.""" + context: Any | None = None + url: str | None = None + wait: Any | None = None - event_class = "browsingContext.load" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) +@dataclass +class NavigateResult: + """NavigateResult.""" + navigation: Any | None = None + url: str | None = None -class FragmentNavigated: - """Event class for browsingContext.fragmentNavigated event.""" - event_class = "browsingContext.fragmentNavigated" +@dataclass +class PrintParameters: + """PrintParameters.""" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) + context: Any | None = None + background: bool | None = None + margin: Any | None = None + page: Any | None = None + scale: Any | None = None + shrink_to_fit: bool | None = None -class DownloadWillBegin: - """Event class for browsingContext.downloadWillBegin event.""" +@dataclass +class PrintMarginParameters: + """PrintMarginParameters.""" - event_class = "browsingContext.downloadWillBegin" + bottom: Any | None = None + left: Any | None = None + right: Any | None = None + top: Any | None = None - @classmethod - def from_json(cls, json: dict): - return DownloadWillBeginParams.from_json(json) +@dataclass +class PrintPageParameters: + """PrintPageParameters.""" -class UserPromptOpened: - """Event class for browsingContext.userPromptOpened event.""" + height: Any | None = None + width: Any | None = None - event_class = "browsingContext.userPromptOpened" - @classmethod - def from_json(cls, json: dict): - return UserPromptOpenedParams.from_json(json) +@dataclass +class PrintResult: + """PrintResult.""" + + data: str | None = None -class UserPromptClosed: - """Event class for browsingContext.userPromptClosed event.""" +@dataclass +class ReloadParameters: + """ReloadParameters.""" - event_class = "browsingContext.userPromptClosed" + context: Any | None = None + ignore_cache: bool | None = None + wait: Any | None = None - @classmethod - def from_json(cls, json: dict): - return UserPromptClosedParams.from_json(json) +@dataclass +class SetViewportParameters: + """SetViewportParameters.""" -class HistoryUpdated: - """Event class for browsingContext.historyUpdated event.""" + context: Any | None = None + viewport: Any | None = None + device_pixel_ratio: Any | None = None + user_contexts: list[Any | None] | None = None - event_class = "browsingContext.historyUpdated" - @classmethod - def from_json(cls, json: dict): - return HistoryUpdatedParams.from_json(json) +@dataclass +class Viewport: + """Viewport.""" + + width: Any | None = None + height: Any | None = None + + +@dataclass +class TraverseHistoryParameters: + """TraverseHistoryParameters.""" + + context: Any | None = None + delta: Any | None = None + + +@dataclass +class HistoryUpdatedParameters: + """HistoryUpdatedParameters.""" + + context: Any | None = None + timestamp: Any | None = None + url: str | None = None + + +@dataclass +class DownloadWillBeginParams: + """DownloadWillBeginParams.""" + + suggested_filename: str | None = None + + +@dataclass +class DownloadCanceledParams: + """DownloadCanceledParams.""" + + status: str = field(default="canceled", init=False) -class DownloadEnd: - """Event class for browsingContext.downloadEnd event.""" +@dataclass +class UserPromptClosedParameters: + """UserPromptClosedParameters.""" + + context: Any | None = None + accepted: bool | None = None + type: Any | None = None + user_text: str | None = None + + +@dataclass +class UserPromptOpenedParameters: + """UserPromptOpenedParameters.""" + + context: Any | None = None + handler: Any | None = None + message: str | None = None + type: Any | None = None + default_value: str | None = None + + +@dataclass +class DownloadWillBeginParams: + """DownloadWillBeginParams.""" - event_class = "browsingContext.downloadEnd" + suggested_filename: str | None = None + +@dataclass +class DownloadCanceledParams: + """DownloadCanceledParams.""" + + status: Any | None = None + +@dataclass +class DownloadParams: + """DownloadParams - fields shared by all download end event variants.""" + + status: str | None = None + context: Any | None = None + navigation: Any | None = None + timestamp: Any | None = None + url: str | None = None + filepath: str | None = None + +@dataclass +class DownloadEndParams: + """DownloadEndParams - params for browsingContext.downloadEnd event.""" + + download_params: "DownloadParams | None" = None @classmethod - def from_json(cls, json: dict): - return DownloadEndParams.from_json(json) + def from_json(cls, params: dict) -> "DownloadEndParams": + """Deserialize from BiDi wire-level params dict.""" + dp = DownloadParams( + status=params.get("status"), + context=params.get("context"), + navigation=params.get("navigation"), + timestamp=params.get("timestamp"), + url=params.get("url"), + filepath=params.get("filepath"), + ) + return cls(download_params=dp) + +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "context_created": "browsingContext.contextCreated", + "context_destroyed": "browsingContext.contextDestroyed", + "navigation_started": "browsingContext.navigationStarted", + "fragment_navigated": "browsingContext.fragmentNavigated", + "history_updated": "browsingContext.historyUpdated", + "dom_content_loaded": "browsingContext.domContentLoaded", + "load": "browsingContext.load", + "download_will_begin": "browsingContext.downloadWillBegin", + "download_end": "browsingContext.downloadEnd", + "navigation_aborted": "browsingContext.navigationAborted", + "navigation_committed": "browsingContext.navigationCommitted", + "navigation_failed": "browsingContext.navigationFailed", + "user_prompt_closed": "browsingContext.userPromptClosed", + "user_prompt_opened": "browsingContext.userPromptOpened", + "download_will_begin": "browsingContext.downloadWillBegin", + "download_end": "browsingContext.downloadEnd", +} + +def _deserialize_info_list(items: list) -> list | None: + """Recursively deserialize a list of dicts to Info objects. + + Args: + items: List of dicts from the API response + + Returns: + List of Info objects with properly nested children, or None if empty + """ + if not items or not isinstance(items, list): + return None + + result = [] + for item in items: + if isinstance(item, dict): + # Recursively deserialize children only if the key exists in response + children_list = None + if "children" in item: + children_list = _deserialize_info_list(item.get("children", [])) + info = Info( + children=children_list, + client_window=item.get("clientWindow"), + context=item.get("context"), + original_opener=item.get("originalOpener"), + url=item.get("url"), + user_context=item.get("userContext"), + parent=item.get("parent"), + ) + result.append(info) + return result if result else None + + @dataclass class EventConfig: + """Configuration for a BiDi event.""" event_key: str bidi_event: str event_class: type +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + class _EventManager: - """Class to manage event subscriptions and callbacks for BrowsingContext.""" + """Manages event subscriptions and callbacks.""" def __init__(self, conn, event_configs: dict[str, EventConfig]): self.conn = conn self.event_configs = event_configs self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} self._available_events = ", ".join(sorted(event_configs.keys())) - # Thread safety lock for subscription operations self._subscription_lock = threading.Lock() + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + def validate_event(self, event: str) -> EventConfig: event_config = self.event_configs.get(event) if not event_config: @@ -614,447 +536,352 @@ def validate_event(self, event: str) -> EventConfig: return event_config def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed. - - Args: - bidi_event: The BiDi event name. - contexts: Optional browsing context IDs to subscribe to. - """ + """Subscribe to a BiDi event if not already subscribed.""" with self._subscription_lock: if bidi_event not in self.subscriptions: session = Session(self.conn) - self.conn.execute(session.subscribe(bidi_event, browsing_contexts=contexts)) - self.subscriptions[bidi_event] = [] + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist. - - Args: - bidi_event: The BiDi event name. - """ + """Unsubscribe from a BiDi event if no more callbacks exist.""" with self._subscription_lock: - callback_list = self.subscriptions.get(bidi_event) - if callback_list is not None and not callback_list: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: session = Session(self.conn) - self.conn.execute(session.unsubscribe(bidi_event)) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) del self.subscriptions[bidi_event] def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: with self._subscription_lock: - self.subscriptions[bidi_event].append(callback_id) + self.subscriptions[bidi_event]["callbacks"].append(callback_id) def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: with self._subscription_lock: - callback_list = self.subscriptions.get(bidi_event) - if callback_list and callback_id in callback_list: - callback_list.remove(callback_id) + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: event_config = self.validate_event(event) - - callback_id = self.conn.add_callback(event_config.event_class, callback) - - # Subscribe to the event if needed + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) self.subscribe_to_event(event_config.bidi_event, contexts) - - # Track the callback self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id def remove_event_handler(self, event: str, callback_id: int) -> None: event_config = self.validate_event(event) - - # Remove the callback from the connection - self.conn.remove_callback(event_config.event_class, callback_id) - - # Remove from tracking collections + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - - # Unsubscribe if no more callbacks exist self.unsubscribe_from_event(event_config.bidi_event) def clear_event_handlers(self) -> None: - """Clear all event handlers from the browsing context.""" + """Clear all event handlers.""" with self._subscription_lock: if not self.subscriptions: return - session = Session(self.conn) - - for bidi_event, callback_ids in list(self.subscriptions.items()): - event_class = self._bidi_to_class.get(bidi_event) - if event_class: - # Remove all callbacks for this event - for callback_id in callback_ids: - self.conn.remove_callback(event_class, callback_id) - - self.conn.execute(session.unsubscribe(bidi_event)) - + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) self.subscriptions.clear() -class BrowsingContext: - """BiDi implementation of the browsingContext module.""" - - EVENT_CONFIGS = { - "context_created": EventConfig("context_created", "browsingContext.contextCreated", ContextCreated), - "context_destroyed": EventConfig("context_destroyed", "browsingContext.contextDestroyed", ContextDestroyed), - "dom_content_loaded": EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", DomContentLoaded), - "download_end": EventConfig("download_end", "browsingContext.downloadEnd", DownloadEnd), - "download_will_begin": EventConfig( - "download_will_begin", "browsingContext.downloadWillBegin", DownloadWillBegin - ), - "fragment_navigated": EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", FragmentNavigated), - "history_updated": EventConfig("history_updated", "browsingContext.historyUpdated", HistoryUpdated), - "load": EventConfig("load", "browsingContext.load", Load), - "navigation_aborted": EventConfig("navigation_aborted", "browsingContext.navigationAborted", NavigationAborted), - "navigation_committed": EventConfig( - "navigation_committed", "browsingContext.navigationCommitted", NavigationCommitted - ), - "navigation_failed": EventConfig("navigation_failed", "browsingContext.navigationFailed", NavigationFailed), - "navigation_started": EventConfig("navigation_started", "browsingContext.navigationStarted", NavigationStarted), - "user_prompt_closed": EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", UserPromptClosed), - "user_prompt_opened": EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", UserPromptOpened), - } - - def __init__(self, conn): - self.conn = conn - self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - - @classmethod - def get_event_names(cls) -> list[str]: - """Get a list of all available event names. - - Returns: - A list of event names that can be used with event handlers. - """ - return list(cls.EVENT_CONFIGS.keys()) - def activate(self, context: str) -> None: - """Activates and focuses the given top-level traversable. - Args: - context: The browsing context ID to activate. +class BrowsingContext: + """WebDriver BiDi browsingContext module.""" - Raises: - Exception: If the browsing context is not a top-level traversable. - """ - params = {"context": context} - self.conn.execute(command_builder("browsingContext.activate", params)) - - def capture_screenshot( - self, - context: str, - origin: str = "viewport", - format: dict | None = None, - clip: dict | None = None, - ) -> str: - """Captures an image of the given navigable, and returns it as a Base64-encoded string. + EVENT_CONFIGS = {} + def __init__(self, conn) -> None: + self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - Args: - context: The browsing context ID to capture. - origin: The origin of the screenshot, either "viewport" or "document". - format: The format of the screenshot. - clip: The clip rectangle of the screenshot. + def activate(self, context: Any | None = None): + """Execute browsingContext.activate.""" + params = { + "context": context, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.activate", params) + result = self._conn.execute(cmd) + return result - Returns: - The Base64-encoded screenshot. - """ - params: dict[str, Any] = {"context": context, "origin": origin} - if format is not None: - params["format"] = format - if clip is not None: - params["clip"] = clip + def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): + """Execute browsingContext.captureScreenshot.""" + params = { + "context": context, + "format": format, + "clip": clip, + "origin": origin, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.captureScreenshot", params) + result = self._conn.execute(cmd) + if result and "data" in result: + extracted = result.get("data") + return extracted + return result - result = self.conn.execute(command_builder("browsingContext.captureScreenshot", params)) - return result["data"] + def close(self, context: Any | None = None, prompt_unload: bool | None = None): + """Execute browsingContext.close.""" + params = { + "context": context, + "promptUnload": prompt_unload, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.close", params) + result = self._conn.execute(cmd) + return result - def close(self, context: str, prompt_unload: bool = False) -> None: - """Closes a top-level traversable. + def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): + """Execute browsingContext.create.""" + params = { + "type": type, + "referenceContext": reference_context, + "background": background, + "userContext": user_context, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.create", params) + result = self._conn.execute(cmd) + if result and "context" in result: + extracted = result.get("context") + return extracted + return result - Args: - context: The browsing context ID to close. - prompt_unload: Whether to prompt to unload. + def get_tree(self, max_depth: Any | None = None, root: Any | None = None): + """Execute browsingContext.getTree.""" + params = { + "maxDepth": max_depth, + "root": root, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.getTree", params) + result = self._conn.execute(cmd) + if result and "contexts" in result: + items = result.get("contexts", []) + return [ + Info( + children=_deserialize_info_list(item.get("children", [])), + client_window=item.get("clientWindow"), + context=item.get("context"), + original_opener=item.get("originalOpener"), + url=item.get("url"), + user_context=item.get("userContext"), + parent=item.get("parent") + ) + for item in items + if isinstance(item, dict) + ] + return [] + + def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): + """Execute browsingContext.handleUserPrompt.""" + params = { + "context": context, + "accept": accept, + "userText": user_text, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.handleUserPrompt", params) + result = self._conn.execute(cmd) + return result - Raises: - Exception: If the browsing context is not a top-level traversable. - """ - params = {"context": context, "promptUnload": prompt_unload} - self.conn.execute(command_builder("browsingContext.close", params)) - - def create( - self, - type: str, - reference_context: str | None = None, - background: bool = False, - user_context: str | None = None, - ) -> str: - """Creates a new navigable, either in a new tab or in a new window, and returns its navigable id. + def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): + """Execute browsingContext.locateNodes.""" + params = { + "context": context, + "locator": locator, + "serializationOptions": serialization_options, + "startNodes": start_nodes, + "maxNodeCount": max_node_count, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.locateNodes", params) + result = self._conn.execute(cmd) + if result and "nodes" in result: + extracted = result.get("nodes") + return extracted + return result - Args: - type: The type of the new navigable, either "tab" or "window". - reference_context: The reference browsing context ID. - background: Whether to create the new navigable in the background. - user_context: The user context ID. + def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): + """Execute browsingContext.navigate.""" + params = { + "context": context, + "url": url, + "wait": wait, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.navigate", params) + result = self._conn.execute(cmd) + return result - Returns: - The browsing context ID of the created navigable. - """ - params: dict[str, Any] = {"type": type} - if reference_context is not None: - params["referenceContext"] = reference_context - if background is not None: - params["background"] = background - if user_context is not None: - params["userContext"] = user_context - - result = self.conn.execute(command_builder("browsingContext.create", params)) - return result["context"] - - def get_tree( - self, - max_depth: int | None = None, - root: str | None = None, - ) -> list[BrowsingContextInfo]: - """Get a tree of all descendent navigables including the given parent itself. - - Returns a tree of all descendent navigables including the given parent itself, or all top-level contexts - when no parent is provided. + def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): + """Execute browsingContext.print.""" + params = { + "context": context, + "background": background, + "margin": margin, + "page": page, + "scale": scale, + "shrinkToFit": shrink_to_fit, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.print", params) + result = self._conn.execute(cmd) + if result and "data" in result: + extracted = result.get("data") + return extracted + return result - Args: - max_depth: The maximum depth of the tree. - root: The root browsing context ID. + def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): + """Execute browsingContext.reload.""" + params = { + "context": context, + "ignoreCache": ignore_cache, + "wait": wait, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.reload", params) + result = self._conn.execute(cmd) + return result - Returns: - A list of browsing context information. - """ - params: dict[str, Any] = {} - if max_depth is not None: - params["maxDepth"] = max_depth - if root is not None: - params["root"] = root - - result = self.conn.execute(command_builder("browsingContext.getTree", params)) - return [BrowsingContextInfo.from_json(context) for context in result["contexts"]] - - def handle_user_prompt( - self, - context: str, - accept: bool | None = None, - user_text: str | None = None, - ) -> None: - """Allows closing an open prompt. + def set_viewport(self, context: str | None = None, viewport: Any | None = None, user_contexts: Any | None = None, device_pixel_ratio: Any | None = None): + """Execute browsingContext.setViewport.""" + params = { + "context": context, + "viewport": viewport, + "userContexts": user_contexts, + "devicePixelRatio": device_pixel_ratio, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.setViewport", params) + result = self._conn.execute(cmd) + return result - Args: - context: The browsing context ID. - accept: Whether to accept the prompt. - user_text: The text to enter in the prompt. - """ - params: dict[str, Any] = {"context": context} - if accept is not None: - params["accept"] = accept - if user_text is not None: - params["userText"] = user_text - - self.conn.execute(command_builder("browsingContext.handleUserPrompt", params)) - - def locate_nodes( - self, - context: str, - locator: dict, - max_node_count: int | None = None, - serialization_options: dict | None = None, - start_nodes: list[dict] | None = None, - ) -> list[dict]: - """Returns a list of all nodes matching the specified locator. + def traverse_history(self, context: Any | None = None, delta: Any | None = None): + """Execute browsingContext.traverseHistory.""" + params = { + "context": context, + "delta": delta, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.traverseHistory", params) + result = self._conn.execute(cmd) + return result - Args: - context: The browsing context ID. - locator: The locator to use. - max_node_count: The maximum number of nodes to return. - serialization_options: The serialization options. - start_nodes: The start nodes. - Returns: - A list of nodes. - """ - params: dict[str, Any] = {"context": context, "locator": locator} - if max_node_count is not None: - params["maxNodeCount"] = max_node_count - if serialization_options is not None: - params["serializationOptions"] = serialization_options - if start_nodes is not None: - params["startNodes"] = start_nodes - - result = self.conn.execute(command_builder("browsingContext.locateNodes", params)) - return result["nodes"] - - def navigate( - self, - context: str, - url: str, - wait: str | None = None, - ) -> dict: - """Navigates a navigable to the given URL. + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. Args: - context: The browsing context ID. - url: The URL to navigate to. - wait: The readiness state to wait for. + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). Returns: - A dictionary containing the navigation result. + The callback ID. """ - params = {"context": context, "url": url} - if wait is not None: - params["wait"] = wait - - result = self.conn.execute(command_builder("browsingContext.navigate", params)) - return result + return self._event_manager.add_event_handler(event, callback, contexts) - def print( - self, - context: str, - background: bool = False, - margin: dict | None = None, - orientation: str = "portrait", - page: dict | None = None, - page_ranges: list[int | str] | None = None, - scale: float = 1.0, - shrink_to_fit: bool = True, - ) -> str: - """Create a paginated PDF representation of the document as a Base64-encoded string. + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. Args: - context: The browsing context ID. - background: Whether to include the background. - margin: The margin parameters. - orientation: The orientation, either "portrait" or "landscape". - page: The page parameters. - page_ranges: The page ranges. - scale: The scale. - shrink_to_fit: Whether to shrink to fit. - - Returns: - The Base64-encoded PDF document. + event: The event to unsubscribe from. + callback_id: The callback ID. """ - params = { - "context": context, - "background": background, - "orientation": orientation, - "scale": scale, - "shrinkToFit": shrink_to_fit, - } - if margin is not None: - params["margin"] = margin - if page is not None: - params["page"] = page - if page_ranges is not None: - params["pageRanges"] = page_ranges - - result = self.conn.execute(command_builder("browsingContext.print", params)) - return result["data"] - - def reload( - self, - context: str, - ignore_cache: bool | None = None, - wait: str | None = None, - ) -> dict: - """Reloads a navigable. + return self._event_manager.remove_event_handler(event, callback_id) - Args: - context: The browsing context ID. - ignore_cache: Whether to ignore the cache. - wait: The readiness state to wait for. + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() - Returns: - A dictionary containing the navigation result. - """ - params: dict[str, Any] = {"context": context} - if ignore_cache is not None: - params["ignoreCache"] = ignore_cache - if wait is not None: - params["wait"] = wait +# Event Info Type Aliases +# Event: browsingContext.contextCreated +ContextCreated = globals().get('Info', dict) # Fallback to dict if type not defined - result = self.conn.execute(command_builder("browsingContext.reload", params)) - return result +# Event: browsingContext.contextDestroyed +ContextDestroyed = globals().get('Info', dict) # Fallback to dict if type not defined - def set_viewport( - self, - context: str | None = None, - viewport: dict | None | Sentinel = UNDEFINED, - device_pixel_ratio: float | None | Sentinel = UNDEFINED, - user_contexts: list[str] | None = None, - ) -> None: - """Modifies specific viewport characteristics on the given top-level traversable. +# Event: browsingContext.navigationStarted +NavigationStarted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - Args: - context: The browsing context ID. - viewport: The viewport parameters - {"width": , "height": } (`None` resets to default). - device_pixel_ratio: The device pixel ratio (`None` resets to default). - user_contexts: The user context IDs. - - Raises: - Exception: If the browsing context is not a top-level traversable - ValueError: If neither `context` nor `user_contexts` is provided - ValueError: If both `context` and `user_contexts` are provided - """ - if context is not None and user_contexts is not None: - raise ValueError("Cannot specify both context and user_contexts") +# Event: browsingContext.fragmentNavigated +FragmentNavigated = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - if context is None and user_contexts is None: - raise ValueError("Must specify either context or user_contexts") +# Event: browsingContext.historyUpdated +HistoryUpdated = globals().get('HistoryUpdatedParameters', dict) # Fallback to dict if type not defined - params: dict[str, Any] = {} - if context is not None: - params["context"] = context - elif user_contexts is not None: - params["userContexts"] = user_contexts - if viewport is not UNDEFINED: - params["viewport"] = viewport - if device_pixel_ratio is not UNDEFINED: - params["devicePixelRatio"] = device_pixel_ratio +# Event: browsingContext.domContentLoaded +DomContentLoaded = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - self.conn.execute(command_builder("browsingContext.setViewport", params)) +# Event: browsingContext.load +Load = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - def traverse_history(self, context: str, delta: int) -> dict: - """Traverses the history of a given navigable by a delta. +# Event: browsingContext.downloadWillBegin +DownloadWillBegin = globals().get('DownloadWillBeginParams', dict) # Fallback to dict if type not defined - Args: - context: The browsing context ID. - delta: The delta to traverse by. +# Event: browsingContext.downloadEnd +DownloadEnd = globals().get('DownloadEndParams', dict) # Fallback to dict if type not defined - Returns: - A dictionary containing the traverse history result. - """ - params = {"context": context, "delta": delta} - result = self.conn.execute(command_builder("browsingContext.traverseHistory", params)) - return result +# Event: browsingContext.navigationAborted +NavigationAborted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - """Add an event handler to the browsing context. +# Event: browsingContext.navigationCommitted +NavigationCommitted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - Args: - event: The event to subscribe to. - callback: The callback function to execute on event. - contexts: The browsing context IDs to subscribe to. +# Event: browsingContext.navigationFailed +NavigationFailed = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined - Returns: - Callback id. - """ - return self._event_manager.add_event_handler(event, callback, contexts) +# Event: browsingContext.userPromptClosed +UserPromptClosed = globals().get('UserPromptClosedParameters', dict) # Fallback to dict if type not defined - def remove_event_handler(self, event: str, callback_id: int) -> None: - """Remove an event handler from the browsing context. +# Event: browsingContext.userPromptOpened +UserPromptOpened = globals().get('UserPromptOpenedParameters', dict) # Fallback to dict if type not defined - Args: - event: The event to unsubscribe from. - callback_id: The callback id to remove. - """ - self._event_manager.remove_event_handler(event, callback_id) - def clear_event_handlers(self) -> None: - """Clear all event handlers from the browsing context.""" - self._event_manager.clear_event_handlers() +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +BrowsingContext.EVENT_CONFIGS = { + "context_created": (EventConfig("context_created", "browsingContext.contextCreated", _globals.get("ContextCreated", dict)) if _globals.get("ContextCreated") else EventConfig("context_created", "browsingContext.contextCreated", dict)), + "context_destroyed": (EventConfig("context_destroyed", "browsingContext.contextDestroyed", _globals.get("ContextDestroyed", dict)) if _globals.get("ContextDestroyed") else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict)), + "navigation_started": (EventConfig("navigation_started", "browsingContext.navigationStarted", _globals.get("NavigationStarted", dict)) if _globals.get("NavigationStarted") else EventConfig("navigation_started", "browsingContext.navigationStarted", dict)), + "fragment_navigated": (EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", _globals.get("FragmentNavigated", dict)) if _globals.get("FragmentNavigated") else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict)), + "history_updated": (EventConfig("history_updated", "browsingContext.historyUpdated", _globals.get("HistoryUpdated", dict)) if _globals.get("HistoryUpdated") else EventConfig("history_updated", "browsingContext.historyUpdated", dict)), + "dom_content_loaded": (EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", _globals.get("DomContentLoaded", dict)) if _globals.get("DomContentLoaded") else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict)), + "load": (EventConfig("load", "browsingContext.load", _globals.get("Load", dict)) if _globals.get("Load") else EventConfig("load", "browsingContext.load", dict)), + "download_will_begin": (EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBegin", dict)) if _globals.get("DownloadWillBegin") else EventConfig("download_will_begin", "browsingContext.downloadWillBegin", dict)), + "download_end": (EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEnd", dict)) if _globals.get("DownloadEnd") else EventConfig("download_end", "browsingContext.downloadEnd", dict)), + "navigation_aborted": (EventConfig("navigation_aborted", "browsingContext.navigationAborted", _globals.get("NavigationAborted", dict)) if _globals.get("NavigationAborted") else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict)), + "navigation_committed": (EventConfig("navigation_committed", "browsingContext.navigationCommitted", _globals.get("NavigationCommitted", dict)) if _globals.get("NavigationCommitted") else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict)), + "navigation_failed": (EventConfig("navigation_failed", "browsingContext.navigationFailed", _globals.get("NavigationFailed", dict)) if _globals.get("NavigationFailed") else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict)), + "user_prompt_closed": (EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", _globals.get("UserPromptClosed", dict)) if _globals.get("UserPromptClosed") else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict)), + "user_prompt_opened": (EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", _globals.get("UserPromptOpened", dict)) if _globals.get("UserPromptOpened") else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict)), + "download_will_begin": EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBeginParams", dict)), + "download_end": EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEndParams", dict)), +} diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index 0f57d07e5f0d4..d90d8c770263a 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -15,22 +15,25 @@ # specific language governing permissions and limitations # under the License. -from collections.abc import Generator +"""Common utilities for BiDi command construction.""" +from typing import Any, Dict, Generator -def command_builder(method: str, params: dict | None = None) -> Generator[dict, dict, dict]: - """Build a command iterator to send to the BiDi protocol. + +def command_builder( + method: str, params: Dict[str, Any] +) -> Generator[Dict[str, Any], Any, Any]: + """Build a BiDi command generator. Args: - method: The method to execute. - params: The parameters to pass to the method. Default is None. + method: The BiDi method name (e.g., "session.status", "browser.close") + params: The parameters for the command + + Yields: + A dictionary representing the BiDi command Returns: - The response from the command execution. + The result from the BiDi command execution """ - if params is None: - params = {} - - command = {"method": method, "params": params} - cmd = yield command - return cmd + result = yield {"method": method, "params": params} + return result diff --git a/py/selenium/webdriver/common/bidi/console.py b/py/selenium/webdriver/common/bidi/console.py old mode 100644 new mode 100755 diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index a6acaefe89b83..4cd6ae2e3c712 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,39 +1,34 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: emulation from __future__ import annotations -from enum import Enum -from typing import TYPE_CHECKING, Any, TypeVar +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass -from selenium.webdriver.common.bidi.common import command_builder -if TYPE_CHECKING: - from selenium.webdriver.remote.websocket_connection import WebSocketConnection +class ForcedColorsModeTheme: + """ForcedColorsModeTheme.""" + LIGHT = "light" + DARK = "dark" -class ScreenOrientationNatural(Enum): - """Natural screen orientation.""" + +class ScreenOrientationNatural: + """ScreenOrientationNatural.""" PORTRAIT = "portrait" LANDSCAPE = "landscape" -class ScreenOrientationType(Enum): - """Screen orientation type.""" +class ScreenOrientationType: + """ScreenOrientationType.""" PORTRAIT_PRIMARY = "portrait-primary" PORTRAIT_SECONDARY = "portrait-secondary" @@ -41,484 +36,494 @@ class ScreenOrientationType(Enum): LANDSCAPE_SECONDARY = "landscape-secondary" -E = TypeVar("E", ScreenOrientationNatural, ScreenOrientationType) +@dataclass +class SetForcedColorsModeThemeOverrideParameters: + """SetForcedColorsModeThemeOverrideParameters.""" + theme: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None -def _convert_to_enum(value: E | str, enum_class: type[E]) -> E: - if isinstance(value, enum_class): - return value - assert isinstance(value, str) - try: - return enum_class(value.lower()) - except ValueError: - raise ValueError(f"Invalid orientation: {value}") +@dataclass +class SetGeolocationOverrideParameters: + """SetGeolocationOverrideParameters.""" -class ScreenOrientation: - """Represents screen orientation configuration.""" + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - def __init__( - self, - natural: ScreenOrientationNatural | str, - type: ScreenOrientationType | str, - ): - """Initialize ScreenOrientation. - Args: - natural: Natural screen orientation ("portrait" or "landscape"). - type: Screen orientation type ("portrait-primary", "portrait-secondary", - "landscape-primary", or "landscape-secondary"). +@dataclass +class GeolocationCoordinates: + """GeolocationCoordinates.""" - Raises: - ValueError: If natural or type values are invalid. - """ - # handle string values - self.natural = _convert_to_enum(natural, ScreenOrientationNatural) - self.type = _convert_to_enum(type, ScreenOrientationType) - - def to_dict(self) -> dict[str, str]: - return { - "natural": self.natural.value, - "type": self.type.value, - } + latitude: Any | None = None + longitude: Any | None = None + accuracy: Any | None = None + altitude: Any | None = None + altitude_accuracy: Any | None = None + heading: Any | None = None + speed: Any | None = None -class GeolocationCoordinates: - """Represents geolocation coordinates.""" +@dataclass +class GeolocationPositionError: + """GeolocationPositionError.""" - def __init__( - self, - latitude: float, - longitude: float, - accuracy: float = 1.0, - altitude: float | None = None, - altitude_accuracy: float | None = None, - heading: float | None = None, - speed: float | None = None, - ): - """Initialize GeolocationCoordinates. + type: str = field(default="positionUnavailable", init=False) - Args: - latitude: Latitude coordinate (-90.0 to 90.0). - longitude: Longitude coordinate (-180.0 to 180.0). - accuracy: Accuracy in meters (>= 0.0), defaults to 1.0. - altitude: Altitude in meters or None, defaults to None. - altitude_accuracy: Altitude accuracy in meters (>= 0.0) or None, defaults to None. - heading: Heading in degrees (0.0 to 360.0) or None, defaults to None. - speed: Speed in meters per second (>= 0.0) or None, defaults to None. - - Raises: - ValueError: If coordinates are out of valid range or if altitude_accuracy is provided without altitude. - """ - self.latitude = latitude - self.longitude = longitude - self.accuracy = accuracy - self.altitude = altitude - self.altitude_accuracy = altitude_accuracy - self.heading = heading - self.speed = speed - - @property - def latitude(self) -> float: - return self._latitude - - @latitude.setter - def latitude(self, value: float) -> None: - if not (-90.0 <= value <= 90.0): - raise ValueError("latitude must be between -90.0 and 90.0") - self._latitude = value - - @property - def longitude(self) -> float: - return self._longitude - - @longitude.setter - def longitude(self, value: float) -> None: - if not (-180.0 <= value <= 180.0): - raise ValueError("longitude must be between -180.0 and 180.0") - self._longitude = value - - @property - def accuracy(self) -> float: - return self._accuracy - - @accuracy.setter - def accuracy(self, value: float) -> None: - if value < 0.0: - raise ValueError("accuracy must be >= 0.0") - self._accuracy = value - - @property - def altitude(self) -> float | None: - return self._altitude - - @altitude.setter - def altitude(self, value: float | None) -> None: - self._altitude = value - - @property - def altitude_accuracy(self) -> float | None: - return self._altitude_accuracy - - @altitude_accuracy.setter - def altitude_accuracy(self, value: float | None) -> None: - if value is not None and self.altitude is None: - raise ValueError("altitude_accuracy cannot be set without altitude") - if value is not None and value < 0.0: - raise ValueError("altitude_accuracy must be >= 0.0") - self._altitude_accuracy = value - - @property - def heading(self) -> float | None: - return self._heading - - @heading.setter - def heading(self, value: float | None) -> None: - if value is not None and not (0.0 <= value < 360.0): - raise ValueError("heading must be between 0.0 and 360.0") - self._heading = value - - @property - def speed(self) -> float | None: - return self._speed - - @speed.setter - def speed(self, value: float | None) -> None: - if value is not None and value < 0.0: - raise ValueError("speed must be >= 0.0") - self._speed = value - - def to_dict(self) -> dict[str, float | None]: - result: dict[str, float | None] = { - "latitude": self.latitude, - "longitude": self.longitude, - "accuracy": self.accuracy, - } - if self.altitude is not None: - result["altitude"] = self.altitude +@dataclass +class SetLocaleOverrideParameters: + """SetLocaleOverrideParameters.""" - if self.altitude_accuracy is not None: - result["altitudeAccuracy"] = self.altitude_accuracy + locale: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - if self.heading is not None: - result["heading"] = self.heading - if self.speed is not None: - result["speed"] = self.speed +@dataclass +class setNetworkConditionsParameters: + """setNetworkConditionsParameters.""" - return result + network_conditions: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None -class GeolocationPositionError: - """Represents a geolocation position error.""" +@dataclass +class NetworkConditionsOffline: + """NetworkConditionsOffline.""" - TYPE_POSITION_UNAVAILABLE = "positionUnavailable" + type: str = field(default="offline", init=False) - def __init__(self, type: str = TYPE_POSITION_UNAVAILABLE): - if type != self.TYPE_POSITION_UNAVAILABLE: - raise ValueError(f'type must be "{self.TYPE_POSITION_UNAVAILABLE}"') - self.type = type - def to_dict(self) -> dict[str, str]: - return {"type": self.type} +@dataclass +class ScreenArea: + """ScreenArea.""" + width: Any | None = None + height: Any | None = None -class Emulation: - """BiDi implementation of the emulation module.""" - def __init__(self, conn: WebSocketConnection) -> None: - self.conn = conn +@dataclass +class SetScreenSettingsOverrideParameters: + """SetScreenSettingsOverrideParameters.""" - def set_geolocation_override( - self, - coordinates: GeolocationCoordinates | None = None, - error: GeolocationPositionError | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set geolocation override for the given contexts or user contexts. + screen_area: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - Args: - coordinates: Geolocation coordinates to emulate, or None. - error: Geolocation error to emulate, or None. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. - - Raises: - ValueError: If both coordinates and error are provided, or if both contexts - and user_contexts are provided, or if neither contexts nor - user_contexts are provided. - """ - if coordinates is not None and error is not None: - raise ValueError("Cannot specify both coordinates and error") - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and userContexts") +@dataclass +class ScreenOrientation: + """ScreenOrientation.""" - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or userContexts") + natural: Any | None = None + type: Any | None = None - params: dict[str, Any] = {} - if coordinates is not None: - params["coordinates"] = coordinates.to_dict() - elif error is not None: - params["error"] = error.to_dict() +@dataclass +class SetScreenOrientationOverrideParameters: + """SetScreenOrientationOverrideParameters.""" - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts + screen_orientation: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - self.conn.execute(command_builder("emulation.setGeolocationOverride", params)) - def set_timezone_override( - self, - timezone: str | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set timezone override for the given contexts or user contexts. +@dataclass +class SetUserAgentOverrideParameters: + """SetUserAgentOverrideParameters.""" - Args: - timezone: Timezone identifier (IANA timezone name or offset string like '+01:00'), - or None to clear the override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. - - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided. - """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") + user_agent: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or user_contexts") - params: dict[str, Any] = {"timezone": timezone} +@dataclass +class SetViewportMetaOverrideParameters: + """SetViewportMetaOverrideParameters.""" - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts + viewport_meta: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - self.conn.execute(command_builder("emulation.setTimezoneOverride", params)) - def set_locale_override( - self, - locale: str | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set locale override for the given contexts or user contexts. +@dataclass +class SetScriptingEnabledParameters: + """SetScriptingEnabledParameters.""" - Args: - locale: Locale string as per BCP 47, or None to clear override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. + enabled: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided, or if locale is invalid. - """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and userContexts") - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or userContexts") +@dataclass +class SetScrollbarTypeOverrideParameters: + """SetScrollbarTypeOverrideParameters.""" - params: dict[str, Any] = {"locale": locale} + scrollbar_type: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts - self.conn.execute(command_builder("emulation.setLocaleOverride", params)) +@dataclass +class SetTimezoneOverrideParameters: + """SetTimezoneOverrideParameters.""" - def set_scripting_enabled( - self, - enabled: bool | None = False, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set scripting enabled override for the given contexts or user contexts. + timezone: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - Args: - enabled: False to disable scripting, None to clear the override. - Note: Only emulation of disabled JavaScript is supported. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. - - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided, or if enabled is True. - """ - if enabled: - raise ValueError("Only emulation of disabled JavaScript is supported (enabled must be False or None)") - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and userContexts") +@dataclass +class SetTouchOverrideParameters: + """SetTouchOverrideParameters.""" - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or userContexts") + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None - params: dict[str, Any] = {"enabled": enabled} - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts +class Emulation: + """WebDriver BiDi emulation module.""" - self.conn.execute(command_builder("emulation.setScriptingEnabled", params)) + def __init__(self, conn) -> None: + self._conn = conn - def set_screen_orientation_override( - self, - screen_orientation: ScreenOrientation | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set screen orientation override for the given contexts or user contexts. + def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setForcedColorsModeThemeOverride.""" + params = { + "theme": theme, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setForcedColorsModeThemeOverride", params) + result = self._conn.execute(cmd) + return result - Args: - screen_orientation: ScreenOrientation object to emulate, or None to clear the override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. + def set_geolocation_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setGeolocationOverride.""" + params = { + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setGeolocationOverride", params) + result = self._conn.execute(cmd) + return result - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided. - """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and userContexts") + def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setLocaleOverride.""" + params = { + "locale": locale, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setLocaleOverride", params) + result = self._conn.execute(cmd) + return result + + def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setNetworkConditions.""" + params = { + "networkConditions": network_conditions, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setNetworkConditions", params) + result = self._conn.execute(cmd) + return result - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or userContexts") + def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScreenSettingsOverride.""" + params = { + "screenArea": screen_area, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScreenSettingsOverride", params) + result = self._conn.execute(cmd) + return result - params: dict[str, Any] = { - "screenOrientation": screen_orientation.to_dict() if screen_orientation is not None else None + def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScreenOrientationOverride.""" + params = { + "screenOrientation": screen_orientation, + "contexts": contexts, + "userContexts": user_contexts, } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScreenOrientationOverride", params) + result = self._conn.execute(cmd) + return result - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts + def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setUserAgentOverride.""" + params = { + "userAgent": user_agent, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setUserAgentOverride", params) + result = self._conn.execute(cmd) + return result - self.conn.execute(command_builder("emulation.setScreenOrientationOverride", params)) + def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setViewportMetaOverride.""" + params = { + "viewportMeta": viewport_meta, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setViewportMetaOverride", params) + result = self._conn.execute(cmd) + return result - def set_user_agent_override( - self, - user_agent: str | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set user agent override for the given contexts or user contexts. + def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScriptingEnabled.""" + params = { + "enabled": enabled, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScriptingEnabled", params) + result = self._conn.execute(cmd) + return result - Args: - user_agent: User agent string to emulate, or None to clear the override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. + def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScrollbarTypeOverride.""" + params = { + "scrollbarType": scrollbar_type, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScrollbarTypeOverride", params) + result = self._conn.execute(cmd) + return result - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided. - """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") + def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setTimezoneOverride.""" + params = { + "timezone": timezone, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setTimezoneOverride", params) + result = self._conn.execute(cmd) + return result + + def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setTouchOverride.""" + params = { + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setTouchOverride", params) + result = self._conn.execute(cmd) + return result - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or user_contexts") + def set_geolocation_override( + self, + coordinates=None, + error=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setGeolocationOverride. - params: dict[str, Any] = {"userAgent": user_agent} + Sets or clears the geolocation override for specified browsing or user contexts. + Args: + coordinates: A GeolocationCoordinates instance (or dict) to override the + position, or ``None`` to clear a previously-set override. + error: A GeolocationPositionError instance (or dict) to simulate a + position-unavailable error. Mutually exclusive with *coordinates*. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params = {} + if coordinates is not None: + if isinstance(coordinates, dict): + coords_dict = coordinates + else: + coords_dict = {} + if coordinates.latitude is not None: + coords_dict["latitude"] = coordinates.latitude + if coordinates.longitude is not None: + coords_dict["longitude"] = coordinates.longitude + if coordinates.accuracy is not None: + coords_dict["accuracy"] = coordinates.accuracy + if coordinates.altitude is not None: + coords_dict["altitude"] = coordinates.altitude + if coordinates.altitude_accuracy is not None: + coords_dict["altitudeAccuracy"] = coordinates.altitude_accuracy + if coordinates.heading is not None: + coords_dict["heading"] = coordinates.heading + if coordinates.speed is not None: + coords_dict["speed"] = coordinates.speed + params["coordinates"] = coords_dict + if error is not None: + if isinstance(error, dict): + params["error"] = error + else: + params["error"] = { + "type": error.type if error.type is not None else "positionUnavailable" + } if contexts is not None: params["contexts"] = contexts - elif user_contexts is not None: + if user_contexts is not None: params["userContexts"] = user_contexts - - self.conn.execute(command_builder("emulation.setUserAgentOverride", params)) - - def set_network_conditions( + cmd = command_builder("emulation.setGeolocationOverride", params) + result = self._conn.execute(cmd) + return result + def set_timezone_override( self, - offline: bool = False, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set network conditions for the given contexts or user contexts. + timezone=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setTimezoneOverride. - Args: - offline: True to emulate offline network conditions, False to clear the override. - contexts: List of browsing context IDs to apply the conditions to. - user_contexts: List of user context IDs to apply the conditions to. + Sets or clears the timezone override for specified browsing or user contexts. + Pass ``timezone=None`` (or omit it) to clear a previously-set override. - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided. + Args: + timezone: IANA timezone string (e.g. ``"America/New_York"``) or ``None`` + to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") - - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or user_contexts") - - params: dict[str, Any] = {} - - if offline: - params["networkConditions"] = {"type": "offline"} - else: - # if offline is False or None, then clear the override - params["networkConditions"] = None - + params = {"timezone": timezone} if contexts is not None: params["contexts"] = contexts - elif user_contexts is not None: + if user_contexts is not None: params["userContexts"] = user_contexts + cmd = command_builder("emulation.setTimezoneOverride", params) + return self._conn.execute(cmd) + def set_scripting_enabled( + self, + enabled=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setScriptingEnabled. - self.conn.execute(command_builder("emulation.setNetworkConditions", params)) + Enables or disables scripting for specified browsing or user contexts. + Pass ``enabled=None`` to restore the default behaviour. - def set_screen_settings_override( + Args: + enabled: ``True`` to enable scripting, ``False`` to disable it, or + ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params = {"enabled": enabled} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScriptingEnabled", params) + return self._conn.execute(cmd) + def set_user_agent_override( self, - width: int | None = None, - height: int | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set screen settings override for the given contexts or user contexts. + user_agent=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setUserAgentOverride. + + Overrides the User-Agent string for specified browsing or user contexts. + Pass ``user_agent=None`` to clear a previously-set override. Args: - width: Screen width in pixels (>= 0). None to clear the override. - height: Screen height in pixels (>= 0). None to clear the override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. - - Raises: - ValueError: If only one of width/height is provided, or if both contexts - and user_contexts are provided, or if neither is provided. + user_agent: Custom User-Agent string, or ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. """ - if (width is None) != (height is None): - raise ValueError("Must provide both width and height, or neither to clear the override") - - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") + params = {"userAgent": user_agent} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setUserAgentOverride", params) + return self._conn.execute(cmd) + def set_screen_orientation_override( + self, + screen_orientation=None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setScreenOrientationOverride. - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or user_contexts") + Sets or clears the screen orientation override for specified browsing or + user contexts. - screen_area = None - if width is not None and height is not None: - if not isinstance(width, int) or not isinstance(height, int): - raise ValueError("width and height must be integers") - if width < 0 or height < 0: - raise ValueError("width and height must be >= 0") - screen_area = {"width": width, "height": height} + Args: + screen_orientation: A :class:`ScreenOrientation` instance (or dict with + ``natural`` and ``type`` keys) to lock the orientation, or ``None`` + to clear a previously-set override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + if screen_orientation is None: + so_value = None + elif isinstance(screen_orientation, dict): + so_value = screen_orientation + else: + natural = screen_orientation.natural + orientation_type = screen_orientation.type + so_value = { + "natural": natural.lower() if isinstance(natural, str) else natural, + "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, + } + params = {"screenOrientation": so_value} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScreenOrientationOverride", params) + return self._conn.execute(cmd) + def set_network_conditions( + self, + network_conditions=None, + offline: bool | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): + """Execute emulation.setNetworkConditions. - params: dict[str, Any] = {"screenArea": screen_area} + Sets or clears network condition emulation for specified browsing or user + contexts. + Args: + network_conditions: A dict with the raw ``networkConditions`` value + (e.g. ``{"type": "offline"}``), or ``None`` to clear the override. + Mutually exclusive with *offline*. + offline: Convenience bool — ``True`` sets offline conditions, + ``False`` clears them (sends ``null``). When provided, this takes + precedence over *network_conditions*. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + if offline is not None: + nc_value = {"type": "offline"} if offline else None + else: + nc_value = network_conditions + params = {"networkConditions": nc_value} if contexts is not None: params["contexts"] = contexts - elif user_contexts is not None: + if user_contexts is not None: params["userContexts"] = user_contexts - - self.conn.execute(command_builder("emulation.setScreenSettingsOverride", params)) + cmd = command_builder("emulation.setNetworkConditions", params) + return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 270ececaf41a1..5dbe71dbd3886 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -1,40 +1,32 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import math -from dataclasses import dataclass, field -from typing import Any - -from selenium.webdriver.common.bidi.common import command_builder +# WebDriver BiDi module: input +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass +import threading +from collections.abc import Callable +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session class PointerType: - """Represents the possible pointer types.""" + """PointerType.""" MOUSE = "mouse" PEN = "pen" TOUCH = "touch" - VALID_TYPES = {MOUSE, PEN, TOUCH} - class Origin: - """Represents the possible origin types.""" + """Origin.""" VIEWPORT = "viewport" POINTER = "pointer" @@ -42,421 +34,425 @@ class Origin: @dataclass class ElementOrigin: - """Represents an element origin for input actions.""" - - type: str - element: dict - - def __init__(self, element_reference: dict): - self.type = "element" - self.element = element_reference + """ElementOrigin.""" - def to_dict(self) -> dict: - """Convert the ElementOrigin to a dictionary.""" - return {"type": self.type, "element": self.element} + type: str = field(default="element", init=False) + element: Any | None = None @dataclass -class PointerParameters: - """Represents pointer parameters for pointer actions.""" - - pointer_type: str = PointerType.MOUSE - - def __post_init__(self): - if self.pointer_type not in PointerType.VALID_TYPES: - raise ValueError(f"Invalid pointer type: {self.pointer_type}. Must be one of {PointerType.VALID_TYPES}") +class PerformActionsParameters: + """PerformActionsParameters.""" - def to_dict(self) -> dict: - """Convert the PointerParameters to a dictionary.""" - return {"pointerType": self.pointer_type} + context: Any | None = None + actions: list[Any | None] | None = None @dataclass -class PointerCommonProperties: - """Common properties for pointer actions.""" - - width: int = 1 - height: int = 1 - pressure: float = 0.0 - tangential_pressure: float = 0.0 - twist: int = 0 - altitude_angle: float = 0.0 - azimuth_angle: float = 0.0 - - def __post_init__(self): - if self.width < 1: - raise ValueError("width must be at least 1") - if self.height < 1: - raise ValueError("height must be at least 1") - if not (0.0 <= self.pressure <= 1.0): - raise ValueError("pressure must be between 0.0 and 1.0") - if not (0.0 <= self.tangential_pressure <= 1.0): - raise ValueError("tangential_pressure must be between 0.0 and 1.0") - if not (0 <= self.twist <= 359): - raise ValueError("twist must be between 0 and 359") - if not (0.0 <= self.altitude_angle <= math.pi / 2): - raise ValueError("altitude_angle must be between 0.0 and π/2") - if not (0.0 <= self.azimuth_angle <= 2 * math.pi): - raise ValueError("azimuth_angle must be between 0.0 and 2π") - - def to_dict(self) -> dict: - """Convert the PointerCommonProperties to a dictionary.""" - result: dict[str, Any] = {} - if self.width != 1: - result["width"] = self.width - if self.height != 1: - result["height"] = self.height - if self.pressure != 0.0: - result["pressure"] = self.pressure - if self.tangential_pressure != 0.0: - result["tangentialPressure"] = self.tangential_pressure - if self.twist != 0: - result["twist"] = self.twist - if self.altitude_angle != 0.0: - result["altitudeAngle"] = self.altitude_angle - if self.azimuth_angle != 0.0: - result["azimuthAngle"] = self.azimuth_angle - return result - +class NoneSourceActions: + """NoneSourceActions.""" -# Action classes -@dataclass -class PauseAction: - """Represents a pause action.""" + type: str = field(default="none", init=False) + id: str | None = None + actions: list[Any | None] | None = None - duration: int | None = None - @property - def type(self) -> str: - return "pause" +@dataclass +class KeySourceActions: + """KeySourceActions.""" - def to_dict(self) -> dict: - """Convert the PauseAction to a dictionary.""" - result: dict[str, Any] = {"type": self.type} - if self.duration is not None: - result["duration"] = self.duration - return result + type: str = field(default="key", init=False) + id: str | None = None + actions: list[Any | None] | None = None @dataclass -class KeyDownAction: - """Represents a key down action.""" - - value: str = "" - - @property - def type(self) -> str: - return "keyDown" +class PointerSourceActions: + """PointerSourceActions.""" - def to_dict(self) -> dict: - """Convert the KeyDownAction to a dictionary.""" - return {"type": self.type, "value": self.value} + type: str = field(default="pointer", init=False) + id: str | None = None + parameters: Any | None = None + actions: list[Any | None] | None = None @dataclass -class KeyUpAction: - """Represents a key up action.""" - - value: str = "" - - @property - def type(self) -> str: - return "keyUp" +class PointerParameters: + """PointerParameters.""" - def to_dict(self) -> dict: - """Convert the KeyUpAction to a dictionary.""" - return {"type": self.type, "value": self.value} + pointer_type: Any | None = None @dataclass -class PointerDownAction: - """Represents a pointer down action.""" +class WheelSourceActions: + """WheelSourceActions.""" - button: int = 0 - properties: PointerCommonProperties | None = None + type: str = field(default="wheel", init=False) + id: str | None = None + actions: list[Any | None] | None = None - @property - def type(self) -> str: - return "pointerDown" - def to_dict(self) -> dict: - """Convert the PointerDownAction to a dictionary.""" - result: dict[str, Any] = {"type": self.type, "button": self.button} - if self.properties: - result.update(self.properties.to_dict()) - return result +@dataclass +class PauseAction: + """PauseAction.""" + + type: str = field(default="pause", init=False) + duration: Any | None = None @dataclass -class PointerUpAction: - """Represents a pointer up action.""" +class KeyDownAction: + """KeyDownAction.""" + + type: str = field(default="keyDown", init=False) + value: str | None = None - button: int = 0 - @property - def type(self) -> str: - return "pointerUp" +@dataclass +class KeyUpAction: + """KeyUpAction.""" - def to_dict(self) -> dict: - """Convert the PointerUpAction to a dictionary.""" - return {"type": self.type, "button": self.button} + type: str = field(default="keyUp", init=False) + value: str | None = None @dataclass -class PointerMoveAction: - """Represents a pointer move action.""" - - x: float = 0 - y: float = 0 - duration: int | None = None - origin: str | ElementOrigin | None = None - properties: PointerCommonProperties | None = None - - @property - def type(self) -> str: - return "pointerMove" - - def to_dict(self) -> dict: - """Convert the PointerMoveAction to a dictionary.""" - result: dict[str, Any] = {"type": self.type, "x": self.x, "y": self.y} - if self.duration is not None: - result["duration"] = self.duration - if self.origin is not None: - if isinstance(self.origin, ElementOrigin): - result["origin"] = self.origin.to_dict() - else: - result["origin"] = self.origin - if self.properties: - result.update(self.properties.to_dict()) - return result +class PointerUpAction: + """PointerUpAction.""" + + type: str = field(default="pointerUp", init=False) + button: Any | None = None @dataclass class WheelScrollAction: - """Represents a wheel scroll action.""" - - x: int = 0 - y: int = 0 - delta_x: int = 0 - delta_y: int = 0 - duration: int | None = None - origin: str | ElementOrigin | None = Origin.VIEWPORT - - @property - def type(self) -> str: - return "scroll" - - def to_dict(self) -> dict: - """Convert the WheelScrollAction to a dictionary.""" - result: dict[str, Any] = { - "type": self.type, - "x": self.x, - "y": self.y, - "deltaX": self.delta_x, - "deltaY": self.delta_y, - } - if self.duration is not None: - result["duration"] = self.duration - if self.origin is not None: - if isinstance(self.origin, ElementOrigin): - result["origin"] = self.origin.to_dict() - else: - result["origin"] = self.origin - return result + """WheelScrollAction.""" + type: str = field(default="scroll", init=False) + x: Any | None = None + y: Any | None = None + delta_x: Any | None = None + delta_y: Any | None = None + duration: Any | None = None + origin: Any | None = None -# Source Actions -@dataclass -class NoneSourceActions: - """Represents a sequence of none actions.""" - id: str = "" - actions: list[PauseAction] = field(default_factory=list) - - @property - def type(self) -> str: - return "none" +@dataclass +class PointerCommonProperties: + """PointerCommonProperties.""" - def to_dict(self) -> dict: - """Convert the NoneSourceActions to a dictionary.""" - return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]} + width: Any | None = None + height: Any | None = None + pressure: Any | None = None + tangential_pressure: Any | None = None + twist: Any | None = None + altitude_angle: Any | None = None + azimuth_angle: Any | None = None @dataclass -class KeySourceActions: - """Represents a sequence of key actions.""" +class ReleaseActionsParameters: + """ReleaseActionsParameters.""" + + context: Any | None = None - id: str = "" - actions: list[PauseAction | KeyDownAction | KeyUpAction] = field(default_factory=list) - @property - def type(self) -> str: - return "key" +@dataclass +class SetFilesParameters: + """SetFilesParameters.""" - def to_dict(self) -> dict: - """Convert the KeySourceActions to a dictionary.""" - return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]} + context: Any | None = None + element: Any | None = None + files: list[Any | None] | None = None @dataclass -class PointerSourceActions: - """Represents a sequence of pointer actions.""" - - id: str = "" - parameters: PointerParameters | None = None - actions: list[PauseAction | PointerDownAction | PointerUpAction | PointerMoveAction] = field(default_factory=list) - - def __post_init__(self): - if self.parameters is None: - self.parameters = PointerParameters() - - @property - def type(self) -> str: - return "pointer" - - def to_dict(self) -> dict: - """Convert the PointerSourceActions to a dictionary.""" - result: dict[str, Any] = { - "type": self.type, - "id": self.id, - "actions": [action.to_dict() for action in self.actions], - } - if self.parameters: - result["parameters"] = self.parameters.to_dict() - return result +class FileDialogInfo: + """FileDialogInfo - parameters for the input.fileDialogOpened event.""" + + context: Any | None = None + element: Any | None = None + multiple: bool | None = None + @classmethod + def from_json(cls, params: dict) -> "FileDialogInfo": + """Deserialize event params into FileDialogInfo.""" + return cls( + context=params.get("context"), + element=params.get("element"), + multiple=params.get("multiple"), + ) @dataclass -class WheelSourceActions: - """Represents a sequence of wheel actions.""" +class PointerMoveAction: + """PointerMoveAction.""" - id: str = "" - actions: list[PauseAction | WheelScrollAction] = field(default_factory=list) + type: str = field(default="pointerMove", init=False) + x: Any | None = None + y: Any | None = None + duration: Any | None = None + origin: Any | None = None + properties: Any | None = None - @property - def type(self) -> str: - return "wheel" +@dataclass +class PointerDownAction: + """PointerDownAction.""" - def to_dict(self) -> dict: - """Convert the WheelSourceActions to a dictionary.""" - return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]} + type: str = field(default="pointerDown", init=False) + button: Any | None = None + properties: Any | None = None +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "file_dialog_opened": "input.fileDialogOpened", +} @dataclass -class FileDialogInfo: - """Represents file dialog information from input.fileDialogOpened event.""" +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type - context: str - multiple: bool - element: dict | None = None - @classmethod - def from_dict(cls, data: dict) -> "FileDialogInfo": - """Creates a FileDialogInfo instance from a dictionary. +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. Args: - data: A dictionary containing the file dialog information. + params: Raw BiDi event params with camelCase keys. Returns: - FileDialogInfo: A new instance of FileDialogInfo. + An instance of the dataclass, or the raw dict on failure. """ - return cls(context=data["context"], multiple=data["multiple"], element=data.get("element")) + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) -# Event Class -class FileDialogOpened: - """Event class for input.fileDialogOpened event.""" +class _EventManager: + """Manages event subscriptions and callbacks.""" - event_class = "input.fileDialogOpened" + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id - @classmethod - def from_json(cls, json): - """Create FileDialogInfo from JSON data.""" - return FileDialogInfo.from_dict(json) + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() -class Input: - """BiDi implementation of the input module.""" - def __init__(self, conn): - self.conn = conn - self.subscriptions = {} - self.callbacks = {} - def perform_actions( - self, - context: str, - actions: list[NoneSourceActions | KeySourceActions | PointerSourceActions | WheelSourceActions], - ) -> None: - """Performs a sequence of user input actions. +class Input: + """WebDriver BiDi input module.""" + + EVENT_CONFIGS = {} + def __init__(self, conn) -> None: + self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) + + def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): + """Execute input.performActions.""" + params = { + "context": context, + "actions": actions, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("input.performActions", params) + result = self._conn.execute(cmd) + return result - Args: - context: The browsing context ID where actions should be performed. - actions: A list of source actions to perform. - """ - params = {"context": context, "actions": [action.to_dict() for action in actions]} - self.conn.execute(command_builder("input.performActions", params)) + def release_actions(self, context: Any | None = None): + """Execute input.releaseActions.""" + params = { + "context": context, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("input.releaseActions", params) + result = self._conn.execute(cmd) + return result + + def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): + """Execute input.setFiles.""" + params = { + "context": context, + "element": element, + "files": files, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("input.setFiles", params) + result = self._conn.execute(cmd) + return result - def release_actions(self, context: str) -> None: - """Releases all input state for the given context. + def add_file_dialog_handler(self, callback) -> int: + """Subscribe to the input.fileDialogOpened event. Args: - context: The browsing context ID to release actions for. + callback: Callable invoked with a FileDialogInfo when a file dialog opens. + + Returns: + A handler ID that can be passed to remove_file_dialog_handler. """ - params = {"context": context} - self.conn.execute(command_builder("input.releaseActions", params)) + return self._event_manager.add_event_handler("file_dialog_opened", callback) - def set_files(self, context: str, element: dict, files: list[str]) -> None: - """Sets files for a file input element. + def remove_file_dialog_handler(self, handler_id: int) -> None: + """Unsubscribe a previously registered file dialog event handler. Args: - context: The browsing context ID. - element: The element reference (script.SharedReference). - files: A list of file paths to set. + handler_id: The handler ID returned by add_file_dialog_handler. """ - params = {"context": context, "element": element, "files": files} - self.conn.execute(command_builder("input.setFiles", params)) + return self._event_manager.remove_event_handler("file_dialog_opened", handler_id) - def add_file_dialog_handler(self, handler) -> int: - """Add a handler for file dialog opened events. + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. Args: - handler: Callback function that takes a FileDialogInfo object. + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). Returns: - int: Callback ID for removing the handler later. + The callback ID. """ - # Subscribe to the event if not already subscribed - if FileDialogOpened.event_class not in self.subscriptions: - session = Session(self.conn) - self.conn.execute(session.subscribe(FileDialogOpened.event_class)) - self.subscriptions[FileDialogOpened.event_class] = [] - - # Add callback - the callback receives the parsed FileDialogInfo directly - callback_id = self.conn.add_callback(FileDialogOpened, handler) + return self._event_manager.add_event_handler(event, callback, contexts) - self.subscriptions[FileDialogOpened.event_class].append(callback_id) - self.callbacks[callback_id] = handler - - return callback_id - - def remove_file_dialog_handler(self, callback_id: int) -> None: - """Remove a file dialog handler. + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. Args: - callback_id: The callback ID returned by add_file_dialog_handler. + event: The event to unsubscribe from. + callback_id: The callback ID. """ - if callback_id in self.callbacks: - del self.callbacks[callback_id] + return self._event_manager.remove_event_handler(event, callback_id) - if FileDialogOpened.event_class in self.subscriptions: - if callback_id in self.subscriptions[FileDialogOpened.event_class]: - self.subscriptions[FileDialogOpened.event_class].remove(callback_id) + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() + +# Event Info Type Aliases +# Event: input.fileDialogOpened +FileDialogOpened = globals().get('FileDialogInfo', dict) # Fallback to dict if type not defined - # If no more callbacks for this event, unsubscribe - if not self.subscriptions[FileDialogOpened.event_class]: - session = Session(self.conn) - self.conn.execute(session.unsubscribe(FileDialogOpened.event_class)) - del self.subscriptions[FileDialogOpened.event_class] - self.conn.remove_callback(FileDialogOpened, callback_id) +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Input.EVENT_CONFIGS = { + "file_dialog_opened": (EventConfig("file_dialog_opened", "input.fileDialogOpened", _globals.get("FileDialogOpened", dict)) if _globals.get("FileDialogOpened") else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict)), +} diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 575545776bda8..faf6c85ae2b6c 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -1,81 +1,109 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: log from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator from dataclasses import dataclass -from typing import Any -class LogEntryAdded: - event_class = "log.entryAdded" +class Level: + """Level.""" + + DEBUG = "debug" + INFO = "info" + WARN = "warn" + ERROR = "error" + + +LogLevel = Level + +@dataclass +class BaseLogEntry: + """BaseLogEntry.""" + + level: Any | None = None + source: Any | None = None + text: Any | None = None + timestamp: Any | None = None + stack_trace: Any | None = None - @classmethod - def from_json(cls, json: dict[str, Any]) -> ConsoleLogEntry | JavaScriptLogEntry | None: - if json["type"] == "console": - return ConsoleLogEntry.from_json(json) - elif json["type"] == "javascript": - return JavaScriptLogEntry.from_json(json) - return None + +@dataclass +class GenericLogEntry: + """GenericLogEntry.""" + + type: str | None = None @dataclass class ConsoleLogEntry: - level: str - text: str - timestamp: str - method: str - args: list[dict[str, Any]] - type_: str + """ConsoleLogEntry - a console log entry from the browser.""" + + type_: str | None = None + method: str | None = None + args: list | None = None + level: Any | None = None + text: Any | None = None + source: Any | None = None + timestamp: Any | None = None + stack_trace: Any | None = None @classmethod - def from_json(cls, json: dict[str, Any]) -> ConsoleLogEntry: + def from_json(cls, params: dict) -> "ConsoleLogEntry": + """Deserialize from BiDi params dict.""" return cls( - level=json["level"], - text=json["text"], - timestamp=json["timestamp"], - method=json["method"], - args=json["args"], - type_=json["type"], + type_=params.get("type"), + method=params.get("method"), + args=params.get("args"), + level=params.get("level"), + text=params.get("text"), + source=params.get("source"), + timestamp=params.get("timestamp"), + stack_trace=params.get("stackTrace"), ) - @dataclass -class JavaScriptLogEntry: - level: str - text: str - timestamp: str - stacktrace: dict[str, Any] - type_: str +class JavascriptLogEntry: + """JavascriptLogEntry - a JavaScript error log entry from the browser.""" + + type_: str | None = None + level: Any | None = None + text: Any | None = None + source: Any | None = None + timestamp: Any | None = None + stacktrace: Any | None = None @classmethod - def from_json(cls, json: dict[str, Any]) -> JavaScriptLogEntry: + def from_json(cls, params: dict) -> "JavascriptLogEntry": + """Deserialize from BiDi params dict.""" return cls( - level=json["level"], - text=json["text"], - timestamp=json["timestamp"], - stacktrace=json["stackTrace"], - type_=json["type"], + type_=params.get("type"), + level=params.get("level"), + text=params.get("text"), + source=params.get("source"), + timestamp=params.get("timestamp"), + stacktrace=params.get("stackTrace"), ) +class Log: + """WebDriver BiDi log module.""" -class LogLevel: - """Represents log level.""" + def __init__(self, conn) -> None: + self._conn = conn + + def entry_added(self): + """Execute log.entryAdded.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("log.entryAdded", params) + result = self._conn.execute(cmd) + return result - DEBUG = "debug" - INFO = "info" - WARN = "warn" - ERROR = "error" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 82472838dccde..4f44e309bffbb 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -1,338 +1,923 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - +# WebDriver BiDi module: network from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass +import threading from collections.abc import Callable -from typing import Any +from dataclasses import dataclass +from selenium.webdriver.common.bidi.session import Session -from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.remote.websocket_connection import WebSocketConnection +class SameSite: + """SameSite.""" -class NetworkEvent: - """Represents a network event.""" + STRICT = "strict" + LAX = "lax" + NONE = "none" + DEFAULT = "default" - def __init__(self, event_class: str, **kwargs: Any) -> None: - self.event_class = event_class - self.params = kwargs - @classmethod - def from_json(cls, json: dict[str, Any]) -> NetworkEvent: - return cls(event_class=json.get("event_class", ""), **json) +class DataType: + """DataType.""" + REQUEST = "request" + RESPONSE = "response" -class Network: - EVENTS = { - "before_request": "network.beforeRequestSent", - "response_started": "network.responseStarted", - "response_completed": "network.responseCompleted", - "auth_required": "network.authRequired", - "fetch_error": "network.fetchError", - "continue_request": "network.continueRequest", - "continue_auth": "network.continueWithAuth", - } - - PHASES = { - "before_request": "beforeRequestSent", - "response_started": "responseStarted", - "auth_required": "authRequired", - } - - def __init__(self, conn: WebSocketConnection) -> None: - self.conn = conn - self.intercepts: list[str] = [] - self.callbacks: dict[str | int, Any] = {} - self.subscriptions: dict[str, list[int]] = {} - - def _add_intercept( - self, - phases: list[str] | None = None, - contexts: list[str] | None = None, - url_patterns: list[Any] | None = None, - ) -> dict[str, Any]: - """Add an intercept to the network. + +class InterceptPhase: + """InterceptPhase.""" + + BEFOREREQUESTSENT = "beforeRequestSent" + RESPONSESTARTED = "responseStarted" + AUTHREQUIRED = "authRequired" + + +class ContinueWithAuthNoCredentials: + """ContinueWithAuthNoCredentials.""" + + DEFAULT = "default" + CANCEL = "cancel" + + +@dataclass +class AuthChallenge: + """AuthChallenge.""" + + scheme: str | None = None + realm: str | None = None + + +@dataclass +class AuthCredentials: + """AuthCredentials.""" + + type: str = field(default="password", init=False) + username: str | None = None + password: str | None = None + + +@dataclass +class BaseParameters: + """BaseParameters.""" + + context: Any | None = None + is_blocked: bool | None = None + navigation: Any | None = None + redirect_count: Any | None = None + request: Any | None = None + timestamp: Any | None = None + intercepts: list[Any | None] | None = None + + +@dataclass +class StringValue: + """StringValue.""" + + type: str = field(default="string", init=False) + value: str | None = None + + +@dataclass +class Base64Value: + """Base64Value.""" + + type: str = field(default="base64", init=False) + value: str | None = None + + +@dataclass +class Cookie: + """Cookie.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + +@dataclass +class CookieHeader: + """CookieHeader.""" + + name: str | None = None + value: Any | None = None + + +@dataclass +class FetchTimingInfo: + """FetchTimingInfo.""" + + time_origin: Any | None = None + request_time: Any | None = None + redirect_start: Any | None = None + redirect_end: Any | None = None + fetch_start: Any | None = None + dns_start: Any | None = None + dns_end: Any | None = None + connect_start: Any | None = None + connect_end: Any | None = None + tls_start: Any | None = None + request_start: Any | None = None + response_start: Any | None = None + response_end: Any | None = None + + +@dataclass +class Header: + """Header.""" + + name: str | None = None + value: Any | None = None + + +@dataclass +class Initiator: + """Initiator.""" + + column_number: Any | None = None + line_number: Any | None = None + request: Any | None = None + stack_trace: Any | None = None + type: Any | None = None + + +@dataclass +class ResponseContent: + """ResponseContent.""" + + size: Any | None = None + + +@dataclass +class ResponseData: + """ResponseData.""" + + url: str | None = None + protocol: str | None = None + status: Any | None = None + status_text: str | None = None + from_cache: bool | None = None + headers: list[Any | None] | None = None + mime_type: str | None = None + bytes_received: Any | None = None + headers_size: Any | None = None + body_size: Any | None = None + content: Any | None = None + auth_challenges: list[Any | None] | None = None + + +@dataclass +class SetCookieHeader: + """SetCookieHeader.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + http_only: bool | None = None + expiry: str | None = None + max_age: Any | None = None + path: str | None = None + same_site: Any | None = None + secure: bool | None = None + + +@dataclass +class UrlPatternPattern: + """UrlPatternPattern.""" + + type: str = field(default="pattern", init=False) + protocol: str | None = None + hostname: str | None = None + port: str | None = None + pathname: str | None = None + search: str | None = None + + +@dataclass +class UrlPatternString: + """UrlPatternString.""" + + type: str = field(default="string", init=False) + pattern: str | None = None + + +@dataclass +class AddDataCollectorParameters: + """AddDataCollectorParameters.""" + + data_types: list[Any | None] | None = None + max_encoded_data_size: Any | None = None + collector_type: Any | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None + + +@dataclass +class AddDataCollectorResult: + """AddDataCollectorResult.""" + + collector: Any | None = None + + +@dataclass +class AddInterceptParameters: + """AddInterceptParameters.""" + + phases: list[Any | None] | None = None + contexts: list[Any | None] | None = None + url_patterns: list[Any | None] | None = None + + +@dataclass +class AddInterceptResult: + """AddInterceptResult.""" + + intercept: Any | None = None + + +@dataclass +class ContinueResponseParameters: + """ContinueResponseParameters.""" + + request: Any | None = None + cookies: list[Any | None] | None = None + credentials: Any | None = None + headers: list[Any | None] | None = None + reason_phrase: str | None = None + status_code: Any | None = None + + +@dataclass +class ContinueWithAuthParameters: + """ContinueWithAuthParameters.""" + + request: Any | None = None + + +@dataclass +class ContinueWithAuthCredentials: + """ContinueWithAuthCredentials.""" + + action: str = field(default="provideCredentials", init=False) + credentials: Any | None = None + + +@dataclass +class disownDataParameters: + """disownDataParameters.""" + + data_type: Any | None = None + collector: Any | None = None + request: Any | None = None + + +@dataclass +class FailRequestParameters: + """FailRequestParameters.""" + + request: Any | None = None + + +@dataclass +class GetDataParameters: + """GetDataParameters.""" + + data_type: Any | None = None + collector: Any | None = None + disown: bool | None = None + request: Any | None = None + + +@dataclass +class GetDataResult: + """GetDataResult.""" + + bytes: Any | None = None + + +@dataclass +class ProvideResponseParameters: + """ProvideResponseParameters.""" + + request: Any | None = None + body: Any | None = None + cookies: list[Any | None] | None = None + headers: list[Any | None] | None = None + reason_phrase: str | None = None + status_code: Any | None = None + + +@dataclass +class RemoveDataCollectorParameters: + """RemoveDataCollectorParameters.""" + + collector: Any | None = None + + +@dataclass +class RemoveInterceptParameters: + """RemoveInterceptParameters.""" + + intercept: Any | None = None + + +@dataclass +class SetCacheBehaviorParameters: + """SetCacheBehaviorParameters.""" + + cache_behavior: Any | None = None + contexts: list[Any | None] | None = None + + +@dataclass +class SetExtraHeadersParameters: + """SetExtraHeadersParameters.""" + + headers: list[Any | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None + + +@dataclass +class ResponseStartedParameters: + """ResponseStartedParameters.""" + + response: Any | None = None + + +class BytesValue: + """A string or base64-encoded bytes value used in cookie operations. + + This corresponds to network.BytesValue in the WebDriver BiDi specification, + wrapping either a plain string or a base64-encoded binary value. + """ + + TYPE_STRING = "string" + TYPE_BASE64 = "base64" + + def __init__(self, type: str, value: str) -> None: + self.type = type + self.value = value + + def to_bidi_dict(self) -> dict: + return {"type": self.type, "value": self.value} + +class Request: + """Wraps a BiDi network request event params and provides request action methods.""" + + def __init__(self, conn, params): + self._conn = conn + self._params = params if isinstance(params, dict) else {} + req = self._params.get("request", {}) or {} + self.url = req.get("url", "") + self._request_id = req.get("request") + + def continue_request(self, **kwargs): + """Continue the intercepted request.""" + from selenium.webdriver.common.bidi.common import command_builder as _cb + + params = {"request": self._request_id} + params.update(kwargs) + self._conn.execute(_cb("network.continueRequest", params)) + +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "auth_required": "network.authRequired", + "before_request": "network.beforeRequestSent", +} + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. Args: - phases: A list of phases to intercept. Default is None (empty list). - contexts: A list of contexts to intercept. Default is None. - url_patterns: A list of URL patterns to intercept. Default is None. + params: Raw BiDi event params with camelCase keys. Returns: - str: intercept id + An instance of the dataclass, or the raw dict on failure. """ - if phases is None: - phases = [] - params = {} - if contexts is not None: - params["contexts"] = contexts - if url_patterns is not None: - params["urlPatterns"] = url_patterns - if len(phases) > 0: - params["phases"] = phases - else: - params["phases"] = ["beforeRequestSent"] + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + + + +class Network: + """WebDriver BiDi network module.""" + + EVENT_CONFIGS = {} + def __init__(self, conn) -> None: + self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) + self.intercepts = [] + + def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute network.addDataCollector.""" + params = { + "dataTypes": data_types, + "maxEncodedDataSize": max_encoded_data_size, + "collectorType": collector_type, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.addDataCollector", params) + result = self._conn.execute(cmd) + return result + + def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): + """Execute network.addIntercept.""" + params = { + "phases": phases, + "contexts": contexts, + "urlPatterns": url_patterns, + } + params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("network.addIntercept", params) + result = self._conn.execute(cmd) + return result - result: dict[str, Any] = self.conn.execute(cmd) - self.intercepts.append(result["intercept"]) + def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): + """Execute network.continueRequest.""" + params = { + "request": request, + "body": body, + "cookies": cookies, + "headers": headers, + "method": method, + "url": url, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.continueRequest", params) + result = self._conn.execute(cmd) return result - def _remove_intercept(self, intercept: str | None = None) -> None: - """Remove a specific intercept, or all intercepts. + def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + """Execute network.continueResponse.""" + params = { + "request": request, + "cookies": cookies, + "credentials": credentials, + "headers": headers, + "reasonPhrase": reason_phrase, + "statusCode": status_code, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.continueResponse", params) + result = self._conn.execute(cmd) + return result - Args: - intercept: The intercept to remove. Default is None. + def continue_with_auth(self, request: Any | None = None): + """Execute network.continueWithAuth.""" + params = { + "request": request, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.continueWithAuth", params) + result = self._conn.execute(cmd) + return result + + def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): + """Execute network.disownData.""" + params = { + "dataType": data_type, + "collector": collector, + "request": request, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.disownData", params) + result = self._conn.execute(cmd) + return result - Raises: - ValueError: If intercept is not found. + def fail_request(self, request: Any | None = None): + """Execute network.failRequest.""" + params = { + "request": request, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.failRequest", params) + result = self._conn.execute(cmd) + return result - Note: - If intercept is None, all intercepts will be removed. - """ - if intercept is None: - intercepts_to_remove = self.intercepts.copy() # create a copy before iterating - for intercept_id in intercepts_to_remove: # remove all intercepts - self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept_id})) - self.intercepts.remove(intercept_id) - else: - try: - self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept})) - self.intercepts.remove(intercept) - except Exception as e: - raise Exception(f"Exception: {e}") - - def _on_request(self, event_name: str, callback: Callable[[Request], Any]) -> int: - """Set a callback function to subscribe to a network event. + def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): + """Execute network.getData.""" + params = { + "dataType": data_type, + "collector": collector, + "disown": disown, + "request": request, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.getData", params) + result = self._conn.execute(cmd) + return result + + def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + """Execute network.provideResponse.""" + params = { + "request": request, + "body": body, + "cookies": cookies, + "headers": headers, + "reasonPhrase": reason_phrase, + "statusCode": status_code, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.provideResponse", params) + result = self._conn.execute(cmd) + return result + + def remove_data_collector(self, collector: Any | None = None): + """Execute network.removeDataCollector.""" + params = { + "collector": collector, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.removeDataCollector", params) + result = self._conn.execute(cmd) + return result + + def remove_intercept(self, intercept: Any | None = None): + """Execute network.removeIntercept.""" + params = { + "intercept": intercept, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.removeIntercept", params) + result = self._conn.execute(cmd) + return result + + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): + """Execute network.setCacheBehavior.""" + params = { + "cacheBehavior": cache_behavior, + "contexts": contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.setCacheBehavior", params) + result = self._conn.execute(cmd) + return result + + def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute network.setExtraHeaders.""" + params = { + "headers": headers, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.setExtraHeaders", params) + result = self._conn.execute(cmd) + return result + + def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.beforeRequestSent.""" + params = { + "initiator": initiator, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.beforeRequestSent", params) + result = self._conn.execute(cmd) + return result + + def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.fetchError.""" + params = { + "errorText": error_text, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.fetchError", params) + result = self._conn.execute(cmd) + return result + + def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.responseCompleted.""" + params = { + "response": response, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.responseCompleted", params) + result = self._conn.execute(cmd) + return result + + def response_started(self, response: Any | None = None): + """Execute network.responseStarted.""" + params = { + "response": response, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.responseStarted", params) + result = self._conn.execute(cmd) + return result + + def _add_intercept(self, phases=None, url_patterns=None): + """Add a low-level network intercept. Args: - event_name: The event to subscribe to. - callback: The callback function to execute on event. - Takes Request object as argument. + phases: list of intercept phases (default: ["beforeRequestSent"]) + url_patterns: optional URL patterns to filter Returns: - int: callback id + dict with "intercept" key containing the intercept ID """ - event = NetworkEvent(event_name) - - def _callback(event_data: NetworkEvent) -> None: - request = Request( - network=self, - request_id=event_data.params["request"].get("request", None), - body_size=event_data.params["request"].get("bodySize", None), - cookies=event_data.params["request"].get("cookies", None), - resource_type=event_data.params["request"].get("goog:resourceType", None), - headers=event_data.params["request"].get("headers", None), - headers_size=event_data.params["request"].get("headersSize", None), - timings=event_data.params["request"].get("timings", None), - url=event_data.params["request"].get("url", None), - ) - callback(request) + from selenium.webdriver.common.bidi.common import command_builder as _cb - callback_id: int = self.conn.add_callback(event, _callback) - - if event_name in self.callbacks: - self.callbacks[event_name].append(callback_id) - else: - self.callbacks[event_name] = [callback_id] - - return callback_id + if phases is None: + phases = ["beforeRequestSent"] + params = {"phases": phases} + if url_patterns: + params["urlPatterns"] = url_patterns + result = self._conn.execute(_cb("network.addIntercept", params)) + if result: + intercept_id = result.get("intercept") + if intercept_id and intercept_id not in self.intercepts: + self.intercepts.append(intercept_id) + return result + def _remove_intercept(self, intercept_id): + """Remove a low-level network intercept.""" + from selenium.webdriver.common.bidi.common import command_builder as _cb - def add_request_handler( - self, - event: str, - callback: Callable[[Request], Any], - url_patterns: list[Any] | None = None, - contexts: list[str] | None = None, - ) -> int: - """Add a request handler to the network. + self._conn.execute(_cb("network.removeIntercept", {"intercept": intercept_id})) + if intercept_id in self.intercepts: + self.intercepts.remove(intercept_id) + def add_request_handler(self, event, callback, url_patterns=None): + """Add a handler for network requests at the specified phase. Args: - event: The event to subscribe to. - callback: The callback function to execute on request interception. - Takes Request object as argument. - url_patterns: A list of URL patterns to intercept. Default is None. - contexts: A list of contexts to intercept. Default is None. + event: Event name, e.g. ``"before_request"``. + callback: Callable receiving a :class:`Request` instance. + url_patterns: optional list of URL pattern dicts to filter. Returns: - int: callback id + callback_id int for later removal via remove_request_handler. """ - try: - event_name = self.EVENTS[event] - phase_name = self.PHASES[event] - except KeyError: - raise Exception(f"Event {event} not found") - - result = self._add_intercept(phases=[phase_name], url_patterns=url_patterns, contexts=contexts) - callback_id = self._on_request(event_name, callback) - - if event_name in self.subscriptions: - self.subscriptions[event_name].append(callback_id) - else: - params: dict[str, Any] = {} - params["events"] = [event_name] - self.conn.execute(command_builder("session.subscribe", params)) - self.subscriptions[event_name] = [callback_id] - - self.callbacks[callback_id] = result["intercept"] - return callback_id + phase_map = { + "before_request": "beforeRequestSent", + "before_request_sent": "beforeRequestSent", + "response_started": "responseStarted", + "auth_required": "authRequired", + } + phase = phase_map.get(event, "beforeRequestSent") + self._add_intercept(phases=[phase], url_patterns=url_patterns) + + def _request_callback(params): + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request = Request(self._conn, raw) + callback(request) - def remove_request_handler(self, event: str, callback_id: int) -> None: - """Remove a request handler from the network. + return self.add_event_handler(event, _request_callback) + def remove_request_handler(self, event, callback_id): + """Remove a network request handler. Args: - event: The event to unsubscribe from. - callback_id: The callback id to remove. + event: The event name used when adding the handler. + callback_id: The int returned by add_request_handler. """ - try: - event_name = self.EVENTS[event] - except KeyError: - raise Exception(f"Event {event} not found") - - net_event = NetworkEvent(event_name) - - self.conn.remove_callback(net_event, callback_id) - self._remove_intercept(self.callbacks[callback_id]) - del self.callbacks[callback_id] - self.subscriptions[event_name].remove(callback_id) - if len(self.subscriptions[event_name]) == 0: - params: dict[str, Any] = {} - params["events"] = [event_name] - self.conn.execute(command_builder("session.unsubscribe", params)) - del self.subscriptions[event_name] - - def clear_request_handlers(self) -> None: - """Clear all request handlers from the network.""" - for event_name in self.subscriptions: - net_event = NetworkEvent(event_name) - for callback_id in self.subscriptions[event_name]: - self.conn.remove_callback(net_event, callback_id) - self._remove_intercept(self.callbacks[callback_id]) - del self.callbacks[callback_id] - params: dict[str, Any] = {} - params["events"] = [event_name] - self.conn.execute(command_builder("session.unsubscribe", params)) - self.subscriptions = {} - - def add_auth_handler(self, username: str, password: str) -> int: - """Add an authentication handler to the network. + self.remove_event_handler(event, callback_id) + def clear_request_handlers(self): + """Clear all request handlers and remove all tracked intercepts.""" + self.clear_event_handlers() + for intercept_id in list(self.intercepts): + self._remove_intercept(intercept_id) + def add_auth_handler(self, username, password): + """Add an auth handler that automatically provides credentials. Args: - username: The username to authenticate with. - password: The password to authenticate with. + username: The username for basic authentication. + password: The password for basic authentication. Returns: - int: callback id + callback_id int for later removal via remove_auth_handler. """ - event = "auth_required" - - def _callback(request: Request) -> None: - request._continue_with_auth(username, password) + from selenium.webdriver.common.bidi.common import command_builder as _cb - return self.add_request_handler(event, _callback) - - def remove_auth_handler(self, callback_id: int) -> None: - """Remove an authentication handler from the network. + def _auth_callback(params): + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request_id = ( + raw.get("request", {}).get("request") + if isinstance(raw, dict) + else None + ) + if request_id: + self._conn.execute( + _cb( + "network.continueWithAuth", + { + "request": request_id, + "action": "provideCredentials", + "credentials": { + "type": "password", + "username": username, + "password": password, + }, + }, + ) + ) + + return self.add_event_handler("auth_required", _auth_callback) + def remove_auth_handler(self, callback_id): + """Remove an auth handler by callback ID.""" + self.remove_event_handler("auth_required", callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. Args: - callback_id: The callback id to remove. + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). + + Returns: + The callback ID. """ - event = "auth_required" - self.remove_request_handler(event, callback_id) + return self._event_manager.add_event_handler(event, callback, contexts) - -class Request: - """Represents an intercepted network request.""" - - def __init__( - self, - network: Network, - request_id: Any, - body_size: int | None = None, - cookies: Any = None, - resource_type: str | None = None, - headers: Any = None, - headers_size: int | None = None, - method: str | None = None, - timings: Any = None, - url: str | None = None, - ) -> None: - self.network = network - self.request_id = request_id - self.body_size = body_size - self.cookies = cookies - self.resource_type = resource_type - self.headers = headers - self.headers_size = headers_size - self.method = method - self.timings = timings - self.url = url - - def fail_request(self) -> None: - """Fail this request.""" - if not self.request_id: - raise ValueError("Request not found.") - - params: dict[str, Any] = {"request": self.request_id} - self.network.conn.execute(command_builder("network.failRequest", params)) - - def continue_request( - self, - body: Any = None, - method: str | None = None, - headers: Any = None, - cookies: Any = None, - url: str | None = None, - ) -> None: - """Continue after intercepting this request.""" - if not self.request_id: - raise ValueError("Request not found.") - - params: dict[str, Any] = {"request": self.request_id} - if body is not None: - params["body"] = body - if method is not None: - params["method"] = method - if headers is not None: - params["headers"] = headers - if cookies is not None: - params["cookies"] = cookies - if url is not None: - params["url"] = url - - self.network.conn.execute(command_builder("network.continueRequest", params)) - - def _continue_with_auth(self, username: str | None = None, password: str | None = None) -> None: - """Continue with authentication. + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. Args: - username: The username to authenticate with. - password: The password to authenticate with. - - Note: - If username or password is None, it attempts auth with no credentials. + event: The event to unsubscribe from. + callback_id: The callback ID. """ - params: dict[str, Any] = {} - params["request"] = self.request_id + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() + +# Event Info Type Aliases +# Event: network.authRequired +AuthRequired = globals().get('AuthRequiredParameters', dict) # Fallback to dict if type not defined - if not username or not password: # no credentials is valid option - params["action"] = "default" - else: - params["action"] = "provideCredentials" - params["credentials"] = {"type": "password", "username": username, "password": password} - self.network.conn.execute(command_builder("network.continueWithAuth", params)) +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Network.EVENT_CONFIGS = { + "auth_required": (EventConfig("auth_required", "network.authRequired", _globals.get("AuthRequired", dict)) if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict)), + "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), +} diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index 17faa1ff5454f..f00e765c62e3b 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -15,12 +15,20 @@ # specific language governing permissions and limitations # under the License. +"""WebDriver BiDi Permissions module.""" -from selenium.webdriver.common.bidi.common import command_builder +from __future__ import annotations +from enum import Enum +from typing import Any, Optional, Union -class PermissionState: - """Represents the possible permission states.""" +from .common import command_builder + +_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"} + + +class PermissionState(str, Enum): + """Permission state enumeration.""" GRANTED = "granted" DENIED = "denied" @@ -28,56 +36,69 @@ class PermissionState: class PermissionDescriptor: - """Represents a permission descriptor.""" + """Descriptor for a permission.""" - def __init__(self, name: str): + def __init__(self, name: str) -> None: + """Initialize a PermissionDescriptor. + + Args: + name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera') + """ self.name = name - def to_dict(self) -> dict: - return {"name": self.name} + def __repr__(self) -> str: + return f"PermissionDescriptor('{self.name}')" class Permissions: - """BiDi implementation of the permissions module.""" + """WebDriver BiDi Permissions module.""" - def __init__(self, conn): - self.conn = conn + def __init__(self, websocket_connection: Any) -> None: + """Initialize the Permissions module. + + Args: + websocket_connection: The WebSocket connection for sending BiDi commands + """ + self._conn = websocket_connection def set_permission( self, - descriptor: str | PermissionDescriptor, - state: str, - origin: str, - user_context: str | None = None, + descriptor: Union[PermissionDescriptor, str], + state: Union[PermissionState, str], + origin: Optional[str] = None, + user_context: Optional[str] = None, ) -> None: - """Sets a permission state for a given permission descriptor. + """Set a permission for a given origin. Args: - descriptor: The permission name (str) or PermissionDescriptor object. - Examples: "geolocation", "camera", "microphone". - state: The permission state (granted, denied, prompt). - origin: The origin for which the permission is set. - user_context: The user context id (optional). + descriptor: The permission descriptor or permission name as a string + state: The desired permission state + origin: The origin for which to set the permission + user_context: Optional user context ID to scope the permission Raises: - ValueError: If the permission state is invalid. + ValueError: If the state is not a valid permission state """ - if state not in [PermissionState.GRANTED, PermissionState.DENIED, PermissionState.PROMPT]: - valid_states = f"{PermissionState.GRANTED}, {PermissionState.DENIED}, {PermissionState.PROMPT}" - raise ValueError(f"Invalid permission state. Must be one of: {valid_states}") + state_value = state.value if isinstance(state, PermissionState) else state + if state_value not in _VALID_PERMISSION_STATES: + raise ValueError( + f"Invalid permission state: {state_value!r}. " + f"Must be one of {sorted(_VALID_PERMISSION_STATES)}" + ) if isinstance(descriptor, str): - permission_descriptor = PermissionDescriptor(descriptor) + descriptor_dict = {"name": descriptor} else: - permission_descriptor = descriptor + descriptor_dict = {"name": descriptor.name} - params = { - "descriptor": permission_descriptor.to_dict(), - "state": state, - "origin": origin, + params: dict[str, Any] = { + "descriptor": descriptor_dict, + "state": state_value, } - + if origin is not None: + params["origin"] = origin if user_context is not None: params["userContext"] = user_context - self.conn.execute(command_builder("permissions.setPermission", params)) + cmd = command_builder("permissions.setPermission", params) + self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/py.typed b/py/selenium/webdriver/common/bidi/py.typed new file mode 100755 index 0000000000000..e69de29bb2d1d diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index e37b3269a4ade..e13c11f71a5cb 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -1,40 +1,33 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import datetime -import math -from dataclasses import dataclass -from typing import Any +# WebDriver BiDi module: script +from __future__ import annotations -from selenium.common.exceptions import WebDriverException -from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi.log import LogEntryAdded +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass +import threading +from collections.abc import Callable +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -class ResultOwnership: - """Represents the possible result ownership types.""" +class SpecialNumber: + """SpecialNumber.""" - NONE = "none" - ROOT = "root" + NAN = "NaN" + _0 = "-0" + INFINITY = "Infinity" + _INFINITY = "-Infinity" class RealmType: - """Represents the possible realm types.""" + """RealmType.""" WINDOW = "window" DEDICATED_WORKER = "dedicated-worker" @@ -46,502 +39,1224 @@ class RealmType: WORKLET = "worklet" +class ResultOwnership: + """ResultOwnership.""" + + ROOT = "root" + NONE = "none" + + +@dataclass +class ChannelValue: + """ChannelValue.""" + + type: str = field(default="channel", init=False) + value: Any | None = None + + +@dataclass +class ChannelProperties: + """ChannelProperties.""" + + channel: Any | None = None + serialization_options: Any | None = None + ownership: Any | None = None + + +@dataclass +class EvaluateResultSuccess: + """EvaluateResultSuccess.""" + + type: str = field(default="success", init=False) + result: Any | None = None + realm: Any | None = None + + +@dataclass +class EvaluateResultException: + """EvaluateResultException.""" + + type: str = field(default="exception", init=False) + exception_details: Any | None = None + realm: Any | None = None + + +@dataclass +class ExceptionDetails: + """ExceptionDetails.""" + + column_number: Any | None = None + exception: Any | None = None + line_number: Any | None = None + stack_trace: Any | None = None + text: str | None = None + + +@dataclass +class ArrayLocalValue: + """ArrayLocalValue.""" + + type: str = field(default="array", init=False) + value: Any | None = None + + +@dataclass +class DateLocalValue: + """DateLocalValue.""" + + type: str = field(default="date", init=False) + value: str | None = None + + +@dataclass +class MapLocalValue: + """MapLocalValue.""" + + type: str = field(default="map", init=False) + value: Any | None = None + + +@dataclass +class ObjectLocalValue: + """ObjectLocalValue.""" + + type: str = field(default="object", init=False) + value: Any | None = None + + +@dataclass +class RegExpValue: + """RegExpValue.""" + + pattern: str | None = None + flags: str | None = None + + +@dataclass +class RegExpLocalValue: + """RegExpLocalValue.""" + + type: str = field(default="regexp", init=False) + value: Any | None = None + + +@dataclass +class SetLocalValue: + """SetLocalValue.""" + + type: str = field(default="set", init=False) + value: Any | None = None + + @dataclass -class RealmInfo: - """Represents information about a realm.""" +class UndefinedValue: + """UndefinedValue.""" + + type: str = field(default="undefined", init=False) + + +@dataclass +class NullValue: + """NullValue.""" + + type: str = field(default="null", init=False) + + +@dataclass +class StringValue: + """StringValue.""" + + type: str = field(default="string", init=False) + value: str | None = None + + +@dataclass +class NumberValue: + """NumberValue.""" + + type: str = field(default="number", init=False) + value: Any | None = None + + +@dataclass +class BooleanValue: + """BooleanValue.""" + + type: str = field(default="boolean", init=False) + value: bool | None = None + + +@dataclass +class BigIntValue: + """BigIntValue.""" + + type: str = field(default="bigint", init=False) + value: str | None = None + + +@dataclass +class BaseRealmInfo: + """BaseRealmInfo.""" + + realm: Any | None = None + origin: str | None = None + - realm: str - origin: str - type: str - context: str | None = None +@dataclass +class WindowRealmInfo: + """WindowRealmInfo.""" + + type: str = field(default="window", init=False) + context: Any | None = None sandbox: str | None = None - @classmethod - def from_json(cls, json: dict[str, Any]) -> "RealmInfo": - """Creates a RealmInfo instance from a dictionary. - Args: - json: A dictionary containing the realm information. +@dataclass +class DedicatedWorkerRealmInfo: + """DedicatedWorkerRealmInfo.""" - Returns: - RealmInfo: A new instance of RealmInfo. - """ - if "realm" not in json: - raise ValueError("Missing required field 'realm' in RealmInfo") - if "origin" not in json: - raise ValueError("Missing required field 'origin' in RealmInfo") - if "type" not in json: - raise ValueError("Missing required field 'type' in RealmInfo") - - return cls( - realm=json["realm"], - origin=json["origin"], - type=json["type"], - context=json.get("context"), - sandbox=json.get("sandbox"), - ) + type: str = field(default="dedicated-worker", init=False) + owners: list[Any | None] | None = None @dataclass -class Source: - """Represents the source of a script message.""" +class SharedWorkerRealmInfo: + """SharedWorkerRealmInfo.""" - realm: str - context: str | None = None + type: str = field(default="shared-worker", init=False) - @classmethod - def from_json(cls, json: dict[str, Any]) -> "Source": - """Creates a Source instance from a dictionary. - Args: - json: A dictionary containing the source information. +@dataclass +class ServiceWorkerRealmInfo: + """ServiceWorkerRealmInfo.""" - Returns: - Source: A new instance of Source. - """ - if "realm" not in json: - raise ValueError("Missing required field 'realm' in Source") + type: str = field(default="service-worker", init=False) - return cls( - realm=json["realm"], - context=json.get("context"), - ) + +@dataclass +class WorkerRealmInfo: + """WorkerRealmInfo.""" + + type: str = field(default="worker", init=False) @dataclass -class EvaluateResult: - """Represents the result of script evaluation.""" +class PaintWorkletRealmInfo: + """PaintWorkletRealmInfo.""" - type: str - realm: str - result: dict | None = None - exception_details: dict | None = None + type: str = field(default="paint-worklet", init=False) - @classmethod - def from_json(cls, json: dict[str, Any]) -> "EvaluateResult": - """Creates an EvaluateResult instance from a dictionary. - Args: - json: A dictionary containing the evaluation result. +@dataclass +class AudioWorkletRealmInfo: + """AudioWorkletRealmInfo.""" - Returns: - EvaluateResult: A new instance of EvaluateResult. - """ - if "realm" not in json: - raise ValueError("Missing required field 'realm' in EvaluateResult") - if "type" not in json: - raise ValueError("Missing required field 'type' in EvaluateResult") - - return cls( - type=json["type"], - realm=json["realm"], - result=json.get("result"), - exception_details=json.get("exceptionDetails"), - ) + type: str = field(default="audio-worklet", init=False) -class ScriptMessage: - """Represents a script message event.""" +@dataclass +class WorkletRealmInfo: + """WorkletRealmInfo.""" - event_class = "script.message" + type: str = field(default="worklet", init=False) - def __init__(self, channel: str, data: dict, source: Source): - self.channel = channel - self.data = data - self.source = source - @classmethod - def from_json(cls, json: dict[str, Any]) -> "ScriptMessage": - """Creates a ScriptMessage instance from a dictionary. +@dataclass +class SharedReference: + """SharedReference.""" - Args: - json: A dictionary containing the script message. + shared_id: Any | None = None + handle: Any | None = None - Returns: - ScriptMessage: A new instance of ScriptMessage. - """ - if "channel" not in json: - raise ValueError("Missing required field 'channel' in ScriptMessage") - if "data" not in json: - raise ValueError("Missing required field 'data' in ScriptMessage") - if "source" not in json: - raise ValueError("Missing required field 'source' in ScriptMessage") - - return cls( - channel=json["channel"], - data=json["data"], - source=Source.from_json(json["source"]), - ) +@dataclass +class RemoteObjectReference: + """RemoteObjectReference.""" -class RealmCreated: - """Represents a realm created event.""" + handle: Any | None = None + shared_id: Any | None = None - event_class = "script.realmCreated" - def __init__(self, realm_info: RealmInfo): - self.realm_info = realm_info +@dataclass +class SymbolRemoteValue: + """SymbolRemoteValue.""" - @classmethod - def from_json(cls, json: dict[str, Any]) -> "RealmCreated": - """Creates a RealmCreated instance from a dictionary. + type: str = field(default="symbol", init=False) + handle: Any | None = None + internal_id: Any | None = None - Args: - json: A dictionary containing the realm created event. - Returns: - RealmCreated: A new instance of RealmCreated. - """ - return cls(realm_info=RealmInfo.from_json(json)) +@dataclass +class ArrayRemoteValue: + """ArrayRemoteValue.""" + type: str = field(default="array", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None -class RealmDestroyed: - """Represents a realm destroyed event.""" - event_class = "script.realmDestroyed" +@dataclass +class ObjectRemoteValue: + """ObjectRemoteValue.""" - def __init__(self, realm: str): - self.realm = realm + type: str = field(default="object", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None - @classmethod - def from_json(cls, json: dict[str, Any]) -> "RealmDestroyed": - """Creates a RealmDestroyed instance from a dictionary. - Args: - json: A dictionary containing the realm destroyed event. +@dataclass +class FunctionRemoteValue: + """FunctionRemoteValue.""" - Returns: - RealmDestroyed: A new instance of RealmDestroyed. - """ - if "realm" not in json: - raise ValueError("Missing required field 'realm' in RealmDestroyed") + type: str = field(default="function", init=False) + handle: Any | None = None + internal_id: Any | None = None - return cls(realm=json["realm"]) +@dataclass +class RegExpRemoteValue: + """RegExpRemoteValue.""" -class Script: - """BiDi implementation of the script module.""" + handle: Any | None = None + internal_id: Any | None = None - EVENTS = { - "message": "script.message", - "realm_created": "script.realmCreated", - "realm_destroyed": "script.realmDestroyed", - } - def __init__(self, conn, driver=None): - self.conn = conn - self.driver = driver - self.log_entry_subscribed = False - self.subscriptions = {} - self.callbacks = {} +@dataclass +class DateRemoteValue: + """DateRemoteValue.""" + + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class MapRemoteValue: + """MapRemoteValue.""" + + type: str = field(default="map", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class SetRemoteValue: + """SetRemoteValue.""" + + type: str = field(default="set", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class WeakMapRemoteValue: + """WeakMapRemoteValue.""" + + type: str = field(default="weakmap", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class WeakSetRemoteValue: + """WeakSetRemoteValue.""" + + type: str = field(default="weakset", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class GeneratorRemoteValue: + """GeneratorRemoteValue.""" + + type: str = field(default="generator", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class ErrorRemoteValue: + """ErrorRemoteValue.""" + + type: str = field(default="error", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class ProxyRemoteValue: + """ProxyRemoteValue.""" + + type: str = field(default="proxy", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class PromiseRemoteValue: + """PromiseRemoteValue.""" + + type: str = field(default="promise", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class TypedArrayRemoteValue: + """TypedArrayRemoteValue.""" + + type: str = field(default="typedarray", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class ArrayBufferRemoteValue: + """ArrayBufferRemoteValue.""" + + type: str = field(default="arraybuffer", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class NodeListRemoteValue: + """NodeListRemoteValue.""" + + type: str = field(default="nodelist", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class HTMLCollectionRemoteValue: + """HTMLCollectionRemoteValue.""" + + type: str = field(default="htmlcollection", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class NodeRemoteValue: + """NodeRemoteValue.""" + + type: str = field(default="node", init=False) + shared_id: Any | None = None + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class NodeProperties: + """NodeProperties.""" + + node_type: Any | None = None + child_node_count: Any | None = None + children: list[Any | None] | None = None + local_name: str | None = None + mode: Any | None = None + namespace_uri: str | None = None + node_value: str | None = None + shadow_root: Any | None = None + + +@dataclass +class WindowProxyRemoteValue: + """WindowProxyRemoteValue.""" + + type: str = field(default="window", init=False) + value: Any | None = None + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class WindowProxyProperties: + """WindowProxyProperties.""" + + context: Any | None = None + + +@dataclass +class StackFrame: + """StackFrame.""" + + column_number: Any | None = None + function_name: str | None = None + line_number: Any | None = None + url: str | None = None + + +@dataclass +class StackTrace: + """StackTrace.""" + + call_frames: list[Any | None] | None = None + + +@dataclass +class Source: + """Source.""" + + realm: Any | None = None + context: Any | None = None + + +@dataclass +class RealmTarget: + """RealmTarget.""" + + realm: Any | None = None + + +@dataclass +class ContextTarget: + """ContextTarget.""" + + context: Any | None = None + sandbox: str | None = None + + +@dataclass +class AddPreloadScriptParameters: + """AddPreloadScriptParameters.""" + + function_declaration: str | None = None + arguments: list[Any | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None + sandbox: str | None = None + + +@dataclass +class AddPreloadScriptResult: + """AddPreloadScriptResult.""" + + script: Any | None = None + - # High-level APIs for SCRIPT module +@dataclass +class DisownParameters: + """DisownParameters.""" + + handles: list[Any | None] | None = None + target: Any | None = None + + +@dataclass +class CallFunctionParameters: + """CallFunctionParameters.""" + + function_declaration: str | None = None + await_promise: bool | None = None + target: Any | None = None + arguments: list[Any | None] | None = None + result_ownership: Any | None = None + serialization_options: Any | None = None + this: Any | None = None + user_activation: bool | None = None - def add_console_message_handler(self, handler): - self._subscribe_to_log_entries() - return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("console", handler)) - def add_javascript_error_handler(self, handler): - self._subscribe_to_log_entries() - return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("javascript", handler)) +@dataclass +class EvaluateParameters: + """EvaluateParameters.""" + + expression: str | None = None + target: Any | None = None + await_promise: bool | None = None + result_ownership: Any | None = None + serialization_options: Any | None = None + user_activation: bool | None = None + + +@dataclass +class GetRealmsParameters: + """GetRealmsParameters.""" + + context: Any | None = None + type: Any | None = None + + +@dataclass +class GetRealmsResult: + """GetRealmsResult.""" + + realms: list[Any | None] | None = None + + +@dataclass +class RemovePreloadScriptParameters: + """RemovePreloadScriptParameters.""" + + script: Any | None = None + + +@dataclass +class MessageParameters: + """MessageParameters.""" + + channel: Any | None = None + data: Any | None = None + source: Any | None = None + + +@dataclass +class RealmDestroyedParameters: + """RealmDestroyedParameters.""" + + realm: Any | None = None + + +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "realm_created": "script.realmCreated", + "realm_destroyed": "script.realmDestroyed", +} + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type - def remove_console_message_handler(self, id): - self.conn.remove_callback(LogEntryAdded, id) - self._unsubscribe_from_log_entries() - remove_javascript_error_handler = remove_console_message_handler +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization - def pin(self, script: str) -> str: - """Pins a script to the current browsing context. + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. Args: - script: The script to pin. + params: Raw BiDi event params with camelCase keys. Returns: - str: The ID of the pinned script. + An instance of the dataclass, or the raw dict on failure. """ - return self._add_preload_script(script) + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() - def unpin(self, script_id: str) -> None: - """Unpins a script from the current browsing context. - Args: - script_id: The ID of the pinned script to unpin. - """ - self._remove_preload_script(script_id) - def execute(self, script: str, *args) -> dict: - """Executes a script in the current browsing context. - Args: - script: The script function to execute. - *args: Arguments to pass to the script function. +class Script: + """WebDriver BiDi script module.""" - Returns: - dict: The result value from the script execution. + EVENT_CONFIGS = {} + def __init__(self, conn, driver=None) -> None: + self._conn = conn + self._driver = driver + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - Raises: - WebDriverException: If the script execution fails. - """ - if self.driver is None: - raise WebDriverException("Driver reference is required for script execution") - browsing_context_id = self.driver.current_window_handle + def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): + """Execute script.addPreloadScript.""" + params = { + "functionDeclaration": function_declaration, + "arguments": arguments, + "contexts": contexts, + "userContexts": user_contexts, + "sandbox": sandbox, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.addPreloadScript", params) + result = self._conn.execute(cmd) + return result - # Convert arguments to the format expected by BiDi call_function (LocalValue Type) - arguments = [] - for arg in args: - arguments.append(self.__convert_to_local_value(arg)) + def disown(self, handles: List[Any] | None = None, target: Any | None = None): + """Execute script.disown.""" + params = { + "handles": handles, + "target": target, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.disown", params) + result = self._conn.execute(cmd) + return result - target = {"context": browsing_context_id} + def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): + """Execute script.callFunction.""" + params = { + "functionDeclaration": function_declaration, + "awaitPromise": await_promise, + "target": target, + "arguments": arguments, + "resultOwnership": result_ownership, + "serializationOptions": serialization_options, + "this": this, + "userActivation": user_activation, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.callFunction", params) + result = self._conn.execute(cmd) + return result - result = self._call_function( - function_declaration=script, await_promise=True, target=target, arguments=arguments if arguments else None - ) + def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): + """Execute script.evaluate.""" + params = { + "expression": expression, + "target": target, + "awaitPromise": await_promise, + "resultOwnership": result_ownership, + "serializationOptions": serialization_options, + "userActivation": user_activation, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.evaluate", params) + result = self._conn.execute(cmd) + return result - if result.type == "success": - return result.result if result.result is not None else {} - else: - error_message = "Error while executing script" - if result.exception_details: - if "text" in result.exception_details: - error_message += f": {result.exception_details['text']}" - elif "message" in result.exception_details: - error_message += f": {result.exception_details['message']}" - - raise WebDriverException(error_message) - - def __convert_to_local_value(self, value) -> dict: - """Converts a Python value to BiDi LocalValue format.""" - if value is None: - return {"type": "null"} - elif isinstance(value, bool): - return {"type": "boolean", "value": value} - elif isinstance(value, (int, float)): + def get_realms(self, context: Any | None = None, type: Any | None = None): + """Execute script.getRealms.""" + params = { + "context": context, + "type": type, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.getRealms", params) + result = self._conn.execute(cmd) + return result + + def remove_preload_script(self, script: Any | None = None): + """Execute script.removePreloadScript.""" + params = { + "script": script, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.removePreloadScript", params) + result = self._conn.execute(cmd) + return result + + def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): + """Execute script.message.""" + params = { + "channel": channel, + "data": data, + "source": source, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.message", params) + result = self._conn.execute(cmd) + return result + + def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: + """Execute a function declaration in the browser context. + + Args: + function_declaration: The function as a string, e.g. ``"() => document.title"``. + *args: Optional Python values to pass as arguments to the function. + Each value is serialised to a BiDi ``LocalValue`` automatically. + Supported types: ``None``, ``bool``, ``int``, ``float`` + (including ``NaN`` and ``Infinity``), ``str``, ``list``, + ``dict``, and ``datetime.datetime``. + context_id: The browsing context ID to run in. Defaults to the + driver's current window handle when a driver was provided. + + Returns: + The inner RemoteValue result dict, or raises WebDriverException on exception. + """ + import math as _math + import datetime as _datetime + from selenium.common.exceptions import WebDriverException as _WebDriverException + + def _serialize_arg(value): + """Serialise a Python value to a BiDi LocalValue dict.""" + if value is None: + return {"type": "null"} + if isinstance(value, bool): + return {"type": "boolean", "value": value} + if isinstance(value, _datetime.datetime): + return {"type": "date", "value": value.isoformat()} if isinstance(value, float): - if math.isnan(value): + if _math.isnan(value): return {"type": "number", "value": "NaN"} - elif math.isinf(value): - if value > 0: - return {"type": "number", "value": "Infinity"} - else: - return {"type": "number", "value": "-Infinity"} - elif value == 0.0 and math.copysign(1.0, value) < 0: - return {"type": "number", "value": "-0"} - - JS_MAX_SAFE_INTEGER = 9007199254740991 - if isinstance(value, int) and (value > JS_MAX_SAFE_INTEGER or value < -JS_MAX_SAFE_INTEGER): - return {"type": "bigint", "value": str(value)} - - return {"type": "number", "value": value} - - elif isinstance(value, str): - return {"type": "string", "value": value} - elif isinstance(value, datetime.datetime): - # Convert Python datetime to JavaScript Date (ISO 8601 format) - return {"type": "date", "value": value.isoformat() + "Z" if value.tzinfo is None else value.isoformat()} - elif isinstance(value, datetime.date): - # Convert Python date to JavaScript Date - dt = datetime.datetime.combine(value, datetime.time.min).replace(tzinfo=datetime.timezone.utc) - return {"type": "date", "value": dt.isoformat()} - elif isinstance(value, set): - return {"type": "set", "value": [self.__convert_to_local_value(item) for item in value]} - elif isinstance(value, (list, tuple)): - return {"type": "array", "value": [self.__convert_to_local_value(item) for item in value]} - elif isinstance(value, dict): - return { - "type": "object", - "value": [ - [self.__convert_to_local_value(k), self.__convert_to_local_value(v)] for k, v in value.items() - ], - } - else: - # For other types, convert to string - return {"type": "string", "value": str(value)} - - # low-level APIs for script module - def _add_preload_script( - self, - function_declaration: str, - arguments: list[dict[str, Any]] | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - sandbox: str | None = None, - ) -> str: - """Adds a preload script. + if _math.isinf(value): + return {"type": "number", "value": "Infinity" if value > 0 else "-Infinity"} + return {"type": "number", "value": value} + if isinstance(value, int): + _MAX_SAFE_INT = 9007199254740991 + if abs(value) > _MAX_SAFE_INT: + return {"type": "bigint", "value": str(value)} + return {"type": "number", "value": value} + if isinstance(value, str): + return {"type": "string", "value": value} + if isinstance(value, list): + return {"type": "array", "value": [_serialize_arg(v) for v in value]} + if isinstance(value, dict): + return {"type": "object", "value": [[str(k), _serialize_arg(v)] for k, v in value.items()]} + return value + + if context_id is None and self._driver is not None: + try: + context_id = self._driver.current_window_handle + except Exception: + pass + target = {"context": context_id} if context_id else {} + serialized_args = [_serialize_arg(a) for a in args] if args else None + raw = self.call_function( + function_declaration=function_declaration, + await_promise=True, + target=target, + arguments=serialized_args, + ) + if isinstance(raw, dict): + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails", {}) + msg = exc.get("text", str(exc)) if isinstance(exc, dict) else str(exc) + raise _WebDriverException(msg) + if raw.get("type") == "success": + return raw.get("result") + return raw + def _add_preload_script(self, function_declaration, arguments=None, contexts=None, user_contexts=None, sandbox=None): + """Add a preload script with validation. Args: - function_declaration: The function declaration to preload. - arguments: The arguments to pass to the function. - contexts: The browsing context IDs to apply the script to. - user_contexts: The user context IDs to apply the script to. - sandbox: The sandbox name to apply the script to. + function_declaration: The JS function to run on page load. + arguments: Optional list of BiDi arguments. + contexts: Optional list of browsing context IDs. + user_contexts: Optional list of user context IDs. + sandbox: Optional sandbox name. Returns: - str: The preload script ID. + script_id: The ID of the added preload script (str). Raises: - ValueError: If both contexts and user_contexts are provided. + ValueError: If both contexts and user_contexts are specified. """ if contexts is not None and user_contexts is not None: raise ValueError("Cannot specify both contexts and user_contexts") + result = self.add_preload_script( + function_declaration=function_declaration, + arguments=arguments, + contexts=contexts, + user_contexts=user_contexts, + sandbox=sandbox, + ) + if isinstance(result, dict): + return result.get("script") + return result + def _remove_preload_script(self, script_id): + """Remove a preload script by ID. - params: dict[str, Any] = {"functionDeclaration": function_declaration} - - if arguments is not None: - params["arguments"] = arguments - if contexts is not None: - params["contexts"] = contexts - if user_contexts is not None: - params["userContexts"] = user_contexts - if sandbox is not None: - params["sandbox"] = sandbox + Args: + script_id: The ID of the preload script to remove. + """ + return self.remove_preload_script(script=script_id) + def pin(self, function_declaration): + """Pin (add) a preload script that runs on every page load. - result = self.conn.execute(command_builder("script.addPreloadScript", params)) - return result["script"] + Args: + function_declaration: The JS function to execute on page load. - def _remove_preload_script(self, script_id: str) -> None: - """Removes a preload script. + Returns: + script_id: The ID of the pinned script (str). + """ + return self._add_preload_script(function_declaration) + def unpin(self, script_id): + """Unpin (remove) a previously pinned preload script. Args: - script_id: The preload script ID to remove. + script_id: The ID returned by pin(). """ - params = {"script": script_id} - self.conn.execute(command_builder("script.removePreloadScript", params)) + return self._remove_preload_script(script_id=script_id) + def _evaluate(self, expression, target, await_promise, result_ownership=None, serialization_options=None, user_activation=None): + """Evaluate a script expression and return a structured result. - def _disown(self, handles: list[str], target: dict) -> None: - """Disowns the given handles. + Args: + expression: The JavaScript expression to evaluate. + target: A dict like {"context": } or {"realm": }. + await_promise: Whether to await a returned promise. + result_ownership: Optional result ownership setting. + serialization_options: Optional serialization options dict. + user_activation: Optional user activation flag. + + Returns: + An object with .realm, .result (dict or None), and .exception_details (or None). + """ + class _EvalResult: + def __init__(self2, realm, result, exception_details): + self2.realm = realm + self2.result = result + self2.exception_details = exception_details + + raw = self.evaluate( + expression=expression, + target=target, + await_promise=await_promise, + result_ownership=result_ownership, + serialization_options=serialization_options, + user_activation=user_activation, + ) + if isinstance(raw, dict): + realm = raw.get("realm") + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails") + return _EvalResult(realm=realm, result=None, exception_details=exc) + return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) + return _EvalResult(realm=None, result=raw, exception_details=None) + def _call_function(self, function_declaration, await_promise, target, arguments=None, result_ownership=None, this=None, user_activation=None, serialization_options=None): + """Call a function and return a structured result. Args: - handles: The handles to disown. - target: The target realm or context. + function_declaration: The JS function string. + await_promise: Whether to await the return value. + target: A dict like {"context": }. + arguments: Optional list of BiDi arguments. + result_ownership: Optional result ownership. + this: Optional 'this' binding. + user_activation: Optional user activation flag. + serialization_options: Optional serialization options dict. + + Returns: + An object with .result (dict or None) and .exception_details (or None). """ - params = { - "handles": handles, - "target": target, - } - self.conn.execute(command_builder("script.disown", params)) - - def _call_function( - self, - function_declaration: str, - await_promise: bool, - target: dict, - arguments: list[dict] | None = None, - result_ownership: str | None = None, - serialization_options: dict | None = None, - this: dict | None = None, - user_activation: bool = False, - ) -> EvaluateResult: - """Calls a provided function with given arguments in a given realm. + class _CallResult: + def __init__(self2, result, exception_details): + self2.result = result + self2.exception_details = exception_details + + raw = self.call_function( + function_declaration=function_declaration, + await_promise=await_promise, + target=target, + arguments=arguments, + result_ownership=result_ownership, + this=this, + user_activation=user_activation, + serialization_options=serialization_options, + ) + if isinstance(raw, dict): + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails") + return _CallResult(result=None, exception_details=exc) + if raw.get("type") == "success": + return _CallResult(result=raw.get("result"), exception_details=None) + return _CallResult(result=raw, exception_details=None) + def _get_realms(self, context=None, type=None): + """Get all realms, optionally filtered by context and type. Args: - function_declaration: The function declaration to call. - await_promise: Whether to await promise resolution. - target: The target realm or context. - arguments: The arguments to pass to the function. - result_ownership: The result ownership type. - serialization_options: The serialization options. - this: The 'this' value for the function call. - user_activation: Whether to trigger user activation. + context: Optional browsing context ID to filter by. + type: Optional realm type string to filter by (e.g. RealmType.WINDOW). Returns: - EvaluateResult: The result of the function call. + List of realm info objects with .realm, .origin, .type, .context attributes. """ - params = { - "functionDeclaration": function_declaration, - "awaitPromise": await_promise, - "target": target, - "userActivation": user_activation, - } + class _RealmInfo: + def __init__(self2, realm, origin, type_, context): + self2.realm = realm + self2.origin = origin + self2.type = type_ + self2.context = context + + raw = self.get_realms(context=context, type=type) + realms_list = raw.get("realms", []) if isinstance(raw, dict) else [] + result = [] + for r in realms_list: + if isinstance(r, dict): + result.append(_RealmInfo( + realm=r.get("realm"), + origin=r.get("origin"), + type_=r.get("type"), + context=r.get("context"), + )) + return result + def _disown(self, handles, target): + """Disown handles in a browsing context. - if arguments is not None: - params["arguments"] = arguments - if result_ownership is not None: - params["resultOwnership"] = result_ownership - if serialization_options is not None: - params["serializationOptions"] = serialization_options - if this is not None: - params["this"] = this - - result = self.conn.execute(command_builder("script.callFunction", params)) - return EvaluateResult.from_json(result) - - def _evaluate( - self, - expression: str, - target: dict, - await_promise: bool, - result_ownership: str | None = None, - serialization_options: dict | None = None, - user_activation: bool = False, - ) -> EvaluateResult: - """Evaluates a provided script in a given realm. + Args: + handles: List of handle strings to disown. + target: A dict like {"context": }. + """ + return self.disown(handles=handles, target=target) + def _subscribe_log_entry(self, callback, entry_type_filter=None): + """Subscribe to log.entryAdded BiDi events with optional type filtering.""" + import threading as _threading + from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + + bidi_event = "log.entryAdded" + + if not hasattr(self, "_log_subscriptions"): + self._log_subscriptions = {} + self._log_lock = _threading.Lock() + + def _deserialize(params): + t = params.get("type") if isinstance(params, dict) else None + if t == "console": + cls = getattr(_log_mod, "ConsoleLogEntry", None) + if cls is not None and hasattr(cls, "from_json"): + try: + return cls.from_json(params) + except Exception: + pass + elif t == "javascript": + cls = getattr(_log_mod, "JavascriptLogEntry", None) + if cls is not None and hasattr(cls, "from_json"): + try: + return cls.from_json(params) + except Exception: + pass + return params + + def _wrapped(raw): + entry = _deserialize(raw) + if entry_type_filter is None: + callback(entry) + else: + t = getattr(entry, "type_", None) or ( + entry.get("type") if isinstance(entry, dict) else None + ) + if t == entry_type_filter: + callback(entry) + + class _BidiRef: + event_class = bidi_event + + def from_json(self2, p): + return p + + _wrapper = _BidiRef() + callback_id = self._conn.add_callback(_wrapper, _wrapped) + with self._log_lock: + if bidi_event not in self._log_subscriptions: + session = _Session(self._conn) + result = session.subscribe([bidi_event]) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self._log_subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + self._log_subscriptions[bidi_event]["callbacks"].append(callback_id) + return callback_id + def _unsubscribe_log_entry(self, callback_id): + """Unsubscribe a log entry callback by ID.""" + from selenium.webdriver.common.bidi.session import Session as _Session + + bidi_event = "log.entryAdded" + if not hasattr(self, "_log_subscriptions"): + return + + class _BidiRef: + event_class = bidi_event + + def from_json(self2, p): + return p + + _wrapper = _BidiRef() + self._conn.remove_callback(_wrapper, callback_id) + with self._log_lock: + entry = self._log_subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + if entry is not None and not entry["callbacks"]: + session = _Session(self._conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self._log_subscriptions[bidi_event] + def add_console_message_handler(self, callback: Callable) -> int: + """Add a handler for console log messages (log.entryAdded type=console). Args: - expression: The script expression to evaluate. - target: The target realm or context. - await_promise: Whether to await promise resolution. - result_ownership: The result ownership type. - serialization_options: The serialization options. - user_activation: Whether to trigger user activation. + callback: Function called with a ConsoleLogEntry on each console message. Returns: - EvaluateResult: The result of the script evaluation. + callback_id for use with remove_console_message_handler. """ - params = { - "expression": expression, - "target": target, - "awaitPromise": await_promise, - "userActivation": user_activation, - } + return self._subscribe_log_entry(callback, entry_type_filter="console") + def remove_console_message_handler(self, callback_id: int) -> None: + """Remove a console message handler by callback ID.""" + self._unsubscribe_log_entry(callback_id) + def add_javascript_error_handler(self, callback: Callable) -> int: + """Add a handler for JavaScript error log messages (log.entryAdded type=javascript). - if result_ownership is not None: - params["resultOwnership"] = result_ownership - if serialization_options is not None: - params["serializationOptions"] = serialization_options + Args: + callback: Function called with a JavascriptLogEntry on each JS error. - result = self.conn.execute(command_builder("script.evaluate", params)) - return EvaluateResult.from_json(result) + Returns: + callback_id for use with remove_javascript_error_handler. + """ + return self._subscribe_log_entry(callback, entry_type_filter="javascript") + def remove_javascript_error_handler(self, callback_id: int) -> None: + """Remove a JavaScript error handler by callback ID.""" + self._unsubscribe_log_entry(callback_id) - def _get_realms( - self, - context: str | None = None, - type: str | None = None, - ) -> list[RealmInfo]: - """Returns a list of all realms, optionally filtered. + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. Args: - context: The browsing context ID to filter by. - type: The realm type to filter by. + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). Returns: - List[RealmInfo]: A list of realm information. + The callback ID. """ - params = {} + return self._event_manager.add_event_handler(event, callback, contexts) - if context is not None: - params["context"] = context - if type is not None: - params["type"] = type + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. - result = self.conn.execute(command_builder("script.getRealms", params)) - return [RealmInfo.from_json(realm) for realm in result["realms"]] + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + """ + return self._event_manager.remove_event_handler(event, callback_id) - def _subscribe_to_log_entries(self): - if not self.log_entry_subscribed: - session = Session(self.conn) - self.conn.execute(session.subscribe(LogEntryAdded.event_class)) - self.log_entry_subscribed = True + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() - def _unsubscribe_from_log_entries(self): - if self.log_entry_subscribed and LogEntryAdded.event_class not in self.conn.callbacks: - session = Session(self.conn) - self.conn.execute(session.unsubscribe(LogEntryAdded.event_class)) - self.log_entry_subscribed = False +# Event Info Type Aliases +# Event: script.realmCreated +RealmCreated = globals().get('RealmInfo', dict) # Fallback to dict if type not defined + +# Event: script.realmDestroyed +RealmDestroyed = globals().get('RealmDestroyedParameters', dict) # Fallback to dict if type not defined - def _handle_log_entry(self, type, handler): - def _handle_log_entry(log_entry): - if log_entry.type_ == type: - handler(log_entry) - return _handle_log_entry +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Script.EVENT_CONFIGS = { + "realm_created": (EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) if _globals.get("RealmCreated") else EventConfig("realm_created", "script.realmCreated", dict)), + "realm_destroyed": (EventConfig("realm_destroyed", "script.realmDestroyed", _globals.get("RealmDestroyed", dict)) if _globals.get("RealmDestroyed") else EventConfig("realm_destroyed", "script.realmDestroyed", dict)), +} diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 3481c2d77842d..9b1daaae557fa 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -1,134 +1,236 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: session +from __future__ import annotations - -from selenium.webdriver.common.bidi.common import command_builder +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass class UserPromptHandlerType: - """Represents the behavior of the user prompt handler.""" + """UserPromptHandlerType.""" ACCEPT = "accept" DISMISS = "dismiss" IGNORE = "ignore" - VALID_TYPES = {ACCEPT, DISMISS, IGNORE} +@dataclass +class CapabilitiesRequest: + """CapabilitiesRequest.""" + + always_match: Any | None = None + first_match: list[Any | None] | None = None + + +@dataclass +class CapabilityRequest: + """CapabilityRequest.""" + + accept_insecure_certs: bool | None = None + browser_name: str | None = None + browser_version: str | None = None + platform_name: str | None = None + proxy: Any | None = None + unhandled_prompt_behavior: Any | None = None + + +@dataclass +class AutodetectProxyConfiguration: + """AutodetectProxyConfiguration.""" + + proxy_type: str = field(default="autodetect", init=False) + + +@dataclass +class DirectProxyConfiguration: + """DirectProxyConfiguration.""" + + proxy_type: str = field(default="direct", init=False) + + +@dataclass +class ManualProxyConfiguration: + """ManualProxyConfiguration.""" + + proxy_type: str = field(default="manual", init=False) + http_proxy: str | None = None + ssl_proxy: str | None = None + no_proxy: list[Any | None] | None = None + + +@dataclass +class SocksProxyConfiguration: + """SocksProxyConfiguration.""" + + socks_proxy: str | None = None + socks_version: Any | None = None + + +@dataclass +class PacProxyConfiguration: + """PacProxyConfiguration.""" + + proxy_type: str = field(default="pac", init=False) + proxy_autoconfig_url: str | None = None + + +@dataclass +class SystemProxyConfiguration: + """SystemProxyConfiguration.""" + + proxy_type: str = field(default="system", init=False) + + +@dataclass +class SubscribeParameters: + """SubscribeParameters.""" + + events: list[str | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None + + +@dataclass +class UnsubscribeByIDRequest: + """UnsubscribeByIDRequest.""" + + subscriptions: list[Any | None] | None = None + + +@dataclass +class UnsubscribeByAttributesRequest: + """UnsubscribeByAttributesRequest.""" + + events: list[str | None] | None = None + + +@dataclass +class StatusResult: + """StatusResult.""" + ready: bool | None = None + message: str | None = None + + +@dataclass +class NewParameters: + """NewParameters.""" + + capabilities: Any | None = None + + +@dataclass +class NewResult: + """NewResult.""" + + session_id: str | None = None + accept_insecure_certs: bool | None = None + browser_name: str | None = None + browser_version: str | None = None + platform_name: str | None = None + set_window_rect: bool | None = None + user_agent: str | None = None + proxy: Any | None = None + unhandled_prompt_behavior: Any | None = None + web_socket_url: str | None = None + + +@dataclass +class SubscribeResult: + """SubscribeResult.""" + + subscription: Any | None = None + + +@dataclass class UserPromptHandler: - """Represents the configuration of the user prompt handler.""" - - def __init__( - self, - alert: str | None = None, - before_unload: str | None = None, - confirm: str | None = None, - default: str | None = None, - file: str | None = None, - prompt: str | None = None, - ): - """Initialize UserPromptHandler. - - Args: - alert: Handler type for alert prompts. - before_unload: Handler type for beforeUnload prompts. - confirm: Handler type for confirm prompts. - default: Default handler type for all prompts. - file: Handler type for file picker prompts. - prompt: Handler type for prompt dialogs. - - Raises: - ValueError: If any handler type is not valid. - """ - for field_name, value in [ - ("alert", alert), - ("before_unload", before_unload), - ("confirm", confirm), - ("default", default), - ("file", file), - ("prompt", prompt), - ]: - if value is not None and value not in UserPromptHandlerType.VALID_TYPES: - raise ValueError( - f"Invalid {field_name} handler type: {value}. Must be one of {UserPromptHandlerType.VALID_TYPES}" - ) - - self.alert = alert - self.before_unload = before_unload - self.confirm = confirm - self.default = default - self.file = file - self.prompt = prompt - - def to_dict(self) -> dict[str, str]: - """Convert the UserPromptHandler to a dictionary for BiDi protocol. - - Returns: - Dictionary representation suitable for BiDi protocol. - """ - field_mapping = { - "alert": "alert", - "before_unload": "beforeUnload", - "confirm": "confirm", - "default": "default", - "file": "file", - "prompt": "prompt", - } + """UserPromptHandler.""" + + alert: Any | None = None + before_unload: Any | None = None + confirm: Any | None = None + default: Any | None = None + file: Any | None = None + prompt: Any | None = None + def to_bidi_dict(self) -> dict: + """Convert to BiDi protocol dict with camelCase keys.""" result = {} - for attr_name, dict_key in field_mapping.items(): - value = getattr(self, attr_name) - if value is not None: - result[dict_key] = value + if self.alert is not None: + result["alert"] = self.alert + if self.before_unload is not None: + result["beforeUnload"] = self.before_unload + if self.confirm is not None: + result["confirm"] = self.confirm + if self.default is not None: + result["default"] = self.default + if self.file is not None: + result["file"] = self.file + if self.prompt is not None: + result["prompt"] = self.prompt return result - class Session: - def __init__(self, conn): - self.conn = conn + """WebDriver BiDi session module.""" + + def __init__(self, conn) -> None: + self._conn = conn - def subscribe(self, *events, browsing_contexts=None): + def status(self): + """Execute session.status.""" params = { - "events": events, } - if browsing_contexts is None: - browsing_contexts = [] - if browsing_contexts: - params["browsingContexts"] = browsing_contexts - return command_builder("session.subscribe", params) + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.status", params) + result = self._conn.execute(cmd) + return result - def unsubscribe(self, *events, browsing_contexts=None): + def new(self, capabilities: Any | None = None): + """Execute session.new.""" params = { - "events": events, + "capabilities": capabilities, } - if browsing_contexts is None: - browsing_contexts = [] - if browsing_contexts: - params["browsingContexts"] = browsing_contexts - return command_builder("session.unsubscribe", params) + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.new", params) + result = self._conn.execute(cmd) + return result - def status(self): - """The session.status command returns information about the remote end's readiness. + def end(self): + """Execute session.end.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.end", params) + result = self._conn.execute(cmd) + return result + + def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute session.subscribe.""" + params = { + "events": events, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.subscribe", params) + result = self._conn.execute(cmd) + return result - Returns information about the remote end's readiness to create new sessions - and may include implementation-specific metadata. + def unsubscribe(self, events: List[Any] | None = None, subscriptions: List[Any] | None = None): + """Execute session.unsubscribe.""" + params = { + "events": events, + "subscriptions": subscriptions, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.unsubscribe", params) + result = self._conn.execute(cmd) + return result - Returns: - Dictionary containing the ready state (bool), message (str) and metadata. - """ - cmd = command_builder("session.status", {}) - return self.conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 992ede07f4100..7e4c9c6dee459 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -1,150 +1,152 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: storage from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass -from selenium.webdriver.common.bidi.common import command_builder -if TYPE_CHECKING: - from selenium.webdriver.remote.websocket_connection import WebSocketConnection +@dataclass +class PartitionKey: + """PartitionKey.""" + user_context: str | None = None + source_origin: str | None = None -class SameSite: - """Represents the possible same site values for cookies.""" - STRICT = "strict" - LAX = "lax" - NONE = "none" - DEFAULT = "default" +@dataclass +class GetCookiesParameters: + """GetCookiesParameters.""" + + filter: Any | None = None + partition: Any | None = None + + +@dataclass +class GetCookiesResult: + """GetCookiesResult.""" + + cookies: list[Any | None] | None = None + partition_key: Any | None = None + + +@dataclass +class SetCookieParameters: + """SetCookieParameters.""" + + cookie: Any | None = None + partition: Any | None = None + + +@dataclass +class SetCookieResult: + """SetCookieResult.""" + + partition_key: Any | None = None + + +@dataclass +class DeleteCookiesParameters: + """DeleteCookiesParameters.""" + + filter: Any | None = None + partition: Any | None = None + + +@dataclass +class DeleteCookiesResult: + """DeleteCookiesResult.""" + + partition_key: Any | None = None class BytesValue: - """Represents a bytes value.""" + """A string or base64-encoded bytes value used in cookie operations. + + This corresponds to network.BytesValue in the WebDriver BiDi specification, + wrapping either a plain string or a base64-encoded binary value. + """ - TYPE_BASE64 = "base64" TYPE_STRING = "string" + TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str): + def __init__(self, type: str, value: str) -> None: self.type = type self.value = value - def to_dict(self) -> dict[str, str]: - """Converts the BytesValue to a dictionary. - - Returns: - A dictionary representation of the BytesValue. - """ + def to_bidi_dict(self) -> dict: return {"type": self.type, "value": self.value} +class SameSite: + """SameSite cookie attribute values.""" -class Cookie: - """Represents a cookie.""" - - def __init__( - self, - name: str, - value: BytesValue, - domain: str, - path: str | None = None, - size: int | None = None, - http_only: bool | None = None, - secure: bool | None = None, - same_site: str | None = None, - expiry: int | None = None, - ): - self.name = name - self.value = value - self.domain = domain - self.path = path - self.size = size - self.http_only = http_only - self.secure = secure - self.same_site = same_site - self.expiry = expiry + STRICT = "strict" + LAX = "lax" + NONE = "none" + DEFAULT = "default" + +@dataclass +class StorageCookie: + """A cookie object returned by storage.getCookies.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None @classmethod - def from_dict(cls, data: dict[str, Any]) -> Cookie: - """Creates a Cookie instance from a dictionary. - - Args: - data: A dictionary containing the cookie information. - - Returns: - A new instance of Cookie. - """ - # Validation for empty strings - name = data.get("name") - if not name: - raise ValueError("name is required and cannot be empty") - domain = data.get("domain") - if not domain: - raise ValueError("domain is required and cannot be empty") - - value = BytesValue(data.get("value", {}).get("type"), data.get("value", {}).get("value")) + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + """Deserialize a wire-level cookie dict to a StorageCookie.""" + value_raw = raw.get("value") + if isinstance(value_raw, dict): + value = BytesValue(value_raw.get("type"), value_raw.get("value")) + else: + value = value_raw return cls( - name=str(name), + name=raw.get("name"), value=value, - domain=str(domain), - path=data.get("path"), - size=data.get("size"), - http_only=data.get("httpOnly"), - secure=data.get("secure"), - same_site=data.get("sameSite"), - expiry=data.get("expiry"), + domain=raw.get("domain"), + path=raw.get("path"), + size=raw.get("size"), + http_only=raw.get("httpOnly"), + secure=raw.get("secure"), + same_site=raw.get("sameSite"), + expiry=raw.get("expiry"), ) - +@dataclass class CookieFilter: - """Represents a filter for cookies.""" - - def __init__( - self, - name: str | None = None, - value: BytesValue | None = None, - domain: str | None = None, - path: str | None = None, - size: int | None = None, - http_only: bool | None = None, - secure: bool | None = None, - same_site: str | None = None, - expiry: int | None = None, - ): - self.name = name - self.value = value - self.domain = domain - self.path = path - self.size = size - self.http_only = http_only - self.secure = secure - self.same_site = same_site - self.expiry = expiry - - def to_dict(self) -> dict[str, Any]: - """Converts the CookieFilter to a dictionary. - - Returns: - A dictionary representation of the CookieFilter. - """ - result: dict[str, Any] = {} + """CookieFilter.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {} if self.name is not None: result["name"] = self.name if self.value is not None: - result["value"] = self.value.to_dict() + result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value if self.domain is not None: result["domain"] = self.domain if self.path is not None: @@ -161,103 +163,28 @@ def to_dict(self) -> dict[str, Any]: result["expiry"] = self.expiry return result - -class PartitionKey: - """Represents a storage partition key.""" - - def __init__(self, user_context: str | None = None, source_origin: str | None = None): - self.user_context = user_context - self.source_origin = source_origin - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> PartitionKey: - """Creates a PartitionKey instance from a dictionary. - - Args: - data: A dictionary containing the partition key information. - - Returns: - A new instance of PartitionKey. - """ - return cls( - user_context=data.get("userContext"), - source_origin=data.get("sourceOrigin"), - ) - - -class BrowsingContextPartitionDescriptor: - """Represents a browsing context partition descriptor.""" - - def __init__(self, context: str): - self.type = "context" - self.context = context - - def to_dict(self) -> dict[str, str]: - """Converts the BrowsingContextPartitionDescriptor to a dictionary. - - Returns: - Dict: A dictionary representation of the BrowsingContextPartitionDescriptor. - """ - return {"type": self.type, "context": self.context} - - -class StorageKeyPartitionDescriptor: - """Represents a storage key partition descriptor.""" - - def __init__(self, user_context: str | None = None, source_origin: str | None = None): - self.type = "storageKey" - self.user_context = user_context - self.source_origin = source_origin - - def to_dict(self) -> dict[str, str]: - """Converts the StorageKeyPartitionDescriptor to a dictionary. - - Returns: - Dict: A dictionary representation of the StorageKeyPartitionDescriptor. - """ - result = {"type": self.type} - if self.user_context is not None: - result["userContext"] = self.user_context - if self.source_origin is not None: - result["sourceOrigin"] = self.source_origin - return result - - +@dataclass class PartialCookie: - """Represents a partial cookie for setting.""" - - def __init__( - self, - name: str, - value: BytesValue, - domain: str, - path: str | None = None, - http_only: bool | None = None, - secure: bool | None = None, - same_site: str | None = None, - expiry: int | None = None, - ): - self.name = name - self.value = value - self.domain = domain - self.path = path - self.http_only = http_only - self.secure = secure - self.same_site = same_site - self.expiry = expiry - - def to_dict(self) -> dict[str, Any]: - """Converts the PartialCookie to a dictionary. - - Returns: - ------- - Dict: A dictionary representation of the PartialCookie. - """ - result: dict[str, Any] = { - "name": self.name, - "value": self.value.to_dict(), - "domain": self.domain, - } + """PartialCookie.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {} + if self.name is not None: + result["name"] = self.name + if self.value is not None: + result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value + if self.domain is not None: + result["domain"] = self.domain if self.path is not None: result["path"] = self.path if self.http_only is not None: @@ -270,144 +197,132 @@ def to_dict(self) -> dict[str, Any]: result["expiry"] = self.expiry return result +class BrowsingContextPartitionDescriptor: + """BrowsingContextPartitionDescriptor. -class GetCookiesResult: - """Represents the result of a getCookies command.""" - - def __init__(self, cookies: list[Cookie], partition_key: PartitionKey): - self.cookies = cookies - self.partition_key = partition_key - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> GetCookiesResult: - """Creates a GetCookiesResult instance from a dictionary. - - Args: - data: A dictionary containing the get cookies result information. - - Returns: - A new instance of GetCookiesResult. - """ - cookies = [Cookie.from_dict(cookie) for cookie in data.get("cookies", [])] - partition_key = PartitionKey.from_dict(data.get("partitionKey", {})) - return cls(cookies=cookies, partition_key=partition_key) - - -class SetCookieResult: - """Represents the result of a setCookie command.""" - - def __init__(self, partition_key: PartitionKey): - self.partition_key = partition_key - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> SetCookieResult: - """Creates a SetCookieResult instance from a dictionary. - - Args: - data: A dictionary containing the set cookie result information. - - Returns: - A new instance of SetCookieResult. - """ - partition_key = PartitionKey.from_dict(data.get("partitionKey", {})) - return cls(partition_key=partition_key) - - -class DeleteCookiesResult: - """Represents the result of a deleteCookies command.""" + The first positional argument is *context* (a browsing-context ID / window + handle), mirroring how the class is used throughout the test suite: + ``BrowsingContextPartitionDescriptor(driver.current_window_handle)``. + """ - def __init__(self, partition_key: PartitionKey): - self.partition_key = partition_key + def __init__(self, context: Any = None, type: str = "context") -> None: + self.context = context + self.type = type - @classmethod - def from_dict(cls, data: dict[str, Any]) -> DeleteCookiesResult: - """Creates a DeleteCookiesResult instance from a dictionary. + def to_bidi_dict(self) -> dict: + return {"type": "context", "context": self.context} - Args: - data: A dictionary containing the delete cookies result information. +@dataclass +class StorageKeyPartitionDescriptor: + """StorageKeyPartitionDescriptor.""" - Returns: - A new instance of DeleteCookiesResult. - """ - partition_key = PartitionKey.from_dict(data.get("partitionKey", {})) - return cls(partition_key=partition_key) + type: Any | None = "storageKey" + user_context: str | None = None + source_origin: str | None = None + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {"type": "storageKey"} + if self.user_context is not None: + result["userContext"] = self.user_context + if self.source_origin is not None: + result["sourceOrigin"] = self.source_origin + return result class Storage: - """BiDi implementation of the storage module.""" + """WebDriver BiDi storage module.""" - def __init__(self, conn: WebSocketConnection) -> None: - self.conn = conn + def __init__(self, conn) -> None: + self._conn = conn - def get_cookies( - self, - filter: CookieFilter | None = None, - partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None, - ) -> GetCookiesResult: - """Gets cookies matching the specified filter. + def get_cookies(self, filter: Any | None = None, partition: Any | None = None): + """Execute storage.getCookies.""" + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.getCookies", params) + result = self._conn.execute(cmd) + return result - Args: - filter: Optional filter to specify which cookies to retrieve. - partition: Optional partition key to limit the scope of the operation. + def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): + """Execute storage.setCookie.""" + params = { + "cookie": cookie, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.setCookie", params) + result = self._conn.execute(cmd) + return result - Returns: - A GetCookiesResult containing the cookies and partition key. + def delete_cookies(self, filter: Any | None = None, partition: Any | None = None): + """Execute storage.deleteCookies.""" + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.deleteCookies", params) + result = self._conn.execute(cmd) + return result - Example: - result = await storage.get_cookies( - filter=CookieFilter(name="sessionId"), - partition=PartitionKey(...) + def get_cookies(self, filter=None, partition=None): + """Execute storage.getCookies and return a GetCookiesResult.""" + if filter and hasattr(filter, "to_bidi_dict"): + filter = filter.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.getCookies", params) + result = self._conn.execute(cmd) + if result and "cookies" in result: + cookies = [ + StorageCookie.from_bidi_dict(c) + for c in result.get("cookies", []) + if isinstance(c, dict) + ] + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None ) - """ - params = {} - if filter is not None: - params["filter"] = filter.to_dict() - if partition is not None: - params["partition"] = partition.to_dict() - - result = self.conn.execute(command_builder("storage.getCookies", params)) - return GetCookiesResult.from_dict(result) - - def set_cookie( - self, - cookie: PartialCookie, - partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None, - ) -> SetCookieResult: - """Sets a cookie in the browser. - - Args: - cookie: The cookie to set. - partition: Optional partition descriptor. - - Returns: - The result of the set cookie command. - """ - params = {"cookie": cookie.to_dict()} - if partition is not None: - params["partition"] = partition.to_dict() - - result = self.conn.execute(command_builder("storage.setCookie", params)) - return SetCookieResult.from_dict(result) - - def delete_cookies( - self, - filter: CookieFilter | None = None, - partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None, - ) -> DeleteCookiesResult: - """Deletes cookies that match the given parameters. - - Args: - filter: Optional filter to match cookies to delete. - partition: Optional partition descriptor. - - Returns: - The result of the delete cookies command. - """ - params = {} - if filter is not None: - params["filter"] = filter.to_dict() - if partition is not None: - params["partition"] = partition.to_dict() - - result = self.conn.execute(command_builder("storage.deleteCookies", params)) - return DeleteCookiesResult.from_dict(result) + return GetCookiesResult(cookies=cookies, partition_key=pk) + return GetCookiesResult(cookies=[], partition_key=None) + def set_cookie(self, cookie=None, partition=None): + """Execute storage.setCookie.""" + if cookie and hasattr(cookie, "to_bidi_dict"): + cookie = cookie.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "cookie": cookie, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.setCookie", params) + result = self._conn.execute(cmd) + return result + def delete_cookies(self, filter=None, partition=None): + """Execute storage.deleteCookies.""" + if filter and hasattr(filter, "to_bidi_dict"): + filter = filter.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.deleteCookies", params) + result = self._conn.execute(cmd) + return result diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 7609a04f3b3a4..8a737efeeafde 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -1,78 +1,112 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: webExtension +from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass -from selenium.common.exceptions import WebDriverException -from selenium.webdriver.common.bidi.common import command_builder + +@dataclass +class InstallParameters: + """InstallParameters.""" + + extension_data: Any | None = None + + +@dataclass +class ExtensionPath: + """ExtensionPath.""" + + type: str = field(default="path", init=False) + path: str | None = None + + +@dataclass +class ExtensionArchivePath: + """ExtensionArchivePath.""" + + type: str = field(default="archivePath", init=False) + path: str | None = None + + +@dataclass +class ExtensionBase64Encoded: + """ExtensionBase64Encoded.""" + + type: str = field(default="base64", init=False) + value: str | None = None + + +@dataclass +class InstallResult: + """InstallResult.""" + + extension: Any | None = None + + +@dataclass +class UninstallParameters: + """UninstallParameters.""" + + extension: Any | None = None class WebExtension: - """BiDi implementation of the webExtension module.""" + """WebDriver BiDi webExtension module.""" - def __init__(self, conn): - self.conn = conn + def __init__(self, conn) -> None: + self._conn = conn - def install(self, path=None, archive_path=None, base64_value=None) -> dict: - """Installs a web extension in the remote end. + def install(self, path: str | None = None, archive_path: str | None = None, base64_value: str | None = None): + """Install a web extension. - You must provide exactly one of the parameters. + Exactly one of the three keyword arguments must be provided. Args: - path: Path to an extension directory. - archive_path: Path to an extension archive file. - base64_value: Base64 encoded string of the extension archive. + path: Directory path to an unpacked extension (also accepted for + signed ``.xpi`` / ``.crx`` archive files on Firefox). + archive_path: File-system path to a packed extension archive. + base64_value: Base64-encoded extension archive string. Returns: - A dictionary containing the extension ID. - """ - if sum(x is not None for x in (path, archive_path, base64_value)) != 1: - raise ValueError("Exactly one of path, archive_path, or base64_value must be provided") + The raw result dict from the BiDi ``webExtension.install`` command + (contains at least an ``"extension"`` key with the extension ID). + Raises: + ValueError: If more than one, or none, of the arguments is provided. + """ + provided = [k for k, v in {"path": path, "archive_path": archive_path, "base64_value": base64_value}.items() if v is not None] + if len(provided) != 1: + raise ValueError( + f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" + ) if path is not None: extension_data = {"type": "path", "path": path} elif archive_path is not None: extension_data = {"type": "archivePath", "path": archive_path} - elif base64_value is not None: + else: extension_data = {"type": "base64", "value": base64_value} - params = {"extensionData": extension_data} - - try: - result = self.conn.execute(command_builder("webExtension.install", params)) - return result - except WebDriverException as e: - if "Method not available" in str(e): - raise WebDriverException( - f"{e!s}. If you are using Chrome or Edge, add '--enable-unsafe-extension-debugging' " - "and '--remote-debugging-pipe' arguments or set options.enable_webextensions = True" - ) from e - raise - - def uninstall(self, extension_id_or_result: str | dict) -> None: - """Uninstalls a web extension from the remote end. + cmd = command_builder("webExtension.install", params) + return self._conn.execute(cmd) + def uninstall(self, extension: Any | None = None): + """Uninstall a web extension. Args: - extension_id_or_result: Either the extension ID as a string or the result dictionary - from a previous install() call containing the extension ID. + extension: Either the extension ID string returned by ``install``, + or the full result dict returned by ``install`` (the + ``"extension"`` value is extracted automatically). """ - if isinstance(extension_id_or_result, dict): - extension_id = extension_id_or_result.get("extension") - else: - extension_id = extension_id_or_result - - params = {"extension": extension_id} - self.conn.execute(command_builder("webExtension.uninstall", params)) + if isinstance(extension, dict): + extension = extension.get("extension") + params = {"extension": extension} + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("webExtension.uninstall", params) + return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/by.py b/py/selenium/webdriver/common/by.py index d2a10ac70a7c6..9cc0ac1b1864a 100644 --- a/py/selenium/webdriver/common/by.py +++ b/py/selenium/webdriver/common/by.py @@ -16,9 +16,20 @@ # under the License. """The By implementation.""" +from __future__ import annotations + from typing import Literal -ByType = Literal["id", "xpath", "link text", "partial link text", "name", "tag name", "class name", "css selector"] +ByType = Literal[ + "id", + "xpath", + "link text", + "partial link text", + "name", + "tag name", + "class name", + "css selector", +] class By: diff --git a/py/selenium/webdriver/common/proxy.py b/py/selenium/webdriver/common/proxy.py index 89172d1122c36..28de19afa5742 100644 --- a/py/selenium/webdriver/common/proxy.py +++ b/py/selenium/webdriver/common/proxy.py @@ -17,6 +17,8 @@ """The Proxy implementation.""" +from __future__ import annotations + class ProxyTypeFactory: """Factory for proxy types.""" @@ -33,13 +35,23 @@ class ProxyType: profile preference, 'string' is id of proxy type. """ - DIRECT = ProxyTypeFactory.make(0, "DIRECT") # Direct connection, no proxy (default on Windows). - MANUAL = ProxyTypeFactory.make(1, "MANUAL") # Manual proxy settings (e.g., for httpProxy). + DIRECT = ProxyTypeFactory.make( + 0, "DIRECT" + ) # Direct connection, no proxy (default on Windows). + MANUAL = ProxyTypeFactory.make( + 1, "MANUAL" + ) # Manual proxy settings (e.g., for httpProxy). PAC = ProxyTypeFactory.make(2, "PAC") # Proxy autoconfiguration from URL. RESERVED_1 = ProxyTypeFactory.make(3, "RESERVED1") # Never used. - AUTODETECT = ProxyTypeFactory.make(4, "AUTODETECT") # Proxy autodetection (presumably with WPAD). - SYSTEM = ProxyTypeFactory.make(5, "SYSTEM") # Use system settings (default on Linux). - UNSPECIFIED = ProxyTypeFactory.make(6, "UNSPECIFIED") # Not initialized (for internal use). + AUTODETECT = ProxyTypeFactory.make( + 4, "AUTODETECT" + ) # Proxy autodetection (presumably with WPAD). + SYSTEM = ProxyTypeFactory.make( + 5, "SYSTEM" + ) # Use system settings (default on Linux). + UNSPECIFIED = ProxyTypeFactory.make( + 6, "UNSPECIFIED" + ) # Not initialized (for internal use). @classmethod def load(cls, value): @@ -48,7 +60,11 @@ def load(cls, value): value = str(value).upper() for attr in dir(cls): attr_value = getattr(cls, attr) - if isinstance(attr_value, dict) and "string" in attr_value and attr_value["string"] == value: + if ( + isinstance(attr_value, dict) + and "string" in attr_value + and attr_value["string"] == value + ): return attr_value raise Exception(f"No proxy type is found for {value}") @@ -203,13 +219,17 @@ def to_bidi_dict(self) -> dict: if self.noProxy: # Convert comma-separated string to list if isinstance(self.noProxy, str): - result["noProxy"] = [host.strip() for host in self.noProxy.split(",") if host.strip()] + result["noProxy"] = [ + host.strip() for host in self.noProxy.split(",") if host.strip() + ] elif isinstance(self.noProxy, list): if not all(isinstance(h, str) for h in self.noProxy): raise TypeError("no_proxy list must contain only strings") result["noProxy"] = self.noProxy else: - raise TypeError("no_proxy must be a comma-separated string or a list of strings") + raise TypeError( + "no_proxy must be a comma-separated string or a list of strings" + ) elif proxy_type == "pac": if self.proxyAutoconfigUrl: diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index b804b5a1b9900..dc64d77265b09 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -20,6 +20,7 @@ import base64 import contextlib import copy +import inspect import os import pkgutil import tempfile @@ -114,7 +115,9 @@ def get_remote_connection( client_config: ClientConfig | None = None, ) -> RemoteConnection: if isinstance(command_executor, str): - client_config = client_config or ClientConfig(remote_server_addr=command_executor) + client_config = client_config or ClientConfig( + remote_server_addr=command_executor + ) client_config.remote_server_addr = command_executor command_executor = RemoteConnection(client_config=client_config) @@ -396,9 +399,13 @@ def create_web_element(self, element_id: str) -> WebElement: def _unwrap_value(self, value): if isinstance(value, dict): if "element-6066-11e4-a52e-4f735466cecf" in value: - return self.create_web_element(value["element-6066-11e4-a52e-4f735466cecf"]) + return self.create_web_element( + value["element-6066-11e4-a52e-4f735466cecf"] + ) if "shadow-6066-11e4-a52e-4f735466cecf" in value: - return self._shadowroot_cls(self, value["shadow-6066-11e4-a52e-4f735466cecf"]) + return self._shadowroot_cls( + self, value["shadow-6066-11e4-a52e-4f735466cecf"] + ) for key, val in value.items(): value[key] = self._unwrap_value(val) return value @@ -424,18 +431,29 @@ def execute_cdp_cmd(self, cmd: str, cmd_args: dict): Example: `driver.execute_cdp_cmd("Network.getResponseBody", {"requestId": requestId})` """ - return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})["value"] + return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})[ + "value" + ] - def execute(self, driver_command: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + def execute( + self, driver_command: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: """Sends a command to be executed by a command.CommandExecutor. Args: - driver_command: The name of the command to execute as a string. + driver_command: The name of the command to execute as a string. Can also be a generator + for BiDi protocol commands. params: A dictionary of named parameters to send with the command. Returns: The command's JSON response loaded into a dictionary object. """ + # Handle BiDi generator commands + if inspect.isgenerator(driver_command): + # BiDi command: use WebSocketConnection directly + return self.command_executor.execute(driver_command) + + # Legacy WebDriver command: handle normally params = self._wrap_value(params) if self.session_id: @@ -444,7 +462,9 @@ def execute(self, driver_command: str, params: dict[str, Any] | None = None) -> elif "sessionId" not in params: params["sessionId"] = self.session_id - response = cast(RemoteConnection, self.command_executor).execute(driver_command, params) + response = cast(RemoteConnection, self.command_executor).execute( + driver_command, params + ) if response: self.error_handler.check_response(response) @@ -500,7 +520,9 @@ def unpin(self, script_key: ScriptKey) -> None: try: self.pinned_scripts.pop(script_key.id) except KeyError: - raise KeyError(f"No script with key: {script_key} existed in {self.pinned_scripts}") from None + raise KeyError( + f"No script with key: {script_key} existed in {self.pinned_scripts}" + ) from None def get_pinned_scripts(self) -> list[str]: """Return a list of all pinned scripts. @@ -533,7 +555,9 @@ def execute_script(self, script: str, *args) -> Any: converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT - return self.execute(command, {"script": script, "args": converted_args})["value"] + return self.execute(command, {"script": script, "args": converted_args})[ + "value" + ] def execute_async_script(self, script: str, *args) -> Any: """Asynchronously Executes JavaScript in the current window/frame. @@ -552,7 +576,9 @@ def execute_async_script(self, script: str, *args) -> Any: converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT_ASYNC - return self.execute(command, {"script": script, "args": converted_args})["value"] + return self.execute(command, {"script": script, "args": converted_args})[ + "value" + ] @property def current_url(self) -> str: @@ -729,7 +755,9 @@ def implicitly_wait(self, time_to_wait: float) -> None: Example: `driver.implicitly_wait(30)` """ - self.execute(Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)}) + self.execute( + Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)} + ) def set_script_timeout(self, time_to_wait: float) -> None: """Set the timeout for asynchronous script execution. @@ -758,9 +786,14 @@ def set_page_load_timeout(self, time_to_wait: float) -> None: `driver.set_page_load_timeout(30)` """ try: - self.execute(Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)}) + self.execute( + Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)} + ) except WebDriverException: - self.execute(Command.SET_TIMEOUTS, {"ms": float(time_to_wait) * 1000, "type": "page load"}) + self.execute( + Command.SET_TIMEOUTS, + {"ms": float(time_to_wait) * 1000, "type": "page load"}, + ) @property def timeouts(self) -> Timeouts: @@ -796,7 +829,9 @@ def timeouts(self, timeouts) -> None: """ _ = self.execute(Command.SET_TIMEOUTS, timeouts._to_json())["value"] - def find_element(self, by: str | RelativeBy = By.ID, value: str | None = None) -> WebElement: + def find_element( + self, by: str | RelativeBy = By.ID, value: str | None = None + ) -> WebElement: """Find an element given a By strategy and locator. Args: @@ -817,12 +852,18 @@ def find_element(self, by: str | RelativeBy = By.ID, value: str | None = None) - if isinstance(by, RelativeBy): elements = self.find_elements(by=by, value=value) if not elements: - raise NoSuchElementException(f"Cannot locate relative element with: {by.root}") + raise NoSuchElementException( + f"Cannot locate relative element with: {by.root}" + ) return elements[0] - return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"] + return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})[ + "value" + ] - def find_elements(self, by: str | RelativeBy = By.ID, value: str | None = None) -> list[WebElement]: + def find_elements( + self, by: str | RelativeBy = By.ID, value: str | None = None + ) -> list[WebElement]: """Find elements given a By strategy and locator. Args: @@ -844,14 +885,21 @@ def find_elements(self, by: str | RelativeBy = By.ID, value: str | None = None) _pkg = ".".join(__name__.split(".")[:-1]) raw_data = pkgutil.get_data(_pkg, "findElements.js") if raw_data is None: - raise FileNotFoundError(f"Could not find findElements.js in package {_pkg}") + raise FileNotFoundError( + f"Could not find findElements.js in package {_pkg}" + ) raw_function = raw_data.decode("utf8") - find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);" + find_element_js = ( + f"/* findElements */return ({raw_function}).apply(null, arguments);" + ) return self.execute_script(find_element_js, by.to_dict()) # Return empty list if driver returns null # See https://github.com/SeleniumHQ/selenium/issues/4555 - return self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] or [] + return ( + self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] + or [] + ) @property def capabilities(self) -> dict: @@ -948,7 +996,9 @@ def get_window_size(self, windowHandle: str = "current") -> dict: return {k: size[k] for k in ("width", "height")} - def set_window_position(self, x: float, y: float, windowHandle: str = "current") -> dict: + def set_window_position( + self, x: float, y: float, windowHandle: str = "current" + ) -> dict: """Sets the x,y position of the current window. Args: @@ -976,7 +1026,10 @@ def get_window_position(self, windowHandle="current") -> dict: def _check_if_window_handle_is_current(self, windowHandle: str) -> None: """Warns if the window handle is not equal to `current`.""" if windowHandle != "current": - warnings.warn("Only 'current' window is supported for W3C compatible browsers.", stacklevel=2) + warnings.warn( + "Only 'current' window is supported for W3C compatible browsers.", + stacklevel=2, + ) def get_window_rect(self) -> dict: """Get the window's position and size. @@ -1004,7 +1057,9 @@ def set_window_rect(self, x=None, y=None, width=None, height=None) -> dict: if (x is None and y is None) and (not height and not width): raise InvalidArgumentException("x and y or height and width need values") - return self.execute(Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height})["value"] + return self.execute( + Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height} + )["value"] @property def file_detector(self) -> FileDetector: @@ -1049,7 +1104,9 @@ def orientation(self, value) -> None: if value.upper() in allowed_values: self.execute(Command.SET_SCREEN_ORIENTATION, {"orientation": value}) else: - raise WebDriverException("You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'") + raise WebDriverException( + "You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'" + ) def start_devtools(self) -> tuple[Any, WebSocketConnection]: global cdp @@ -1064,7 +1121,9 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) if cdp is None: raise WebDriverException("CDP module not loaded") @@ -1073,20 +1132,28 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: if self._websocket_connection: return self._devtools, self._websocket_connection if self.caps["browserName"].lower() == "firefox": - raise RuntimeError("CDP support for Firefox has been removed. Please switch to WebDriver BiDi.") + raise RuntimeError( + "CDP support for Firefox has been removed. Please switch to WebDriver BiDi." + ) if not isinstance(self.command_executor, RemoteConnection): - raise WebDriverException("command_executor must be a RemoteConnection instance for CDP support") + raise WebDriverException( + "command_executor must be a RemoteConnection instance for CDP support" + ) self._websocket_connection = WebSocketConnection( ws_url, self.command_executor.client_config.websocket_timeout, self.command_executor.client_config.websocket_interval, ) - targets = self._websocket_connection.execute(self._devtools.target.get_targets()) + targets = self._websocket_connection.execute( + self._devtools.target.get_targets() + ) for target in targets: if target.target_id == self.current_window_handle: target_id = target.target_id break - session = self._websocket_connection.execute(self._devtools.target.attach_to_target(target_id, True)) + session = self._websocket_connection.execute( + self._devtools.target.attach_to_target(target_id, True) + ) self._websocket_connection.session_id = session return self._devtools, self._websocket_connection @@ -1101,7 +1168,9 @@ async def bidi_connection(self): version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) devtools = cdp.import_devtools(version) async with cdp.open_cdp(ws_url) as conn: @@ -1127,10 +1196,14 @@ def _start_bidi(self) -> None: if self.caps.get("webSocketUrl"): ws_url = self.caps.get("webSocketUrl") else: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) if not isinstance(self.command_executor, RemoteConnection): - raise WebDriverException("command_executor must be a RemoteConnection instance for BiDi support") + raise WebDriverException( + "command_executor must be a RemoteConnection instance for BiDi support" + ) self._websocket_connection = WebSocketConnection( ws_url, @@ -1346,9 +1419,13 @@ def _get_cdp_details(self): http = urllib3.PoolManager() try: if self.caps.get("browserName") == "chrome": - debugger_address = self.caps.get("goog:chromeOptions").get("debuggerAddress") + debugger_address = self.caps.get("goog:chromeOptions").get( + "debuggerAddress" + ) elif self.caps.get("browserName") in ("MicrosoftEdge", "webview2"): - debugger_address = self.caps.get("ms:edgeOptions").get("debuggerAddress") + debugger_address = self.caps.get("ms:edgeOptions").get( + "debuggerAddress" + ) except AttributeError: raise WebDriverException("Can't get debugger address.") @@ -1376,7 +1453,9 @@ def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> Non driver.add_virtual_authenticator(options) ``` """ - self._authenticator_id = self.execute(Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict())["value"] + self._authenticator_id = self.execute( + Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict() + )["value"] @property def virtual_authenticator_id(self) -> str | None: @@ -1390,7 +1469,10 @@ def remove_virtual_authenticator(self) -> None: The authenticator is no longer valid after removal, so no methods may be called. """ - self.execute(Command.REMOVE_VIRTUAL_AUTHENTICATOR, {"authenticatorId": self._authenticator_id}) + self.execute( + Command.REMOVE_VIRTUAL_AUTHENTICATOR, + {"authenticatorId": self._authenticator_id}, + ) self._authenticator_id = None @required_virtual_authenticator @@ -1405,13 +1487,20 @@ def add_credential(self, credential: Credential) -> None: driver.add_credential(credential) ``` """ - self.execute(Command.ADD_CREDENTIAL, {**credential.to_dict(), "authenticatorId": self._authenticator_id}) + self.execute( + Command.ADD_CREDENTIAL, + {**credential.to_dict(), "authenticatorId": self._authenticator_id}, + ) @required_virtual_authenticator def get_credentials(self) -> list[Credential]: """Returns the list of credentials owned by the authenticator.""" - credential_data = self.execute(Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id}) - return [Credential.from_dict(credential) for credential in credential_data["value"]] + credential_data = self.execute( + Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id} + ) + return [ + Credential.from_dict(credential) for credential in credential_data["value"] + ] @required_virtual_authenticator def remove_credential(self, credential_id: str | bytearray) -> None: @@ -1426,13 +1515,16 @@ def remove_credential(self, credential_id: str | bytearray) -> None: credential_id = urlsafe_b64encode(credential_id).decode() self.execute( - Command.REMOVE_CREDENTIAL, {"credentialId": credential_id, "authenticatorId": self._authenticator_id} + Command.REMOVE_CREDENTIAL, + {"credentialId": credential_id, "authenticatorId": self._authenticator_id}, ) @required_virtual_authenticator def remove_all_credentials(self) -> None: """Removes all credentials from the authenticator.""" - self.execute(Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id}) + self.execute( + Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id} + ) @required_virtual_authenticator def set_user_verified(self, verified: bool) -> None: @@ -1445,12 +1537,17 @@ def set_user_verified(self, verified: bool) -> None: Example: `driver.set_user_verified(True)` """ - self.execute(Command.SET_USER_VERIFIED, {"authenticatorId": self._authenticator_id, "isUserVerified": verified}) + self.execute( + Command.SET_USER_VERIFIED, + {"authenticatorId": self._authenticator_id, "isUserVerified": verified}, + ) def get_downloadable_files(self) -> list: """Retrieves the downloadable files as a list of file names.""" if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) return self.execute(Command.GET_DOWNLOADABLE_FILES)["value"]["names"] @@ -1465,12 +1562,16 @@ def download_file(self, file_name: str, target_directory: str) -> None: `driver.download_file("example.zip", "/path/to/directory")` """ if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) if not os.path.exists(target_directory): os.makedirs(target_directory) - contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"]["contents"] + contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"][ + "contents" + ] with tempfile.TemporaryDirectory() as tmp_dir: zip_file = os.path.join(tmp_dir, file_name + ".zip") @@ -1483,7 +1584,9 @@ def download_file(self, file_name: str, target_directory: str) -> None: def delete_downloadable_files(self) -> None: """Deletes all downloadable files.""" if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) self.execute(Command.DELETE_DOWNLOADABLE_FILES) @@ -1589,5 +1692,10 @@ def _check_fedcm() -> Dialog | None: except NoAlertPresentException: return None - wait = WebDriverWait(self, timeout, poll_frequency=poll_frequency, ignored_exceptions=ignored_exceptions) + wait = WebDriverWait( + self, + timeout, + poll_frequency=poll_frequency, + ignored_exceptions=ignored_exceptions, + ) return wait.until(lambda _: _check_fedcm()) diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 98bf4f4b9057a..68358e4a09974 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import dataclasses import json import logging from ssl import CERT_NONE @@ -25,6 +26,40 @@ from selenium.common import WebDriverException + +def _snake_to_camel(name: str) -> str: + """Convert snake_case field name to camelCase for BiDi protocol.""" + parts = name.split("_") + return parts[0] + "".join(p.title() for p in parts[1:]) + + +class _BiDiEncoder(json.JSONEncoder): + """JSON encoder for BiDi dataclass instances. + + Converts snake_case field names to camelCase, strips ``None`` values, + and flattens a ``properties`` field (e.g. ``PointerCommonProperties``) + directly into its parent action dict as required by the BiDi spec. + """ + + def default(self, o): + if dataclasses.is_dataclass(o) and not isinstance(o, type): + result = {} + for f in dataclasses.fields(o): + value = getattr(o, f.name) + if value is None: + continue + camel_key = _snake_to_camel(f.name) + # Flatten PointerCommonProperties fields inline into the parent + if camel_key == "properties" and dataclasses.is_dataclass(value): + for pf in dataclasses.fields(value): + pv = getattr(value, pf.name) + if pv is not None: + result[_snake_to_camel(pf.name)] = pv + else: + result[camel_key] = value + return result + return super().default(o) + logger = logging.getLogger(__name__) @@ -63,7 +98,7 @@ def execute(self, command): if self.session_id: payload["sessionId"] = self.session_id - data = json.dumps(payload) + data = json.dumps(payload, cls=_BiDiEncoder) logger.debug(f"-> {data}"[: self._max_log_message_size]) self._ws.send(data) diff --git a/py/test/selenium/webdriver/common/bidi_browser_tests.py b/py/test/selenium/webdriver/common/bidi_browser_tests.py index 7fd054b73627d..b9e042403dcba 100644 --- a/py/test/selenium/webdriver/common/bidi_browser_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browser_tests.py @@ -20,7 +20,7 @@ import pytest from selenium.common.exceptions import TimeoutException -from selenium.webdriver.common.bidi.browser import ClientWindowInfo, ClientWindowState +from selenium.webdriver.common.bidi.browser import ClientWindowInfo, ClientWindowNamedState from selenium.webdriver.common.bidi.browsing_context import ReadinessState from selenium.webdriver.common.bidi.session import UserPromptHandler, UserPromptHandlerType from selenium.webdriver.common.by import By @@ -100,10 +100,9 @@ def test_raises_exception_when_removing_default_user_context(driver): def test_client_window_state_constants(driver): - assert ClientWindowState.FULLSCREEN == "fullscreen" - assert ClientWindowState.MAXIMIZED == "maximized" - assert ClientWindowState.MINIMIZED == "minimized" - assert ClientWindowState.NORMAL == "normal" + """Test ClientWindowNamedState constants.""" + assert ClientWindowNamedState.MAXIMIZED == "maximized" + assert ClientWindowNamedState.MINIMIZED == "minimized" def test_create_user_context_with_accept_insecure_certs(driver): @@ -177,7 +176,7 @@ def test_create_user_context_with_manual_proxy_all_params(driver, proxy_server): # Visit a site that should be proxied driver.get("http://example.com/") - body_text = driver.find_element("tag name", "body").text + body_text = driver.find_element(By.TAG_NAME, "body").text assert "proxied response" in body_text.lower() finally: From 7d03c32f5813996c76323e01e8ff166b6a3794cc Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 27 Feb 2026 14:07:04 +0000 Subject: [PATCH 02/37] fixup --- py/generate_bidi.py | 204 ++++--- py/private/bidi_enhancements_manifest.py | 77 ++- py/selenium/webdriver/common/bidi/__init__.py | 23 + py/selenium/webdriver/common/bidi/browser.py | 43 +- .../webdriver/common/bidi/browsing_context.py | 171 ++++-- py/selenium/webdriver/common/bidi/cdp.py | 515 ------------------ py/selenium/webdriver/common/bidi/common.py | 7 +- py/selenium/webdriver/common/bidi/console.py | 0 .../webdriver/common/bidi/emulation.py | 187 +++---- py/selenium/webdriver/common/bidi/input.py | 36 +- py/selenium/webdriver/common/bidi/log.py | 223 +++++++- py/selenium/webdriver/common/bidi/network.py | 115 ++-- .../webdriver/common/bidi/permissions.py | 10 +- py/selenium/webdriver/common/bidi/script.py | 113 +++- py/selenium/webdriver/common/bidi/session.py | 30 +- py/selenium/webdriver/common/bidi/storage.py | 44 +- .../webdriver/common/bidi/webextension.py | 20 +- 17 files changed, 873 insertions(+), 945 deletions(-) delete mode 100644 py/selenium/webdriver/common/bidi/cdp.py mode change 100755 => 100644 py/selenium/webdriver/common/bidi/console.py diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 1770cf436bef1..2db595ff37cd0 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -18,12 +18,11 @@ import logging import re import sys -from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from textwrap import dedent, indent as tw_indent -from typing import Any, Dict, List, Optional, Set, Tuple +from textwrap import indent as tw_indent +from typing import Any __version__ = "1.0.0" @@ -43,8 +42,7 @@ # WebDriver BiDi module: {{}} from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder +from typing import Any """ @@ -53,7 +51,7 @@ def indent(s: str, n: int) -> str: return tw_indent(s, n * " ") -def load_enhancements_manifest(manifest_path: Optional[str]) -> Dict[str, Any]: +def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: """Load enhancement manifest from a Python file. Args: @@ -124,10 +122,10 @@ def get_annotation(cls, cddl_type: str) -> str: if cddl_type.startswith("["): # Array inner = cddl_type.strip("[]+ ") inner_type = cls.get_annotation(inner) - return f"List[{inner_type}]" + return f"list[{inner_type}]" if cddl_type.startswith("{"): # Map/Dict - return "Dict[str, Any]" + return "dict[str, Any]" # Default to Any for unknown types return "Any" @@ -139,11 +137,11 @@ class CddlCommand: module: str name: str - params: Dict[str, str] = field(default_factory=dict) - result: Optional[str] = None + params: dict[str, str] = field(default_factory=dict) + result: str | None = None description: str = "" - def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python method code for this command. Args: @@ -174,8 +172,15 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str else: param_list = "self" - # Build method body - body = f" def {method_name}({param_list}):\n" + # Build method body - wrap long signatures over multiple lines if needed + sig_line = f" def {method_name}({param_list}):" + if len(sig_line) > 120 and param_strs: + body = f" def {method_name}(\n self,\n" + for p in param_strs: + body += f" {p},\n" + body += " ):\n" + else: + body = sig_line + "\n" body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' # Add validation if specified @@ -237,7 +242,6 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str if result_param == "download_behavior": body += ' "downloadBehavior": download_behavior,\n' # Add remaining parameters that weren't part of the transform - override_params = enhancements.get("params_override", {}) for cddl_param_name in self.params: if cddl_param_name not in ["downloadBehavior"]: snake_name = self._camel_to_snake(cddl_param_name) @@ -264,45 +268,45 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str # Extract property from list items body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f' item.get("{extract_property}")\n' - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" elif extract_field in deserialize_rules: # Extract field and deserialize to typed objects type_name = deserialize_rules[extract_field] body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f" {type_name}(\n" body += self._generate_field_args(extract_field, type_name) - body += f" )\n" - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" else: # Simple field extraction (return the value directly, not wrapped in result dict) body += f' if result and "{extract_field}" in result:\n' body += f' extracted = result.get("{extract_field}")\n' - body += f" return extracted\n" - body += f" return result\n" + body += " return extracted\n" + body += " return result\n" elif "deserialize" in enhancements: # Deserialize response to typed objects (legacy, without extract_field) deserialize_rules = enhancements["deserialize"] for response_field, type_name in deserialize_rules.items(): body += f' if result and "{response_field}" in result:\n' body += f' items = result.get("{response_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f" {type_name}(\n" body += self._generate_field_args(response_field, type_name) - body += f" )\n" - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" else: # No special response handling, just return the result body += " return result\n" @@ -351,10 +355,10 @@ class CddlTypeDefinition: module: str name: str - fields: Dict[str, str] = field(default_factory=dict) + fields: dict[str, str] = field(default_factory=dict) description: str = "" - def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python dataclass code for this type. Args: @@ -366,7 +370,7 @@ def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> # Generate class name from type name (keep it as-is, don't split on underscores) class_name = self.name - code = f"@dataclass\n" + code = "@dataclass\n" code += f"class {class_name}:\n" code += f' """{self.description or self.name}."""\n\n' @@ -386,7 +390,7 @@ def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> literal_value = literal_match.group(1) code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' # Check if this field is a list type - elif "List[" in python_type: + elif "list[" in python_type: code += f" {snake_name}: {python_type} = field(default_factory=list)\n" else: code += f" {snake_name}: {python_type} = None\n" @@ -453,7 +457,7 @@ class CddlEnum: module: str name: str - values: List[str] = field(default_factory=list) + values: list[str] = field(default_factory=list) description: str = "" def to_python_class(self) -> str: @@ -530,10 +534,10 @@ class CddlModule: """Represents a CDDL module (e.g., script, network, browsing_context).""" name: str - commands: List[CddlCommand] = field(default_factory=list) - types: List[CddlTypeDefinition] = field(default_factory=list) - enums: List[CddlEnum] = field(default_factory=list) - events: List[CddlEvent] = field(default_factory=list) + commands: list[CddlCommand] = field(default_factory=list) + types: list[CddlTypeDefinition] = field(default_factory=list) + enums: list[CddlEnum] = field(default_factory=list) + events: list[CddlEvent] = field(default_factory=list) @staticmethod def _convert_method_to_event_name(method_suffix: str) -> str: @@ -548,7 +552,33 @@ def _convert_method_to_event_name(method_suffix: str) -> str: s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def _needs_field_import(self, enhancements: dict[str, Any] | None = None) -> bool: + """Check if any type definition in this module requires the 'field' import. + + Respects the same type exclusions applied during code generation. + """ + enhancements = enhancements or {} + extra_cls_names: set[str] = set() + for extra_cls in enhancements.get("extra_dataclasses", []): + m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) + if m: + extra_cls_names.add(m.group(1)) + exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names + + for type_def in self.types: + if type_def.name in exclude_types: + continue + for field_type in type_def.fields.values(): + # Literal string discriminants use field(default=..., init=False) + if re.match(r'^"', field_type.strip()): + return True + # List-typed fields use field(default_factory=list) + python_type = CddlTypeDefinition._get_python_type(field_type) + if python_type.startswith("list["): + return True + return False + + def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python code for this module. Args: @@ -558,17 +588,21 @@ def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: code = MODULE_HEADER.format(self.name) # Add imports if needed - if self.types: - code += "from dataclasses import field\n" + if self.commands: + code += "from .common import command_builder\n" + dataclass_imported = False if self.commands or self.types: - code += "from typing import Generator\n" code += "from dataclasses import dataclass\n" + dataclass_imported = True + if self.types and self._needs_field_import(enhancements): + code += "from dataclasses import field\n" # Add imports for event handling if needed if self.events: code += "import threading\n" code += "from collections.abc import Callable\n" - code += "from dataclasses import dataclass\n" + if not dataclass_imported: + code += "from dataclasses import dataclass\n" code += "from selenium.webdriver.common.bidi.session import Session\n" code += "\n\n" @@ -660,7 +694,13 @@ def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses - exclude_types = set(enhancements.get("exclude_types", [])) + # Also auto-exclude types whose names appear in extra_dataclasses + extra_cls_names = set() + for extra_cls in enhancements.get("extra_dataclasses", []): + m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) + if m: + extra_cls_names.add(m.group(1)) + exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names for type_def in self.types: if type_def.name in exclude_types: continue @@ -680,13 +720,16 @@ def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: # Generate EVENT_NAME_MAPPING for the module code += "# BiDi Event Name to Parameter Type Mapping\n" code += "EVENT_NAME_MAPPING = {\n" + # Collect event keys from extra_events so we skip CDDL duplicates + extra_event_keys = {evt["event_key"] for evt in enhancements.get("extra_events", [])} for event_def in self.events: # Convert method name to user-friendly event name # e.g., "browsingContext.contextCreated" -> "context_created" method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) - code += f' "{event_name}": "{event_def.method}",\n' + if event_name not in extra_event_keys: + code += f' "{event_name}": "{event_def.method}",\n' # Extra events not in the CDDL spec (e.g. Chromium-specific events) for extra_evt in enhancements.get("extra_events", []): code += ( @@ -923,7 +966,13 @@ def clear_event_handlers(self) -> None: code += "\n" # Generate command methods - exclude_methods = enhancements.get("exclude_methods", []) + # Auto-exclude methods whose names appear in extra_methods to prevent duplicates + extra_method_names = set() + for extra_meth in enhancements.get("extra_methods", []): + m = re.search(r"def\s+(\w+)\s*\(", extra_meth) + if m: + extra_method_names.add(m.group(1)) + exclude_methods = set(enhancements.get("exclude_methods", [])) | extra_method_names if self.commands: for command in self.commands: # Get method-specific enhancements @@ -981,24 +1030,44 @@ def clear_event_handlers(self) -> None: code += "\n" # Now populate EVENT_CONFIGS after the aliases are defined - code += f"\n# Populate EVENT_CONFIGS with event configuration mappings\n" + code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n" # Use globals() to look up types dynamically to handle missing types gracefully - code += f"_globals = globals()\n" + code += "_globals = globals()\n" code += f"{class_name}.EVENT_CONFIGS = {{\n" + # Collect extra event keys to skip CDDL duplicates + extra_event_keys_cfg = {evt["event_key"] for evt in enhancements.get("extra_events", [])} for event_def in self.events: # Convert method name to user-friendly event name method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) + if event_name in extra_event_keys_cfg: + continue # The event class is the event name (e.g., ContextCreated) # Try to get it from globals, default to dict if not found - code += f' "{event_name}": (EventConfig("{event_name}", "{event_def.method}", _globals.get("{event_def.name}", dict)) if _globals.get("{event_def.name}") else EventConfig("{event_name}", "{event_def.method}", dict)),\n' + code += ( + f' "{event_name}": (\n' + f' EventConfig("{event_name}", "{event_def.method}",\n' + f' _globals.get("{event_def.name}", dict))\n' + f' if _globals.get("{event_def.name}")\n' + f' else EventConfig("{event_name}", "{event_def.method}", dict)\n' + f' ),\n' + ) # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): ek = extra_evt["event_key"] be = extra_evt["bidi_event"] ec = extra_evt["event_class"] - code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' + single = f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),' + if len(single) > 120: + code += ( + f' "{ek}": EventConfig(\n' + f' "{ek}", "{be}",\n' + f' _globals.get("{ec}", dict),\n' + f' ),\n' + ) + else: + code += single + "\n" code += "}\n" return code @@ -1011,9 +1080,9 @@ def __init__(self, cddl_path: str): """Initialize parser with CDDL file path.""" self.cddl_path = Path(cddl_path) self.content = "" - self.modules: Dict[str, CddlModule] = {} - self.definitions: Dict[str, str] = {} - self.event_names: Set[str] = set() # Names of definitions that are events + self.modules: dict[str, CddlModule] = {} + self.definitions: dict[str, str] = {} + self.event_names: set[str] = set() # Names of definitions that are events self._read_file() def _read_file(self) -> None: @@ -1021,12 +1090,12 @@ def _read_file(self) -> None: if not self.cddl_path.exists(): raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") - with open(self.cddl_path, "r", encoding="utf-8") as f: + with open(self.cddl_path, encoding="utf-8") as f: self.content = f.read() logger.info(f"Loaded CDDL file: {self.cddl_path}") - def parse(self) -> Dict[str, CddlModule]: + def parse(self) -> dict[str, CddlModule]: """Parse CDDL content and return modules.""" # Remove comments content = self._remove_comments(self.content) @@ -1090,9 +1159,6 @@ def _extract_event_names(self) -> None: ... ) """ - # Look for definitions like "BrowsingContextEvent", "SessionEvent", etc. - event_union_pattern = re.compile(r"(\w+\.)?(\w+)Event") - for def_name, def_content in self.definitions.items(): # Check if this looks like an event union (name ends with "Event") and # contains a module-qualified reference like "module.EventName". @@ -1175,7 +1241,7 @@ def _is_enum_definition(self, definition: str) -> bool: # Pattern: "something" / "something_else" return " / " in clean_def and '"' in clean_def - def _extract_enum_values(self, enum_definition: str) -> List[str]: + def _extract_enum_values(self, enum_definition: str) -> list[str]: """Extract individual values from an enum definition. Enums are defined as: "value1" / "value2" / "value3" @@ -1225,7 +1291,7 @@ def _normalize_cddl_type(field_type: str) -> str: result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) return result.strip() - def _extract_type_fields(self, type_definition: str) -> Dict[str, str]: + def _extract_type_fields(self, type_definition: str) -> dict[str, str]: """Extract fields from a type definition block.""" fields = {} @@ -1352,8 +1418,8 @@ def _extract_commands(self) -> None: ) def _extract_parameters( - self, params_type: str, _seen: Optional[Set[str]] = None - ) -> Dict[str, str]: + self, params_type: str, _seen: set[str] | None = None + ) -> dict[str, str]: """Extract parameters from a parameter type definition. Handles both struct types ({...}) and top-level union types (TypeA / TypeB), @@ -1466,7 +1532,7 @@ def module_name_to_filename(module_name: str) -> str: return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() -def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> None: +def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None: """Generate __init__.py file for the module.""" init_path = output_path / "__init__.py" @@ -1481,7 +1547,7 @@ def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> Non filename = module_name_to_filename(module_name) code += f"from .{filename} import {class_name}\n" - code += f"\n__all__ = [\n" + code += "\n__all__ = [\n" for module_name in sorted(modules.keys()): class_name = module_name_to_class_name(module_name) code += f' "{class_name}",\n' @@ -1703,7 +1769,7 @@ def main( cddl_file: str, output_dir: str, spec_version: str = "1.0", - enhancements_manifest: Optional[str] = None, + enhancements_manifest: str | None = None, ) -> None: """Main entry point. diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index ae7229f6ddebd..39af67d4c635b 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -85,7 +85,12 @@ # downloadBehavior is never stripped by the generic None filter. # The BiDi spec marks it as required (can be null, but must be present). "extra_methods": [ - ''' def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + ''' def set_download_behavior( + self, + allowed: bool | None = None, + destination_folder: str | None = None, + user_contexts: list[Any] | None = None, + ): """Set the download behavior for the browser. Args: @@ -272,8 +277,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": self, coordinates=None, error=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setGeolocationOverride. @@ -325,8 +330,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": ''' def set_timezone_override( self, timezone=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setTimezoneOverride. @@ -349,8 +354,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": ''' def set_scripting_enabled( self, enabled=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScriptingEnabled. @@ -373,8 +378,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": ''' def set_user_agent_override( self, user_agent=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setUserAgentOverride. @@ -396,8 +401,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": ''' def set_screen_orientation_override( self, screen_orientation=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScreenOrientationOverride. @@ -433,8 +438,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": self, network_conditions=None, offline: bool | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setNetworkConditions. @@ -534,7 +539,14 @@ def _serialize_arg(value): if raw.get("type") == "success": return raw.get("result") return raw''', - ''' def _add_preload_script(self, function_declaration, arguments=None, contexts=None, user_contexts=None, sandbox=None): + ''' def _add_preload_script( + self, + function_declaration, + arguments=None, + contexts=None, + user_contexts=None, + sandbox=None, + ): """Add a preload script with validation. Args: @@ -586,7 +598,15 @@ def _serialize_arg(value): script_id: The ID returned by pin(). """ return self._remove_preload_script(script_id=script_id)''', - ''' def _evaluate(self, expression, target, await_promise, result_ownership=None, serialization_options=None, user_activation=None): + ''' def _evaluate( + self, + expression, + target, + await_promise, + result_ownership=None, + serialization_options=None, + user_activation=None, + ): """Evaluate a script expression and return a structured result. Args: @@ -621,7 +641,17 @@ def __init__(self2, realm, result, exception_details): return _EvalResult(realm=realm, result=None, exception_details=exc) return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) return _EvalResult(realm=None, result=raw, exception_details=None)''', - ''' def _call_function(self, function_declaration, await_promise, target, arguments=None, result_ownership=None, this=None, user_activation=None, serialization_options=None): + ''' def _call_function( + self, + function_declaration, + await_promise, + target, + arguments=None, + result_ownership=None, + this=None, + user_activation=None, + serialization_options=None, + ): """Call a function and return a structured result. Args: @@ -1256,7 +1286,12 @@ def to_bidi_dict(self) -> dict: # Suppress the raw generated stubs; hand-written versions follow below "exclude_methods": ["install", "uninstall"], "extra_methods": [ - ''' def install(self, path: str | None = None, archive_path: str | None = None, base64_value: str | None = None): + ''' def install( + self, + path: str | None = None, + archive_path: str | None = None, + base64_value: str | None = None, + ): """Install a web extension. Exactly one of the three keyword arguments must be provided. @@ -1274,7 +1309,11 @@ def to_bidi_dict(self) -> dict: Raises: ValueError: If more than one, or none, of the arguments is provided. """ - provided = [k for k, v in {"path": path, "archive_path": archive_path, "base64_value": base64_value}.items() if v is not None] + provided = [ + k for k, v in { + "path": path, "archive_path": archive_path, "base64_value": base64_value, + }.items() if v is not None + ] if len(provided) != 1: raise ValueError( f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" @@ -1502,6 +1541,7 @@ def _add_event_handler( - 'history_updated' Args: + self: The module instance this handler is bound to. event_name: The name of the event to subscribe to callback: Callback function to invoke when event occurs contexts: Optional list of context IDs to limit event subscription @@ -1538,6 +1578,7 @@ def _remove_event_handler( """Remove an event handler by its callback ID. Args: + self: The module instance this handler is bound to. callback_id: The callback ID returned from add_event_handler """ if not hasattr(self, "_event_handlers"): diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index ab96f2d81e292..7be7bd4f73856 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -5,3 +5,26 @@ from __future__ import annotations +from .browser import Browser +from .browsing_context import BrowsingContext +from .emulation import Emulation +from .input import Input +from .log import Log +from .network import Network +from .script import Script +from .session import Session +from .storage import Storage +from .webextension import WebExtension + +__all__ = [ + "Browser", + "BrowsingContext", + "Emulation", + "Input", + "Log", + "Network", + "Script", + "Session", + "Storage", + "WebExtension", +] diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index ed6a4d8f33bc5..acda63f71953e 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: browser from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass def transform_download_params( @@ -131,14 +130,14 @@ class CreateUserContextParameters: class GetClientWindowsResult: """GetClientWindowsResult.""" - client_windows: list[Any | None] | None = None + client_windows: list[Any | None] | None = field(default_factory=list) @dataclass class GetUserContextsResult: """GetUserContextsResult.""" - user_contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -171,7 +170,7 @@ class SetDownloadBehaviorParameters: """SetDownloadBehaviorParameters.""" download_behavior: Any | None = None - user_contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -204,7 +203,12 @@ def close(self): result = self._conn.execute(cmd) return result - def create_user_context(self, accept_insecure_certs: bool | None = None, proxy: Any | None = None, unhandled_prompt_behavior: Any | None = None): + def create_user_context( + self, + accept_insecure_certs: bool | None = None, + proxy: Any | None = None, + unhandled_prompt_behavior: Any | None = None, + ): """Execute browser.createUserContext.""" if proxy and hasattr(proxy, 'to_bidi_dict'): proxy = proxy.to_bidi_dict() @@ -285,23 +289,12 @@ def set_client_window_state(self, client_window: Any | None = None): result = self._conn.execute(cmd) return result - def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): - """Execute browser.setDownloadBehavior.""" - validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) - - download_behavior = None - download_behavior = transform_download_params(allowed, destination_folder) - - params = { - "downloadBehavior": download_behavior, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.setDownloadBehavior", params) - result = self._conn.execute(cmd) - return result - - def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + def set_download_behavior( + self, + allowed: bool | None = None, + destination_folder: str | None = None, + user_contexts: list[Any] | None = None, + ): """Set the download behavior for the browser. Args: diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 35aea615d1780..5f128635df29d 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class ReadinessState: """ReadinessState.""" @@ -220,14 +219,14 @@ class LocateNodesParameters: context: Any | None = None locator: Any | None = None serialization_options: Any | None = None - start_nodes: list[Any | None] | None = None + start_nodes: list[Any | None] | None = field(default_factory=list) @dataclass class LocateNodesResult: """LocateNodesResult.""" - nodes: list[Any | None] | None = None + nodes: list[Any | None] | None = field(default_factory=list) @dataclass @@ -300,7 +299,7 @@ class SetViewportParameters: context: Any | None = None viewport: Any | None = None device_pixel_ratio: Any | None = None - user_contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -328,20 +327,6 @@ class HistoryUpdatedParameters: url: str | None = None -@dataclass -class DownloadWillBeginParams: - """DownloadWillBeginParams.""" - - suggested_filename: str | None = None - - -@dataclass -class DownloadCanceledParams: - """DownloadCanceledParams.""" - - status: str = field(default="canceled", init=False) - - @dataclass class UserPromptClosedParameters: """UserPromptClosedParameters.""" @@ -390,10 +375,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: "DownloadParams | None" = None + download_params: DownloadParams | None = None @classmethod - def from_json(cls, params: dict) -> "DownloadEndParams": + def from_json(cls, params: dict) -> DownloadEndParams: """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), @@ -414,8 +399,6 @@ def from_json(cls, params: dict) -> "DownloadEndParams": "history_updated": "browsingContext.historyUpdated", "dom_content_loaded": "browsingContext.domContentLoaded", "load": "browsingContext.load", - "download_will_begin": "browsingContext.downloadWillBegin", - "download_end": "browsingContext.downloadEnd", "navigation_aborted": "browsingContext.navigationAborted", "navigation_committed": "browsingContext.navigationCommitted", "navigation_failed": "browsingContext.navigationFailed", @@ -630,7 +613,13 @@ def activate(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): + def capture_screenshot( + self, + context: str | None = None, + format: Any | None = None, + clip: Any | None = None, + origin: str | None = None, + ): """Execute browsingContext.captureScreenshot.""" params = { "context": context, @@ -657,7 +646,13 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): result = self._conn.execute(cmd) return result - def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): + def create( + self, + type: Any | None = None, + reference_context: Any | None = None, + background: bool | None = None, + user_context: Any | None = None, + ): """Execute browsingContext.create.""" params = { "type": type, @@ -711,7 +706,14 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N result = self._conn.execute(cmd) return result - def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): + def locate_nodes( + self, + context: str | None = None, + locator: Any | None = None, + serialization_options: Any | None = None, + start_nodes: Any | None = None, + max_node_count: int | None = None, + ): """Execute browsingContext.locateNodes.""" params = { "context": context, @@ -740,7 +742,15 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any result = self._conn.execute(cmd) return result - def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): + def print( + self, + context: Any | None = None, + background: bool | None = None, + margin: Any | None = None, + page: Any | None = None, + scale: Any | None = None, + shrink_to_fit: bool | None = None, + ): """Execute browsingContext.print.""" params = { "context": context, @@ -770,7 +780,13 @@ def reload(self, context: Any | None = None, ignore_cache: bool | None = None, w result = self._conn.execute(cmd) return result - def set_viewport(self, context: str | None = None, viewport: Any | None = None, user_contexts: Any | None = None, device_pixel_ratio: Any | None = None): + def set_viewport( + self, + context: str | None = None, + viewport: Any | None = None, + user_contexts: Any | None = None, + device_pixel_ratio: Any | None = None, + ): """Execute browsingContext.setViewport.""" params = { "context": context, @@ -868,20 +884,81 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() BrowsingContext.EVENT_CONFIGS = { - "context_created": (EventConfig("context_created", "browsingContext.contextCreated", _globals.get("ContextCreated", dict)) if _globals.get("ContextCreated") else EventConfig("context_created", "browsingContext.contextCreated", dict)), - "context_destroyed": (EventConfig("context_destroyed", "browsingContext.contextDestroyed", _globals.get("ContextDestroyed", dict)) if _globals.get("ContextDestroyed") else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict)), - "navigation_started": (EventConfig("navigation_started", "browsingContext.navigationStarted", _globals.get("NavigationStarted", dict)) if _globals.get("NavigationStarted") else EventConfig("navigation_started", "browsingContext.navigationStarted", dict)), - "fragment_navigated": (EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", _globals.get("FragmentNavigated", dict)) if _globals.get("FragmentNavigated") else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict)), - "history_updated": (EventConfig("history_updated", "browsingContext.historyUpdated", _globals.get("HistoryUpdated", dict)) if _globals.get("HistoryUpdated") else EventConfig("history_updated", "browsingContext.historyUpdated", dict)), - "dom_content_loaded": (EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", _globals.get("DomContentLoaded", dict)) if _globals.get("DomContentLoaded") else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict)), - "load": (EventConfig("load", "browsingContext.load", _globals.get("Load", dict)) if _globals.get("Load") else EventConfig("load", "browsingContext.load", dict)), - "download_will_begin": (EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBegin", dict)) if _globals.get("DownloadWillBegin") else EventConfig("download_will_begin", "browsingContext.downloadWillBegin", dict)), - "download_end": (EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEnd", dict)) if _globals.get("DownloadEnd") else EventConfig("download_end", "browsingContext.downloadEnd", dict)), - "navigation_aborted": (EventConfig("navigation_aborted", "browsingContext.navigationAborted", _globals.get("NavigationAborted", dict)) if _globals.get("NavigationAborted") else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict)), - "navigation_committed": (EventConfig("navigation_committed", "browsingContext.navigationCommitted", _globals.get("NavigationCommitted", dict)) if _globals.get("NavigationCommitted") else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict)), - "navigation_failed": (EventConfig("navigation_failed", "browsingContext.navigationFailed", _globals.get("NavigationFailed", dict)) if _globals.get("NavigationFailed") else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict)), - "user_prompt_closed": (EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", _globals.get("UserPromptClosed", dict)) if _globals.get("UserPromptClosed") else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict)), - "user_prompt_opened": (EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", _globals.get("UserPromptOpened", dict)) if _globals.get("UserPromptOpened") else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict)), - "download_will_begin": EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBeginParams", dict)), + "context_created": ( + EventConfig("context_created", "browsingContext.contextCreated", + _globals.get("ContextCreated", dict)) + if _globals.get("ContextCreated") + else EventConfig("context_created", "browsingContext.contextCreated", dict) + ), + "context_destroyed": ( + EventConfig("context_destroyed", "browsingContext.contextDestroyed", + _globals.get("ContextDestroyed", dict)) + if _globals.get("ContextDestroyed") + else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict) + ), + "navigation_started": ( + EventConfig("navigation_started", "browsingContext.navigationStarted", + _globals.get("NavigationStarted", dict)) + if _globals.get("NavigationStarted") + else EventConfig("navigation_started", "browsingContext.navigationStarted", dict) + ), + "fragment_navigated": ( + EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", + _globals.get("FragmentNavigated", dict)) + if _globals.get("FragmentNavigated") + else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict) + ), + "history_updated": ( + EventConfig("history_updated", "browsingContext.historyUpdated", + _globals.get("HistoryUpdated", dict)) + if _globals.get("HistoryUpdated") + else EventConfig("history_updated", "browsingContext.historyUpdated", dict) + ), + "dom_content_loaded": ( + EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", + _globals.get("DomContentLoaded", dict)) + if _globals.get("DomContentLoaded") + else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict) + ), + "load": ( + EventConfig("load", "browsingContext.load", + _globals.get("Load", dict)) + if _globals.get("Load") + else EventConfig("load", "browsingContext.load", dict) + ), + "navigation_aborted": ( + EventConfig("navigation_aborted", "browsingContext.navigationAborted", + _globals.get("NavigationAborted", dict)) + if _globals.get("NavigationAborted") + else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict) + ), + "navigation_committed": ( + EventConfig("navigation_committed", "browsingContext.navigationCommitted", + _globals.get("NavigationCommitted", dict)) + if _globals.get("NavigationCommitted") + else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict) + ), + "navigation_failed": ( + EventConfig("navigation_failed", "browsingContext.navigationFailed", + _globals.get("NavigationFailed", dict)) + if _globals.get("NavigationFailed") + else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict) + ), + "user_prompt_closed": ( + EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", + _globals.get("UserPromptClosed", dict)) + if _globals.get("UserPromptClosed") + else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict) + ), + "user_prompt_opened": ( + EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", + _globals.get("UserPromptOpened", dict)) + if _globals.get("UserPromptOpened") + else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict) + ), + "download_will_begin": EventConfig( + "download_will_begin", "browsingContext.downloadWillBegin", + _globals.get("DownloadWillBeginParams", dict), + ), "download_end": EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEndParams", dict)), } diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py deleted file mode 100644 index 38dcf8d803ea3..0000000000000 --- a/py/selenium/webdriver/common/bidi/cdp.py +++ /dev/null @@ -1,515 +0,0 @@ -# The MIT License(MIT) -# -# Copyright(c) 2018 Hyperion Gray -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files(the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -# -# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp - -import contextvars -import importlib -import itertools -import json -import logging -import pathlib -from collections import defaultdict -from collections.abc import AsyncGenerator, AsyncIterator, Generator -from contextlib import asynccontextmanager, contextmanager -from dataclasses import dataclass -from typing import Any, TypeVar - -import trio -from trio_websocket import ConnectionClosed as WsConnectionClosed -from trio_websocket import connect_websocket_url - -logger = logging.getLogger("trio_cdp") -T = TypeVar("T") -MAX_WS_MESSAGE_SIZE = 2**24 - -devtools = None -version = None - - -def import_devtools(ver): - """Attempt to load the current latest available devtools into the module cache for use later.""" - global devtools - global version - version = ver - base = "selenium.webdriver.common.devtools.v" - try: - devtools = importlib.import_module(f"{base}{ver}") - return devtools - except ModuleNotFoundError: - # Attempt to parse and load the 'most recent' devtools module. This is likely - # because cdp has been updated but selenium python has not been released yet. - devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") - versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir() and f.name != "latest") - latest = max(int(x[1:]) for x in versions) - selenium_logger = logging.getLogger(__name__) - selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) - devtools = importlib.import_module(f"{base}{latest}") - return devtools - - -_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context") -_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") - - -def get_connection_context(fn_name): - """Look up the current connection. - - If there is no current connection, raise a ``RuntimeError`` with a - helpful message. - """ - try: - return _connection_context.get() - except LookupError: - raise RuntimeError(f"{fn_name}() must be called in a connection context.") - - -def get_session_context(fn_name): - """Look up the current session. - - If there is no current session, raise a ``RuntimeError`` with a - helpful message. - """ - try: - return _session_context.get() - except LookupError: - raise RuntimeError(f"{fn_name}() must be called in a session context.") - - -@contextmanager -def connection_context(connection): - """Context manager installs ``connection`` as the session context for the current Trio task.""" - token = _connection_context.set(connection) - try: - yield - finally: - _connection_context.reset(token) - - -@contextmanager -def session_context(session): - """Context manager installs ``session`` as the session context for the current Trio task.""" - token = _session_context.set(session) - try: - yield - finally: - _session_context.reset(token) - - -def set_global_connection(connection): - """Install ``connection`` in the root context so that it will become the default connection for all tasks. - - This is generally not recommended, except it may be necessary in - certain use cases such as running inside Jupyter notebook. - """ - global _connection_context - _connection_context = contextvars.ContextVar("_connection_context", default=connection) - - -def set_global_session(session): - """Install ``session`` in the root context so that it will become the default session for all tasks. - - This is generally not recommended, except it may be necessary in - certain use cases such as running inside Jupyter notebook. - """ - global _session_context - _session_context = contextvars.ContextVar("_session_context", default=session) - - -class BrowserError(Exception): - """This exception is raised when the browser's response to a command indicates that an error occurred.""" - - def __init__(self, obj): - self.code = obj.get("code") - self.message = obj.get("message") - self.detail = obj.get("data") - - def __str__(self): - return f"BrowserError {self.detail}" - - -class CdpConnectionClosed(WsConnectionClosed): - """Raised when a public method is called on a closed CDP connection.""" - - def __init__(self, reason): - """Constructor. - - Args: - reason: wsproto.frame_protocol.CloseReason - """ - self.reason = reason - - def __repr__(self): - """Return representation.""" - return f"{self.__class__.__name__}<{self.reason}>" - - -class InternalError(Exception): - """This exception is only raised when there is faulty logic in TrioCDP or the integration with PyCDP.""" - - pass - - -@dataclass -class CmEventProxy: - """A proxy object returned by :meth:`CdpBase.wait_for()``. - - After the context manager executes, this proxy object will have a - value set that contains the returned event. - """ - - value: Any = None - - -class CdpBase: - def __init__(self, ws, session_id, target_id): - self.ws = ws - self.session_id = session_id - self.target_id = target_id - self.channels = defaultdict(set) - self.id_iter = itertools.count() - self.inflight_cmd = {} - self.inflight_result = {} - - async def execute(self, cmd: Generator[dict, T, Any]) -> T: - """Execute a command on the server and wait for the result. - - Args: - cmd: any CDP command - - Returns: - a CDP result - """ - cmd_id = next(self.id_iter) - cmd_event = trio.Event() - self.inflight_cmd[cmd_id] = cmd, cmd_event - request = next(cmd) - request["id"] = cmd_id - if self.session_id: - request["sessionId"] = self.session_id - request_str = json.dumps(request) - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}") - try: - await self.ws.send_message(request_str) - except WsConnectionClosed as wcc: - raise CdpConnectionClosed(wcc.reason) from None - await cmd_event.wait() - response = self.inflight_result.pop(cmd_id) - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Received CDP message: {response}") - if isinstance(response, Exception): - if logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}") - raise response - return response - - def listen(self, *event_types, buffer_size=10): - """Listen for events. - - Returns: - An async iterator that iterates over events matching the indicated types. - """ - sender, receiver = trio.open_memory_channel(buffer_size) - for event_type in event_types: - self.channels[event_type].add(sender) - return receiver - - @asynccontextmanager - async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]: - """Wait for an event of the given type and return it. - - This is an async context manager, so you should open it inside - an async with block. The block will not exit until the indicated - event is received. - """ - sender: trio.MemorySendChannel - receiver: trio.MemoryReceiveChannel - sender, receiver = trio.open_memory_channel(buffer_size) - self.channels[event_type].add(sender) - proxy = CmEventProxy() - yield proxy - async with receiver: - event = await receiver.receive() - proxy.value = event - - def _handle_data(self, data): - """Handle incoming WebSocket data. - - Args: - data: a JSON dictionary - """ - if "id" in data: - self._handle_cmd_response(data) - else: - self._handle_event(data) - - def _handle_cmd_response(self, data: dict): - """Handle a response to a command. - - This will set an event flag that will return control to the - task that called the command. - - Args: - data: response as a JSON dictionary - """ - cmd_id = data["id"] - try: - cmd, event = self.inflight_cmd.pop(cmd_id) - except KeyError: - logger.warning("Got a message with a command ID that does not exist: %s", data) - return - if "error" in data: - # If the server reported an error, convert it to an exception and do - # not process the response any further. - self.inflight_result[cmd_id] = BrowserError(data["error"]) - else: - # Otherwise, continue the generator to parse the JSON result - # into a CDP object. - try: - _ = cmd.send(data["result"]) - raise InternalError("The command's generator function did not exit when expected!") - except StopIteration as exit: - return_ = exit.value - self.inflight_result[cmd_id] = return_ - event.set() - - def _handle_event(self, data: dict): - """Handle an event. - - Args: - data: event as a JSON dictionary - """ - global devtools - if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") - event = devtools.util.parse_json_event(data) - logger.debug("Received event: %s", event) - to_remove = set() - for sender in self.channels[type(event)]: - try: - sender.send_nowait(event) - except trio.WouldBlock: - logger.error('Unable to send event "%r" due to full channel %s', event, sender) - except trio.BrokenResourceError: - to_remove.add(sender) - if to_remove: - self.channels[type(event)] -= to_remove - - -class CdpSession(CdpBase): - """Contains the state for a CDP session. - - Generally you should not instantiate this object yourself; you should call - :meth:`CdpConnection.open_session`. - """ - - def __init__(self, ws, session_id, target_id): - """Constructor. - - Args: - ws: trio_websocket.WebSocketConnection - session_id: devtools.target.SessionID - target_id: devtools.target.TargetID - """ - super().__init__(ws, session_id, target_id) - - self._dom_enable_count = 0 - self._dom_enable_lock = trio.Lock() - self._page_enable_count = 0 - self._page_enable_lock = trio.Lock() - - @asynccontextmanager - async def dom_enable(self): - """Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``. - - This keeps track of concurrent callers and only disables DOM - events when all callers have exited. - """ - global devtools - async with self._dom_enable_lock: - self._dom_enable_count += 1 - if self._dom_enable_count == 1: - await self.execute(devtools.dom.enable()) - - yield - - async with self._dom_enable_lock: - self._dom_enable_count -= 1 - if self._dom_enable_count == 0: - await self.execute(devtools.dom.disable()) - - @asynccontextmanager - async def page_enable(self): - """Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits. - - This keeps track of concurrent callers and only disables page - events when all callers have exited. - """ - global devtools - async with self._page_enable_lock: - self._page_enable_count += 1 - if self._page_enable_count == 1: - await self.execute(devtools.page.enable()) - - yield - - async with self._page_enable_lock: - self._page_enable_count -= 1 - if self._page_enable_count == 0: - await self.execute(devtools.page.disable()) - - -class CdpConnection(CdpBase, trio.abc.AsyncResource): - """Contains the connection state for a Chrome DevTools Protocol server. - - CDP can multiplex multiple "sessions" over a single connection. This - class corresponds to the "root" session, i.e. the implicitly created - session that has no session ID. This class is responsible for - reading incoming WebSocket messages and forwarding them to the - corresponding session, as well as handling messages targeted at the - root session itself. You should generally call the - :func:`open_cdp()` instead of instantiating this class directly. - """ - - def __init__(self, ws): - """Constructor. - - Args: - ws: trio_websocket.WebSocketConnection - """ - super().__init__(ws, session_id=None, target_id=None) - self.sessions = {} - - async def aclose(self): - """Close the underlying WebSocket connection. - - This will cause the reader task to gracefully exit when it tries - to read the next message from the WebSocket. All of the public - APIs (``execute()``, ``listen()``, etc.) will raise - ``CdpConnectionClosed`` after the CDP connection is closed. It - is safe to call this multiple times. - """ - await self.ws.aclose() - - @asynccontextmanager - async def open_session(self, target_id) -> AsyncIterator[CdpSession]: - """Context manager opens a session and enables the "simple" style of calling CDP APIs. - - For example, inside a session context, you can call ``await - dom.get_document()`` and it will execute on the current session - automatically. - """ - session = await self.connect_session(target_id) - with session_context(session): - yield session - - async def connect_session(self, target_id) -> "CdpSession": - """Returns a new :class:`CdpSession` connected to the specified target.""" - global devtools - if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") - session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) - session = CdpSession(self.ws, session_id, target_id) - self.sessions[session_id] = session - return session - - async def _reader_task(self): - """Runs in the background and handles incoming messages. - - Dispatches responses to commands and events to listeners. - """ - global devtools - if devtools is None: - raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") - while True: - try: - message = await self.ws.get_message() - except WsConnectionClosed: - # If the WebSocket is closed, we don't want to throw an - # exception from the reader task. Instead we will throw - # exceptions from the public API methods, and we can quietly - # exit the reader task here. - break - try: - data = json.loads(message) - except json.JSONDecodeError: - raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message}) - logger.debug("Received message %r", data) - if "sessionId" in data: - session_id = devtools.target.SessionID(data["sessionId"]) - try: - session = self.sessions[session_id] - except KeyError: - raise BrowserError( - { - "code": -32700, - "message": "Browser sent a message for an invalid session", - "data": f"{session_id!r}", - } - ) - session._handle_data(data) - else: - self._handle_data(data) - - for _, session in self.sessions.items(): - for _, senders in session.channels.items(): - for sender in senders: - sender.close() - - -@asynccontextmanager -async def open_cdp(url) -> AsyncIterator[CdpConnection]: - """Async context manager opens a connection to the browser then closes the connection when the block exits. - - The context manager also sets the connection as the default - connection for the current task, so that commands like ``await - target.get_targets()`` will run on this connection automatically. If - you want to use multiple connections concurrently, it is recommended - to open each on in a separate task. - """ - async with trio.open_nursery() as nursery: - conn = await connect_cdp(nursery, url) - try: - with connection_context(conn): - yield conn - finally: - await conn.aclose() - - -async def connect_cdp(nursery, url) -> CdpConnection: - """Connect to the browser specified by ``url`` and spawn a background task in the specified nursery. - - The ``open_cdp()`` context manager is preferred in most situations. - You should only use this function if you need to specify a custom - nursery. This connection is not automatically closed! You can either - use the connection object as a context manager (``async with - conn:``) or else call ``await conn.aclose()`` on it when you are - done with it. If ``set_context`` is True, then the returned - connection will be installed as the default connection for the - current task. This argument is for unusual use cases, such as - running inside of a notebook. - """ - ws = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE) - cdp_conn = CdpConnection(ws) - nursery.start_soon(cdp_conn._reader_task) - return cdp_conn diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index d90d8c770263a..d7cb436a08471 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,12 +17,13 @@ """Common utilities for BiDi command construction.""" -from typing import Any, Dict, Generator +from collections.abc import Generator +from typing import Any def command_builder( - method: str, params: Dict[str, Any] -) -> Generator[Dict[str, Any], Any, Any]: + method: str, params: dict[str, Any] +) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/console.py b/py/selenium/webdriver/common/bidi/console.py old mode 100755 new mode 100644 diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 4cd6ae2e3c712..cb575bbdc54dd 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: emulation from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass class ForcedColorsModeTheme: @@ -41,16 +40,16 @@ class SetForcedColorsModeThemeOverrideParameters: """SetForcedColorsModeThemeOverrideParameters.""" theme: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass class SetGeolocationOverrideParameters: """SetGeolocationOverrideParameters.""" - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -78,8 +77,8 @@ class SetLocaleOverrideParameters: """SetLocaleOverrideParameters.""" locale: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -87,8 +86,8 @@ class setNetworkConditionsParameters: """setNetworkConditionsParameters.""" network_conditions: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -111,8 +110,8 @@ class SetScreenSettingsOverrideParameters: """SetScreenSettingsOverrideParameters.""" screen_area: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -128,8 +127,8 @@ class SetScreenOrientationOverrideParameters: """SetScreenOrientationOverrideParameters.""" screen_orientation: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -137,8 +136,8 @@ class SetUserAgentOverrideParameters: """SetUserAgentOverrideParameters.""" user_agent: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -146,8 +145,8 @@ class SetViewportMetaOverrideParameters: """SetViewportMetaOverrideParameters.""" viewport_meta: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -155,8 +154,8 @@ class SetScriptingEnabledParameters: """SetScriptingEnabledParameters.""" enabled: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -164,8 +163,8 @@ class SetScrollbarTypeOverrideParameters: """SetScrollbarTypeOverrideParameters.""" scrollbar_type: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -173,16 +172,16 @@ class SetTimezoneOverrideParameters: """SetTimezoneOverrideParameters.""" timezone: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass class SetTouchOverrideParameters: """SetTouchOverrideParameters.""" - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) class Emulation: @@ -191,7 +190,12 @@ class Emulation: def __init__(self, conn) -> None: self._conn = conn - def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_forced_colors_mode_theme_override( + self, + theme: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setForcedColorsModeThemeOverride.""" params = { "theme": theme, @@ -203,18 +207,12 @@ def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contex result = self._conn.execute(cmd) return result - def set_geolocation_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setGeolocationOverride.""" - params = { - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setGeolocationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_locale_override( + self, + locale: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setLocaleOverride.""" params = { "locale": locale, @@ -226,19 +224,12 @@ def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | N result = self._conn.execute(cmd) return result - def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setNetworkConditions.""" - params = { - "networkConditions": network_conditions, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setNetworkConditions", params) - result = self._conn.execute(cmd) - return result - - def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_screen_settings_override( + self, + screen_area: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setScreenSettingsOverride.""" params = { "screenArea": screen_area, @@ -250,31 +241,12 @@ def set_screen_settings_override(self, screen_area: Any | None = None, contexts: result = self._conn.execute(cmd) return result - def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setScreenOrientationOverride.""" - params = { - "screenOrientation": screen_orientation, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScreenOrientationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setUserAgentOverride.""" - params = { - "userAgent": user_agent, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setUserAgentOverride", params) - result = self._conn.execute(cmd) - return result - - def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_viewport_meta_override( + self, + viewport_meta: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setViewportMetaOverride.""" params = { "viewportMeta": viewport_meta, @@ -286,19 +258,12 @@ def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: result = self._conn.execute(cmd) return result - def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setScriptingEnabled.""" - params = { - "enabled": enabled, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScriptingEnabled", params) - result = self._conn.execute(cmd) - return result - - def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_scrollbar_type_override( + self, + scrollbar_type: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute emulation.setScrollbarTypeOverride.""" params = { "scrollbarType": scrollbar_type, @@ -310,19 +275,7 @@ def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, context result = self._conn.execute(cmd) return result - def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setTimezoneOverride.""" - params = { - "timezone": timezone, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setTimezoneOverride", params) - result = self._conn.execute(cmd) - return result - - def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_touch_override(self, contexts: list[Any] | None = None, user_contexts: list[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { "contexts": contexts, @@ -337,8 +290,8 @@ def set_geolocation_override( self, coordinates=None, error=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setGeolocationOverride. @@ -390,8 +343,8 @@ def set_geolocation_override( def set_timezone_override( self, timezone=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setTimezoneOverride. @@ -414,8 +367,8 @@ def set_timezone_override( def set_scripting_enabled( self, enabled=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScriptingEnabled. @@ -438,8 +391,8 @@ def set_scripting_enabled( def set_user_agent_override( self, user_agent=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setUserAgentOverride. @@ -461,8 +414,8 @@ def set_user_agent_override( def set_screen_orientation_override( self, screen_orientation=None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScreenOrientationOverride. @@ -498,8 +451,8 @@ def set_network_conditions( self, network_conditions=None, offline: bool | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setNetworkConditions. diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 5dbe71dbd3886..13f43361293f2 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: input from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class PointerType: """PointerType.""" @@ -45,7 +44,7 @@ class PerformActionsParameters: """PerformActionsParameters.""" context: Any | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -54,7 +53,7 @@ class NoneSourceActions: type: str = field(default="none", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -63,7 +62,7 @@ class KeySourceActions: type: str = field(default="key", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -73,7 +72,7 @@ class PointerSourceActions: type: str = field(default="pointer", init=False) id: str | None = None parameters: Any | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -89,7 +88,7 @@ class WheelSourceActions: type: str = field(default="wheel", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any | None] | None = field(default_factory=list) @dataclass @@ -163,7 +162,7 @@ class SetFilesParameters: context: Any | None = None element: Any | None = None - files: list[Any | None] | None = None + files: list[Any | None] | None = field(default_factory=list) @dataclass @@ -175,7 +174,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> "FileDialogInfo": + def from_json(cls, params: dict) -> FileDialogInfo: """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), @@ -368,7 +367,7 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): + def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): """Execute input.performActions.""" params = { "context": context, @@ -389,7 +388,7 @@ def release_actions(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): + def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): """Execute input.setFiles.""" params = { "context": context, @@ -454,5 +453,10 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Input.EVENT_CONFIGS = { - "file_dialog_opened": (EventConfig("file_dialog_opened", "input.fileDialogOpened", _globals.get("FileDialogOpened", dict)) if _globals.get("FileDialogOpened") else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict)), + "file_dialog_opened": ( + EventConfig("file_dialog_opened", "input.fileDialogOpened", + _globals.get("FileDialogOpened", dict)) + if _globals.get("FileDialogOpened") + else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict) + ), } diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index faf6c85ae2b6c..7971b807e94a1 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,11 +6,12 @@ # WebDriver BiDi module: log from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator +import threading +from collections.abc import Callable from dataclasses import dataclass +from typing import Any + +from selenium.webdriver.common.bidi.session import Session class Level: @@ -56,7 +57,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "ConsoleLogEntry": + def from_json(cls, params: dict) -> ConsoleLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -81,7 +82,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "JavascriptLogEntry": + def from_json(cls, params: dict) -> JavascriptLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -92,18 +93,212 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": stacktrace=params.get("stackTrace"), ) +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "entry_added": "log.entryAdded", +} + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + + + class Log: """WebDriver BiDi log module.""" + EVENT_CONFIGS = {} def __init__(self, conn) -> None: self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) + + pass + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. + + Args: + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). + + Returns: + The callback ID. + """ + return self._event_manager.add_event_handler(event, callback, contexts) + + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. + + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + """ + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() + +# Event Info Type Aliases +# Event: log.entryAdded +EntryAdded = globals().get('Entry', dict) # Fallback to dict if type not defined - def entry_added(self): - """Execute log.entryAdded.""" - params = { - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("log.entryAdded", params) - result = self._conn.execute(cmd) - return result +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Log.EVENT_CONFIGS = { + "entry_added": ( + EventConfig("entry_added", "log.entryAdded", + _globals.get("EntryAdded", dict)) + if _globals.get("EntryAdded") + else EventConfig("entry_added", "log.entryAdded", dict) + ), +} diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 4f44e309bffbb..6e02eeabc4ed7 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: network from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SameSite: """SameSite.""" @@ -75,7 +74,7 @@ class BaseParameters: redirect_count: Any | None = None request: Any | None = None timestamp: Any | None = None - intercepts: list[Any | None] | None = None + intercepts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -171,13 +170,13 @@ class ResponseData: status: Any | None = None status_text: str | None = None from_cache: bool | None = None - headers: list[Any | None] | None = None + headers: list[Any | None] | None = field(default_factory=list) mime_type: str | None = None bytes_received: Any | None = None headers_size: Any | None = None body_size: Any | None = None content: Any | None = None - auth_challenges: list[Any | None] | None = None + auth_challenges: list[Any | None] | None = field(default_factory=list) @dataclass @@ -219,11 +218,11 @@ class UrlPatternString: class AddDataCollectorParameters: """AddDataCollectorParameters.""" - data_types: list[Any | None] | None = None + data_types: list[Any | None] | None = field(default_factory=list) max_encoded_data_size: Any | None = None collector_type: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -237,9 +236,9 @@ class AddDataCollectorResult: class AddInterceptParameters: """AddInterceptParameters.""" - phases: list[Any | None] | None = None - contexts: list[Any | None] | None = None - url_patterns: list[Any | None] | None = None + phases: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = field(default_factory=list) + url_patterns: list[Any | None] | None = field(default_factory=list) @dataclass @@ -254,9 +253,9 @@ class ContinueResponseParameters: """ContinueResponseParameters.""" request: Any | None = None - cookies: list[Any | None] | None = None + cookies: list[Any | None] | None = field(default_factory=list) credentials: Any | None = None - headers: list[Any | None] | None = None + headers: list[Any | None] | None = field(default_factory=list) reason_phrase: str | None = None status_code: Any | None = None @@ -315,8 +314,8 @@ class ProvideResponseParameters: request: Any | None = None body: Any | None = None - cookies: list[Any | None] | None = None - headers: list[Any | None] | None = None + cookies: list[Any | None] | None = field(default_factory=list) + headers: list[Any | None] | None = field(default_factory=list) reason_phrase: str | None = None status_code: Any | None = None @@ -340,16 +339,16 @@ class SetCacheBehaviorParameters: """SetCacheBehaviorParameters.""" cache_behavior: Any | None = None - contexts: list[Any | None] | None = None + contexts: list[Any | None] | None = field(default_factory=list) @dataclass class SetExtraHeadersParameters: """SetExtraHeadersParameters.""" - headers: list[Any | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + headers: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass @@ -562,7 +561,14 @@ def __init__(self, conn) -> None: self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) self.intercepts = [] - def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def add_data_collector( + self, + data_types: list[Any] | None = None, + max_encoded_data_size: Any | None = None, + collector_type: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute network.addDataCollector.""" params = { "dataTypes": data_types, @@ -576,7 +582,12 @@ def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_da result = self._conn.execute(cmd) return result - def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): + def add_intercept( + self, + phases: list[Any] | None = None, + contexts: list[Any] | None = None, + url_patterns: list[Any] | None = None, + ): """Execute network.addIntercept.""" params = { "phases": phases, @@ -588,7 +599,15 @@ def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | N result = self._conn.execute(cmd) return result - def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): + def continue_request( + self, + request: Any | None = None, + body: Any | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, + method: Any | None = None, + url: Any | None = None, + ): """Execute network.continueRequest.""" params = { "request": request, @@ -603,7 +622,15 @@ def continue_request(self, request: Any | None = None, body: Any | None = None, result = self._conn.execute(cmd) return result - def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + def continue_response( + self, + request: Any | None = None, + cookies: list[Any] | None = None, + credentials: Any | None = None, + headers: list[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): """Execute network.continueResponse.""" params = { "request": request, @@ -650,7 +677,13 @@ def fail_request(self, request: Any | None = None): result = self._conn.execute(cmd) return result - def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): + def get_data( + self, + data_type: Any | None = None, + collector: Any | None = None, + disown: bool | None = None, + request: Any | None = None, + ): """Execute network.getData.""" params = { "dataType": data_type, @@ -663,7 +696,15 @@ def get_data(self, data_type: Any | None = None, collector: Any | None = None, d result = self._conn.execute(cmd) return result - def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + def provide_response( + self, + request: Any | None = None, + body: Any | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): """Execute network.provideResponse.""" params = { "request": request, @@ -698,7 +739,7 @@ def remove_intercept(self, intercept: Any | None = None): result = self._conn.execute(cmd) return result - def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): """Execute network.setCacheBehavior.""" params = { "cacheBehavior": cache_behavior, @@ -709,7 +750,12 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[A result = self._conn.execute(cmd) return result - def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_extra_headers( + self, + headers: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute network.setExtraHeaders.""" params = { "headers": headers, @@ -918,6 +964,11 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Network.EVENT_CONFIGS = { - "auth_required": (EventConfig("auth_required", "network.authRequired", _globals.get("AuthRequired", dict)) if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict)), + "auth_required": ( + EventConfig("auth_required", "network.authRequired", + _globals.get("AuthRequired", dict)) + if _globals.get("AuthRequired") + else EventConfig("auth_required", "network.authRequired", dict) + ), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index f00e765c62e3b..6dd138da17309 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -20,7 +20,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Optional, Union +from typing import Any from .common import command_builder @@ -63,10 +63,10 @@ def __init__(self, websocket_connection: Any) -> None: def set_permission( self, - descriptor: Union[PermissionDescriptor, str], - state: Union[PermissionState, str], - origin: Optional[str] = None, - user_context: Optional[str] = None, + descriptor: PermissionDescriptor | str, + state: PermissionState | str, + origin: str | None = None, + user_context: str | None = None, ) -> None: """Set a permission for a given origin. diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index e13c11f71a5cb..b29721db88503 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: script from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SpecialNumber: """SpecialNumber.""" @@ -216,7 +215,7 @@ class DedicatedWorkerRealmInfo: """DedicatedWorkerRealmInfo.""" type: str = field(default="dedicated-worker", init=False) - owners: list[Any | None] | None = None + owners: list[Any | None] | None = field(default_factory=list) @dataclass @@ -460,7 +459,7 @@ class NodeProperties: node_type: Any | None = None child_node_count: Any | None = None - children: list[Any | None] | None = None + children: list[Any | None] | None = field(default_factory=list) local_name: str | None = None mode: Any | None = None namespace_uri: str | None = None @@ -499,7 +498,7 @@ class StackFrame: class StackTrace: """StackTrace.""" - call_frames: list[Any | None] | None = None + call_frames: list[Any | None] | None = field(default_factory=list) @dataclass @@ -530,9 +529,9 @@ class AddPreloadScriptParameters: """AddPreloadScriptParameters.""" function_declaration: str | None = None - arguments: list[Any | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + arguments: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) sandbox: str | None = None @@ -547,7 +546,7 @@ class AddPreloadScriptResult: class DisownParameters: """DisownParameters.""" - handles: list[Any | None] | None = None + handles: list[Any | None] | None = field(default_factory=list) target: Any | None = None @@ -558,7 +557,7 @@ class CallFunctionParameters: function_declaration: str | None = None await_promise: bool | None = None target: Any | None = None - arguments: list[Any | None] | None = None + arguments: list[Any | None] | None = field(default_factory=list) result_ownership: Any | None = None serialization_options: Any | None = None this: Any | None = None @@ -589,7 +588,7 @@ class GetRealmsParameters: class GetRealmsResult: """GetRealmsResult.""" - realms: list[Any | None] | None = None + realms: list[Any | None] | None = field(default_factory=list) @dataclass @@ -783,7 +782,14 @@ def __init__(self, conn, driver=None) -> None: self._driver = driver self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): + def add_preload_script( + self, + function_declaration: Any | None = None, + arguments: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + sandbox: Any | None = None, + ): """Execute script.addPreloadScript.""" params = { "functionDeclaration": function_declaration, @@ -797,7 +803,7 @@ def add_preload_script(self, function_declaration: Any | None = None, arguments: result = self._conn.execute(cmd) return result - def disown(self, handles: List[Any] | None = None, target: Any | None = None): + def disown(self, handles: list[Any] | None = None, target: Any | None = None): """Execute script.disown.""" params = { "handles": handles, @@ -808,7 +814,17 @@ def disown(self, handles: List[Any] | None = None, target: Any | None = None): result = self._conn.execute(cmd) return result - def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): + def call_function( + self, + function_declaration: Any | None = None, + await_promise: bool | None = None, + target: Any | None = None, + arguments: list[Any] | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + this: Any | None = None, + user_activation: bool | None = None, + ): """Execute script.callFunction.""" params = { "functionDeclaration": function_declaration, @@ -825,7 +841,15 @@ def call_function(self, function_declaration: Any | None = None, await_promise: result = self._conn.execute(cmd) return result - def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): + def evaluate( + self, + expression: Any | None = None, + target: Any | None = None, + await_promise: bool | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + user_activation: bool | None = None, + ): """Execute script.evaluate.""" params = { "expression": expression, @@ -889,8 +913,9 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import math as _math import datetime as _datetime + import math as _math + from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -941,7 +966,14 @@ def _serialize_arg(value): if raw.get("type") == "success": return raw.get("result") return raw - def _add_preload_script(self, function_declaration, arguments=None, contexts=None, user_contexts=None, sandbox=None): + def _add_preload_script( + self, + function_declaration, + arguments=None, + contexts=None, + user_contexts=None, + sandbox=None, + ): """Add a preload script with validation. Args: @@ -993,7 +1025,15 @@ def unpin(self, script_id): script_id: The ID returned by pin(). """ return self._remove_preload_script(script_id=script_id) - def _evaluate(self, expression, target, await_promise, result_ownership=None, serialization_options=None, user_activation=None): + def _evaluate( + self, + expression, + target, + await_promise, + result_ownership=None, + serialization_options=None, + user_activation=None, + ): """Evaluate a script expression and return a structured result. Args: @@ -1028,7 +1068,17 @@ def __init__(self2, realm, result, exception_details): return _EvalResult(realm=realm, result=None, exception_details=exc) return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) return _EvalResult(realm=None, result=raw, exception_details=None) - def _call_function(self, function_declaration, await_promise, target, arguments=None, result_ownership=None, this=None, user_activation=None, serialization_options=None): + def _call_function( + self, + function_declaration, + await_promise, + target, + arguments=None, + result_ownership=None, + this=None, + user_activation=None, + serialization_options=None, + ): """Call a function and return a structured result. Args: @@ -1106,8 +1156,9 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + from selenium.webdriver.common.bidi.session import Session as _Session bidi_event = "log.entryAdded" @@ -1257,6 +1308,16 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { - "realm_created": (EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) if _globals.get("RealmCreated") else EventConfig("realm_created", "script.realmCreated", dict)), - "realm_destroyed": (EventConfig("realm_destroyed", "script.realmDestroyed", _globals.get("RealmDestroyed", dict)) if _globals.get("RealmDestroyed") else EventConfig("realm_destroyed", "script.realmDestroyed", dict)), + "realm_created": ( + EventConfig("realm_created", "script.realmCreated", + _globals.get("RealmCreated", dict)) + if _globals.get("RealmCreated") + else EventConfig("realm_created", "script.realmCreated", dict) + ), + "realm_destroyed": ( + EventConfig("realm_destroyed", "script.realmDestroyed", + _globals.get("RealmDestroyed", dict)) + if _globals.get("RealmDestroyed") + else EventConfig("realm_destroyed", "script.realmDestroyed", dict) + ), } diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 9b1daaae557fa..c1b5be09ca024 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: session from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass class UserPromptHandlerType: @@ -26,7 +25,7 @@ class CapabilitiesRequest: """CapabilitiesRequest.""" always_match: Any | None = None - first_match: list[Any | None] | None = None + first_match: list[Any | None] | None = field(default_factory=list) @dataclass @@ -62,7 +61,7 @@ class ManualProxyConfiguration: proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None ssl_proxy: str | None = None - no_proxy: list[Any | None] | None = None + no_proxy: list[Any | None] | None = field(default_factory=list) @dataclass @@ -92,23 +91,23 @@ class SystemProxyConfiguration: class SubscribeParameters: """SubscribeParameters.""" - events: list[str | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + events: list[str | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass class UnsubscribeByIDRequest: """UnsubscribeByIDRequest.""" - subscriptions: list[Any | None] | None = None + subscriptions: list[Any | None] | None = field(default_factory=list) @dataclass class UnsubscribeByAttributesRequest: """UnsubscribeByAttributesRequest.""" - events: list[str | None] | None = None + events: list[str | None] | None = field(default_factory=list) @dataclass @@ -211,7 +210,12 @@ def end(self): result = self._conn.execute(cmd) return result - def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def subscribe( + self, + events: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): """Execute session.subscribe.""" params = { "events": events, @@ -223,7 +227,7 @@ def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None result = self._conn.execute(cmd) return result - def unsubscribe(self, events: List[Any] | None = None, subscriptions: List[Any] | None = None): + def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] | None = None): """Execute session.unsubscribe.""" params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 7e4c9c6dee459..3f29b85d13a23 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: storage from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass @dataclass @@ -33,7 +32,7 @@ class GetCookiesParameters: class GetCookiesResult: """GetCookiesResult.""" - cookies: list[Any | None] | None = None + cookies: list[Any | None] | None = field(default_factory=list) partition_key: Any | None = None @@ -107,7 +106,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): @@ -235,39 +234,6 @@ class Storage: def __init__(self, conn) -> None: self._conn = conn - def get_cookies(self, filter: Any | None = None, partition: Any | None = None): - """Execute storage.getCookies.""" - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.getCookies", params) - result = self._conn.execute(cmd) - return result - - def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): - """Execute storage.setCookie.""" - params = { - "cookie": cookie, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.setCookie", params) - result = self._conn.execute(cmd) - return result - - def delete_cookies(self, filter: Any | None = None, partition: Any | None = None): - """Execute storage.deleteCookies.""" - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.deleteCookies", params) - result = self._conn.execute(cmd) - return result - def get_cookies(self, filter=None, partition=None): """Execute storage.getCookies and return a GetCookiesResult.""" if filter and hasattr(filter, "to_bidi_dict"): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 8a737efeeafde..ebbe6729499b2 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: webExtension from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass @dataclass @@ -64,7 +63,12 @@ class WebExtension: def __init__(self, conn) -> None: self._conn = conn - def install(self, path: str | None = None, archive_path: str | None = None, base64_value: str | None = None): + def install( + self, + path: str | None = None, + archive_path: str | None = None, + base64_value: str | None = None, + ): """Install a web extension. Exactly one of the three keyword arguments must be provided. @@ -82,7 +86,11 @@ def install(self, path: str | None = None, archive_path: str | None = None, base Raises: ValueError: If more than one, or none, of the arguments is provided. """ - provided = [k for k, v in {"path": path, "archive_path": archive_path, "base64_value": base64_value}.items() if v is not None] + provided = [ + k for k, v in { + "path": path, "archive_path": archive_path, "base64_value": base64_value, + }.items() if v is not None + ] if len(provided) != 1: raise ValueError( f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" From 71278c1f65b4db6ad719c0ba470ed5e8186be781 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 28 Feb 2026 08:54:27 +0000 Subject: [PATCH 03/37] fixup --- py/generate_bidi.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 2db595ff37cd0..4bf0d8b64514e 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -721,7 +721,9 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += "# BiDi Event Name to Parameter Type Mapping\n" code += "EVENT_NAME_MAPPING = {\n" # Collect event keys from extra_events so we skip CDDL duplicates - extra_event_keys = {evt["event_key"] for evt in enhancements.get("extra_events", [])} + extra_event_keys = { + evt["event_key"] for evt in enhancements.get("extra_events", []) + } for event_def in self.events: # Convert method name to user-friendly event name # e.g., "browsingContext.contextCreated" -> "context_created" @@ -972,7 +974,9 @@ def clear_event_handlers(self) -> None: m = re.search(r"def\s+(\w+)\s*\(", extra_meth) if m: extra_method_names.add(m.group(1)) - exclude_methods = set(enhancements.get("exclude_methods", [])) | extra_method_names + exclude_methods = ( + set(enhancements.get("exclude_methods", [])) | extra_method_names + ) if self.commands: for command in self.commands: # Get method-specific enhancements @@ -1035,7 +1039,9 @@ def clear_event_handlers(self) -> None: code += "_globals = globals()\n" code += f"{class_name}.EVENT_CONFIGS = {{\n" # Collect extra event keys to skip CDDL duplicates - extra_event_keys_cfg = {evt["event_key"] for evt in enhancements.get("extra_events", [])} + extra_event_keys_cfg = { + evt["event_key"] for evt in enhancements.get("extra_events", []) + } for event_def in self.events: # Convert method name to user-friendly event name method_parts = event_def.method.split(".") @@ -1051,7 +1057,7 @@ def clear_event_handlers(self) -> None: f' _globals.get("{event_def.name}", dict))\n' f' if _globals.get("{event_def.name}")\n' f' else EventConfig("{event_name}", "{event_def.method}", dict)\n' - f' ),\n' + f" ),\n" ) # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): @@ -1064,7 +1070,7 @@ def clear_event_handlers(self) -> None: f' "{ek}": EventConfig(\n' f' "{ek}", "{be}",\n' f' _globals.get("{ec}", dict),\n' - f' ),\n' + f" ),\n" ) else: code += single + "\n" From f998e6d43d09dbba385fd19579261618508ac944 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 28 Feb 2026 08:55:57 +0000 Subject: [PATCH 04/37] fixup --- py/selenium/webdriver/common/bidi/cdp.py | 515 +++++++++++++++++++++++ 1 file changed, 515 insertions(+) create mode 100644 py/selenium/webdriver/common/bidi/cdp.py diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py new file mode 100644 index 0000000000000..b097762fe50cd --- /dev/null +++ b/py/selenium/webdriver/common/bidi/cdp.py @@ -0,0 +1,515 @@ +# The MIT License(MIT) +# +# Copyright(c) 2018 Hyperion Gray +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp + +import contextvars +import importlib +import itertools +import json +import logging +import pathlib +from collections import defaultdict +from collections.abc import AsyncGenerator, AsyncIterator, Generator +from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass +from typing import Any, TypeVar + +import trio +from trio_websocket import ConnectionClosed as WsConnectionClosed +from trio_websocket import connect_websocket_url + +logger = logging.getLogger("trio_cdp") +T = TypeVar("T") +MAX_WS_MESSAGE_SIZE = 2**24 + +devtools = None +version = None + + +def import_devtools(ver): + """Attempt to load the current latest available devtools into the module cache for use later.""" + global devtools + global version + version = ver + base = "selenium.webdriver.common.devtools.v" + try: + devtools = importlib.import_module(f"{base}{ver}") + return devtools + except ModuleNotFoundError: + # Attempt to parse and load the 'most recent' devtools module. This is likely + # because cdp has been updated but selenium python has not been released yet. + devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") + versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) + latest = max(int(x[1:]) for x in versions) + selenium_logger = logging.getLogger(__name__) + selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) + devtools = importlib.import_module(f"{base}{latest}") + return devtools + + +_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context") +_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") + + +def get_connection_context(fn_name): + """Look up the current connection. + + If there is no current connection, raise a ``RuntimeError`` with a + helpful message. + """ + try: + return _connection_context.get() + except LookupError: + raise RuntimeError(f"{fn_name}() must be called in a connection context.") + + +def get_session_context(fn_name): + """Look up the current session. + + If there is no current session, raise a ``RuntimeError`` with a + helpful message. + """ + try: + return _session_context.get() + except LookupError: + raise RuntimeError(f"{fn_name}() must be called in a session context.") + + +@contextmanager +def connection_context(connection): + """Context manager installs ``connection`` as the session context for the current Trio task.""" + token = _connection_context.set(connection) + try: + yield + finally: + _connection_context.reset(token) + + +@contextmanager +def session_context(session): + """Context manager installs ``session`` as the session context for the current Trio task.""" + token = _session_context.set(session) + try: + yield + finally: + _session_context.reset(token) + + +def set_global_connection(connection): + """Install ``connection`` in the root context so that it will become the default connection for all tasks. + + This is generally not recommended, except it may be necessary in + certain use cases such as running inside Jupyter notebook. + """ + global _connection_context + _connection_context = contextvars.ContextVar("_connection_context", default=connection) + + +def set_global_session(session): + """Install ``session`` in the root context so that it will become the default session for all tasks. + + This is generally not recommended, except it may be necessary in + certain use cases such as running inside Jupyter notebook. + """ + global _session_context + _session_context = contextvars.ContextVar("_session_context", default=session) + + +class BrowserError(Exception): + """This exception is raised when the browser's response to a command indicates that an error occurred.""" + + def __init__(self, obj): + self.code = obj.get("code") + self.message = obj.get("message") + self.detail = obj.get("data") + + def __str__(self): + return f"BrowserError {self.detail}" + + +class CdpConnectionClosed(WsConnectionClosed): + """Raised when a public method is called on a closed CDP connection.""" + + def __init__(self, reason): + """Constructor. + + Args: + reason: wsproto.frame_protocol.CloseReason + """ + self.reason = reason + + def __repr__(self): + """Return representation.""" + return f"{self.__class__.__name__}<{self.reason}>" + + +class InternalError(Exception): + """This exception is only raised when there is faulty logic in TrioCDP or the integration with PyCDP.""" + + pass + + +@dataclass +class CmEventProxy: + """A proxy object returned by :meth:`CdpBase.wait_for()``. + + After the context manager executes, this proxy object will have a + value set that contains the returned event. + """ + + value: Any = None + + +class CdpBase: + def __init__(self, ws, session_id, target_id): + self.ws = ws + self.session_id = session_id + self.target_id = target_id + self.channels = defaultdict(set) + self.id_iter = itertools.count() + self.inflight_cmd = {} + self.inflight_result = {} + + async def execute(self, cmd: Generator[dict, T, Any]) -> T: + """Execute a command on the server and wait for the result. + + Args: + cmd: any CDP command + + Returns: + a CDP result + """ + cmd_id = next(self.id_iter) + cmd_event = trio.Event() + self.inflight_cmd[cmd_id] = cmd, cmd_event + request = next(cmd) + request["id"] = cmd_id + if self.session_id: + request["sessionId"] = self.session_id + request_str = json.dumps(request) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}") + try: + await self.ws.send_message(request_str) + except WsConnectionClosed as wcc: + raise CdpConnectionClosed(wcc.reason) from None + await cmd_event.wait() + response = self.inflight_result.pop(cmd_id) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Received CDP message: {response}") + if isinstance(response, Exception): + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}") + raise response + return response + + def listen(self, *event_types, buffer_size=10): + """Listen for events. + + Returns: + An async iterator that iterates over events matching the indicated types. + """ + sender, receiver = trio.open_memory_channel(buffer_size) + for event_type in event_types: + self.channels[event_type].add(sender) + return receiver + + @asynccontextmanager + async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]: + """Wait for an event of the given type and return it. + + This is an async context manager, so you should open it inside + an async with block. The block will not exit until the indicated + event is received. + """ + sender: trio.MemorySendChannel + receiver: trio.MemoryReceiveChannel + sender, receiver = trio.open_memory_channel(buffer_size) + self.channels[event_type].add(sender) + proxy = CmEventProxy() + yield proxy + async with receiver: + event = await receiver.receive() + proxy.value = event + + def _handle_data(self, data): + """Handle incoming WebSocket data. + + Args: + data: a JSON dictionary + """ + if "id" in data: + self._handle_cmd_response(data) + else: + self._handle_event(data) + + def _handle_cmd_response(self, data: dict): + """Handle a response to a command. + + This will set an event flag that will return control to the + task that called the command. + + Args: + data: response as a JSON dictionary + """ + cmd_id = data["id"] + try: + cmd, event = self.inflight_cmd.pop(cmd_id) + except KeyError: + logger.warning("Got a message with a command ID that does not exist: %s", data) + return + if "error" in data: + # If the server reported an error, convert it to an exception and do + # not process the response any further. + self.inflight_result[cmd_id] = BrowserError(data["error"]) + else: + # Otherwise, continue the generator to parse the JSON result + # into a CDP object. + try: + _ = cmd.send(data["result"]) + raise InternalError("The command's generator function did not exit when expected!") + except StopIteration as exit: + return_ = exit.value + self.inflight_result[cmd_id] = return_ + event.set() + + def _handle_event(self, data: dict): + """Handle an event. + + Args: + data: event as a JSON dictionary + """ + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + event = devtools.util.parse_json_event(data) + logger.debug("Received event: %s", event) + to_remove = set() + for sender in self.channels[type(event)]: + try: + sender.send_nowait(event) + except trio.WouldBlock: + logger.error('Unable to send event "%r" due to full channel %s', event, sender) + except trio.BrokenResourceError: + to_remove.add(sender) + if to_remove: + self.channels[type(event)] -= to_remove + + +class CdpSession(CdpBase): + """Contains the state for a CDP session. + + Generally you should not instantiate this object yourself; you should call + :meth:`CdpConnection.open_session`. + """ + + def __init__(self, ws, session_id, target_id): + """Constructor. + + Args: + ws: trio_websocket.WebSocketConnection + session_id: devtools.target.SessionID + target_id: devtools.target.TargetID + """ + super().__init__(ws, session_id, target_id) + + self._dom_enable_count = 0 + self._dom_enable_lock = trio.Lock() + self._page_enable_count = 0 + self._page_enable_lock = trio.Lock() + + @asynccontextmanager + async def dom_enable(self): + """Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``. + + This keeps track of concurrent callers and only disables DOM + events when all callers have exited. + """ + global devtools + async with self._dom_enable_lock: + self._dom_enable_count += 1 + if self._dom_enable_count == 1: + await self.execute(devtools.dom.enable()) + + yield + + async with self._dom_enable_lock: + self._dom_enable_count -= 1 + if self._dom_enable_count == 0: + await self.execute(devtools.dom.disable()) + + @asynccontextmanager + async def page_enable(self): + """Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits. + + This keeps track of concurrent callers and only disables page + events when all callers have exited. + """ + global devtools + async with self._page_enable_lock: + self._page_enable_count += 1 + if self._page_enable_count == 1: + await self.execute(devtools.page.enable()) + + yield + + async with self._page_enable_lock: + self._page_enable_count -= 1 + if self._page_enable_count == 0: + await self.execute(devtools.page.disable()) + + +class CdpConnection(CdpBase, trio.abc.AsyncResource): + """Contains the connection state for a Chrome DevTools Protocol server. + + CDP can multiplex multiple "sessions" over a single connection. This + class corresponds to the "root" session, i.e. the implicitly created + session that has no session ID. This class is responsible for + reading incoming WebSocket messages and forwarding them to the + corresponding session, as well as handling messages targeted at the + root session itself. You should generally call the + :func:`open_cdp()` instead of instantiating this class directly. + """ + + def __init__(self, ws): + """Constructor. + + Args: + ws: trio_websocket.WebSocketConnection + """ + super().__init__(ws, session_id=None, target_id=None) + self.sessions = {} + + async def aclose(self): + """Close the underlying WebSocket connection. + + This will cause the reader task to gracefully exit when it tries + to read the next message from the WebSocket. All of the public + APIs (``execute()``, ``listen()``, etc.) will raise + ``CdpConnectionClosed`` after the CDP connection is closed. It + is safe to call this multiple times. + """ + await self.ws.aclose() + + @asynccontextmanager + async def open_session(self, target_id) -> AsyncIterator[CdpSession]: + """Context manager opens a session and enables the "simple" style of calling CDP APIs. + + For example, inside a session context, you can call ``await + dom.get_document()`` and it will execute on the current session + automatically. + """ + session = await self.connect_session(target_id) + with session_context(session): + yield session + + async def connect_session(self, target_id) -> "CdpSession": + """Returns a new :class:`CdpSession` connected to the specified target.""" + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) + session = CdpSession(self.ws, session_id, target_id) + self.sessions[session_id] = session + return session + + async def _reader_task(self): + """Runs in the background and handles incoming messages. + + Dispatches responses to commands and events to listeners. + """ + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + while True: + try: + message = await self.ws.get_message() + except WsConnectionClosed: + # If the WebSocket is closed, we don't want to throw an + # exception from the reader task. Instead we will throw + # exceptions from the public API methods, and we can quietly + # exit the reader task here. + break + try: + data = json.loads(message) + except json.JSONDecodeError: + raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message}) + logger.debug("Received message %r", data) + if "sessionId" in data: + session_id = devtools.target.SessionID(data["sessionId"]) + try: + session = self.sessions[session_id] + except KeyError: + raise BrowserError( + { + "code": -32700, + "message": "Browser sent a message for an invalid session", + "data": f"{session_id!r}", + } + ) + session._handle_data(data) + else: + self._handle_data(data) + + for _, session in self.sessions.items(): + for _, senders in session.channels.items(): + for sender in senders: + sender.close() + + +@asynccontextmanager +async def open_cdp(url) -> AsyncIterator[CdpConnection]: + """Async context manager opens a connection to the browser then closes the connection when the block exits. + + The context manager also sets the connection as the default + connection for the current task, so that commands like ``await + target.get_targets()`` will run on this connection automatically. If + you want to use multiple connections concurrently, it is recommended + to open each on in a separate task. + """ + async with trio.open_nursery() as nursery: + conn = await connect_cdp(nursery, url) + try: + with connection_context(conn): + yield conn + finally: + await conn.aclose() + + +async def connect_cdp(nursery, url) -> CdpConnection: + """Connect to the browser specified by ``url`` and spawn a background task in the specified nursery. + + The ``open_cdp()`` context manager is preferred in most situations. + You should only use this function if you need to specify a custom + nursery. This connection is not automatically closed! You can either + use the connection object as a context manager (``async with + conn:``) or else call ``await conn.aclose()`` on it when you are + done with it. If ``set_context`` is True, then the returned + connection will be installed as the default connection for the + current task. This argument is for unusual use cases, such as + running inside of a notebook. + """ + ws = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE) + cdp_conn = CdpConnection(ws) + nursery.start_soon(cdp_conn._reader_task) + return cdp_conn From 803b617249cfde9daebc2e97ab5390f31a666eb1 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Mon, 2 Mar 2026 11:26:58 +0000 Subject: [PATCH 05/37] handle comments --- py/generate_bidi.py | 64 +++--- py/private/bidi_enhancements_manifest.py | 36 +++- py/selenium/webdriver/common/bidi/browser.py | 38 ++-- .../webdriver/common/bidi/browsing_context.py | 68 +++--- py/selenium/webdriver/common/bidi/common.py | 6 +- .../webdriver/common/bidi/emulation.py | 36 ++-- py/selenium/webdriver/common/bidi/input.py | 30 +-- py/selenium/webdriver/common/bidi/log.py | 4 +- py/selenium/webdriver/common/bidi/network.py | 194 ++++++++++-------- py/selenium/webdriver/common/bidi/script.py | 154 +++++++------- py/selenium/webdriver/common/bidi/session.py | 30 +-- py/selenium/webdriver/common/bidi/storage.py | 14 +- .../webdriver/common/bidi/webextension.py | 12 +- py/selenium/webdriver/remote/webdriver.py | 8 +- 14 files changed, 386 insertions(+), 308 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 4bf0d8b64514e..5d7f39e53abfc 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -368,11 +368,14 @@ def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str dataclass_methods = enhancements.get("dataclass_methods", {}) method_docstrings = enhancements.get("method_docstrings", {}) - # Generate class name from type name (keep it as-is, don't split on underscores) - class_name = self.name + # Generate class name from type name. + # CDDL type names that start with a lowercase letter (e.g. camelCase + # command-parameter types like "setNetworkConditionsParameters") are + # capitalised so that the resulting Python class follows PascalCase. + class_name = self.name[0].upper() + self.name[1:] if self.name else self.name code = "@dataclass\n" code += f"class {class_name}:\n" - code += f' """{self.description or self.name}."""\n\n' + code += f' """{class_name} type definition."""\n\n' if not self.fields: code += " pass\n" @@ -466,9 +469,9 @@ def to_python_class(self) -> str: Generates a simple class with string constants to match the existing pattern in the codebase (e.g., ClientWindowState). """ - class_name = self.name + class_name = self.name[0].upper() + self.name[1:] if self.name else self.name code = f"class {class_name}:\n" - code += f' """{self.description or self.name}."""\n\n' + code += f' """{class_name}."""\n\n' for value in self.values: # Convert value to UPPER_SNAKE_CASE constant name @@ -684,8 +687,19 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - # Generate enums first + # Collect names of extra_dataclasses so we can skip CDDL-generated + # enums and types that are overridden by manual definitions. + extra_cls_names = set() + for extra_cls in enhancements.get("extra_dataclasses", []): + m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) + if m: + extra_cls_names.add(m.group(1)) + exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names + + # Generate enums first, skipping any that are overridden via extra_dataclasses for enum_def in self.enums: + if enum_def.name in exclude_types: + continue code += enum_def.to_python_class() code += "\n\n" @@ -694,13 +708,6 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses - # Also auto-exclude types whose names appear in extra_dataclasses - extra_cls_names = set() - for extra_cls in enhancements.get("extra_dataclasses", []): - m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) - if m: - extra_cls_names.add(m.group(1)) - exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names for type_def in self.types: if type_def.name in exclude_types: continue @@ -1146,8 +1153,12 @@ def _remove_comments(self, content: str) -> str: def _extract_definitions(self, content: str) -> None: """Extract CDDL definitions (type definitions, commands, etc.).""" # Match pattern: Name = Definition - # Handles multiline definitions properly - pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)?\s*=|\Z)" + # Handles multiline definitions properly. + # The \s* after \n in the lookahead allows definitions that start with + # leading whitespace (e.g. " network.BeforeRequestSent = (") to be + # recognised as separate definitions instead of being swallowed into + # the body of the preceding definition. + pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\s*\w+(?:\.\w+)?\s*=|\Z)" for match in re.finditer(pattern, content, re.DOTALL): name = match.group(1).strip() @@ -1589,12 +1600,15 @@ def generate_common_file(output_path: Path) -> None: "\n" '"""Common utilities for BiDi command construction."""\n' "\n" - "from typing import Any, Dict, Generator\n" + "from __future__ import annotations\n" + "\n" + "from collections.abc import Generator\n" + "from typing import Any\n" "\n" "\n" "def command_builder(\n" - " method: str, params: Dict[str, Any]\n" - ") -> Generator[Dict[str, Any], Any, Any]:\n" + " method: str, params: dict[str, Any] | None = None\n" + ") -> Generator[dict[str, Any], Any, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" @@ -1607,6 +1621,8 @@ def generate_common_file(output_path: Path) -> None: " Returns:\n" " The result from the BiDi command execution\n" ' """\n' + " if params is None:\n" + " params = {}\n" ' result = yield {"method": method, "params": params}\n' " return result\n" ) @@ -1680,8 +1696,10 @@ def generate_permissions_file(output_path: Path) -> None: "\n" "from __future__ import annotations\n" "\n" + "from __future__ import annotations\n" + "\n" "from enum import Enum\n" - "from typing import Any, Optional, Union\n" + "from typing import Any\n" "\n" "from .common import command_builder\n" "\n" @@ -1724,10 +1742,10 @@ def generate_permissions_file(output_path: Path) -> None: "\n" " def set_permission(\n" " self,\n" - " descriptor: Union[PermissionDescriptor, str],\n" - " state: Union[PermissionState, str],\n" - " origin: Optional[str] = None,\n" - " user_context: Optional[str] = None,\n" + " descriptor: PermissionDescriptor | str,\n" + " state: PermissionState | str,\n" + " origin: str | None = None,\n" + " user_context: str | None = None,\n" " ) -> None:\n" ' """Set a permission for a given origin.\n' "\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 39af67d4c635b..adf0a17128af3 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -81,6 +81,20 @@ "result_param": "download_behavior", }, }, + # Replace the auto-generated ClientWindowNamedState so we can add the + # convenience NORMAL constant. In the BiDi spec "normal" is the state + # represented by ClientWindowRectState, but exposing it here keeps the + # Python API consistent with the old ClientWindowState enum. + "exclude_types": ["ClientWindowNamedState"], + "extra_dataclasses": [ + '''class ClientWindowNamedState: + """Named states for a browser client window.""" + + FULLSCREEN = "fullscreen" + MAXIMIZED = "maximized" + MINIMIZED = "minimized" + NORMAL = "normal"''', + ], # Override the generator-produced set_download_behavior so that # downloadBehavior is never stripped by the generic None filter. # The BiDi spec marks it as required (can be null, but must be present). @@ -845,8 +859,11 @@ def from_json(self2, p): ], }, "network": { - # Initialize intercepts tracking list in __init__ - "extra_init_code": ["self.intercepts = []"], + # Initialize intercepts tracking list and per-handler intercept map + "extra_init_code": [ + "self.intercepts = []", + "self._handler_intercepts: dict = {}", + ], # Request class wraps a beforeRequestSent event params and provides actions "extra_dataclasses": [ '''class BytesValue: @@ -940,7 +957,8 @@ def continue_request(self, **kwargs): "auth_required": "authRequired", } phase = phase_map.get(event, "beforeRequestSent") - self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_result = self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_id = intercept_result.get("intercept") if intercept_result else None def _request_callback(params): raw = ( @@ -951,15 +969,21 @@ def _request_callback(params): request = Request(self._conn, raw) callback(request) - return self.add_event_handler(event, _request_callback)''', + callback_id = self.add_event_handler(event, _request_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id''', ''' def remove_request_handler(self, event, callback_id): - """Remove a network request handler. + """Remove a network request handler and its associated network intercept. Args: event: The event name used when adding the handler. callback_id: The int returned by add_request_handler. """ - self.remove_event_handler(event, callback_id)''', + self.remove_event_handler(event, callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id)''', ''' def clear_request_handlers(self): """Clear all request handlers and remove all tracked intercepts.""" self.clear_event_handlers() diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index acda63f71953e..71f917634304d 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -60,17 +60,9 @@ def validate_download_behavior( raise ValueError("destination_folder should not be provided when allowed=False") -class ClientWindowNamedState: - """ClientWindowNamedState.""" - - FULLSCREEN = "fullscreen" - MAXIMIZED = "maximized" - MINIMIZED = "minimized" - - @dataclass class ClientWindowInfo: - """ClientWindowInfo.""" + """ClientWindowInfo type definition.""" active: bool | None = None client_window: Any | None = None @@ -112,14 +104,14 @@ def get_y(self): @dataclass class UserContextInfo: - """UserContextInfo.""" + """UserContextInfo type definition.""" user_context: Any | None = None @dataclass class CreateUserContextParameters: - """CreateUserContextParameters.""" + """CreateUserContextParameters type definition.""" accept_insecure_certs: bool | None = None proxy: Any | None = None @@ -128,35 +120,35 @@ class CreateUserContextParameters: @dataclass class GetClientWindowsResult: - """GetClientWindowsResult.""" + """GetClientWindowsResult type definition.""" client_windows: list[Any | None] | None = field(default_factory=list) @dataclass class GetUserContextsResult: - """GetUserContextsResult.""" + """GetUserContextsResult type definition.""" user_contexts: list[Any | None] | None = field(default_factory=list) @dataclass class RemoveUserContextParameters: - """RemoveUserContextParameters.""" + """RemoveUserContextParameters type definition.""" user_context: Any | None = None @dataclass class SetClientWindowStateParameters: - """SetClientWindowStateParameters.""" + """SetClientWindowStateParameters type definition.""" client_window: Any | None = None @dataclass class ClientWindowRectState: - """ClientWindowRectState.""" + """ClientWindowRectState type definition.""" state: str = field(default="normal", init=False) width: Any | None = None @@ -167,7 +159,7 @@ class ClientWindowRectState: @dataclass class SetDownloadBehaviorParameters: - """SetDownloadBehaviorParameters.""" + """SetDownloadBehaviorParameters type definition.""" download_behavior: Any | None = None user_contexts: list[Any | None] | None = field(default_factory=list) @@ -175,7 +167,7 @@ class SetDownloadBehaviorParameters: @dataclass class DownloadBehaviorAllowed: - """DownloadBehaviorAllowed.""" + """DownloadBehaviorAllowed type definition.""" type: str = field(default="allowed", init=False) destination_folder: str | None = None @@ -183,11 +175,19 @@ class DownloadBehaviorAllowed: @dataclass class DownloadBehaviorDenied: - """DownloadBehaviorDenied.""" + """DownloadBehaviorDenied type definition.""" type: str = field(default="denied", init=False) +class ClientWindowNamedState: + """Named states for a browser client window.""" + + FULLSCREEN = "fullscreen" + MAXIMIZED = "maximized" + MINIMIZED = "minimized" + NORMAL = "normal" + class Browser: """WebDriver BiDi browser module.""" diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 5f128635df29d..ede96071778c3 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -48,7 +48,7 @@ class DownloadCompleteParams: @dataclass class Info: - """Info.""" + """Info type definition.""" children: Any | None = None client_window: Any | None = None @@ -61,7 +61,7 @@ class Info: @dataclass class AccessibilityLocator: - """AccessibilityLocator.""" + """AccessibilityLocator type definition.""" type: str = field(default="accessibility", init=False) name: str | None = None @@ -70,7 +70,7 @@ class AccessibilityLocator: @dataclass class CssLocator: - """CssLocator.""" + """CssLocator type definition.""" type: str = field(default="css", init=False) value: str | None = None @@ -78,7 +78,7 @@ class CssLocator: @dataclass class ContextLocator: - """ContextLocator.""" + """ContextLocator type definition.""" type: str = field(default="context", init=False) context: Any | None = None @@ -86,7 +86,7 @@ class ContextLocator: @dataclass class InnerTextLocator: - """InnerTextLocator.""" + """InnerTextLocator type definition.""" type: str = field(default="innerText", init=False) value: str | None = None @@ -97,7 +97,7 @@ class InnerTextLocator: @dataclass class XPathLocator: - """XPathLocator.""" + """XPathLocator type definition.""" type: str = field(default="xpath", init=False) value: str | None = None @@ -105,7 +105,7 @@ class XPathLocator: @dataclass class BaseNavigationInfo: - """BaseNavigationInfo.""" + """BaseNavigationInfo type definition.""" context: Any | None = None navigation: Any | None = None @@ -115,14 +115,14 @@ class BaseNavigationInfo: @dataclass class ActivateParameters: - """ActivateParameters.""" + """ActivateParameters type definition.""" context: Any | None = None @dataclass class CaptureScreenshotParameters: - """CaptureScreenshotParameters.""" + """CaptureScreenshotParameters type definition.""" context: Any | None = None format: Any | None = None @@ -131,7 +131,7 @@ class CaptureScreenshotParameters: @dataclass class ImageFormat: - """ImageFormat.""" + """ImageFormat type definition.""" type: str | None = None quality: Any | None = None @@ -139,7 +139,7 @@ class ImageFormat: @dataclass class ElementClipRectangle: - """ElementClipRectangle.""" + """ElementClipRectangle type definition.""" type: str = field(default="element", init=False) element: Any | None = None @@ -147,7 +147,7 @@ class ElementClipRectangle: @dataclass class BoxClipRectangle: - """BoxClipRectangle.""" + """BoxClipRectangle type definition.""" type: str = field(default="box", init=False) x: Any | None = None @@ -158,14 +158,14 @@ class BoxClipRectangle: @dataclass class CaptureScreenshotResult: - """CaptureScreenshotResult.""" + """CaptureScreenshotResult type definition.""" data: str | None = None @dataclass class CloseParameters: - """CloseParameters.""" + """CloseParameters type definition.""" context: Any | None = None prompt_unload: bool | None = None @@ -173,7 +173,7 @@ class CloseParameters: @dataclass class CreateParameters: - """CreateParameters.""" + """CreateParameters type definition.""" type: Any | None = None reference_context: Any | None = None @@ -183,14 +183,14 @@ class CreateParameters: @dataclass class CreateResult: - """CreateResult.""" + """CreateResult type definition.""" context: Any | None = None @dataclass class GetTreeParameters: - """GetTreeParameters.""" + """GetTreeParameters type definition.""" max_depth: Any | None = None root: Any | None = None @@ -198,14 +198,14 @@ class GetTreeParameters: @dataclass class GetTreeResult: - """GetTreeResult.""" + """GetTreeResult type definition.""" contexts: Any | None = None @dataclass class HandleUserPromptParameters: - """HandleUserPromptParameters.""" + """HandleUserPromptParameters type definition.""" context: Any | None = None accept: bool | None = None @@ -214,7 +214,7 @@ class HandleUserPromptParameters: @dataclass class LocateNodesParameters: - """LocateNodesParameters.""" + """LocateNodesParameters type definition.""" context: Any | None = None locator: Any | None = None @@ -224,14 +224,14 @@ class LocateNodesParameters: @dataclass class LocateNodesResult: - """LocateNodesResult.""" + """LocateNodesResult type definition.""" nodes: list[Any | None] | None = field(default_factory=list) @dataclass class NavigateParameters: - """NavigateParameters.""" + """NavigateParameters type definition.""" context: Any | None = None url: str | None = None @@ -240,7 +240,7 @@ class NavigateParameters: @dataclass class NavigateResult: - """NavigateResult.""" + """NavigateResult type definition.""" navigation: Any | None = None url: str | None = None @@ -248,7 +248,7 @@ class NavigateResult: @dataclass class PrintParameters: - """PrintParameters.""" + """PrintParameters type definition.""" context: Any | None = None background: bool | None = None @@ -260,7 +260,7 @@ class PrintParameters: @dataclass class PrintMarginParameters: - """PrintMarginParameters.""" + """PrintMarginParameters type definition.""" bottom: Any | None = None left: Any | None = None @@ -270,7 +270,7 @@ class PrintMarginParameters: @dataclass class PrintPageParameters: - """PrintPageParameters.""" + """PrintPageParameters type definition.""" height: Any | None = None width: Any | None = None @@ -278,14 +278,14 @@ class PrintPageParameters: @dataclass class PrintResult: - """PrintResult.""" + """PrintResult type definition.""" data: str | None = None @dataclass class ReloadParameters: - """ReloadParameters.""" + """ReloadParameters type definition.""" context: Any | None = None ignore_cache: bool | None = None @@ -294,7 +294,7 @@ class ReloadParameters: @dataclass class SetViewportParameters: - """SetViewportParameters.""" + """SetViewportParameters type definition.""" context: Any | None = None viewport: Any | None = None @@ -304,7 +304,7 @@ class SetViewportParameters: @dataclass class Viewport: - """Viewport.""" + """Viewport type definition.""" width: Any | None = None height: Any | None = None @@ -312,7 +312,7 @@ class Viewport: @dataclass class TraverseHistoryParameters: - """TraverseHistoryParameters.""" + """TraverseHistoryParameters type definition.""" context: Any | None = None delta: Any | None = None @@ -320,7 +320,7 @@ class TraverseHistoryParameters: @dataclass class HistoryUpdatedParameters: - """HistoryUpdatedParameters.""" + """HistoryUpdatedParameters type definition.""" context: Any | None = None timestamp: Any | None = None @@ -329,7 +329,7 @@ class HistoryUpdatedParameters: @dataclass class UserPromptClosedParameters: - """UserPromptClosedParameters.""" + """UserPromptClosedParameters type definition.""" context: Any | None = None accepted: bool | None = None @@ -339,7 +339,7 @@ class UserPromptClosedParameters: @dataclass class UserPromptOpenedParameters: - """UserPromptOpenedParameters.""" + """UserPromptOpenedParameters type definition.""" context: Any | None = None handler: Any | None = None diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index d7cb436a08471..dae051876833e 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,12 +17,14 @@ """Common utilities for BiDi command construction.""" +from __future__ import annotations + from collections.abc import Generator from typing import Any def command_builder( - method: str, params: dict[str, Any] + method: str, params: dict[str, Any] | None = None ) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. @@ -36,5 +38,7 @@ def command_builder( Returns: The result from the BiDi command execution """ + if params is None: + params = {} result = yield {"method": method, "params": params} return result diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index cb575bbdc54dd..fbbe0966d8b3a 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -37,7 +37,7 @@ class ScreenOrientationType: @dataclass class SetForcedColorsModeThemeOverrideParameters: - """SetForcedColorsModeThemeOverrideParameters.""" + """SetForcedColorsModeThemeOverrideParameters type definition.""" theme: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -46,7 +46,7 @@ class SetForcedColorsModeThemeOverrideParameters: @dataclass class SetGeolocationOverrideParameters: - """SetGeolocationOverrideParameters.""" + """SetGeolocationOverrideParameters type definition.""" contexts: list[Any | None] | None = field(default_factory=list) user_contexts: list[Any | None] | None = field(default_factory=list) @@ -54,7 +54,7 @@ class SetGeolocationOverrideParameters: @dataclass class GeolocationCoordinates: - """GeolocationCoordinates.""" + """GeolocationCoordinates type definition.""" latitude: Any | None = None longitude: Any | None = None @@ -67,14 +67,14 @@ class GeolocationCoordinates: @dataclass class GeolocationPositionError: - """GeolocationPositionError.""" + """GeolocationPositionError type definition.""" type: str = field(default="positionUnavailable", init=False) @dataclass class SetLocaleOverrideParameters: - """SetLocaleOverrideParameters.""" + """SetLocaleOverrideParameters type definition.""" locale: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -82,8 +82,8 @@ class SetLocaleOverrideParameters: @dataclass -class setNetworkConditionsParameters: - """setNetworkConditionsParameters.""" +class SetNetworkConditionsParameters: + """SetNetworkConditionsParameters type definition.""" network_conditions: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -92,14 +92,14 @@ class setNetworkConditionsParameters: @dataclass class NetworkConditionsOffline: - """NetworkConditionsOffline.""" + """NetworkConditionsOffline type definition.""" type: str = field(default="offline", init=False) @dataclass class ScreenArea: - """ScreenArea.""" + """ScreenArea type definition.""" width: Any | None = None height: Any | None = None @@ -107,7 +107,7 @@ class ScreenArea: @dataclass class SetScreenSettingsOverrideParameters: - """SetScreenSettingsOverrideParameters.""" + """SetScreenSettingsOverrideParameters type definition.""" screen_area: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -116,7 +116,7 @@ class SetScreenSettingsOverrideParameters: @dataclass class ScreenOrientation: - """ScreenOrientation.""" + """ScreenOrientation type definition.""" natural: Any | None = None type: Any | None = None @@ -124,7 +124,7 @@ class ScreenOrientation: @dataclass class SetScreenOrientationOverrideParameters: - """SetScreenOrientationOverrideParameters.""" + """SetScreenOrientationOverrideParameters type definition.""" screen_orientation: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -133,7 +133,7 @@ class SetScreenOrientationOverrideParameters: @dataclass class SetUserAgentOverrideParameters: - """SetUserAgentOverrideParameters.""" + """SetUserAgentOverrideParameters type definition.""" user_agent: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -142,7 +142,7 @@ class SetUserAgentOverrideParameters: @dataclass class SetViewportMetaOverrideParameters: - """SetViewportMetaOverrideParameters.""" + """SetViewportMetaOverrideParameters type definition.""" viewport_meta: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -151,7 +151,7 @@ class SetViewportMetaOverrideParameters: @dataclass class SetScriptingEnabledParameters: - """SetScriptingEnabledParameters.""" + """SetScriptingEnabledParameters type definition.""" enabled: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -160,7 +160,7 @@ class SetScriptingEnabledParameters: @dataclass class SetScrollbarTypeOverrideParameters: - """SetScrollbarTypeOverrideParameters.""" + """SetScrollbarTypeOverrideParameters type definition.""" scrollbar_type: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -169,7 +169,7 @@ class SetScrollbarTypeOverrideParameters: @dataclass class SetTimezoneOverrideParameters: - """SetTimezoneOverrideParameters.""" + """SetTimezoneOverrideParameters type definition.""" timezone: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -178,7 +178,7 @@ class SetTimezoneOverrideParameters: @dataclass class SetTouchOverrideParameters: - """SetTouchOverrideParameters.""" + """SetTouchOverrideParameters type definition.""" contexts: list[Any | None] | None = field(default_factory=list) user_contexts: list[Any | None] | None = field(default_factory=list) diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 13f43361293f2..c8e58181b343e 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -33,7 +33,7 @@ class Origin: @dataclass class ElementOrigin: - """ElementOrigin.""" + """ElementOrigin type definition.""" type: str = field(default="element", init=False) element: Any | None = None @@ -41,7 +41,7 @@ class ElementOrigin: @dataclass class PerformActionsParameters: - """PerformActionsParameters.""" + """PerformActionsParameters type definition.""" context: Any | None = None actions: list[Any | None] | None = field(default_factory=list) @@ -49,7 +49,7 @@ class PerformActionsParameters: @dataclass class NoneSourceActions: - """NoneSourceActions.""" + """NoneSourceActions type definition.""" type: str = field(default="none", init=False) id: str | None = None @@ -58,7 +58,7 @@ class NoneSourceActions: @dataclass class KeySourceActions: - """KeySourceActions.""" + """KeySourceActions type definition.""" type: str = field(default="key", init=False) id: str | None = None @@ -67,7 +67,7 @@ class KeySourceActions: @dataclass class PointerSourceActions: - """PointerSourceActions.""" + """PointerSourceActions type definition.""" type: str = field(default="pointer", init=False) id: str | None = None @@ -77,14 +77,14 @@ class PointerSourceActions: @dataclass class PointerParameters: - """PointerParameters.""" + """PointerParameters type definition.""" pointer_type: Any | None = None @dataclass class WheelSourceActions: - """WheelSourceActions.""" + """WheelSourceActions type definition.""" type: str = field(default="wheel", init=False) id: str | None = None @@ -93,7 +93,7 @@ class WheelSourceActions: @dataclass class PauseAction: - """PauseAction.""" + """PauseAction type definition.""" type: str = field(default="pause", init=False) duration: Any | None = None @@ -101,7 +101,7 @@ class PauseAction: @dataclass class KeyDownAction: - """KeyDownAction.""" + """KeyDownAction type definition.""" type: str = field(default="keyDown", init=False) value: str | None = None @@ -109,7 +109,7 @@ class KeyDownAction: @dataclass class KeyUpAction: - """KeyUpAction.""" + """KeyUpAction type definition.""" type: str = field(default="keyUp", init=False) value: str | None = None @@ -117,7 +117,7 @@ class KeyUpAction: @dataclass class PointerUpAction: - """PointerUpAction.""" + """PointerUpAction type definition.""" type: str = field(default="pointerUp", init=False) button: Any | None = None @@ -125,7 +125,7 @@ class PointerUpAction: @dataclass class WheelScrollAction: - """WheelScrollAction.""" + """WheelScrollAction type definition.""" type: str = field(default="scroll", init=False) x: Any | None = None @@ -138,7 +138,7 @@ class WheelScrollAction: @dataclass class PointerCommonProperties: - """PointerCommonProperties.""" + """PointerCommonProperties type definition.""" width: Any | None = None height: Any | None = None @@ -151,14 +151,14 @@ class PointerCommonProperties: @dataclass class ReleaseActionsParameters: - """ReleaseActionsParameters.""" + """ReleaseActionsParameters type definition.""" context: Any | None = None @dataclass class SetFilesParameters: - """SetFilesParameters.""" + """SetFilesParameters type definition.""" context: Any | None = None element: Any | None = None diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 7971b807e94a1..eaf52a2ec08c2 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -27,7 +27,7 @@ class Level: @dataclass class BaseLogEntry: - """BaseLogEntry.""" + """BaseLogEntry type definition.""" level: Any | None = None source: Any | None = None @@ -38,7 +38,7 @@ class BaseLogEntry: @dataclass class GenericLogEntry: - """GenericLogEntry.""" + """GenericLogEntry type definition.""" type: str | None = None diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 6e02eeabc4ed7..c9737ac9131d0 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -49,7 +49,7 @@ class ContinueWithAuthNoCredentials: @dataclass class AuthChallenge: - """AuthChallenge.""" + """AuthChallenge type definition.""" scheme: str | None = None realm: str | None = None @@ -57,7 +57,7 @@ class AuthChallenge: @dataclass class AuthCredentials: - """AuthCredentials.""" + """AuthCredentials type definition.""" type: str = field(default="password", init=False) username: str | None = None @@ -66,7 +66,7 @@ class AuthCredentials: @dataclass class BaseParameters: - """BaseParameters.""" + """BaseParameters type definition.""" context: Any | None = None is_blocked: bool | None = None @@ -79,7 +79,7 @@ class BaseParameters: @dataclass class StringValue: - """StringValue.""" + """StringValue type definition.""" type: str = field(default="string", init=False) value: str | None = None @@ -87,7 +87,7 @@ class StringValue: @dataclass class Base64Value: - """Base64Value.""" + """Base64Value type definition.""" type: str = field(default="base64", init=False) value: str | None = None @@ -95,7 +95,7 @@ class Base64Value: @dataclass class Cookie: - """Cookie.""" + """Cookie type definition.""" name: str | None = None value: Any | None = None @@ -110,7 +110,7 @@ class Cookie: @dataclass class CookieHeader: - """CookieHeader.""" + """CookieHeader type definition.""" name: str | None = None value: Any | None = None @@ -118,7 +118,7 @@ class CookieHeader: @dataclass class FetchTimingInfo: - """FetchTimingInfo.""" + """FetchTimingInfo type definition.""" time_origin: Any | None = None request_time: Any | None = None @@ -137,7 +137,7 @@ class FetchTimingInfo: @dataclass class Header: - """Header.""" + """Header type definition.""" name: str | None = None value: Any | None = None @@ -145,7 +145,7 @@ class Header: @dataclass class Initiator: - """Initiator.""" + """Initiator type definition.""" column_number: Any | None = None line_number: Any | None = None @@ -156,14 +156,14 @@ class Initiator: @dataclass class ResponseContent: - """ResponseContent.""" + """ResponseContent type definition.""" size: Any | None = None @dataclass class ResponseData: - """ResponseData.""" + """ResponseData type definition.""" url: str | None = None protocol: str | None = None @@ -181,7 +181,7 @@ class ResponseData: @dataclass class SetCookieHeader: - """SetCookieHeader.""" + """SetCookieHeader type definition.""" name: str | None = None value: Any | None = None @@ -196,7 +196,7 @@ class SetCookieHeader: @dataclass class UrlPatternPattern: - """UrlPatternPattern.""" + """UrlPatternPattern type definition.""" type: str = field(default="pattern", init=False) protocol: str | None = None @@ -208,7 +208,7 @@ class UrlPatternPattern: @dataclass class UrlPatternString: - """UrlPatternString.""" + """UrlPatternString type definition.""" type: str = field(default="string", init=False) pattern: str | None = None @@ -216,7 +216,7 @@ class UrlPatternString: @dataclass class AddDataCollectorParameters: - """AddDataCollectorParameters.""" + """AddDataCollectorParameters type definition.""" data_types: list[Any | None] | None = field(default_factory=list) max_encoded_data_size: Any | None = None @@ -227,14 +227,14 @@ class AddDataCollectorParameters: @dataclass class AddDataCollectorResult: - """AddDataCollectorResult.""" + """AddDataCollectorResult type definition.""" collector: Any | None = None @dataclass class AddInterceptParameters: - """AddInterceptParameters.""" + """AddInterceptParameters type definition.""" phases: list[Any | None] | None = field(default_factory=list) contexts: list[Any | None] | None = field(default_factory=list) @@ -243,14 +243,14 @@ class AddInterceptParameters: @dataclass class AddInterceptResult: - """AddInterceptResult.""" + """AddInterceptResult type definition.""" intercept: Any | None = None @dataclass class ContinueResponseParameters: - """ContinueResponseParameters.""" + """ContinueResponseParameters type definition.""" request: Any | None = None cookies: list[Any | None] | None = field(default_factory=list) @@ -262,22 +262,22 @@ class ContinueResponseParameters: @dataclass class ContinueWithAuthParameters: - """ContinueWithAuthParameters.""" + """ContinueWithAuthParameters type definition.""" request: Any | None = None @dataclass class ContinueWithAuthCredentials: - """ContinueWithAuthCredentials.""" + """ContinueWithAuthCredentials type definition.""" action: str = field(default="provideCredentials", init=False) credentials: Any | None = None @dataclass -class disownDataParameters: - """disownDataParameters.""" +class DisownDataParameters: + """DisownDataParameters type definition.""" data_type: Any | None = None collector: Any | None = None @@ -286,14 +286,14 @@ class disownDataParameters: @dataclass class FailRequestParameters: - """FailRequestParameters.""" + """FailRequestParameters type definition.""" request: Any | None = None @dataclass class GetDataParameters: - """GetDataParameters.""" + """GetDataParameters type definition.""" data_type: Any | None = None collector: Any | None = None @@ -303,14 +303,14 @@ class GetDataParameters: @dataclass class GetDataResult: - """GetDataResult.""" + """GetDataResult type definition.""" bytes: Any | None = None @dataclass class ProvideResponseParameters: - """ProvideResponseParameters.""" + """ProvideResponseParameters type definition.""" request: Any | None = None body: Any | None = None @@ -322,21 +322,21 @@ class ProvideResponseParameters: @dataclass class RemoveDataCollectorParameters: - """RemoveDataCollectorParameters.""" + """RemoveDataCollectorParameters type definition.""" collector: Any | None = None @dataclass class RemoveInterceptParameters: - """RemoveInterceptParameters.""" + """RemoveInterceptParameters type definition.""" intercept: Any | None = None @dataclass class SetCacheBehaviorParameters: - """SetCacheBehaviorParameters.""" + """SetCacheBehaviorParameters type definition.""" cache_behavior: Any | None = None contexts: list[Any | None] | None = field(default_factory=list) @@ -344,16 +344,44 @@ class SetCacheBehaviorParameters: @dataclass class SetExtraHeadersParameters: - """SetExtraHeadersParameters.""" + """SetExtraHeadersParameters type definition.""" headers: list[Any | None] | None = field(default_factory=list) contexts: list[Any | None] | None = field(default_factory=list) user_contexts: list[Any | None] | None = field(default_factory=list) +@dataclass +class AuthRequiredParameters: + """AuthRequiredParameters type definition.""" + + response: Any | None = None + + +@dataclass +class BeforeRequestSentParameters: + """BeforeRequestSentParameters type definition.""" + + initiator: Any | None = None + + +@dataclass +class FetchErrorParameters: + """FetchErrorParameters type definition.""" + + error_text: str | None = None + + +@dataclass +class ResponseCompletedParameters: + """ResponseCompletedParameters type definition.""" + + response: Any | None = None + + @dataclass class ResponseStartedParameters: - """ResponseStartedParameters.""" + """ResponseStartedParameters type definition.""" response: Any | None = None @@ -396,6 +424,10 @@ def continue_request(self, **kwargs): # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "auth_required": "network.authRequired", + "before_request_sent": "network.beforeRequestSent", + "fetch_error": "network.fetchError", + "response_completed": "network.responseCompleted", + "response_started": "network.responseStarted", "before_request": "network.beforeRequestSent", } @@ -560,6 +592,7 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) self.intercepts = [] + self._handler_intercepts: dict = {} def add_data_collector( self, @@ -767,52 +800,6 @@ def set_extra_headers( result = self._conn.execute(cmd) return result - def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): - """Execute network.beforeRequestSent.""" - params = { - "initiator": initiator, - "method": method, - "params": params, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.beforeRequestSent", params) - result = self._conn.execute(cmd) - return result - - def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): - """Execute network.fetchError.""" - params = { - "errorText": error_text, - "method": method, - "params": params, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.fetchError", params) - result = self._conn.execute(cmd) - return result - - def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): - """Execute network.responseCompleted.""" - params = { - "response": response, - "method": method, - "params": params, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.responseCompleted", params) - result = self._conn.execute(cmd) - return result - - def response_started(self, response: Any | None = None): - """Execute network.responseStarted.""" - params = { - "response": response, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("network.responseStarted", params) - result = self._conn.execute(cmd) - return result - def _add_intercept(self, phases=None, url_patterns=None): """Add a low-level network intercept. @@ -861,7 +848,8 @@ def add_request_handler(self, event, callback, url_patterns=None): "auth_required": "authRequired", } phase = phase_map.get(event, "beforeRequestSent") - self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_result = self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_id = intercept_result.get("intercept") if intercept_result else None def _request_callback(params): raw = ( @@ -872,15 +860,21 @@ def _request_callback(params): request = Request(self._conn, raw) callback(request) - return self.add_event_handler(event, _request_callback) + callback_id = self.add_event_handler(event, _request_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id def remove_request_handler(self, event, callback_id): - """Remove a network request handler. + """Remove a network request handler and its associated network intercept. Args: event: The event name used when adding the handler. callback_id: The int returned by add_request_handler. """ self.remove_event_handler(event, callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id) def clear_request_handlers(self): """Clear all request handlers and remove all tracked intercepts.""" self.clear_event_handlers() @@ -960,6 +954,18 @@ def clear_event_handlers(self) -> None: # Event: network.authRequired AuthRequired = globals().get('AuthRequiredParameters', dict) # Fallback to dict if type not defined +# Event: network.beforeRequestSent +BeforeRequestSent = globals().get('BeforeRequestSentParameters', dict) # Fallback to dict if type not defined + +# Event: network.fetchError +FetchError = globals().get('FetchErrorParameters', dict) # Fallback to dict if type not defined + +# Event: network.responseCompleted +ResponseCompleted = globals().get('ResponseCompletedParameters', dict) # Fallback to dict if type not defined + +# Event: network.responseStarted +ResponseStarted = globals().get('ResponseStartedParameters', dict) # Fallback to dict if type not defined + # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() @@ -970,5 +976,29 @@ def clear_event_handlers(self) -> None: if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict) ), + "before_request_sent": ( + EventConfig("before_request_sent", "network.beforeRequestSent", + _globals.get("BeforeRequestSent", dict)) + if _globals.get("BeforeRequestSent") + else EventConfig("before_request_sent", "network.beforeRequestSent", dict) + ), + "fetch_error": ( + EventConfig("fetch_error", "network.fetchError", + _globals.get("FetchError", dict)) + if _globals.get("FetchError") + else EventConfig("fetch_error", "network.fetchError", dict) + ), + "response_completed": ( + EventConfig("response_completed", "network.responseCompleted", + _globals.get("ResponseCompleted", dict)) + if _globals.get("ResponseCompleted") + else EventConfig("response_completed", "network.responseCompleted", dict) + ), + "response_started": ( + EventConfig("response_started", "network.responseStarted", + _globals.get("ResponseStarted", dict)) + if _globals.get("ResponseStarted") + else EventConfig("response_started", "network.responseStarted", dict) + ), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index b29721db88503..061bb17b0deec 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -47,7 +47,7 @@ class ResultOwnership: @dataclass class ChannelValue: - """ChannelValue.""" + """ChannelValue type definition.""" type: str = field(default="channel", init=False) value: Any | None = None @@ -55,7 +55,7 @@ class ChannelValue: @dataclass class ChannelProperties: - """ChannelProperties.""" + """ChannelProperties type definition.""" channel: Any | None = None serialization_options: Any | None = None @@ -64,7 +64,7 @@ class ChannelProperties: @dataclass class EvaluateResultSuccess: - """EvaluateResultSuccess.""" + """EvaluateResultSuccess type definition.""" type: str = field(default="success", init=False) result: Any | None = None @@ -73,7 +73,7 @@ class EvaluateResultSuccess: @dataclass class EvaluateResultException: - """EvaluateResultException.""" + """EvaluateResultException type definition.""" type: str = field(default="exception", init=False) exception_details: Any | None = None @@ -82,7 +82,7 @@ class EvaluateResultException: @dataclass class ExceptionDetails: - """ExceptionDetails.""" + """ExceptionDetails type definition.""" column_number: Any | None = None exception: Any | None = None @@ -93,7 +93,7 @@ class ExceptionDetails: @dataclass class ArrayLocalValue: - """ArrayLocalValue.""" + """ArrayLocalValue type definition.""" type: str = field(default="array", init=False) value: Any | None = None @@ -101,7 +101,7 @@ class ArrayLocalValue: @dataclass class DateLocalValue: - """DateLocalValue.""" + """DateLocalValue type definition.""" type: str = field(default="date", init=False) value: str | None = None @@ -109,7 +109,7 @@ class DateLocalValue: @dataclass class MapLocalValue: - """MapLocalValue.""" + """MapLocalValue type definition.""" type: str = field(default="map", init=False) value: Any | None = None @@ -117,7 +117,7 @@ class MapLocalValue: @dataclass class ObjectLocalValue: - """ObjectLocalValue.""" + """ObjectLocalValue type definition.""" type: str = field(default="object", init=False) value: Any | None = None @@ -125,7 +125,7 @@ class ObjectLocalValue: @dataclass class RegExpValue: - """RegExpValue.""" + """RegExpValue type definition.""" pattern: str | None = None flags: str | None = None @@ -133,7 +133,7 @@ class RegExpValue: @dataclass class RegExpLocalValue: - """RegExpLocalValue.""" + """RegExpLocalValue type definition.""" type: str = field(default="regexp", init=False) value: Any | None = None @@ -141,7 +141,7 @@ class RegExpLocalValue: @dataclass class SetLocalValue: - """SetLocalValue.""" + """SetLocalValue type definition.""" type: str = field(default="set", init=False) value: Any | None = None @@ -149,21 +149,21 @@ class SetLocalValue: @dataclass class UndefinedValue: - """UndefinedValue.""" + """UndefinedValue type definition.""" type: str = field(default="undefined", init=False) @dataclass class NullValue: - """NullValue.""" + """NullValue type definition.""" type: str = field(default="null", init=False) @dataclass class StringValue: - """StringValue.""" + """StringValue type definition.""" type: str = field(default="string", init=False) value: str | None = None @@ -171,7 +171,7 @@ class StringValue: @dataclass class NumberValue: - """NumberValue.""" + """NumberValue type definition.""" type: str = field(default="number", init=False) value: Any | None = None @@ -179,7 +179,7 @@ class NumberValue: @dataclass class BooleanValue: - """BooleanValue.""" + """BooleanValue type definition.""" type: str = field(default="boolean", init=False) value: bool | None = None @@ -187,7 +187,7 @@ class BooleanValue: @dataclass class BigIntValue: - """BigIntValue.""" + """BigIntValue type definition.""" type: str = field(default="bigint", init=False) value: str | None = None @@ -195,7 +195,7 @@ class BigIntValue: @dataclass class BaseRealmInfo: - """BaseRealmInfo.""" + """BaseRealmInfo type definition.""" realm: Any | None = None origin: str | None = None @@ -203,7 +203,7 @@ class BaseRealmInfo: @dataclass class WindowRealmInfo: - """WindowRealmInfo.""" + """WindowRealmInfo type definition.""" type: str = field(default="window", init=False) context: Any | None = None @@ -212,7 +212,7 @@ class WindowRealmInfo: @dataclass class DedicatedWorkerRealmInfo: - """DedicatedWorkerRealmInfo.""" + """DedicatedWorkerRealmInfo type definition.""" type: str = field(default="dedicated-worker", init=False) owners: list[Any | None] | None = field(default_factory=list) @@ -220,49 +220,49 @@ class DedicatedWorkerRealmInfo: @dataclass class SharedWorkerRealmInfo: - """SharedWorkerRealmInfo.""" + """SharedWorkerRealmInfo type definition.""" type: str = field(default="shared-worker", init=False) @dataclass class ServiceWorkerRealmInfo: - """ServiceWorkerRealmInfo.""" + """ServiceWorkerRealmInfo type definition.""" type: str = field(default="service-worker", init=False) @dataclass class WorkerRealmInfo: - """WorkerRealmInfo.""" + """WorkerRealmInfo type definition.""" type: str = field(default="worker", init=False) @dataclass class PaintWorkletRealmInfo: - """PaintWorkletRealmInfo.""" + """PaintWorkletRealmInfo type definition.""" type: str = field(default="paint-worklet", init=False) @dataclass class AudioWorkletRealmInfo: - """AudioWorkletRealmInfo.""" + """AudioWorkletRealmInfo type definition.""" type: str = field(default="audio-worklet", init=False) @dataclass class WorkletRealmInfo: - """WorkletRealmInfo.""" + """WorkletRealmInfo type definition.""" type: str = field(default="worklet", init=False) @dataclass class SharedReference: - """SharedReference.""" + """SharedReference type definition.""" shared_id: Any | None = None handle: Any | None = None @@ -270,7 +270,7 @@ class SharedReference: @dataclass class RemoteObjectReference: - """RemoteObjectReference.""" + """RemoteObjectReference type definition.""" handle: Any | None = None shared_id: Any | None = None @@ -278,7 +278,7 @@ class RemoteObjectReference: @dataclass class SymbolRemoteValue: - """SymbolRemoteValue.""" + """SymbolRemoteValue type definition.""" type: str = field(default="symbol", init=False) handle: Any | None = None @@ -287,7 +287,7 @@ class SymbolRemoteValue: @dataclass class ArrayRemoteValue: - """ArrayRemoteValue.""" + """ArrayRemoteValue type definition.""" type: str = field(default="array", init=False) handle: Any | None = None @@ -297,7 +297,7 @@ class ArrayRemoteValue: @dataclass class ObjectRemoteValue: - """ObjectRemoteValue.""" + """ObjectRemoteValue type definition.""" type: str = field(default="object", init=False) handle: Any | None = None @@ -307,7 +307,7 @@ class ObjectRemoteValue: @dataclass class FunctionRemoteValue: - """FunctionRemoteValue.""" + """FunctionRemoteValue type definition.""" type: str = field(default="function", init=False) handle: Any | None = None @@ -316,7 +316,7 @@ class FunctionRemoteValue: @dataclass class RegExpRemoteValue: - """RegExpRemoteValue.""" + """RegExpRemoteValue type definition.""" handle: Any | None = None internal_id: Any | None = None @@ -324,7 +324,7 @@ class RegExpRemoteValue: @dataclass class DateRemoteValue: - """DateRemoteValue.""" + """DateRemoteValue type definition.""" handle: Any | None = None internal_id: Any | None = None @@ -332,7 +332,7 @@ class DateRemoteValue: @dataclass class MapRemoteValue: - """MapRemoteValue.""" + """MapRemoteValue type definition.""" type: str = field(default="map", init=False) handle: Any | None = None @@ -342,7 +342,7 @@ class MapRemoteValue: @dataclass class SetRemoteValue: - """SetRemoteValue.""" + """SetRemoteValue type definition.""" type: str = field(default="set", init=False) handle: Any | None = None @@ -352,7 +352,7 @@ class SetRemoteValue: @dataclass class WeakMapRemoteValue: - """WeakMapRemoteValue.""" + """WeakMapRemoteValue type definition.""" type: str = field(default="weakmap", init=False) handle: Any | None = None @@ -361,7 +361,7 @@ class WeakMapRemoteValue: @dataclass class WeakSetRemoteValue: - """WeakSetRemoteValue.""" + """WeakSetRemoteValue type definition.""" type: str = field(default="weakset", init=False) handle: Any | None = None @@ -370,7 +370,7 @@ class WeakSetRemoteValue: @dataclass class GeneratorRemoteValue: - """GeneratorRemoteValue.""" + """GeneratorRemoteValue type definition.""" type: str = field(default="generator", init=False) handle: Any | None = None @@ -379,7 +379,7 @@ class GeneratorRemoteValue: @dataclass class ErrorRemoteValue: - """ErrorRemoteValue.""" + """ErrorRemoteValue type definition.""" type: str = field(default="error", init=False) handle: Any | None = None @@ -388,7 +388,7 @@ class ErrorRemoteValue: @dataclass class ProxyRemoteValue: - """ProxyRemoteValue.""" + """ProxyRemoteValue type definition.""" type: str = field(default="proxy", init=False) handle: Any | None = None @@ -397,7 +397,7 @@ class ProxyRemoteValue: @dataclass class PromiseRemoteValue: - """PromiseRemoteValue.""" + """PromiseRemoteValue type definition.""" type: str = field(default="promise", init=False) handle: Any | None = None @@ -406,7 +406,7 @@ class PromiseRemoteValue: @dataclass class TypedArrayRemoteValue: - """TypedArrayRemoteValue.""" + """TypedArrayRemoteValue type definition.""" type: str = field(default="typedarray", init=False) handle: Any | None = None @@ -415,7 +415,7 @@ class TypedArrayRemoteValue: @dataclass class ArrayBufferRemoteValue: - """ArrayBufferRemoteValue.""" + """ArrayBufferRemoteValue type definition.""" type: str = field(default="arraybuffer", init=False) handle: Any | None = None @@ -424,7 +424,7 @@ class ArrayBufferRemoteValue: @dataclass class NodeListRemoteValue: - """NodeListRemoteValue.""" + """NodeListRemoteValue type definition.""" type: str = field(default="nodelist", init=False) handle: Any | None = None @@ -434,7 +434,7 @@ class NodeListRemoteValue: @dataclass class HTMLCollectionRemoteValue: - """HTMLCollectionRemoteValue.""" + """HTMLCollectionRemoteValue type definition.""" type: str = field(default="htmlcollection", init=False) handle: Any | None = None @@ -444,7 +444,7 @@ class HTMLCollectionRemoteValue: @dataclass class NodeRemoteValue: - """NodeRemoteValue.""" + """NodeRemoteValue type definition.""" type: str = field(default="node", init=False) shared_id: Any | None = None @@ -455,7 +455,7 @@ class NodeRemoteValue: @dataclass class NodeProperties: - """NodeProperties.""" + """NodeProperties type definition.""" node_type: Any | None = None child_node_count: Any | None = None @@ -469,7 +469,7 @@ class NodeProperties: @dataclass class WindowProxyRemoteValue: - """WindowProxyRemoteValue.""" + """WindowProxyRemoteValue type definition.""" type: str = field(default="window", init=False) value: Any | None = None @@ -479,14 +479,14 @@ class WindowProxyRemoteValue: @dataclass class WindowProxyProperties: - """WindowProxyProperties.""" + """WindowProxyProperties type definition.""" context: Any | None = None @dataclass class StackFrame: - """StackFrame.""" + """StackFrame type definition.""" column_number: Any | None = None function_name: str | None = None @@ -496,14 +496,14 @@ class StackFrame: @dataclass class StackTrace: - """StackTrace.""" + """StackTrace type definition.""" call_frames: list[Any | None] | None = field(default_factory=list) @dataclass class Source: - """Source.""" + """Source type definition.""" realm: Any | None = None context: Any | None = None @@ -511,14 +511,14 @@ class Source: @dataclass class RealmTarget: - """RealmTarget.""" + """RealmTarget type definition.""" realm: Any | None = None @dataclass class ContextTarget: - """ContextTarget.""" + """ContextTarget type definition.""" context: Any | None = None sandbox: str | None = None @@ -526,7 +526,7 @@ class ContextTarget: @dataclass class AddPreloadScriptParameters: - """AddPreloadScriptParameters.""" + """AddPreloadScriptParameters type definition.""" function_declaration: str | None = None arguments: list[Any | None] | None = field(default_factory=list) @@ -537,14 +537,14 @@ class AddPreloadScriptParameters: @dataclass class AddPreloadScriptResult: - """AddPreloadScriptResult.""" + """AddPreloadScriptResult type definition.""" script: Any | None = None @dataclass class DisownParameters: - """DisownParameters.""" + """DisownParameters type definition.""" handles: list[Any | None] | None = field(default_factory=list) target: Any | None = None @@ -552,7 +552,7 @@ class DisownParameters: @dataclass class CallFunctionParameters: - """CallFunctionParameters.""" + """CallFunctionParameters type definition.""" function_declaration: str | None = None await_promise: bool | None = None @@ -566,7 +566,7 @@ class CallFunctionParameters: @dataclass class EvaluateParameters: - """EvaluateParameters.""" + """EvaluateParameters type definition.""" expression: str | None = None target: Any | None = None @@ -578,7 +578,7 @@ class EvaluateParameters: @dataclass class GetRealmsParameters: - """GetRealmsParameters.""" + """GetRealmsParameters type definition.""" context: Any | None = None type: Any | None = None @@ -586,21 +586,21 @@ class GetRealmsParameters: @dataclass class GetRealmsResult: - """GetRealmsResult.""" + """GetRealmsResult type definition.""" realms: list[Any | None] | None = field(default_factory=list) @dataclass class RemovePreloadScriptParameters: - """RemovePreloadScriptParameters.""" + """RemovePreloadScriptParameters type definition.""" script: Any | None = None @dataclass class MessageParameters: - """MessageParameters.""" + """MessageParameters type definition.""" channel: Any | None = None data: Any | None = None @@ -609,13 +609,14 @@ class MessageParameters: @dataclass class RealmDestroyedParameters: - """RealmDestroyedParameters.""" + """RealmDestroyedParameters type definition.""" realm: Any | None = None # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { + "message": "script.message", "realm_created": "script.realmCreated", "realm_destroyed": "script.realmDestroyed", } @@ -885,18 +886,6 @@ def remove_preload_script(self, script: Any | None = None): result = self._conn.execute(cmd) return result - def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): - """Execute script.message.""" - params = { - "channel": channel, - "data": data, - "source": source, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("script.message", params) - result = self._conn.execute(cmd) - return result - def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: """Execute a function declaration in the browser context. @@ -1298,6 +1287,9 @@ def clear_event_handlers(self) -> None: return self._event_manager.clear_event_handlers() # Event Info Type Aliases +# Event: script.message +Message = globals().get('MessageParameters', dict) # Fallback to dict if type not defined + # Event: script.realmCreated RealmCreated = globals().get('RealmInfo', dict) # Fallback to dict if type not defined @@ -1308,6 +1300,12 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { + "message": ( + EventConfig("message", "script.message", + _globals.get("Message", dict)) + if _globals.get("Message") + else EventConfig("message", "script.message", dict) + ), "realm_created": ( EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index c1b5be09ca024..da12c1cd49792 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -22,7 +22,7 @@ class UserPromptHandlerType: @dataclass class CapabilitiesRequest: - """CapabilitiesRequest.""" + """CapabilitiesRequest type definition.""" always_match: Any | None = None first_match: list[Any | None] | None = field(default_factory=list) @@ -30,7 +30,7 @@ class CapabilitiesRequest: @dataclass class CapabilityRequest: - """CapabilityRequest.""" + """CapabilityRequest type definition.""" accept_insecure_certs: bool | None = None browser_name: str | None = None @@ -42,21 +42,21 @@ class CapabilityRequest: @dataclass class AutodetectProxyConfiguration: - """AutodetectProxyConfiguration.""" + """AutodetectProxyConfiguration type definition.""" proxy_type: str = field(default="autodetect", init=False) @dataclass class DirectProxyConfiguration: - """DirectProxyConfiguration.""" + """DirectProxyConfiguration type definition.""" proxy_type: str = field(default="direct", init=False) @dataclass class ManualProxyConfiguration: - """ManualProxyConfiguration.""" + """ManualProxyConfiguration type definition.""" proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None @@ -66,7 +66,7 @@ class ManualProxyConfiguration: @dataclass class SocksProxyConfiguration: - """SocksProxyConfiguration.""" + """SocksProxyConfiguration type definition.""" socks_proxy: str | None = None socks_version: Any | None = None @@ -74,7 +74,7 @@ class SocksProxyConfiguration: @dataclass class PacProxyConfiguration: - """PacProxyConfiguration.""" + """PacProxyConfiguration type definition.""" proxy_type: str = field(default="pac", init=False) proxy_autoconfig_url: str | None = None @@ -82,14 +82,14 @@ class PacProxyConfiguration: @dataclass class SystemProxyConfiguration: - """SystemProxyConfiguration.""" + """SystemProxyConfiguration type definition.""" proxy_type: str = field(default="system", init=False) @dataclass class SubscribeParameters: - """SubscribeParameters.""" + """SubscribeParameters type definition.""" events: list[str | None] | None = field(default_factory=list) contexts: list[Any | None] | None = field(default_factory=list) @@ -98,21 +98,21 @@ class SubscribeParameters: @dataclass class UnsubscribeByIDRequest: - """UnsubscribeByIDRequest.""" + """UnsubscribeByIDRequest type definition.""" subscriptions: list[Any | None] | None = field(default_factory=list) @dataclass class UnsubscribeByAttributesRequest: - """UnsubscribeByAttributesRequest.""" + """UnsubscribeByAttributesRequest type definition.""" events: list[str | None] | None = field(default_factory=list) @dataclass class StatusResult: - """StatusResult.""" + """StatusResult type definition.""" ready: bool | None = None message: str | None = None @@ -120,14 +120,14 @@ class StatusResult: @dataclass class NewParameters: - """NewParameters.""" + """NewParameters type definition.""" capabilities: Any | None = None @dataclass class NewResult: - """NewResult.""" + """NewResult type definition.""" session_id: str | None = None accept_insecure_certs: bool | None = None @@ -143,7 +143,7 @@ class NewResult: @dataclass class SubscribeResult: - """SubscribeResult.""" + """SubscribeResult type definition.""" subscription: Any | None = None diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 3f29b85d13a23..c5a4666ebaf07 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -14,7 +14,7 @@ @dataclass class PartitionKey: - """PartitionKey.""" + """PartitionKey type definition.""" user_context: str | None = None source_origin: str | None = None @@ -22,7 +22,7 @@ class PartitionKey: @dataclass class GetCookiesParameters: - """GetCookiesParameters.""" + """GetCookiesParameters type definition.""" filter: Any | None = None partition: Any | None = None @@ -30,7 +30,7 @@ class GetCookiesParameters: @dataclass class GetCookiesResult: - """GetCookiesResult.""" + """GetCookiesResult type definition.""" cookies: list[Any | None] | None = field(default_factory=list) partition_key: Any | None = None @@ -38,7 +38,7 @@ class GetCookiesResult: @dataclass class SetCookieParameters: - """SetCookieParameters.""" + """SetCookieParameters type definition.""" cookie: Any | None = None partition: Any | None = None @@ -46,14 +46,14 @@ class SetCookieParameters: @dataclass class SetCookieResult: - """SetCookieResult.""" + """SetCookieResult type definition.""" partition_key: Any | None = None @dataclass class DeleteCookiesParameters: - """DeleteCookiesParameters.""" + """DeleteCookiesParameters type definition.""" filter: Any | None = None partition: Any | None = None @@ -61,7 +61,7 @@ class DeleteCookiesParameters: @dataclass class DeleteCookiesResult: - """DeleteCookiesResult.""" + """DeleteCookiesResult type definition.""" partition_key: Any | None = None diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index ebbe6729499b2..0a3998a611125 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -14,14 +14,14 @@ @dataclass class InstallParameters: - """InstallParameters.""" + """InstallParameters type definition.""" extension_data: Any | None = None @dataclass class ExtensionPath: - """ExtensionPath.""" + """ExtensionPath type definition.""" type: str = field(default="path", init=False) path: str | None = None @@ -29,7 +29,7 @@ class ExtensionPath: @dataclass class ExtensionArchivePath: - """ExtensionArchivePath.""" + """ExtensionArchivePath type definition.""" type: str = field(default="archivePath", init=False) path: str | None = None @@ -37,7 +37,7 @@ class ExtensionArchivePath: @dataclass class ExtensionBase64Encoded: - """ExtensionBase64Encoded.""" + """ExtensionBase64Encoded type definition.""" type: str = field(default="base64", init=False) value: str | None = None @@ -45,14 +45,14 @@ class ExtensionBase64Encoded: @dataclass class InstallResult: - """InstallResult.""" + """InstallResult type definition.""" extension: Any | None = None @dataclass class UninstallParameters: - """UninstallParameters.""" + """UninstallParameters type definition.""" extension: Any | None = None diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index dc64d77265b09..38013a56f7b7d 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -450,8 +450,12 @@ def execute( """ # Handle BiDi generator commands if inspect.isgenerator(driver_command): - # BiDi command: use WebSocketConnection directly - return self.command_executor.execute(driver_command) + # BiDi command: route through the WebSocket connection, not the + # HTTP RemoteConnection which only accepts (command, params) pairs. + if not self._websocket_connection: + self._start_bidi() + assert self._websocket_connection is not None + return self._websocket_connection.execute(driver_command) # Legacy WebDriver command: handle normally params = self._wrap_value(params) From c0503635fe21c255db80e1a0099a51d58c84d992 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Mon, 2 Mar 2026 13:55:49 +0000 Subject: [PATCH 06/37] [py] Fix Copilot review: license headers, _BiDiEncoder nested types, revert unrelated requirements changes --- py/generate_bidi.py | 19 ++++++++++++++++++- py/requirements_lock.txt | 5 ++++- py/selenium/webdriver/common/bidi/__init__.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/browser.py | 17 +++++++++++++++++ .../webdriver/common/bidi/browsing_context.py | 17 +++++++++++++++++ .../webdriver/common/bidi/emulation.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/input.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/log.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/network.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/script.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/session.py | 17 +++++++++++++++++ py/selenium/webdriver/common/bidi/storage.py | 17 +++++++++++++++++ .../webdriver/common/bidi/webextension.py | 17 +++++++++++++++++ .../webdriver/remote/websocket_connection.py | 14 ++++++++++++-- 14 files changed, 221 insertions(+), 4 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 5d7f39e53abfc..412494517772a 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -32,7 +32,24 @@ logger = logging.getLogger("generate_bidi") # File headers -SHARED_HEADER = """# DO NOT EDIT THIS FILE! +SHARED_HEADER = """# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make # changes, edit the generator and regenerate all of the modules.""" diff --git a/py/requirements_lock.txt b/py/requirements_lock.txt index c58f4b1c76fe6..68f8d858bb6f4 100644 --- a/py/requirements_lock.txt +++ b/py/requirements_lock.txt @@ -461,6 +461,7 @@ jeepney==0.9.0 \ --hash=sha256:cf0e9e845622b81e4a28df94c40345400256ec608d0e55bb8a3feaa9163f5732 # via # -r py/requirements.txt + # keyring # secretstorage jinja2==3.1.6 \ --hash=sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d \ @@ -1037,7 +1038,9 @@ rich==14.3.3 \ secretstorage==3.5.0 \ --hash=sha256:0ce65888c0725fcb2c5bc0fdb8e5438eece02c523557ea40ce0703c266248137 \ --hash=sha256:f04b8e4689cbce351744d5537bf6b1329c6fc68f91fa666f60a380edddcd11be - # via -r py/requirements.txt + # via + # -r py/requirements.txt + # keyring sniffio==1.3.1 \ --hash=sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2 \ --hash=sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index 7be7bd4f73856..bb129d5f6a195 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 71f917634304d..ff0c2d59b8cf2 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index ede96071778c3..7a0f8faf8687e 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index fbbe0966d8b3a..c58f6d5f78d6c 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index c8e58181b343e..e9c3f8345f05d 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index eaf52a2ec08c2..94f511d7185f8 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index c9737ac9131d0..9dc5fb94d8488 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 061bb17b0deec..0b2ec04101933 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index da12c1cd49792..771a5327151bf 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index c5a4666ebaf07..7623381706040 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 0a3998a611125..99250afca4c68 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -1,3 +1,20 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 68358e4a09974..8d6f745d4ac5b 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -41,6 +41,16 @@ class _BiDiEncoder(json.JSONEncoder): directly into its parent action dict as required by the BiDi spec. """ + def _convert(self, value): + """Recursively convert a value, handling nested dataclasses, lists, and dicts.""" + if dataclasses.is_dataclass(value) and not isinstance(value, type): + return self.default(value) + if isinstance(value, list): + return [self._convert(item) for item in value] + if isinstance(value, dict): + return {k: self._convert(v) for k, v in value.items()} + return value + def default(self, o): if dataclasses.is_dataclass(o) and not isinstance(o, type): result = {} @@ -54,9 +64,9 @@ def default(self, o): for pf in dataclasses.fields(value): pv = getattr(value, pf.name) if pv is not None: - result[_snake_to_camel(pf.name)] = pv + result[_snake_to_camel(pf.name)] = self._convert(pv) else: - result[camel_key] = value + result[camel_key] = self._convert(value) return result return super().default(o) From 4bf0219e15aedec6f0ce80cc86ec84c3c22a0167 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 11:48:05 +0000 Subject: [PATCH 07/37] fixup --- py/generate_bidi.py | 1134 +---------------- .../webdriver/common/bidi/_event_manager.py | 186 +++ .../webdriver/remote/websocket_connection.py | 13 +- 3 files changed, 199 insertions(+), 1134 deletions(-) create mode 100644 py/selenium/webdriver/common/bidi/_event_manager.py diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 412494517772a..8103cafe40684 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -619,11 +619,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Add imports for event handling if needed if self.events: - code += "import threading\n" - code += "from collections.abc import Callable\n" - if not dataclass_imported: - code += "from dataclasses import dataclass\n" - code += "from selenium.webdriver.common.bidi.session import Session\n" + code += "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager\n" code += "\n\n" @@ -801,1131 +797,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ code += "\n\n" - # Generate EventConfig and _EventManager for modules with events - if self.events: - # Generate EventConfig dataclass - code += """@dataclass -class EventConfig: - \"\"\"Configuration for a BiDi event.\"\"\" - event_key: str - bidi_event: str - event_class: type - - -""" - - # Generate _EventManager class - code += """class _EventWrapper: - \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - \"\"\"Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - \"\"\" - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, \"from_json\") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend([\"_\", char.lower()]) - else: - result.append(char) - return \"\".join(result) - - -class _EventManager: - \"\"\"Manages event subscriptions and callbacks.\"\"\" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get(\"subscription\") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - \"callbacks\": [], - \"subscription_id\": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry[\"callbacks\"]: - session = Session(self.conn) - sub_id = entry.get(\"subscription_id\") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry[\"callbacks\"]: - entry[\"callbacks\"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - \"\"\"Clear all event handlers.\"\"\" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get(\"subscription_id\") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - -""" - code += "\n\n" + # EventConfig, _EventWrapper, and _EventManager are imported from + # ._event_manager (see the import block above); nothing to emit here. # Generate class - # Convert module name (camelCase or snake_case) to proper class name (PascalCase) - class_name = module_name_to_class_name(self.name) - code += f"class {class_name}:\n" - code += f' """WebDriver BiDi {self.name} module."""\n\n' - - # Add EVENT_CONFIGS dict if there are events - if self.events: - code += ( - " EVENT_CONFIGS = {}\n" # Will be populated after types are defined - ) - - if self.name == "script": - code += " def __init__(self, conn, driver=None) -> None:\n" - code += " self._conn = conn\n" - code += " self._driver = driver\n" - else: - code += " def __init__(self, conn) -> None:\n" - code += " self._conn = conn\n" - - # Initialize _event_manager if there are events - if self.events: - code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n" - - # Append extra init code from enhancements (e.g. self.intercepts = []) - for init_line in enhancements.get("extra_init_code", []): - code += f" {init_line}\n" - - code += "\n" - - # Generate command methods - # Auto-exclude methods whose names appear in extra_methods to prevent duplicates - extra_method_names = set() - for extra_meth in enhancements.get("extra_methods", []): - m = re.search(r"def\s+(\w+)\s*\(", extra_meth) - if m: - extra_method_names.add(m.group(1)) - exclude_methods = ( - set(enhancements.get("exclude_methods", [])) | extra_method_names - ) - if self.commands: - for command in self.commands: - # Get method-specific enhancements - # Convert command name to snake_case to match enhancement manifest keys - method_name_snake = command._camel_to_snake(command.name) - if method_name_snake in exclude_methods: - continue - method_enhancements = enhancements.get(method_name_snake, {}) - code += command.to_python_method(method_enhancements) - code += "\n" - else: - code += " pass\n" - - # Emit extra methods from enhancement manifest - for extra_method in enhancements.get("extra_methods", []): - code += extra_method - code += "\n" - - # Add delegating event handler methods if events are present - if self.events: - code += """ - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - \"\"\"Add an event handler. - - Args: - event: The event to subscribe to. - callback: The callback function to execute on event. - contexts: The context IDs to subscribe to (optional). - - Returns: - The callback ID. - \"\"\" - return self._event_manager.add_event_handler(event, callback, contexts) - - def remove_event_handler(self, event: str, callback_id: int) -> None: - \"\"\"Remove an event handler. - - Args: - event: The event to unsubscribe from. - callback_id: The callback ID. - \"\"\" - return self._event_manager.remove_event_handler(event, callback_id) - - def clear_event_handlers(self) -> None: - \"\"\"Clear all event handlers.\"\"\" - return self._event_manager.clear_event_handlers() -""" - - # Generate event info type aliases AFTER the class definition - # This ensures all types are available when we create the aliases - if self.events: - code += "\n# Event Info Type Aliases\n" - for event_def in self.events: - code += event_def.to_python_dataclass() - code += "\n" - - # Now populate EVENT_CONFIGS after the aliases are defined - code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n" - # Use globals() to look up types dynamically to handle missing types gracefully - code += "_globals = globals()\n" - code += f"{class_name}.EVENT_CONFIGS = {{\n" - # Collect extra event keys to skip CDDL duplicates - extra_event_keys_cfg = { - evt["event_key"] for evt in enhancements.get("extra_events", []) - } - for event_def in self.events: - # Convert method name to user-friendly event name - method_parts = event_def.method.split(".") - if len(method_parts) == 2: - event_name = self._convert_method_to_event_name(method_parts[1]) - if event_name in extra_event_keys_cfg: - continue - # The event class is the event name (e.g., ContextCreated) - # Try to get it from globals, default to dict if not found - code += ( - f' "{event_name}": (\n' - f' EventConfig("{event_name}", "{event_def.method}",\n' - f' _globals.get("{event_def.name}", dict))\n' - f' if _globals.get("{event_def.name}")\n' - f' else EventConfig("{event_name}", "{event_def.method}", dict)\n' - f" ),\n" - ) - # Extra events not in the CDDL spec - for extra_evt in enhancements.get("extra_events", []): - ek = extra_evt["event_key"] - be = extra_evt["bidi_event"] - ec = extra_evt["event_class"] - single = f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),' - if len(single) > 120: - code += ( - f' "{ek}": EventConfig(\n' - f' "{ek}", "{be}",\n' - f' _globals.get("{ec}", dict),\n' - f" ),\n" - ) - else: - code += single + "\n" - code += "}\n" - - return code - - -class CddlParser: - """Parse CDDL specification files.""" - - def __init__(self, cddl_path: str): - """Initialize parser with CDDL file path.""" - self.cddl_path = Path(cddl_path) - self.content = "" - self.modules: dict[str, CddlModule] = {} - self.definitions: dict[str, str] = {} - self.event_names: set[str] = set() # Names of definitions that are events - self._read_file() - - def _read_file(self) -> None: - """Read and preprocess CDDL file.""" - if not self.cddl_path.exists(): - raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") - - with open(self.cddl_path, encoding="utf-8") as f: - self.content = f.read() - - logger.info(f"Loaded CDDL file: {self.cddl_path}") - - def parse(self) -> dict[str, CddlModule]: - """Parse CDDL content and return modules.""" - # Remove comments - content = self._remove_comments(self.content) - - # Extract all definitions - self._extract_definitions(content) - - # Extract event names from event union definitions - self._extract_event_names() - - # Extract type definitions by module - self._extract_types() - - # Extract event definitions by module - self._extract_events() - - # Extract command definitions by module - self._extract_commands() - - # If no modules found, create a default one from the filename - if not self.modules: - module_name = self.cddl_path.stem - default_module = CddlModule(name=module_name) - self.modules[module_name] = default_module - logger.warning(f"No modules found in CDDL, creating default: {module_name}") - - return self.modules - - def _remove_comments(self, content: str) -> str: - """Remove comments from CDDL content.""" - # CDDL uses ; for comments to end of line - lines = content.split("\n") - cleaned = [] - for line in lines: - if ";" in line and not line.strip().startswith(";"): - line = line[: line.index(";")] - elif line.strip().startswith(";"): - continue - cleaned.append(line) - return "\n".join(cleaned) - - def _extract_definitions(self, content: str) -> None: - """Extract CDDL definitions (type definitions, commands, etc.).""" - # Match pattern: Name = Definition - # Handles multiline definitions properly. - # The \s* after \n in the lookahead allows definitions that start with - # leading whitespace (e.g. " network.BeforeRequestSent = (") to be - # recognised as separate definitions instead of being swallowed into - # the body of the preceding definition. - pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\s*\w+(?:\.\w+)?\s*=|\Z)" - - for match in re.finditer(pattern, content, re.DOTALL): - name = match.group(1).strip() - definition = match.group(2).strip() - self.definitions[name] = definition - logger.debug(f"Extracted definition: {name}") - - def _extract_event_names(self) -> None: - """Extract event names from event union definitions. - - Event union definitions follow pattern: - module.ModuleEvent = ( - module.EventName1 // - module.EventName2 // - ... - ) - """ - for def_name, def_content in self.definitions.items(): - # Check if this looks like an event union (name ends with "Event") and - # contains a module-qualified reference like "module.EventName". - # Handles both single-item (no //) and multi-item (// separated) unions. - if "Event" in def_name and re.search(r"\w+\.\w+", def_content): - # Extract event names from the union (works for single and multi-item) - event_refs = re.findall(r"(\w+\.\w+)", def_content) - for event_ref in event_refs: - self.event_names.add(event_ref) - logger.debug(f"Identified event: {event_ref} (from {def_name})") - - def _extract_types(self) -> None: - """Extract type definitions from parsed definitions.""" - # Type definitions follow pattern: module.TypeName = { field: type, ... } - # They have dots in the name and curly braces in the content - # But they DON'T have method: "..." pattern (which means it's not a command) - # Enums follow pattern: module.EnumName = "value1" / "value2" / ... - - for def_name, def_content in self.definitions.items(): - # Skip if not a namespaced name (e.g., skip "EmptyParams", "Extensible") - if "." not in def_name: - continue - - # Skip if it's a command (contains method: pattern) - if "method:" in def_content: - continue - - # Extract module.TypeName - if "." in def_name: - module_name, type_name = def_name.rsplit(".", 1) - - # Create module if not exists - if module_name not in self.modules: - self.modules[module_name] = CddlModule(name=module_name) - - # Check if this is an enum (string union with /) - if self._is_enum_definition(def_content): - # Extract enum values - values = self._extract_enum_values(def_content) - if values: - enum_def = CddlEnum( - module=module_name, - name=type_name, - values=values, - description=f"{type_name}", - ) - self.modules[module_name].enums.append(enum_def) - logger.debug( - f"Found enum: {def_name} with {len(values)} values" - ) - else: - # Extract fields from type definition - fields = self._extract_type_fields(def_content) - - if fields: # Only create type if it has fields - type_def = CddlTypeDefinition( - module=module_name, - name=type_name, - fields=fields, - description=f"{type_name}", - ) - self.modules[module_name].types.append(type_def) - logger.debug( - f"Found type: {def_name} with {len(fields)} fields" - ) - - def _is_enum_definition(self, definition: str) -> bool: - """Check if a definition is an enum (string union with /). - - Enums are defined as: "value1" / "value2" / "value3" - """ - # Clean whitespace - clean_def = definition.strip() - - # Must not have curly braces (that would be a type definition) - if "{" in clean_def or "}" in clean_def: - return False - - # Must contain the union operator / surrounded by quotes - # Pattern: "something" / "something_else" - return " / " in clean_def and '"' in clean_def - - def _extract_enum_values(self, enum_definition: str) -> list[str]: - """Extract individual values from an enum definition. - - Enums are defined as: "value1" / "value2" / "value3" - Can span multiple lines. - """ - values = [] - - # Clean the definition and extract quoted strings - # Split by / and extract quoted values - parts = enum_definition.split("/") - - for part in parts: - part = part.strip() - - # Extract quoted string - use search instead of match to find quotes anywhere - match = re.search(r'"([^"]*)"', part) - if match: - value = match.group(1) - values.append(value) - logger.debug(f"Extracted enum value: {value}") - - return values - - @staticmethod - def _normalize_cddl_type(field_type: str) -> str: - """Normalize a CDDL type expression to a simple Python-compatible form. - - Strips CDDL control operators (.ge, .le, .gt, .lt, .default, etc.) and - replaces interval/constraint expressions with their base types so that - the caller can safely check for nested struct syntax. - - Examples: - '(float .ge 0.0) .default 1.0' -> 'float' - '(float .ge 0.0) / null' -> 'float / null' - '(0.0...360.0) / null' -> 'float / null' - '-90.0..90.0' -> 'float' - 'float / null .default null' -> 'float / null' - """ - result = field_type - # Remove trailing .default annotations - result = re.sub(r"\s*\.default\s+\S+", "", result) - # Replace parenthesised constraint expressions: (baseType .operator ...) -> baseType - result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result) - # Replace parenthesised numeric interval types: (0.0...360.0) -> float - result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result) - # Replace bare numeric interval types: -90.0..90.0 -> float - result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) - return result.strip() - - def _extract_type_fields(self, type_definition: str) -> dict[str, str]: - """Extract fields from a type definition block.""" - fields = {} - - # Remove outer braces - clean_def = type_definition.strip() - if clean_def.startswith("{"): - clean_def = clean_def[1:] - if clean_def.endswith("}"): - clean_def = clean_def[:-1] - - # Parse each line for field: type patterns - for line in clean_def.split("\n"): - line = line.strip() - if not line or "Extensible" in line or line.startswith("//"): - continue - - # Match pattern: [?] fieldName: type - match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) - if not match: - # Try without optional marker - match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) - - if match: - field_name = match.group(1).strip() - field_type = match.group(2).strip() - normalized_type = self._normalize_cddl_type(field_type) - - # Skip lines that are part of nested definitions - if "{" not in normalized_type and "(" not in normalized_type: - fields[field_name] = normalized_type - logger.debug(f"Extracted field {field_name}: {normalized_type}") - - return fields - - def _extract_events(self) -> None: - """Extract event definitions from parsed definitions. - - Events are definitions that: - 1. Are listed in an event union (e.g., BrowsingContextEvent) - 2. Have method: "..." and params: ... fields - - Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) - """ - # Find definitions that are in the event_names set - event_pattern = re.compile( - r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" - ) - - for def_name, def_content in self.definitions.items(): - # Skip if not identified as an event - if def_name not in self.event_names: - continue - - # Extract method and params - match = event_pattern.search(def_content) - if match: - method = match.group(1) # e.g., "browsingContext.contextCreated" - params_type = match.group(2) # e.g., "browsingContext.Info" - - # Extract module name from method - if "." in method: - module_name, _ = method.split(".", 1) - - # Create module if not exists - if module_name not in self.modules: - self.modules[module_name] = CddlModule(name=module_name) - - # Extract event name from definition name (e.g., browsingContext.ContextCreated) - _, event_name = def_name.rsplit(".", 1) - - # Create event - event = CddlEvent( - module=module_name, - name=event_name, - method=method, - params_type=params_type, - description=f"Event: {method}", - ) - - self.modules[module_name].events.append(event) - logger.debug( - f"Found event: {def_name} (method={method}, params={params_type})" - ) - - def _extract_commands(self) -> None: - """Extract command definitions from parsed definitions.""" - # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) - command_pattern = re.compile( - r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" - ) - - for def_name, def_content in self.definitions.items(): - # Skip definitions that are events (they share the same pattern) - if def_name in self.event_names: - continue - matches = list(command_pattern.finditer(def_content)) - if matches: - for match in matches: - method = match.group(1) # e.g., "session.new" - params_type = match.group(2) # e.g., "session.NewParameters" - - # Extract module name from method - if "." in method: - module_name, command_name = method.split(".", 1) - - # Create module if not exists - if module_name not in self.modules: - self.modules[module_name] = CddlModule(name=module_name) - - # Extract parameters - params = self._extract_parameters(params_type) - - # Create command - cmd = CddlCommand( - module=module_name, - name=command_name, - params=params, - description=f"Execute {method}", - ) - - self.modules[module_name].commands.append(cmd) - logger.debug( - f"Found command: {method} with params {params_type}" - ) - - def _extract_parameters( - self, params_type: str, _seen: set[str] | None = None - ) -> dict[str, str]: - """Extract parameters from a parameter type definition. - - Handles both struct types ({...}) and top-level union types (TypeA / TypeB), - merging all fields from each alternative as optional parameters. - """ - params = {} - - if _seen is None: - _seen = set() - if params_type in _seen: - return params - _seen.add(params_type) - - if params_type not in self.definitions: - logger.debug(f"Parameter type not found: {params_type}") - return params - - definition = self.definitions[params_type] - - # Handle top-level type alias that is a union of other named types: - # e.g. session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest - # These definitions contain a single line with "/" separating type names - # (not the double-slash "//" used for command unions). - stripped = definition.strip() - if not stripped.startswith("{") and "/" in stripped and "//" not in stripped: - # Each token separated by "/" should be a named type reference - alternatives = [a.strip() for a in stripped.split("/") if a.strip()] - all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) - if all_named: - for alt_type in alternatives: - alt_params = self._extract_parameters(alt_type, _seen) - params.update(alt_params) - return params - - # Remove the outer curly braces and split by comma - # Then parse each line for key: type patterns - clean_def = stripped - if clean_def.startswith("{"): - clean_def = clean_def[1:] - if clean_def.endswith("}"): - clean_def = clean_def[:-1] - - # Split by newlines and process each line - for line in clean_def.split("\n"): - line = line.strip() - if not line or "Extensible" in line: - continue - - # Match pattern: [?] name: type - # Using a simple pattern that handles optional prefix - match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) - if not match: - # Try without optional marker - match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) - - if match: - param_name = match.group(1).strip() - param_type = match.group(2).strip() - normalized_type = self._normalize_cddl_type(param_type) - - # Skip lines that are part of nested definitions - if "{" not in normalized_type and "(" not in normalized_type: - params[param_name] = normalized_type - logger.debug( - f"Extracted param {param_name}: {normalized_type} from {params_type}" - ) - - return params - - -def module_name_to_class_name(module_name: str) -> str: - """Convert module name to class name (PascalCase). - - Handles both camelCase (browsingContext) and snake_case (browsing_context). - """ - if "_" in module_name: - # Snake_case: browsing_context -> BrowsingContext - return "".join(word.capitalize() for word in module_name.split("_")) - else: - # CamelCase: browsingContext -> BrowsingContext - return module_name[0].upper() + module_name[1:] if module_name else "" - - -def module_name_to_filename(module_name: str) -> str: - """Convert module name to Python filename (snake_case). - - Handles both camelCase (browsingContext) and snake_case (browsing_context). - Special cases: - - browsingContext -> browsing_context - - webExtension -> webextension - """ - # Handle explicit mappings for known camelCase names - camel_to_snake_map = { - "browsingContext": "browsing_context", - "webExtension": "webextension", - } - - if module_name in camel_to_snake_map: - return camel_to_snake_map[module_name] - - if "_" in module_name: - # Already snake_case - return module_name - else: - # Convert camelCase to snake_case for other cases - # This handles cases like "myModuleName" -> "my_module_name" - import re - - s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", module_name) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - - -def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None: - """Generate __init__.py file for the module.""" - init_path = output_path / "__init__.py" - - code = f"""{SHARED_HEADER} - -from __future__ import annotations - -""" - - for module_name in sorted(modules.keys()): - class_name = module_name_to_class_name(module_name) - filename = module_name_to_filename(module_name) - code += f"from .{filename} import {class_name}\n" - - code += "\n__all__ = [\n" - for module_name in sorted(modules.keys()): - class_name = module_name_to_class_name(module_name) - code += f' "{class_name}",\n' - code += "]\n" - - with open(init_path, "w", encoding="utf-8") as f: - f.write(code) - - logger.info(f"Generated: {init_path}") - - -def generate_common_file(output_path: Path) -> None: - """Generate common.py file with shared utilities.""" - common_path = output_path / "common.py" - - code = ( - "# Licensed to the Software Freedom Conservancy (SFC) under one\n" - "# or more contributor license agreements. See the NOTICE file\n" - "# distributed with this work for additional information\n" - "# regarding copyright ownership. The SFC licenses this file\n" - "# to you under the Apache License, Version 2.0 (the\n" - '# "License"); you may not use this file except in compliance\n' - "# with the License. You may obtain a copy of the License at\n" - "#\n" - "# http://www.apache.org/licenses/LICENSE-2.0\n" - "#\n" - "# Unless required by applicable law or agreed to in writing,\n" - "# software distributed under the License is distributed on an\n" - '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' - "# KIND, either express or implied. See the License for the\n" - "# specific language governing permissions and limitations\n" - "# under the License.\n" - "\n" - '"""Common utilities for BiDi command construction."""\n' - "\n" - "from __future__ import annotations\n" - "\n" - "from collections.abc import Generator\n" - "from typing import Any\n" - "\n" - "\n" - "def command_builder(\n" - " method: str, params: dict[str, Any] | None = None\n" - ") -> Generator[dict[str, Any], Any, Any]:\n" - ' """Build a BiDi command generator.\n' - "\n" - " Args:\n" - ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' - " params: The parameters for the command\n" - "\n" - " Yields:\n" - " A dictionary representing the BiDi command\n" - "\n" - " Returns:\n" - " The result from the BiDi command execution\n" - ' """\n' - " if params is None:\n" - " params = {}\n" - ' result = yield {"method": method, "params": params}\n' - " return result\n" - ) - - with open(common_path, "w", encoding="utf-8") as f: - f.write(code) - - logger.info(f"Generated: {common_path}") - - -def generate_console_file(output_path: Path) -> None: - """Generate console.py file with Console enum helper.""" - console_path = output_path / "console.py" - - code = ( - "# Licensed to the Software Freedom Conservancy (SFC) under one\n" - "# or more contributor license agreements. See the NOTICE file\n" - "# distributed with this work for additional information\n" - "# regarding copyright ownership. The SFC licenses this file\n" - "# to you under the Apache License, Version 2.0 (the\n" - '# "License"); you may not use this file except in compliance\n' - "# with the License. You may obtain a copy of the License at\n" - "#\n" - "# http://www.apache.org/licenses/LICENSE-2.0\n" - "#\n" - "# Unless required by applicable law or agreed to in writing,\n" - "# software distributed under the License is distributed on an\n" - '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' - "# KIND, either express or implied. See the License for the\n" - "# specific language governing permissions and limitations\n" - "# under the License.\n" - "\n" - "from enum import Enum\n" - "\n" - "\n" - "class Console(Enum):\n" - ' ALL = "all"\n' - ' LOG = "log"\n' - ' ERROR = "error"\n' - ) - - with open(console_path, "w", encoding="utf-8") as f: - f.write(code) - - logger.info(f"Generated: {console_path}") - - -def generate_permissions_file(output_path: Path) -> None: - """Generate permissions.py file with permission-related classes.""" - permissions_path = output_path / "permissions.py" - - code = ( - "# Licensed to the Software Freedom Conservancy (SFC) under one\n" - "# or more contributor license agreements. See the NOTICE file\n" - "# distributed with this work for additional information\n" - "# regarding copyright ownership. The SFC licenses this file\n" - "# to you under the Apache License, Version 2.0 (the\n" - '# "License"); you may not use this file except in compliance\n' - "# with the License. You may obtain a copy of the License at\n" - "#\n" - "# http://www.apache.org/licenses/LICENSE-2.0\n" - "#\n" - "# Unless required by applicable law or agreed to in writing,\n" - "# software distributed under the License is distributed on an\n" - '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' - "# KIND, either express or implied. See the License for the\n" - "# specific language governing permissions and limitations\n" - "# under the License.\n" - "\n" - '"""WebDriver BiDi Permissions module."""\n' - "\n" - "from __future__ import annotations\n" - "\n" - "from __future__ import annotations\n" - "\n" - "from enum import Enum\n" - "from typing import Any\n" - "\n" - "from .common import command_builder\n" - "\n" - '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' - "\n" - "\n" - "class PermissionState(str, Enum):\n" - ' """Permission state enumeration."""\n' - "\n" - ' GRANTED = "granted"\n' - ' DENIED = "denied"\n' - ' PROMPT = "prompt"\n' - "\n" - "\n" - "class PermissionDescriptor:\n" - ' """Descriptor for a permission."""\n' - "\n" - " def __init__(self, name: str) -> None:\n" - ' """Initialize a PermissionDescriptor.\n' - "\n" - " Args:\n" - " name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera')\n" - ' """\n' - " self.name = name\n" - "\n" - " def __repr__(self) -> str:\n" - " return f\"PermissionDescriptor('{self.name}')\"\n" - "\n" - "\n" - "class Permissions:\n" - ' """WebDriver BiDi Permissions module."""\n' - "\n" - " def __init__(self, websocket_connection: Any) -> None:\n" - ' """Initialize the Permissions module.\n' - "\n" - " Args:\n" - " websocket_connection: The WebSocket connection for sending BiDi commands\n" - ' """\n' - " self._conn = websocket_connection\n" - "\n" - " def set_permission(\n" - " self,\n" - " descriptor: PermissionDescriptor | str,\n" - " state: PermissionState | str,\n" - " origin: str | None = None,\n" - " user_context: str | None = None,\n" - " ) -> None:\n" - ' """Set a permission for a given origin.\n' - "\n" - " Args:\n" - " descriptor: The permission descriptor or permission name as a string\n" - " state: The desired permission state\n" - " origin: The origin for which to set the permission\n" - " user_context: Optional user context ID to scope the permission\n" - "\n" - " Raises:\n" - " ValueError: If the state is not a valid permission state\n" - ' """\n' - " state_value = state.value if isinstance(state, PermissionState) else state\n" - " if state_value not in _VALID_PERMISSION_STATES:\n" - " raise ValueError(\n" - ' f"Invalid permission state: {state_value!r}. "\n' - ' f"Must be one of {sorted(_VALID_PERMISSION_STATES)}"\n' - " )\n" - "\n" - " if isinstance(descriptor, str):\n" - ' descriptor_dict = {"name": descriptor}\n' - " else:\n" - ' descriptor_dict = {"name": descriptor.name}\n' - "\n" - " params: dict[str, Any] = {\n" - ' "descriptor": descriptor_dict,\n' - ' "state": state_value,\n' - " }\n" - " if origin is not None:\n" - ' params["origin"] = origin\n' - " if user_context is not None:\n" - ' params["userContext"] = user_context\n' - "\n" - ' cmd = command_builder("permissions.setPermission", params)\n' - " self._conn.execute(cmd)\n" - ) - - with open(permissions_path, "w", encoding="utf-8") as f: - f.write(code) - - logger.info(f"Generated: {permissions_path}") - - -def main( - cddl_file: str, - output_dir: str, - spec_version: str = "1.0", - enhancements_manifest: str | None = None, -) -> None: - """Main entry point. - - Args: - cddl_file: Path to CDDL specification file - output_dir: Output directory for generated modules - spec_version: BiDi spec version - enhancements_manifest: Path to enhancement manifest Python file - """ - output_path = Path(output_dir).resolve() - output_path.mkdir(parents=True, exist_ok=True) - - logger.info(f"WebDriver BiDi Code Generator v{__version__}") - logger.info(f"Input CDDL: {cddl_file}") - logger.info(f"Output directory: {output_path}") - logger.info(f"Spec version: {spec_version}") - - # Load enhancement manifest - manifest = load_enhancements_manifest(enhancements_manifest) - if manifest: - logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}") - - # Parse CDDL - parser = CddlParser(cddl_file) - modules = parser.parse() - - logger.info(f"Parsed {len(modules)} modules") - - # Clean up existing generated files - for file_path in output_path.glob("*.py"): - if file_path.name != "py.typed" and not file_path.name.startswith("_"): - file_path.unlink() - logger.debug(f"Removed: {file_path}") - - # Generate module files using snake_case filenames - for module_name, module in sorted(modules.items()): - filename = module_name_to_filename(module_name) - module_path = output_path / f"{filename}.py" - - # Get module-specific enhancements (merge with dataclass templates) - module_enhancements = manifest.get("enhancements", {}).get(module_name, {}) - - # Add dataclass methods and docstrings to the enhancement data for this module - full_module_enhancements = { - **module_enhancements, - "dataclass_methods": manifest.get("dataclass_methods", {}), - "method_docstrings": manifest.get("method_docstrings", {}), - } - - with open(module_path, "w", encoding="utf-8") as f: - f.write(module.generate_code(full_module_enhancements)) - logger.info(f"Generated: {module_path}") - - # Generate __init__.py - generate_init_file(output_path, modules) - - # Generate common.py - generate_common_file(output_path) - - # Generate permissions.py - generate_permissions_file(output_path) - - # Generate console.py - generate_console_file(output_path) - - # Create py.typed marker - py_typed_path = output_path / "py.typed" - py_typed_path.touch() - logger.info(f"Generated type marker: {py_typed_path}") - - logger.info("Code generation complete!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate Python WebDriver BiDi modules from CDDL specification" - ) - parser.add_argument( - "cddl_file", - help="Path to CDDL specification file", - ) - parser.add_argument( - "output_dir", - help="Output directory for generated Python modules", - ) - parser.add_argument( - "--version", - default="1.0", - help="BiDi spec version (default: 1.0)", - ) - parser.add_argument( - "--enhancements-manifest", - default=None, - help="Path to enhancement manifest Python file (optional)", - ) - parser.add_argument( - "-v", - "--verbose", - action="store_true", - help="Enable verbose logging", - ) - - args = parser.parse_args() - - if args.verbose: - logging.getLogger("generate_bidi").setLevel(logging.DEBUG) - - try: - main( - args.cddl_file, - args.output_dir, - args.version, - args.enhancements_manifest, - ) - sys.exit(0) - except Exception as e: - logger.error(f"Generation failed: {e}", exc_info=True) - sys.exit(1) diff --git a/py/selenium/webdriver/common/bidi/_event_manager.py b/py/selenium/webdriver/common/bidi/_event_manager.py new file mode 100644 index 0000000000000..216a5b8eccb70 --- /dev/null +++ b/py/selenium/webdriver/common/bidi/_event_manager.py @@ -0,0 +1,186 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Shared event management helpers for generated WebDriver BiDi modules. + +``EventConfig``, ``_EventWrapper``, and ``_EventManager`` are emitted +identically into every generated module that exposes events. Rather than +duplicating ~160 lines of code across all of those modules, they are defined +once here and imported by the generated files. +""" + +from __future__ import annotations + +import threading +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from selenium.webdriver.common.bidi.session import Session + + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 8d6f745d4ac5b..b4cac118df033 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -70,6 +70,7 @@ def default(self, o): return result return super().default(o) + logger = logging.getLogger(__name__) @@ -154,7 +155,9 @@ def _serialize_command(self, command): def _deserialize_result(self, result, command): try: _ = command.send(result) - raise WebDriverException("The command's generator function did not exit when expected!") + raise WebDriverException( + "The command's generator function did not exit when expected!" + ) except StopIteration as exit: return exit.value @@ -171,11 +174,15 @@ def on_error(ws, error): def run_socket(): if self.url.startswith("wss://"): - self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True) + self._ws.run_forever( + sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True + ) else: self._ws.run_forever(suppress_origin=True) - self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error) + self._ws = WebSocketApp( + self.url, on_open=on_open, on_message=on_message, on_error=on_error + ) self._ws_thread = Thread(target=run_socket, daemon=True) self._ws_thread.start() From 6bf1a33027c429bad14d5b6632ae293f27d3cd81 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 12:12:02 +0000 Subject: [PATCH 08/37] remove --version call --- py/private/generate_bidi.bzl | 1 - 1 file changed, 1 deletion(-) diff --git a/py/private/generate_bidi.bzl b/py/private/generate_bidi.bzl index c11b6efe4735f..e072279f85e94 100644 --- a/py/private/generate_bidi.bzl +++ b/py/private/generate_bidi.bzl @@ -53,7 +53,6 @@ def _generate_bidi_impl(ctx): args = [ cddl_file.path, output_base, - "--version", spec_version, ] From 4b314d66ca4d60d0b694fbbf72fecebb567b4d6b Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 12:37:46 +0000 Subject: [PATCH 09/37] correct web extensions --- py/generate_bidi.py | 1274 +++++++++++++++-- py/private/bidi_enhancements_manifest.py | 10 +- py/selenium/webdriver/common/bidi/__init__.py | 17 - py/selenium/webdriver/common/bidi/browser.py | 83 +- .../webdriver/common/bidi/browsing_context.py | 256 ++-- py/selenium/webdriver/common/bidi/common.py | 11 +- .../webdriver/common/bidi/emulation.py | 216 +-- py/selenium/webdriver/common/bidi/input.py | 83 +- py/selenium/webdriver/common/bidi/log.py | 39 +- py/selenium/webdriver/common/bidi/network.py | 312 ++-- .../webdriver/common/bidi/permissions.py | 10 +- py/selenium/webdriver/common/bidi/script.py | 253 ++-- py/selenium/webdriver/common/bidi/session.py | 77 +- py/selenium/webdriver/common/bidi/storage.py | 75 +- .../webdriver/common/bidi/webextension.py | 57 +- 15 files changed, 1764 insertions(+), 1009 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 8103cafe40684..d14e2575c8bfd 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -18,11 +18,12 @@ import logging import re import sys +from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from textwrap import indent as tw_indent -from typing import Any +from textwrap import dedent, indent as tw_indent +from typing import Any, Dict, List, Optional, Set, Tuple __version__ = "1.0.0" @@ -32,24 +33,7 @@ logger = logging.getLogger("generate_bidi") # File headers -SHARED_HEADER = """# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# DO NOT EDIT THIS FILE! +SHARED_HEADER = """# DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make # changes, edit the generator and regenerate all of the modules.""" @@ -59,7 +43,8 @@ # WebDriver BiDi module: {{}} from __future__ import annotations -from typing import Any +from typing import Any, Dict, List, Optional, Union +from .common import command_builder """ @@ -68,7 +53,7 @@ def indent(s: str, n: int) -> str: return tw_indent(s, n * " ") -def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: +def load_enhancements_manifest(manifest_path: Optional[str]) -> Dict[str, Any]: """Load enhancement manifest from a Python file. Args: @@ -139,10 +124,10 @@ def get_annotation(cls, cddl_type: str) -> str: if cddl_type.startswith("["): # Array inner = cddl_type.strip("[]+ ") inner_type = cls.get_annotation(inner) - return f"list[{inner_type}]" + return f"List[{inner_type}]" if cddl_type.startswith("{"): # Map/Dict - return "dict[str, Any]" + return "Dict[str, Any]" # Default to Any for unknown types return "Any" @@ -154,11 +139,11 @@ class CddlCommand: module: str name: str - params: dict[str, str] = field(default_factory=dict) - result: str | None = None + params: Dict[str, str] = field(default_factory=dict) + result: Optional[str] = None description: str = "" - def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: + def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str: """Generate Python method code for this command. Args: @@ -189,15 +174,8 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: else: param_list = "self" - # Build method body - wrap long signatures over multiple lines if needed - sig_line = f" def {method_name}({param_list}):" - if len(sig_line) > 120 and param_strs: - body = f" def {method_name}(\n self,\n" - for p in param_strs: - body += f" {p},\n" - body += " ):\n" - else: - body = sig_line + "\n" + # Build method body + body = f" def {method_name}({param_list}):\n" body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' # Add validation if specified @@ -259,6 +237,7 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if result_param == "download_behavior": body += ' "downloadBehavior": download_behavior,\n' # Add remaining parameters that weren't part of the transform + override_params = enhancements.get("params_override", {}) for cddl_param_name in self.params: if cddl_param_name not in ["downloadBehavior"]: snake_name = self._camel_to_snake(cddl_param_name) @@ -285,45 +264,45 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: # Extract property from list items body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += " return [\n" + body += f" return [\n" body += f' item.get("{extract_property}")\n' - body += " for item in items\n" - body += " if isinstance(item, dict)\n" - body += " ]\n" - body += " return []\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" elif extract_field in deserialize_rules: # Extract field and deserialize to typed objects type_name = deserialize_rules[extract_field] body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += " return [\n" + body += f" return [\n" body += f" {type_name}(\n" body += self._generate_field_args(extract_field, type_name) - body += " )\n" - body += " for item in items\n" - body += " if isinstance(item, dict)\n" - body += " ]\n" - body += " return []\n" + body += f" )\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" else: # Simple field extraction (return the value directly, not wrapped in result dict) body += f' if result and "{extract_field}" in result:\n' body += f' extracted = result.get("{extract_field}")\n' - body += " return extracted\n" - body += " return result\n" + body += f" return extracted\n" + body += f" return result\n" elif "deserialize" in enhancements: # Deserialize response to typed objects (legacy, without extract_field) deserialize_rules = enhancements["deserialize"] for response_field, type_name in deserialize_rules.items(): body += f' if result and "{response_field}" in result:\n' body += f' items = result.get("{response_field}", [])\n' - body += " return [\n" + body += f" return [\n" body += f" {type_name}(\n" body += self._generate_field_args(response_field, type_name) - body += " )\n" - body += " for item in items\n" - body += " if isinstance(item, dict)\n" - body += " ]\n" - body += " return []\n" + body += f" )\n" + body += f" for item in items\n" + body += f" if isinstance(item, dict)\n" + body += f" ]\n" + body += f" return []\n" else: # No special response handling, just return the result body += " return result\n" @@ -372,10 +351,10 @@ class CddlTypeDefinition: module: str name: str - fields: dict[str, str] = field(default_factory=dict) + fields: Dict[str, str] = field(default_factory=dict) description: str = "" - def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str: + def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> str: """Generate Python dataclass code for this type. Args: @@ -385,14 +364,11 @@ def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str dataclass_methods = enhancements.get("dataclass_methods", {}) method_docstrings = enhancements.get("method_docstrings", {}) - # Generate class name from type name. - # CDDL type names that start with a lowercase letter (e.g. camelCase - # command-parameter types like "setNetworkConditionsParameters") are - # capitalised so that the resulting Python class follows PascalCase. - class_name = self.name[0].upper() + self.name[1:] if self.name else self.name - code = "@dataclass\n" + # Generate class name from type name (keep it as-is, don't split on underscores) + class_name = self.name + code = f"@dataclass\n" code += f"class {class_name}:\n" - code += f' """{class_name} type definition."""\n\n' + code += f' """{self.description or self.name}."""\n\n' if not self.fields: code += " pass\n" @@ -410,7 +386,7 @@ def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str literal_value = literal_match.group(1) code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' # Check if this field is a list type - elif "list[" in python_type: + elif "List[" in python_type: code += f" {snake_name}: {python_type} = field(default_factory=list)\n" else: code += f" {snake_name}: {python_type} = None\n" @@ -477,7 +453,7 @@ class CddlEnum: module: str name: str - values: list[str] = field(default_factory=list) + values: List[str] = field(default_factory=list) description: str = "" def to_python_class(self) -> str: @@ -486,9 +462,9 @@ def to_python_class(self) -> str: Generates a simple class with string constants to match the existing pattern in the codebase (e.g., ClientWindowState). """ - class_name = self.name[0].upper() + self.name[1:] if self.name else self.name + class_name = self.name code = f"class {class_name}:\n" - code += f' """{class_name}."""\n\n' + code += f' """{self.description or self.name}."""\n\n' for value in self.values: # Convert value to UPPER_SNAKE_CASE constant name @@ -554,10 +530,10 @@ class CddlModule: """Represents a CDDL module (e.g., script, network, browsing_context).""" name: str - commands: list[CddlCommand] = field(default_factory=list) - types: list[CddlTypeDefinition] = field(default_factory=list) - enums: list[CddlEnum] = field(default_factory=list) - events: list[CddlEvent] = field(default_factory=list) + commands: List[CddlCommand] = field(default_factory=list) + types: List[CddlTypeDefinition] = field(default_factory=list) + enums: List[CddlEnum] = field(default_factory=list) + events: List[CddlEvent] = field(default_factory=list) @staticmethod def _convert_method_to_event_name(method_suffix: str) -> str: @@ -572,33 +548,7 @@ def _convert_method_to_event_name(method_suffix: str) -> str: s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - def _needs_field_import(self, enhancements: dict[str, Any] | None = None) -> bool: - """Check if any type definition in this module requires the 'field' import. - - Respects the same type exclusions applied during code generation. - """ - enhancements = enhancements or {} - extra_cls_names: set[str] = set() - for extra_cls in enhancements.get("extra_dataclasses", []): - m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) - if m: - extra_cls_names.add(m.group(1)) - exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names - - for type_def in self.types: - if type_def.name in exclude_types: - continue - for field_type in type_def.fields.values(): - # Literal string discriminants use field(default=..., init=False) - if re.match(r'^"', field_type.strip()): - return True - # List-typed fields use field(default_factory=list) - python_type = CddlTypeDefinition._get_python_type(field_type) - if python_type.startswith("list["): - return True - return False - - def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: + def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: """Generate Python code for this module. Args: @@ -608,18 +558,18 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code = MODULE_HEADER.format(self.name) # Add imports if needed - if self.commands: - code += "from .common import command_builder\n" - dataclass_imported = False + if self.types: + code += "from dataclasses import field\n" if self.commands or self.types: + code += "from typing import Generator\n" code += "from dataclasses import dataclass\n" - dataclass_imported = True - if self.types and self._needs_field_import(enhancements): - code += "from dataclasses import field\n" # Add imports for event handling if needed if self.events: - code += "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager\n" + code += "import threading\n" + code += "from collections.abc import Callable\n" + code += "from dataclasses import dataclass\n" + code += "from selenium.webdriver.common.bidi.session import Session\n" code += "\n\n" @@ -700,19 +650,8 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - # Collect names of extra_dataclasses so we can skip CDDL-generated - # enums and types that are overridden by manual definitions. - extra_cls_names = set() - for extra_cls in enhancements.get("extra_dataclasses", []): - m = re.search(r"^class\s+(\w+)", extra_cls, re.MULTILINE) - if m: - extra_cls_names.add(m.group(1)) - exclude_types = set(enhancements.get("exclude_types", [])) | extra_cls_names - - # Generate enums first, skipping any that are overridden via extra_dataclasses + # Generate enums first for enum_def in self.enums: - if enum_def.name in exclude_types: - continue code += enum_def.to_python_class() code += "\n\n" @@ -721,6 +660,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses + exclude_types = set(enhancements.get("exclude_types", [])) for type_def in self.types: if type_def.name in exclude_types: continue @@ -740,18 +680,13 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Generate EVENT_NAME_MAPPING for the module code += "# BiDi Event Name to Parameter Type Mapping\n" code += "EVENT_NAME_MAPPING = {\n" - # Collect event keys from extra_events so we skip CDDL duplicates - extra_event_keys = { - evt["event_key"] for evt in enhancements.get("extra_events", []) - } for event_def in self.events: # Convert method name to user-friendly event name # e.g., "browsingContext.contextCreated" -> "context_created" method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) - if event_name not in extra_event_keys: - code += f' "{event_name}": "{event_def.method}",\n' + code += f' "{event_name}": "{event_def.method}",\n' # Extra events not in the CDDL spec (e.g. Chromium-specific events) for extra_evt in enhancements.get("extra_events", []): code += ( @@ -797,7 +732,1094 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ code += "\n\n" - # EventConfig, _EventWrapper, and _EventManager are imported from - # ._event_manager (see the import block above); nothing to emit here. + # Generate EventConfig and _EventManager for modules with events + if self.events: + # Generate EventConfig dataclass + code += """@dataclass +class EventConfig: + \"\"\"Configuration for a BiDi event.\"\"\" + event_key: str + bidi_event: str + event_class: type + + +""" + + # Generate _EventManager class + code += """class _EventWrapper: + \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + \"\"\"Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + \"\"\" + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, \"from_json\") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend([\"_\", char.lower()]) + else: + result.append(char) + return \"\".join(result) + + +class _EventManager: + \"\"\"Manages event subscriptions and callbacks.\"\"\" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get(\"subscription\") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + \"callbacks\": [], + \"subscription_id\": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry[\"callbacks\"]: + session = Session(self.conn) + sub_id = entry.get(\"subscription_id\") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry[\"callbacks\"]: + entry[\"callbacks\"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get(\"subscription_id\") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + +""" + code += "\n\n" # Generate class + # Convert module name (camelCase or snake_case) to proper class name (PascalCase) + class_name = module_name_to_class_name(self.name) + code += f"class {class_name}:\n" + code += f' """WebDriver BiDi {self.name} module."""\n\n' + + # Add EVENT_CONFIGS dict if there are events + if self.events: + code += ( + " EVENT_CONFIGS = {}\n" # Will be populated after types are defined + ) + + if self.name == "script": + code += " def __init__(self, conn, driver=None) -> None:\n" + code += " self._conn = conn\n" + code += " self._driver = driver\n" + else: + code += " def __init__(self, conn) -> None:\n" + code += " self._conn = conn\n" + + # Initialize _event_manager if there are events + if self.events: + code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n" + + # Append extra init code from enhancements (e.g. self.intercepts = []) + for init_line in enhancements.get("extra_init_code", []): + code += f" {init_line}\n" + + code += "\n" + + # Generate command methods + exclude_methods = enhancements.get("exclude_methods", []) + if self.commands: + for command in self.commands: + # Get method-specific enhancements + # Convert command name to snake_case to match enhancement manifest keys + method_name_snake = command._camel_to_snake(command.name) + if method_name_snake in exclude_methods: + continue + method_enhancements = enhancements.get(method_name_snake, {}) + code += command.to_python_method(method_enhancements) + code += "\n" + else: + code += " pass\n" + + # Emit extra methods from enhancement manifest + for extra_method in enhancements.get("extra_methods", []): + code += extra_method + code += "\n" + + # Add delegating event handler methods if events are present + if self.events: + code += """ + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + \"\"\"Add an event handler. + + Args: + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). + + Returns: + The callback ID. + \"\"\" + return self._event_manager.add_event_handler(event, callback, contexts) + + def remove_event_handler(self, event: str, callback_id: int) -> None: + \"\"\"Remove an event handler. + + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + \"\"\" + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + return self._event_manager.clear_event_handlers() +""" + + # Generate event info type aliases AFTER the class definition + # This ensures all types are available when we create the aliases + if self.events: + code += "\n# Event Info Type Aliases\n" + for event_def in self.events: + code += event_def.to_python_dataclass() + code += "\n" + + # Now populate EVENT_CONFIGS after the aliases are defined + code += f"\n# Populate EVENT_CONFIGS with event configuration mappings\n" + # Use globals() to look up types dynamically to handle missing types gracefully + code += f"_globals = globals()\n" + code += f"{class_name}.EVENT_CONFIGS = {{\n" + for event_def in self.events: + # Convert method name to user-friendly event name + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + # The event class is the event name (e.g., ContextCreated) + # Try to get it from globals, default to dict if not found + code += f' "{event_name}": (EventConfig("{event_name}", "{event_def.method}", _globals.get("{event_def.name}", dict)) if _globals.get("{event_def.name}") else EventConfig("{event_name}", "{event_def.method}", dict)),\n' + # Extra events not in the CDDL spec + for extra_evt in enhancements.get("extra_events", []): + ek = extra_evt["event_key"] + be = extra_evt["bidi_event"] + ec = extra_evt["event_class"] + code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' + code += "}\n" + + return code + + +class CddlParser: + """Parse CDDL specification files.""" + + def __init__(self, cddl_path: str): + """Initialize parser with CDDL file path.""" + self.cddl_path = Path(cddl_path) + self.content = "" + self.modules: Dict[str, CddlModule] = {} + self.definitions: Dict[str, str] = {} + self.event_names: Set[str] = set() # Names of definitions that are events + self._read_file() + + def _read_file(self) -> None: + """Read and preprocess CDDL file.""" + if not self.cddl_path.exists(): + raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") + + with open(self.cddl_path, "r", encoding="utf-8") as f: + self.content = f.read() + + logger.info(f"Loaded CDDL file: {self.cddl_path}") + + def parse(self) -> Dict[str, CddlModule]: + """Parse CDDL content and return modules.""" + # Remove comments + content = self._remove_comments(self.content) + + # Extract all definitions + self._extract_definitions(content) + + # Extract event names from event union definitions + self._extract_event_names() + + # Extract type definitions by module + self._extract_types() + + # Extract event definitions by module + self._extract_events() + + # Extract command definitions by module + self._extract_commands() + + # If no modules found, create a default one from the filename + if not self.modules: + module_name = self.cddl_path.stem + default_module = CddlModule(name=module_name) + self.modules[module_name] = default_module + logger.warning(f"No modules found in CDDL, creating default: {module_name}") + + return self.modules + + def _remove_comments(self, content: str) -> str: + """Remove comments from CDDL content.""" + # CDDL uses ; for comments to end of line + lines = content.split("\n") + cleaned = [] + for line in lines: + if ";" in line and not line.strip().startswith(";"): + line = line[: line.index(";")] + elif line.strip().startswith(";"): + continue + cleaned.append(line) + return "\n".join(cleaned) + + def _extract_definitions(self, content: str) -> None: + """Extract CDDL definitions (type definitions, commands, etc.).""" + # Match pattern: Name = Definition + # Handles multiline definitions properly + pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)?\s*=|\Z)" + + for match in re.finditer(pattern, content, re.DOTALL): + name = match.group(1).strip() + definition = match.group(2).strip() + self.definitions[name] = definition + logger.debug(f"Extracted definition: {name}") + + def _extract_event_names(self) -> None: + """Extract event names from event union definitions. + + Event union definitions follow pattern: + module.ModuleEvent = ( + module.EventName1 // + module.EventName2 // + ... + ) + """ + # Look for definitions like "BrowsingContextEvent", "SessionEvent", etc. + event_union_pattern = re.compile(r"(\w+\.)?(\w+)Event") + + for def_name, def_content in self.definitions.items(): + # Check if this looks like an event union (name ends with "Event") and + # contains a module-qualified reference like "module.EventName". + # Handles both single-item (no //) and multi-item (// separated) unions. + if "Event" in def_name and re.search(r"\w+\.\w+", def_content): + # Extract event names from the union (works for single and multi-item) + event_refs = re.findall(r"(\w+\.\w+)", def_content) + for event_ref in event_refs: + self.event_names.add(event_ref) + logger.debug(f"Identified event: {event_ref} (from {def_name})") + + def _extract_types(self) -> None: + """Extract type definitions from parsed definitions.""" + # Type definitions follow pattern: module.TypeName = { field: type, ... } + # They have dots in the name and curly braces in the content + # But they DON'T have method: "..." pattern (which means it's not a command) + # Enums follow pattern: module.EnumName = "value1" / "value2" / ... + + for def_name, def_content in self.definitions.items(): + # Skip if not a namespaced name (e.g., skip "EmptyParams", "Extensible") + if "." not in def_name: + continue + + # Skip if it's a command (contains method: pattern) + if "method:" in def_content: + continue + + # Extract module.TypeName + if "." in def_name: + module_name, type_name = def_name.rsplit(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Check if this is an enum (string union with /) + if self._is_enum_definition(def_content): + # Extract enum values + values = self._extract_enum_values(def_content) + if values: + enum_def = CddlEnum( + module=module_name, + name=type_name, + values=values, + description=f"{type_name}", + ) + self.modules[module_name].enums.append(enum_def) + logger.debug( + f"Found enum: {def_name} with {len(values)} values" + ) + else: + # Extract fields from type definition + fields = self._extract_type_fields(def_content) + + if fields: # Only create type if it has fields + type_def = CddlTypeDefinition( + module=module_name, + name=type_name, + fields=fields, + description=f"{type_name}", + ) + self.modules[module_name].types.append(type_def) + logger.debug( + f"Found type: {def_name} with {len(fields)} fields" + ) + + def _is_enum_definition(self, definition: str) -> bool: + """Check if a definition is an enum (string union with /). + + Enums are defined as: "value1" / "value2" / "value3" + """ + # Clean whitespace + clean_def = definition.strip() + + # Must not have curly braces (that would be a type definition) + if "{" in clean_def or "}" in clean_def: + return False + + # Must contain the union operator / surrounded by quotes + # Pattern: "something" / "something_else" + return " / " in clean_def and '"' in clean_def + + def _extract_enum_values(self, enum_definition: str) -> List[str]: + """Extract individual values from an enum definition. + + Enums are defined as: "value1" / "value2" / "value3" + Can span multiple lines. + """ + values = [] + + # Clean the definition and extract quoted strings + # Split by / and extract quoted values + parts = enum_definition.split("/") + + for part in parts: + part = part.strip() + + # Extract quoted string - use search instead of match to find quotes anywhere + match = re.search(r'"([^"]*)"', part) + if match: + value = match.group(1) + values.append(value) + logger.debug(f"Extracted enum value: {value}") + + return values + + @staticmethod + def _normalize_cddl_type(field_type: str) -> str: + """Normalize a CDDL type expression to a simple Python-compatible form. + + Strips CDDL control operators (.ge, .le, .gt, .lt, .default, etc.) and + replaces interval/constraint expressions with their base types so that + the caller can safely check for nested struct syntax. + + Examples: + '(float .ge 0.0) .default 1.0' -> 'float' + '(float .ge 0.0) / null' -> 'float / null' + '(0.0...360.0) / null' -> 'float / null' + '-90.0..90.0' -> 'float' + 'float / null .default null' -> 'float / null' + """ + result = field_type + # Remove trailing .default annotations + result = re.sub(r"\s*\.default\s+\S+", "", result) + # Replace parenthesised constraint expressions: (baseType .operator ...) -> baseType + result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result) + # Replace parenthesised numeric interval types: (0.0...360.0) -> float + result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result) + # Replace bare numeric interval types: -90.0..90.0 -> float + result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) + return result.strip() + + def _extract_type_fields(self, type_definition: str) -> Dict[str, str]: + """Extract fields from a type definition block.""" + fields = {} + + # Remove outer braces + clean_def = type_definition.strip() + if clean_def.startswith("{"): + clean_def = clean_def[1:] + if clean_def.endswith("}"): + clean_def = clean_def[:-1] + + # Parse each line for field: type patterns + for line in clean_def.split("\n"): + line = line.strip() + if not line or "Extensible" in line or line.startswith("//"): + continue + + # Match pattern: [?] fieldName: type + match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + if not match: + # Try without optional marker + match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + + if match: + field_name = match.group(1).strip() + field_type = match.group(2).strip() + normalized_type = self._normalize_cddl_type(field_type) + + # Skip lines that are part of nested definitions + if "{" not in normalized_type and "(" not in normalized_type: + fields[field_name] = normalized_type + logger.debug(f"Extracted field {field_name}: {normalized_type}") + + return fields + + def _extract_events(self) -> None: + """Extract event definitions from parsed definitions. + + Events are definitions that: + 1. Are listed in an event union (e.g., BrowsingContextEvent) + 2. Have method: "..." and params: ... fields + + Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) + """ + # Find definitions that are in the event_names set + event_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) + + for def_name, def_content in self.definitions.items(): + # Skip if not identified as an event + if def_name not in self.event_names: + continue + + # Extract method and params + match = event_pattern.search(def_content) + if match: + method = match.group(1) # e.g., "browsingContext.contextCreated" + params_type = match.group(2) # e.g., "browsingContext.Info" + + # Extract module name from method + if "." in method: + module_name, _ = method.split(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Extract event name from definition name (e.g., browsingContext.ContextCreated) + _, event_name = def_name.rsplit(".", 1) + + # Create event + event = CddlEvent( + module=module_name, + name=event_name, + method=method, + params_type=params_type, + description=f"Event: {method}", + ) + + self.modules[module_name].events.append(event) + logger.debug( + f"Found event: {def_name} (method={method}, params={params_type})" + ) + + def _extract_commands(self) -> None: + """Extract command definitions from parsed definitions.""" + # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) + command_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) + + for def_name, def_content in self.definitions.items(): + # Skip definitions that are events (they share the same pattern) + if def_name in self.event_names: + continue + matches = list(command_pattern.finditer(def_content)) + if matches: + for match in matches: + method = match.group(1) # e.g., "session.new" + params_type = match.group(2) # e.g., "session.NewParameters" + + # Extract module name from method + if "." in method: + module_name, command_name = method.split(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Extract parameters + params = self._extract_parameters(params_type) + + # Create command + cmd = CddlCommand( + module=module_name, + name=command_name, + params=params, + description=f"Execute {method}", + ) + + self.modules[module_name].commands.append(cmd) + logger.debug( + f"Found command: {method} with params {params_type}" + ) + + def _extract_parameters( + self, params_type: str, _seen: Optional[Set[str]] = None + ) -> Dict[str, str]: + """Extract parameters from a parameter type definition. + + Handles both struct types ({...}) and top-level union types (TypeA / TypeB), + merging all fields from each alternative as optional parameters. + """ + params = {} + + if _seen is None: + _seen = set() + if params_type in _seen: + return params + _seen.add(params_type) + + if params_type not in self.definitions: + logger.debug(f"Parameter type not found: {params_type}") + return params + + definition = self.definitions[params_type] + + # Handle top-level type alias that is a union of other named types: + # e.g. session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest + # These definitions contain a single line with "/" separating type names + # (not the double-slash "//" used for command unions). + stripped = definition.strip() + if not stripped.startswith("{") and "/" in stripped and "//" not in stripped: + # Each token separated by "/" should be a named type reference + alternatives = [a.strip() for a in stripped.split("/") if a.strip()] + all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) + if all_named: + for alt_type in alternatives: + alt_params = self._extract_parameters(alt_type, _seen) + params.update(alt_params) + return params + + # Remove the outer curly braces and split by comma + # Then parse each line for key: type patterns + clean_def = stripped + if clean_def.startswith("{"): + clean_def = clean_def[1:] + if clean_def.endswith("}"): + clean_def = clean_def[:-1] + + # Split by newlines and process each line + for line in clean_def.split("\n"): + line = line.strip() + if not line or "Extensible" in line: + continue + + # Match pattern: [?] name: type + # Using a simple pattern that handles optional prefix + match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + if not match: + # Try without optional marker + match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + + if match: + param_name = match.group(1).strip() + param_type = match.group(2).strip() + normalized_type = self._normalize_cddl_type(param_type) + + # Skip lines that are part of nested definitions + if "{" not in normalized_type and "(" not in normalized_type: + params[param_name] = normalized_type + logger.debug( + f"Extracted param {param_name}: {normalized_type} from {params_type}" + ) + + return params + + +def module_name_to_class_name(module_name: str) -> str: + """Convert module name to class name (PascalCase). + + Handles both camelCase (browsingContext) and snake_case (browsing_context). + """ + if "_" in module_name: + # Snake_case: browsing_context -> BrowsingContext + return "".join(word.capitalize() for word in module_name.split("_")) + else: + # CamelCase: browsingContext -> BrowsingContext + return module_name[0].upper() + module_name[1:] if module_name else "" + + +def module_name_to_filename(module_name: str) -> str: + """Convert module name to Python filename (snake_case). + + Handles both camelCase (browsingContext) and snake_case (browsing_context). + Special cases: + - browsingContext -> browsing_context + - webExtension -> webextension + """ + # Handle explicit mappings for known camelCase names + camel_to_snake_map = { + "browsingContext": "browsing_context", + "webExtension": "webextension", + } + + if module_name in camel_to_snake_map: + return camel_to_snake_map[module_name] + + if "_" in module_name: + # Already snake_case + return module_name + else: + # Convert camelCase to snake_case for other cases + # This handles cases like "myModuleName" -> "my_module_name" + import re + + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", module_name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> None: + """Generate __init__.py file for the module.""" + init_path = output_path / "__init__.py" + + code = f"""{SHARED_HEADER} + +from __future__ import annotations + +""" + + for module_name in sorted(modules.keys()): + class_name = module_name_to_class_name(module_name) + filename = module_name_to_filename(module_name) + code += f"from .{filename} import {class_name}\n" + + code += f"\n__all__ = [\n" + for module_name in sorted(modules.keys()): + class_name = module_name_to_class_name(module_name) + code += f' "{class_name}",\n' + code += "]\n" + + with open(init_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {init_path}") + + +def generate_common_file(output_path: Path) -> None: + """Generate common.py file with shared utilities.""" + common_path = output_path / "common.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + '"""Common utilities for BiDi command construction."""\n' + "\n" + "from typing import Any, Dict, Generator\n" + "\n" + "\n" + "def command_builder(\n" + " method: str, params: Dict[str, Any]\n" + ") -> Generator[Dict[str, Any], Any, Any]:\n" + ' """Build a BiDi command generator.\n' + "\n" + " Args:\n" + ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' + " params: The parameters for the command\n" + "\n" + " Yields:\n" + " A dictionary representing the BiDi command\n" + "\n" + " Returns:\n" + " The result from the BiDi command execution\n" + ' """\n' + ' result = yield {"method": method, "params": params}\n' + " return result\n" + ) + + with open(common_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {common_path}") + + +def generate_console_file(output_path: Path) -> None: + """Generate console.py file with Console enum helper.""" + console_path = output_path / "console.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + "from enum import Enum\n" + "\n" + "\n" + "class Console(Enum):\n" + ' ALL = "all"\n' + ' LOG = "log"\n' + ' ERROR = "error"\n' + ) + + with open(console_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {console_path}") + + +def generate_permissions_file(output_path: Path) -> None: + """Generate permissions.py file with permission-related classes.""" + permissions_path = output_path / "permissions.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + '"""WebDriver BiDi Permissions module."""\n' + "\n" + "from __future__ import annotations\n" + "\n" + "from enum import Enum\n" + "from typing import Any, Optional, Union\n" + "\n" + "from .common import command_builder\n" + "\n" + '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' + "\n" + "\n" + "class PermissionState(str, Enum):\n" + ' """Permission state enumeration."""\n' + "\n" + ' GRANTED = "granted"\n' + ' DENIED = "denied"\n' + ' PROMPT = "prompt"\n' + "\n" + "\n" + "class PermissionDescriptor:\n" + ' """Descriptor for a permission."""\n' + "\n" + " def __init__(self, name: str) -> None:\n" + ' """Initialize a PermissionDescriptor.\n' + "\n" + " Args:\n" + " name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera')\n" + ' """\n' + " self.name = name\n" + "\n" + " def __repr__(self) -> str:\n" + " return f\"PermissionDescriptor('{self.name}')\"\n" + "\n" + "\n" + "class Permissions:\n" + ' """WebDriver BiDi Permissions module."""\n' + "\n" + " def __init__(self, websocket_connection: Any) -> None:\n" + ' """Initialize the Permissions module.\n' + "\n" + " Args:\n" + " websocket_connection: The WebSocket connection for sending BiDi commands\n" + ' """\n' + " self._conn = websocket_connection\n" + "\n" + " def set_permission(\n" + " self,\n" + " descriptor: Union[PermissionDescriptor, str],\n" + " state: Union[PermissionState, str],\n" + " origin: Optional[str] = None,\n" + " user_context: Optional[str] = None,\n" + " ) -> None:\n" + ' """Set a permission for a given origin.\n' + "\n" + " Args:\n" + " descriptor: The permission descriptor or permission name as a string\n" + " state: The desired permission state\n" + " origin: The origin for which to set the permission\n" + " user_context: Optional user context ID to scope the permission\n" + "\n" + " Raises:\n" + " ValueError: If the state is not a valid permission state\n" + ' """\n' + " state_value = state.value if isinstance(state, PermissionState) else state\n" + " if state_value not in _VALID_PERMISSION_STATES:\n" + " raise ValueError(\n" + ' f"Invalid permission state: {state_value!r}. "\n' + ' f"Must be one of {sorted(_VALID_PERMISSION_STATES)}"\n' + " )\n" + "\n" + " if isinstance(descriptor, str):\n" + ' descriptor_dict = {"name": descriptor}\n' + " else:\n" + ' descriptor_dict = {"name": descriptor.name}\n' + "\n" + " params: dict[str, Any] = {\n" + ' "descriptor": descriptor_dict,\n' + ' "state": state_value,\n' + " }\n" + " if origin is not None:\n" + ' params["origin"] = origin\n' + " if user_context is not None:\n" + ' params["userContext"] = user_context\n' + "\n" + ' cmd = command_builder("permissions.setPermission", params)\n' + " self._conn.execute(cmd)\n" + ) + + with open(permissions_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {permissions_path}") + + +def main( + cddl_file: str, + output_dir: str, + spec_version: str = "1.0", + enhancements_manifest: Optional[str] = None, +) -> None: + """Main entry point. + + Args: + cddl_file: Path to CDDL specification file + output_dir: Output directory for generated modules + spec_version: BiDi spec version + enhancements_manifest: Path to enhancement manifest Python file + """ + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + + logger.info(f"WebDriver BiDi Code Generator v{__version__}") + logger.info(f"Input CDDL: {cddl_file}") + logger.info(f"Output directory: {output_path}") + logger.info(f"Spec version: {spec_version}") + + # Load enhancement manifest + manifest = load_enhancements_manifest(enhancements_manifest) + if manifest: + logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}") + + # Parse CDDL + parser = CddlParser(cddl_file) + modules = parser.parse() + + logger.info(f"Parsed {len(modules)} modules") + + # Clean up existing generated files + for file_path in output_path.glob("*.py"): + if file_path.name != "py.typed" and not file_path.name.startswith("_"): + file_path.unlink() + logger.debug(f"Removed: {file_path}") + + # Generate module files using snake_case filenames + for module_name, module in sorted(modules.items()): + filename = module_name_to_filename(module_name) + module_path = output_path / f"{filename}.py" + + # Get module-specific enhancements (merge with dataclass templates) + module_enhancements = manifest.get("enhancements", {}).get(module_name, {}) + + # Add dataclass methods and docstrings to the enhancement data for this module + full_module_enhancements = { + **module_enhancements, + "dataclass_methods": manifest.get("dataclass_methods", {}), + "method_docstrings": manifest.get("method_docstrings", {}), + } + + with open(module_path, "w", encoding="utf-8") as f: + f.write(module.generate_code(full_module_enhancements)) + logger.info(f"Generated: {module_path}") + + # Generate __init__.py + generate_init_file(output_path, modules) + + # Generate common.py + generate_common_file(output_path) + + # Generate permissions.py + generate_permissions_file(output_path) + + # Generate console.py + generate_console_file(output_path) + + # Create py.typed marker + py_typed_path = output_path / "py.typed" + py_typed_path.touch() + logger.info(f"Generated type marker: {py_typed_path}") + + logger.info("Code generation complete!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate Python WebDriver BiDi modules from CDDL specification" + ) + parser.add_argument( + "cddl_file", + help="Path to CDDL specification file", + ) + parser.add_argument( + "output_dir", + help="Output directory for generated Python modules", + ) + parser.add_argument( + "spec_version", + nargs="?", + default="1.0", + help="BiDi spec version (default: 1.0)", + ) + parser.add_argument( + "--enhancements-manifest", + default=None, + help="Path to enhancement manifest Python file (optional)", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger("generate_bidi").setLevel(logging.DEBUG) + + try: + main( + args.cddl_file, + args.output_dir, + args.spec_version, + args.enhancements_manifest, + ) + sys.exit(0) + except Exception as e: + logger.error(f"Generation failed: {e}", exc_info=True) + sys.exit(1) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index adf0a17128af3..5dcce3c25ffeb 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1351,18 +1351,24 @@ def to_bidi_dict(self) -> dict: params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) return self._conn.execute(cmd)''', - ''' def uninstall(self, extension: Any | None = None): + ''' def uninstall(self, extension: str | dict): """Uninstall a web extension. Args: extension: Either the extension ID string returned by ``install``, or the full result dict returned by ``install`` (the ``"extension"`` value is extracted automatically). + + Raises: + ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): extension = extension.get("extension") + + if extension is None: + raise ValueError("extension parameter is required") + params = {"extension": extension} - params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd)''', ], diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index bb129d5f6a195..7be7bd4f73856 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index ff0c2d59b8cf2..7cf9678c9b007 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,10 +6,11 @@ # WebDriver BiDi module: browser from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass def transform_download_params( @@ -77,9 +61,17 @@ def validate_download_behavior( raise ValueError("destination_folder should not be provided when allowed=False") +class ClientWindowNamedState: + """ClientWindowNamedState.""" + + FULLSCREEN = "fullscreen" + MAXIMIZED = "maximized" + MINIMIZED = "minimized" + + @dataclass class ClientWindowInfo: - """ClientWindowInfo type definition.""" + """ClientWindowInfo.""" active: bool | None = None client_window: Any | None = None @@ -121,14 +113,14 @@ def get_y(self): @dataclass class UserContextInfo: - """UserContextInfo type definition.""" + """UserContextInfo.""" user_context: Any | None = None @dataclass class CreateUserContextParameters: - """CreateUserContextParameters type definition.""" + """CreateUserContextParameters.""" accept_insecure_certs: bool | None = None proxy: Any | None = None @@ -137,35 +129,35 @@ class CreateUserContextParameters: @dataclass class GetClientWindowsResult: - """GetClientWindowsResult type definition.""" + """GetClientWindowsResult.""" - client_windows: list[Any | None] | None = field(default_factory=list) + client_windows: list[Any | None] | None = None @dataclass class GetUserContextsResult: - """GetUserContextsResult type definition.""" + """GetUserContextsResult.""" - user_contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = None @dataclass class RemoveUserContextParameters: - """RemoveUserContextParameters type definition.""" + """RemoveUserContextParameters.""" user_context: Any | None = None @dataclass class SetClientWindowStateParameters: - """SetClientWindowStateParameters type definition.""" + """SetClientWindowStateParameters.""" client_window: Any | None = None @dataclass class ClientWindowRectState: - """ClientWindowRectState type definition.""" + """ClientWindowRectState.""" state: str = field(default="normal", init=False) width: Any | None = None @@ -176,15 +168,15 @@ class ClientWindowRectState: @dataclass class SetDownloadBehaviorParameters: - """SetDownloadBehaviorParameters type definition.""" + """SetDownloadBehaviorParameters.""" download_behavior: Any | None = None - user_contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = None @dataclass class DownloadBehaviorAllowed: - """DownloadBehaviorAllowed type definition.""" + """DownloadBehaviorAllowed.""" type: str = field(default="allowed", init=False) destination_folder: str | None = None @@ -192,7 +184,7 @@ class DownloadBehaviorAllowed: @dataclass class DownloadBehaviorDenied: - """DownloadBehaviorDenied type definition.""" + """DownloadBehaviorDenied.""" type: str = field(default="denied", init=False) @@ -220,12 +212,7 @@ def close(self): result = self._conn.execute(cmd) return result - def create_user_context( - self, - accept_insecure_certs: bool | None = None, - proxy: Any | None = None, - unhandled_prompt_behavior: Any | None = None, - ): + def create_user_context(self, accept_insecure_certs: bool | None = None, proxy: Any | None = None, unhandled_prompt_behavior: Any | None = None): """Execute browser.createUserContext.""" if proxy and hasattr(proxy, 'to_bidi_dict'): proxy = proxy.to_bidi_dict() @@ -306,6 +293,22 @@ def set_client_window_state(self, client_window: Any | None = None): result = self._conn.execute(cmd) return result + def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): + """Execute browser.setDownloadBehavior.""" + validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) + + download_behavior = None + download_behavior = transform_download_params(allowed, destination_folder) + + params = { + "downloadBehavior": download_behavior, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.setDownloadBehavior", params) + result = self._conn.execute(cmd) + return result + def set_download_behavior( self, allowed: bool | None = None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 7a0f8faf8687e..35aea615d1780 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class ReadinessState: """ReadinessState.""" @@ -65,7 +49,7 @@ class DownloadCompleteParams: @dataclass class Info: - """Info type definition.""" + """Info.""" children: Any | None = None client_window: Any | None = None @@ -78,7 +62,7 @@ class Info: @dataclass class AccessibilityLocator: - """AccessibilityLocator type definition.""" + """AccessibilityLocator.""" type: str = field(default="accessibility", init=False) name: str | None = None @@ -87,7 +71,7 @@ class AccessibilityLocator: @dataclass class CssLocator: - """CssLocator type definition.""" + """CssLocator.""" type: str = field(default="css", init=False) value: str | None = None @@ -95,7 +79,7 @@ class CssLocator: @dataclass class ContextLocator: - """ContextLocator type definition.""" + """ContextLocator.""" type: str = field(default="context", init=False) context: Any | None = None @@ -103,7 +87,7 @@ class ContextLocator: @dataclass class InnerTextLocator: - """InnerTextLocator type definition.""" + """InnerTextLocator.""" type: str = field(default="innerText", init=False) value: str | None = None @@ -114,7 +98,7 @@ class InnerTextLocator: @dataclass class XPathLocator: - """XPathLocator type definition.""" + """XPathLocator.""" type: str = field(default="xpath", init=False) value: str | None = None @@ -122,7 +106,7 @@ class XPathLocator: @dataclass class BaseNavigationInfo: - """BaseNavigationInfo type definition.""" + """BaseNavigationInfo.""" context: Any | None = None navigation: Any | None = None @@ -132,14 +116,14 @@ class BaseNavigationInfo: @dataclass class ActivateParameters: - """ActivateParameters type definition.""" + """ActivateParameters.""" context: Any | None = None @dataclass class CaptureScreenshotParameters: - """CaptureScreenshotParameters type definition.""" + """CaptureScreenshotParameters.""" context: Any | None = None format: Any | None = None @@ -148,7 +132,7 @@ class CaptureScreenshotParameters: @dataclass class ImageFormat: - """ImageFormat type definition.""" + """ImageFormat.""" type: str | None = None quality: Any | None = None @@ -156,7 +140,7 @@ class ImageFormat: @dataclass class ElementClipRectangle: - """ElementClipRectangle type definition.""" + """ElementClipRectangle.""" type: str = field(default="element", init=False) element: Any | None = None @@ -164,7 +148,7 @@ class ElementClipRectangle: @dataclass class BoxClipRectangle: - """BoxClipRectangle type definition.""" + """BoxClipRectangle.""" type: str = field(default="box", init=False) x: Any | None = None @@ -175,14 +159,14 @@ class BoxClipRectangle: @dataclass class CaptureScreenshotResult: - """CaptureScreenshotResult type definition.""" + """CaptureScreenshotResult.""" data: str | None = None @dataclass class CloseParameters: - """CloseParameters type definition.""" + """CloseParameters.""" context: Any | None = None prompt_unload: bool | None = None @@ -190,7 +174,7 @@ class CloseParameters: @dataclass class CreateParameters: - """CreateParameters type definition.""" + """CreateParameters.""" type: Any | None = None reference_context: Any | None = None @@ -200,14 +184,14 @@ class CreateParameters: @dataclass class CreateResult: - """CreateResult type definition.""" + """CreateResult.""" context: Any | None = None @dataclass class GetTreeParameters: - """GetTreeParameters type definition.""" + """GetTreeParameters.""" max_depth: Any | None = None root: Any | None = None @@ -215,14 +199,14 @@ class GetTreeParameters: @dataclass class GetTreeResult: - """GetTreeResult type definition.""" + """GetTreeResult.""" contexts: Any | None = None @dataclass class HandleUserPromptParameters: - """HandleUserPromptParameters type definition.""" + """HandleUserPromptParameters.""" context: Any | None = None accept: bool | None = None @@ -231,24 +215,24 @@ class HandleUserPromptParameters: @dataclass class LocateNodesParameters: - """LocateNodesParameters type definition.""" + """LocateNodesParameters.""" context: Any | None = None locator: Any | None = None serialization_options: Any | None = None - start_nodes: list[Any | None] | None = field(default_factory=list) + start_nodes: list[Any | None] | None = None @dataclass class LocateNodesResult: - """LocateNodesResult type definition.""" + """LocateNodesResult.""" - nodes: list[Any | None] | None = field(default_factory=list) + nodes: list[Any | None] | None = None @dataclass class NavigateParameters: - """NavigateParameters type definition.""" + """NavigateParameters.""" context: Any | None = None url: str | None = None @@ -257,7 +241,7 @@ class NavigateParameters: @dataclass class NavigateResult: - """NavigateResult type definition.""" + """NavigateResult.""" navigation: Any | None = None url: str | None = None @@ -265,7 +249,7 @@ class NavigateResult: @dataclass class PrintParameters: - """PrintParameters type definition.""" + """PrintParameters.""" context: Any | None = None background: bool | None = None @@ -277,7 +261,7 @@ class PrintParameters: @dataclass class PrintMarginParameters: - """PrintMarginParameters type definition.""" + """PrintMarginParameters.""" bottom: Any | None = None left: Any | None = None @@ -287,7 +271,7 @@ class PrintMarginParameters: @dataclass class PrintPageParameters: - """PrintPageParameters type definition.""" + """PrintPageParameters.""" height: Any | None = None width: Any | None = None @@ -295,14 +279,14 @@ class PrintPageParameters: @dataclass class PrintResult: - """PrintResult type definition.""" + """PrintResult.""" data: str | None = None @dataclass class ReloadParameters: - """ReloadParameters type definition.""" + """ReloadParameters.""" context: Any | None = None ignore_cache: bool | None = None @@ -311,17 +295,17 @@ class ReloadParameters: @dataclass class SetViewportParameters: - """SetViewportParameters type definition.""" + """SetViewportParameters.""" context: Any | None = None viewport: Any | None = None device_pixel_ratio: Any | None = None - user_contexts: list[Any | None] | None = field(default_factory=list) + user_contexts: list[Any | None] | None = None @dataclass class Viewport: - """Viewport type definition.""" + """Viewport.""" width: Any | None = None height: Any | None = None @@ -329,7 +313,7 @@ class Viewport: @dataclass class TraverseHistoryParameters: - """TraverseHistoryParameters type definition.""" + """TraverseHistoryParameters.""" context: Any | None = None delta: Any | None = None @@ -337,16 +321,30 @@ class TraverseHistoryParameters: @dataclass class HistoryUpdatedParameters: - """HistoryUpdatedParameters type definition.""" + """HistoryUpdatedParameters.""" context: Any | None = None timestamp: Any | None = None url: str | None = None +@dataclass +class DownloadWillBeginParams: + """DownloadWillBeginParams.""" + + suggested_filename: str | None = None + + +@dataclass +class DownloadCanceledParams: + """DownloadCanceledParams.""" + + status: str = field(default="canceled", init=False) + + @dataclass class UserPromptClosedParameters: - """UserPromptClosedParameters type definition.""" + """UserPromptClosedParameters.""" context: Any | None = None accepted: bool | None = None @@ -356,7 +354,7 @@ class UserPromptClosedParameters: @dataclass class UserPromptOpenedParameters: - """UserPromptOpenedParameters type definition.""" + """UserPromptOpenedParameters.""" context: Any | None = None handler: Any | None = None @@ -392,10 +390,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: DownloadParams | None = None + download_params: "DownloadParams | None" = None @classmethod - def from_json(cls, params: dict) -> DownloadEndParams: + def from_json(cls, params: dict) -> "DownloadEndParams": """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), @@ -416,6 +414,8 @@ def from_json(cls, params: dict) -> DownloadEndParams: "history_updated": "browsingContext.historyUpdated", "dom_content_loaded": "browsingContext.domContentLoaded", "load": "browsingContext.load", + "download_will_begin": "browsingContext.downloadWillBegin", + "download_end": "browsingContext.downloadEnd", "navigation_aborted": "browsingContext.navigationAborted", "navigation_committed": "browsingContext.navigationCommitted", "navigation_failed": "browsingContext.navigationFailed", @@ -630,13 +630,7 @@ def activate(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def capture_screenshot( - self, - context: str | None = None, - format: Any | None = None, - clip: Any | None = None, - origin: str | None = None, - ): + def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): """Execute browsingContext.captureScreenshot.""" params = { "context": context, @@ -663,13 +657,7 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): result = self._conn.execute(cmd) return result - def create( - self, - type: Any | None = None, - reference_context: Any | None = None, - background: bool | None = None, - user_context: Any | None = None, - ): + def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): """Execute browsingContext.create.""" params = { "type": type, @@ -723,14 +711,7 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N result = self._conn.execute(cmd) return result - def locate_nodes( - self, - context: str | None = None, - locator: Any | None = None, - serialization_options: Any | None = None, - start_nodes: Any | None = None, - max_node_count: int | None = None, - ): + def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): """Execute browsingContext.locateNodes.""" params = { "context": context, @@ -759,15 +740,7 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any result = self._conn.execute(cmd) return result - def print( - self, - context: Any | None = None, - background: bool | None = None, - margin: Any | None = None, - page: Any | None = None, - scale: Any | None = None, - shrink_to_fit: bool | None = None, - ): + def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): """Execute browsingContext.print.""" params = { "context": context, @@ -797,13 +770,7 @@ def reload(self, context: Any | None = None, ignore_cache: bool | None = None, w result = self._conn.execute(cmd) return result - def set_viewport( - self, - context: str | None = None, - viewport: Any | None = None, - user_contexts: Any | None = None, - device_pixel_ratio: Any | None = None, - ): + def set_viewport(self, context: str | None = None, viewport: Any | None = None, user_contexts: Any | None = None, device_pixel_ratio: Any | None = None): """Execute browsingContext.setViewport.""" params = { "context": context, @@ -901,81 +868,20 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() BrowsingContext.EVENT_CONFIGS = { - "context_created": ( - EventConfig("context_created", "browsingContext.contextCreated", - _globals.get("ContextCreated", dict)) - if _globals.get("ContextCreated") - else EventConfig("context_created", "browsingContext.contextCreated", dict) - ), - "context_destroyed": ( - EventConfig("context_destroyed", "browsingContext.contextDestroyed", - _globals.get("ContextDestroyed", dict)) - if _globals.get("ContextDestroyed") - else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict) - ), - "navigation_started": ( - EventConfig("navigation_started", "browsingContext.navigationStarted", - _globals.get("NavigationStarted", dict)) - if _globals.get("NavigationStarted") - else EventConfig("navigation_started", "browsingContext.navigationStarted", dict) - ), - "fragment_navigated": ( - EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", - _globals.get("FragmentNavigated", dict)) - if _globals.get("FragmentNavigated") - else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict) - ), - "history_updated": ( - EventConfig("history_updated", "browsingContext.historyUpdated", - _globals.get("HistoryUpdated", dict)) - if _globals.get("HistoryUpdated") - else EventConfig("history_updated", "browsingContext.historyUpdated", dict) - ), - "dom_content_loaded": ( - EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", - _globals.get("DomContentLoaded", dict)) - if _globals.get("DomContentLoaded") - else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict) - ), - "load": ( - EventConfig("load", "browsingContext.load", - _globals.get("Load", dict)) - if _globals.get("Load") - else EventConfig("load", "browsingContext.load", dict) - ), - "navigation_aborted": ( - EventConfig("navigation_aborted", "browsingContext.navigationAborted", - _globals.get("NavigationAborted", dict)) - if _globals.get("NavigationAborted") - else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict) - ), - "navigation_committed": ( - EventConfig("navigation_committed", "browsingContext.navigationCommitted", - _globals.get("NavigationCommitted", dict)) - if _globals.get("NavigationCommitted") - else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict) - ), - "navigation_failed": ( - EventConfig("navigation_failed", "browsingContext.navigationFailed", - _globals.get("NavigationFailed", dict)) - if _globals.get("NavigationFailed") - else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict) - ), - "user_prompt_closed": ( - EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", - _globals.get("UserPromptClosed", dict)) - if _globals.get("UserPromptClosed") - else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict) - ), - "user_prompt_opened": ( - EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", - _globals.get("UserPromptOpened", dict)) - if _globals.get("UserPromptOpened") - else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict) - ), - "download_will_begin": EventConfig( - "download_will_begin", "browsingContext.downloadWillBegin", - _globals.get("DownloadWillBeginParams", dict), - ), + "context_created": (EventConfig("context_created", "browsingContext.contextCreated", _globals.get("ContextCreated", dict)) if _globals.get("ContextCreated") else EventConfig("context_created", "browsingContext.contextCreated", dict)), + "context_destroyed": (EventConfig("context_destroyed", "browsingContext.contextDestroyed", _globals.get("ContextDestroyed", dict)) if _globals.get("ContextDestroyed") else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict)), + "navigation_started": (EventConfig("navigation_started", "browsingContext.navigationStarted", _globals.get("NavigationStarted", dict)) if _globals.get("NavigationStarted") else EventConfig("navigation_started", "browsingContext.navigationStarted", dict)), + "fragment_navigated": (EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", _globals.get("FragmentNavigated", dict)) if _globals.get("FragmentNavigated") else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict)), + "history_updated": (EventConfig("history_updated", "browsingContext.historyUpdated", _globals.get("HistoryUpdated", dict)) if _globals.get("HistoryUpdated") else EventConfig("history_updated", "browsingContext.historyUpdated", dict)), + "dom_content_loaded": (EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", _globals.get("DomContentLoaded", dict)) if _globals.get("DomContentLoaded") else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict)), + "load": (EventConfig("load", "browsingContext.load", _globals.get("Load", dict)) if _globals.get("Load") else EventConfig("load", "browsingContext.load", dict)), + "download_will_begin": (EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBegin", dict)) if _globals.get("DownloadWillBegin") else EventConfig("download_will_begin", "browsingContext.downloadWillBegin", dict)), + "download_end": (EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEnd", dict)) if _globals.get("DownloadEnd") else EventConfig("download_end", "browsingContext.downloadEnd", dict)), + "navigation_aborted": (EventConfig("navigation_aborted", "browsingContext.navigationAborted", _globals.get("NavigationAborted", dict)) if _globals.get("NavigationAborted") else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict)), + "navigation_committed": (EventConfig("navigation_committed", "browsingContext.navigationCommitted", _globals.get("NavigationCommitted", dict)) if _globals.get("NavigationCommitted") else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict)), + "navigation_failed": (EventConfig("navigation_failed", "browsingContext.navigationFailed", _globals.get("NavigationFailed", dict)) if _globals.get("NavigationFailed") else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict)), + "user_prompt_closed": (EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", _globals.get("UserPromptClosed", dict)) if _globals.get("UserPromptClosed") else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict)), + "user_prompt_opened": (EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", _globals.get("UserPromptOpened", dict)) if _globals.get("UserPromptOpened") else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict)), + "download_will_begin": EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBeginParams", dict)), "download_end": EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEndParams", dict)), } diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index dae051876833e..d90d8c770263a 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,15 +17,12 @@ """Common utilities for BiDi command construction.""" -from __future__ import annotations - -from collections.abc import Generator -from typing import Any +from typing import Any, Dict, Generator def command_builder( - method: str, params: dict[str, Any] | None = None -) -> Generator[dict[str, Any], Any, Any]: + method: str, params: Dict[str, Any] +) -> Generator[Dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: @@ -38,7 +35,5 @@ def command_builder( Returns: The result from the BiDi command execution """ - if params is None: - params = {} result = yield {"method": method, "params": params} return result diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index c58f6d5f78d6c..a85eaad3e223a 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,10 +6,11 @@ # WebDriver BiDi module: emulation from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass class ForcedColorsModeTheme: @@ -54,24 +38,24 @@ class ScreenOrientationType: @dataclass class SetForcedColorsModeThemeOverrideParameters: - """SetForcedColorsModeThemeOverrideParameters type definition.""" + """SetForcedColorsModeThemeOverrideParameters.""" theme: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetGeolocationOverrideParameters: - """SetGeolocationOverrideParameters type definition.""" + """SetGeolocationOverrideParameters.""" - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class GeolocationCoordinates: - """GeolocationCoordinates type definition.""" + """GeolocationCoordinates.""" latitude: Any | None = None longitude: Any | None = None @@ -84,39 +68,39 @@ class GeolocationCoordinates: @dataclass class GeolocationPositionError: - """GeolocationPositionError type definition.""" + """GeolocationPositionError.""" type: str = field(default="positionUnavailable", init=False) @dataclass class SetLocaleOverrideParameters: - """SetLocaleOverrideParameters type definition.""" + """SetLocaleOverrideParameters.""" locale: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass -class SetNetworkConditionsParameters: - """SetNetworkConditionsParameters type definition.""" +class setNetworkConditionsParameters: + """setNetworkConditionsParameters.""" network_conditions: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class NetworkConditionsOffline: - """NetworkConditionsOffline type definition.""" + """NetworkConditionsOffline.""" type: str = field(default="offline", init=False) @dataclass class ScreenArea: - """ScreenArea type definition.""" + """ScreenArea.""" width: Any | None = None height: Any | None = None @@ -124,16 +108,16 @@ class ScreenArea: @dataclass class SetScreenSettingsOverrideParameters: - """SetScreenSettingsOverrideParameters type definition.""" + """SetScreenSettingsOverrideParameters.""" screen_area: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class ScreenOrientation: - """ScreenOrientation type definition.""" + """ScreenOrientation.""" natural: Any | None = None type: Any | None = None @@ -141,64 +125,64 @@ class ScreenOrientation: @dataclass class SetScreenOrientationOverrideParameters: - """SetScreenOrientationOverrideParameters type definition.""" + """SetScreenOrientationOverrideParameters.""" screen_orientation: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetUserAgentOverrideParameters: - """SetUserAgentOverrideParameters type definition.""" + """SetUserAgentOverrideParameters.""" user_agent: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetViewportMetaOverrideParameters: - """SetViewportMetaOverrideParameters type definition.""" + """SetViewportMetaOverrideParameters.""" viewport_meta: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetScriptingEnabledParameters: - """SetScriptingEnabledParameters type definition.""" + """SetScriptingEnabledParameters.""" enabled: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetScrollbarTypeOverrideParameters: - """SetScrollbarTypeOverrideParameters type definition.""" + """SetScrollbarTypeOverrideParameters.""" scrollbar_type: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetTimezoneOverrideParameters: - """SetTimezoneOverrideParameters type definition.""" + """SetTimezoneOverrideParameters.""" timezone: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class SetTouchOverrideParameters: - """SetTouchOverrideParameters type definition.""" + """SetTouchOverrideParameters.""" - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None class Emulation: @@ -207,12 +191,7 @@ class Emulation: def __init__(self, conn) -> None: self._conn = conn - def set_forced_colors_mode_theme_override( - self, - theme: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setForcedColorsModeThemeOverride.""" params = { "theme": theme, @@ -224,12 +203,18 @@ def set_forced_colors_mode_theme_override( result = self._conn.execute(cmd) return result - def set_locale_override( - self, - locale: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_geolocation_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setGeolocationOverride.""" + params = { + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setGeolocationOverride", params) + result = self._conn.execute(cmd) + return result + + def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setLocaleOverride.""" params = { "locale": locale, @@ -241,12 +226,19 @@ def set_locale_override( result = self._conn.execute(cmd) return result - def set_screen_settings_override( - self, - screen_area: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setNetworkConditions.""" + params = { + "networkConditions": network_conditions, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setNetworkConditions", params) + result = self._conn.execute(cmd) + return result + + def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScreenSettingsOverride.""" params = { "screenArea": screen_area, @@ -258,12 +250,31 @@ def set_screen_settings_override( result = self._conn.execute(cmd) return result - def set_viewport_meta_override( - self, - viewport_meta: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScreenOrientationOverride.""" + params = { + "screenOrientation": screen_orientation, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScreenOrientationOverride", params) + result = self._conn.execute(cmd) + return result + + def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setUserAgentOverride.""" + params = { + "userAgent": user_agent, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setUserAgentOverride", params) + result = self._conn.execute(cmd) + return result + + def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setViewportMetaOverride.""" params = { "viewportMeta": viewport_meta, @@ -275,12 +286,19 @@ def set_viewport_meta_override( result = self._conn.execute(cmd) return result - def set_scrollbar_type_override( - self, - scrollbar_type: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setScriptingEnabled.""" + params = { + "enabled": enabled, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScriptingEnabled", params) + result = self._conn.execute(cmd) + return result + + def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScrollbarTypeOverride.""" params = { "scrollbarType": scrollbar_type, @@ -292,7 +310,19 @@ def set_scrollbar_type_override( result = self._conn.execute(cmd) return result - def set_touch_override(self, contexts: list[Any] | None = None, user_contexts: list[Any] | None = None): + def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + """Execute emulation.setTimezoneOverride.""" + params = { + "timezone": timezone, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setTimezoneOverride", params) + result = self._conn.execute(cmd) + return result + + def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { "contexts": contexts, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index e9c3f8345f05d..5dbe71dbd3886 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: input from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class PointerType: """PointerType.""" @@ -50,7 +34,7 @@ class Origin: @dataclass class ElementOrigin: - """ElementOrigin type definition.""" + """ElementOrigin.""" type: str = field(default="element", init=False) element: Any | None = None @@ -58,59 +42,59 @@ class ElementOrigin: @dataclass class PerformActionsParameters: - """PerformActionsParameters type definition.""" + """PerformActionsParameters.""" context: Any | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class NoneSourceActions: - """NoneSourceActions type definition.""" + """NoneSourceActions.""" type: str = field(default="none", init=False) id: str | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class KeySourceActions: - """KeySourceActions type definition.""" + """KeySourceActions.""" type: str = field(default="key", init=False) id: str | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class PointerSourceActions: - """PointerSourceActions type definition.""" + """PointerSourceActions.""" type: str = field(default="pointer", init=False) id: str | None = None parameters: Any | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class PointerParameters: - """PointerParameters type definition.""" + """PointerParameters.""" pointer_type: Any | None = None @dataclass class WheelSourceActions: - """WheelSourceActions type definition.""" + """WheelSourceActions.""" type: str = field(default="wheel", init=False) id: str | None = None - actions: list[Any | None] | None = field(default_factory=list) + actions: list[Any | None] | None = None @dataclass class PauseAction: - """PauseAction type definition.""" + """PauseAction.""" type: str = field(default="pause", init=False) duration: Any | None = None @@ -118,7 +102,7 @@ class PauseAction: @dataclass class KeyDownAction: - """KeyDownAction type definition.""" + """KeyDownAction.""" type: str = field(default="keyDown", init=False) value: str | None = None @@ -126,7 +110,7 @@ class KeyDownAction: @dataclass class KeyUpAction: - """KeyUpAction type definition.""" + """KeyUpAction.""" type: str = field(default="keyUp", init=False) value: str | None = None @@ -134,7 +118,7 @@ class KeyUpAction: @dataclass class PointerUpAction: - """PointerUpAction type definition.""" + """PointerUpAction.""" type: str = field(default="pointerUp", init=False) button: Any | None = None @@ -142,7 +126,7 @@ class PointerUpAction: @dataclass class WheelScrollAction: - """WheelScrollAction type definition.""" + """WheelScrollAction.""" type: str = field(default="scroll", init=False) x: Any | None = None @@ -155,7 +139,7 @@ class WheelScrollAction: @dataclass class PointerCommonProperties: - """PointerCommonProperties type definition.""" + """PointerCommonProperties.""" width: Any | None = None height: Any | None = None @@ -168,18 +152,18 @@ class PointerCommonProperties: @dataclass class ReleaseActionsParameters: - """ReleaseActionsParameters type definition.""" + """ReleaseActionsParameters.""" context: Any | None = None @dataclass class SetFilesParameters: - """SetFilesParameters type definition.""" + """SetFilesParameters.""" context: Any | None = None element: Any | None = None - files: list[Any | None] | None = field(default_factory=list) + files: list[Any | None] | None = None @dataclass @@ -191,7 +175,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> FileDialogInfo: + def from_json(cls, params: dict) -> "FileDialogInfo": """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), @@ -384,7 +368,7 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): + def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): """Execute input.performActions.""" params = { "context": context, @@ -405,7 +389,7 @@ def release_actions(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): + def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): """Execute input.setFiles.""" params = { "context": context, @@ -470,10 +454,5 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Input.EVENT_CONFIGS = { - "file_dialog_opened": ( - EventConfig("file_dialog_opened", "input.fileDialogOpened", - _globals.get("FileDialogOpened", dict)) - if _globals.get("FileDialogOpened") - else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict) - ), + "file_dialog_opened": (EventConfig("file_dialog_opened", "input.fileDialogOpened", _globals.get("FileDialogOpened", dict)) if _globals.get("FileDialogOpened") else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict)), } diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 94f511d7185f8..7aa7fbf7a3171 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,11 +6,14 @@ # WebDriver BiDi module: log from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable from dataclasses import dataclass -from typing import Any - from selenium.webdriver.common.bidi.session import Session @@ -44,7 +30,7 @@ class Level: @dataclass class BaseLogEntry: - """BaseLogEntry type definition.""" + """BaseLogEntry.""" level: Any | None = None source: Any | None = None @@ -55,7 +41,7 @@ class BaseLogEntry: @dataclass class GenericLogEntry: - """GenericLogEntry type definition.""" + """GenericLogEntry.""" type: str | None = None @@ -74,7 +60,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> ConsoleLogEntry: + def from_json(cls, params: dict) -> "ConsoleLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -99,7 +85,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> JavascriptLogEntry: + def from_json(cls, params: dict) -> "JavascriptLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -312,10 +298,5 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Log.EVENT_CONFIGS = { - "entry_added": ( - EventConfig("entry_added", "log.entryAdded", - _globals.get("EntryAdded", dict)) - if _globals.get("EntryAdded") - else EventConfig("entry_added", "log.entryAdded", dict) - ), + "entry_added": (EventConfig("entry_added", "log.entryAdded", _globals.get("EntryAdded", dict)) if _globals.get("EntryAdded") else EventConfig("entry_added", "log.entryAdded", dict)), } diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 9dc5fb94d8488..2290c9fec12d3 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: network from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class SameSite: """SameSite.""" @@ -66,7 +50,7 @@ class ContinueWithAuthNoCredentials: @dataclass class AuthChallenge: - """AuthChallenge type definition.""" + """AuthChallenge.""" scheme: str | None = None realm: str | None = None @@ -74,7 +58,7 @@ class AuthChallenge: @dataclass class AuthCredentials: - """AuthCredentials type definition.""" + """AuthCredentials.""" type: str = field(default="password", init=False) username: str | None = None @@ -83,7 +67,7 @@ class AuthCredentials: @dataclass class BaseParameters: - """BaseParameters type definition.""" + """BaseParameters.""" context: Any | None = None is_blocked: bool | None = None @@ -91,12 +75,12 @@ class BaseParameters: redirect_count: Any | None = None request: Any | None = None timestamp: Any | None = None - intercepts: list[Any | None] | None = field(default_factory=list) + intercepts: list[Any | None] | None = None @dataclass class StringValue: - """StringValue type definition.""" + """StringValue.""" type: str = field(default="string", init=False) value: str | None = None @@ -104,7 +88,7 @@ class StringValue: @dataclass class Base64Value: - """Base64Value type definition.""" + """Base64Value.""" type: str = field(default="base64", init=False) value: str | None = None @@ -112,7 +96,7 @@ class Base64Value: @dataclass class Cookie: - """Cookie type definition.""" + """Cookie.""" name: str | None = None value: Any | None = None @@ -127,7 +111,7 @@ class Cookie: @dataclass class CookieHeader: - """CookieHeader type definition.""" + """CookieHeader.""" name: str | None = None value: Any | None = None @@ -135,7 +119,7 @@ class CookieHeader: @dataclass class FetchTimingInfo: - """FetchTimingInfo type definition.""" + """FetchTimingInfo.""" time_origin: Any | None = None request_time: Any | None = None @@ -154,7 +138,7 @@ class FetchTimingInfo: @dataclass class Header: - """Header type definition.""" + """Header.""" name: str | None = None value: Any | None = None @@ -162,7 +146,7 @@ class Header: @dataclass class Initiator: - """Initiator type definition.""" + """Initiator.""" column_number: Any | None = None line_number: Any | None = None @@ -173,32 +157,32 @@ class Initiator: @dataclass class ResponseContent: - """ResponseContent type definition.""" + """ResponseContent.""" size: Any | None = None @dataclass class ResponseData: - """ResponseData type definition.""" + """ResponseData.""" url: str | None = None protocol: str | None = None status: Any | None = None status_text: str | None = None from_cache: bool | None = None - headers: list[Any | None] | None = field(default_factory=list) + headers: list[Any | None] | None = None mime_type: str | None = None bytes_received: Any | None = None headers_size: Any | None = None body_size: Any | None = None content: Any | None = None - auth_challenges: list[Any | None] | None = field(default_factory=list) + auth_challenges: list[Any | None] | None = None @dataclass class SetCookieHeader: - """SetCookieHeader type definition.""" + """SetCookieHeader.""" name: str | None = None value: Any | None = None @@ -213,7 +197,7 @@ class SetCookieHeader: @dataclass class UrlPatternPattern: - """UrlPatternPattern type definition.""" + """UrlPatternPattern.""" type: str = field(default="pattern", init=False) protocol: str | None = None @@ -225,7 +209,7 @@ class UrlPatternPattern: @dataclass class UrlPatternString: - """UrlPatternString type definition.""" + """UrlPatternString.""" type: str = field(default="string", init=False) pattern: str | None = None @@ -233,68 +217,68 @@ class UrlPatternString: @dataclass class AddDataCollectorParameters: - """AddDataCollectorParameters type definition.""" + """AddDataCollectorParameters.""" - data_types: list[Any | None] | None = field(default_factory=list) + data_types: list[Any | None] | None = None max_encoded_data_size: Any | None = None collector_type: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class AddDataCollectorResult: - """AddDataCollectorResult type definition.""" + """AddDataCollectorResult.""" collector: Any | None = None @dataclass class AddInterceptParameters: - """AddInterceptParameters type definition.""" + """AddInterceptParameters.""" - phases: list[Any | None] | None = field(default_factory=list) - contexts: list[Any | None] | None = field(default_factory=list) - url_patterns: list[Any | None] | None = field(default_factory=list) + phases: list[Any | None] | None = None + contexts: list[Any | None] | None = None + url_patterns: list[Any | None] | None = None @dataclass class AddInterceptResult: - """AddInterceptResult type definition.""" + """AddInterceptResult.""" intercept: Any | None = None @dataclass class ContinueResponseParameters: - """ContinueResponseParameters type definition.""" + """ContinueResponseParameters.""" request: Any | None = None - cookies: list[Any | None] | None = field(default_factory=list) + cookies: list[Any | None] | None = None credentials: Any | None = None - headers: list[Any | None] | None = field(default_factory=list) + headers: list[Any | None] | None = None reason_phrase: str | None = None status_code: Any | None = None @dataclass class ContinueWithAuthParameters: - """ContinueWithAuthParameters type definition.""" + """ContinueWithAuthParameters.""" request: Any | None = None @dataclass class ContinueWithAuthCredentials: - """ContinueWithAuthCredentials type definition.""" + """ContinueWithAuthCredentials.""" action: str = field(default="provideCredentials", init=False) credentials: Any | None = None @dataclass -class DisownDataParameters: - """DisownDataParameters type definition.""" +class disownDataParameters: + """disownDataParameters.""" data_type: Any | None = None collector: Any | None = None @@ -303,14 +287,14 @@ class DisownDataParameters: @dataclass class FailRequestParameters: - """FailRequestParameters type definition.""" + """FailRequestParameters.""" request: Any | None = None @dataclass class GetDataParameters: - """GetDataParameters type definition.""" + """GetDataParameters.""" data_type: Any | None = None collector: Any | None = None @@ -320,85 +304,57 @@ class GetDataParameters: @dataclass class GetDataResult: - """GetDataResult type definition.""" + """GetDataResult.""" bytes: Any | None = None @dataclass class ProvideResponseParameters: - """ProvideResponseParameters type definition.""" + """ProvideResponseParameters.""" request: Any | None = None body: Any | None = None - cookies: list[Any | None] | None = field(default_factory=list) - headers: list[Any | None] | None = field(default_factory=list) + cookies: list[Any | None] | None = None + headers: list[Any | None] | None = None reason_phrase: str | None = None status_code: Any | None = None @dataclass class RemoveDataCollectorParameters: - """RemoveDataCollectorParameters type definition.""" + """RemoveDataCollectorParameters.""" collector: Any | None = None @dataclass class RemoveInterceptParameters: - """RemoveInterceptParameters type definition.""" + """RemoveInterceptParameters.""" intercept: Any | None = None @dataclass class SetCacheBehaviorParameters: - """SetCacheBehaviorParameters type definition.""" + """SetCacheBehaviorParameters.""" cache_behavior: Any | None = None - contexts: list[Any | None] | None = field(default_factory=list) + contexts: list[Any | None] | None = None @dataclass class SetExtraHeadersParameters: - """SetExtraHeadersParameters type definition.""" - - headers: list[Any | None] | None = field(default_factory=list) - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) - - -@dataclass -class AuthRequiredParameters: - """AuthRequiredParameters type definition.""" - - response: Any | None = None - - -@dataclass -class BeforeRequestSentParameters: - """BeforeRequestSentParameters type definition.""" - - initiator: Any | None = None - - -@dataclass -class FetchErrorParameters: - """FetchErrorParameters type definition.""" - - error_text: str | None = None + """SetExtraHeadersParameters.""" - -@dataclass -class ResponseCompletedParameters: - """ResponseCompletedParameters type definition.""" - - response: Any | None = None + headers: list[Any | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class ResponseStartedParameters: - """ResponseStartedParameters type definition.""" + """ResponseStartedParameters.""" response: Any | None = None @@ -441,10 +397,6 @@ def continue_request(self, **kwargs): # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "auth_required": "network.authRequired", - "before_request_sent": "network.beforeRequestSent", - "fetch_error": "network.fetchError", - "response_completed": "network.responseCompleted", - "response_started": "network.responseStarted", "before_request": "network.beforeRequestSent", } @@ -611,14 +563,7 @@ def __init__(self, conn) -> None: self.intercepts = [] self._handler_intercepts: dict = {} - def add_data_collector( - self, - data_types: list[Any] | None = None, - max_encoded_data_size: Any | None = None, - collector_type: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.addDataCollector.""" params = { "dataTypes": data_types, @@ -632,12 +577,7 @@ def add_data_collector( result = self._conn.execute(cmd) return result - def add_intercept( - self, - phases: list[Any] | None = None, - contexts: list[Any] | None = None, - url_patterns: list[Any] | None = None, - ): + def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): """Execute network.addIntercept.""" params = { "phases": phases, @@ -649,15 +589,7 @@ def add_intercept( result = self._conn.execute(cmd) return result - def continue_request( - self, - request: Any | None = None, - body: Any | None = None, - cookies: list[Any] | None = None, - headers: list[Any] | None = None, - method: Any | None = None, - url: Any | None = None, - ): + def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): """Execute network.continueRequest.""" params = { "request": request, @@ -672,15 +604,7 @@ def continue_request( result = self._conn.execute(cmd) return result - def continue_response( - self, - request: Any | None = None, - cookies: list[Any] | None = None, - credentials: Any | None = None, - headers: list[Any] | None = None, - reason_phrase: Any | None = None, - status_code: Any | None = None, - ): + def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.continueResponse.""" params = { "request": request, @@ -727,13 +651,7 @@ def fail_request(self, request: Any | None = None): result = self._conn.execute(cmd) return result - def get_data( - self, - data_type: Any | None = None, - collector: Any | None = None, - disown: bool | None = None, - request: Any | None = None, - ): + def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): """Execute network.getData.""" params = { "dataType": data_type, @@ -746,15 +664,7 @@ def get_data( result = self._conn.execute(cmd) return result - def provide_response( - self, - request: Any | None = None, - body: Any | None = None, - cookies: list[Any] | None = None, - headers: list[Any] | None = None, - reason_phrase: Any | None = None, - status_code: Any | None = None, - ): + def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.provideResponse.""" params = { "request": request, @@ -789,7 +699,7 @@ def remove_intercept(self, intercept: Any | None = None): result = self._conn.execute(cmd) return result - def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): """Execute network.setCacheBehavior.""" params = { "cacheBehavior": cache_behavior, @@ -800,12 +710,7 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[A result = self._conn.execute(cmd) return result - def set_extra_headers( - self, - headers: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.setExtraHeaders.""" params = { "headers": headers, @@ -817,6 +722,52 @@ def set_extra_headers( result = self._conn.execute(cmd) return result + def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.beforeRequestSent.""" + params = { + "initiator": initiator, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.beforeRequestSent", params) + result = self._conn.execute(cmd) + return result + + def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.fetchError.""" + params = { + "errorText": error_text, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.fetchError", params) + result = self._conn.execute(cmd) + return result + + def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.responseCompleted.""" + params = { + "response": response, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.responseCompleted", params) + result = self._conn.execute(cmd) + return result + + def response_started(self, response: Any | None = None): + """Execute network.responseStarted.""" + params = { + "response": response, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.responseStarted", params) + result = self._conn.execute(cmd) + return result + def _add_intercept(self, phases=None, url_patterns=None): """Add a low-level network intercept. @@ -971,51 +922,10 @@ def clear_event_handlers(self) -> None: # Event: network.authRequired AuthRequired = globals().get('AuthRequiredParameters', dict) # Fallback to dict if type not defined -# Event: network.beforeRequestSent -BeforeRequestSent = globals().get('BeforeRequestSentParameters', dict) # Fallback to dict if type not defined - -# Event: network.fetchError -FetchError = globals().get('FetchErrorParameters', dict) # Fallback to dict if type not defined - -# Event: network.responseCompleted -ResponseCompleted = globals().get('ResponseCompletedParameters', dict) # Fallback to dict if type not defined - -# Event: network.responseStarted -ResponseStarted = globals().get('ResponseStartedParameters', dict) # Fallback to dict if type not defined - # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Network.EVENT_CONFIGS = { - "auth_required": ( - EventConfig("auth_required", "network.authRequired", - _globals.get("AuthRequired", dict)) - if _globals.get("AuthRequired") - else EventConfig("auth_required", "network.authRequired", dict) - ), - "before_request_sent": ( - EventConfig("before_request_sent", "network.beforeRequestSent", - _globals.get("BeforeRequestSent", dict)) - if _globals.get("BeforeRequestSent") - else EventConfig("before_request_sent", "network.beforeRequestSent", dict) - ), - "fetch_error": ( - EventConfig("fetch_error", "network.fetchError", - _globals.get("FetchError", dict)) - if _globals.get("FetchError") - else EventConfig("fetch_error", "network.fetchError", dict) - ), - "response_completed": ( - EventConfig("response_completed", "network.responseCompleted", - _globals.get("ResponseCompleted", dict)) - if _globals.get("ResponseCompleted") - else EventConfig("response_completed", "network.responseCompleted", dict) - ), - "response_started": ( - EventConfig("response_started", "network.responseStarted", - _globals.get("ResponseStarted", dict)) - if _globals.get("ResponseStarted") - else EventConfig("response_started", "network.responseStarted", dict) - ), + "auth_required": (EventConfig("auth_required", "network.authRequired", _globals.get("AuthRequired", dict)) if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict)), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index 6dd138da17309..f00e765c62e3b 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -20,7 +20,7 @@ from __future__ import annotations from enum import Enum -from typing import Any +from typing import Any, Optional, Union from .common import command_builder @@ -63,10 +63,10 @@ def __init__(self, websocket_connection: Any) -> None: def set_permission( self, - descriptor: PermissionDescriptor | str, - state: PermissionState | str, - origin: str | None = None, - user_context: str | None = None, + descriptor: Union[PermissionDescriptor, str], + state: Union[PermissionState, str], + origin: Optional[str] = None, + user_context: Optional[str] = None, ) -> None: """Set a permission for a given origin. diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 0b2ec04101933..c7bfcb3774dff 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: script from __future__ import annotations +from typing import Any, Dict, List, Optional, Union +from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - +from dataclasses import dataclass from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class SpecialNumber: """SpecialNumber.""" @@ -64,7 +48,7 @@ class ResultOwnership: @dataclass class ChannelValue: - """ChannelValue type definition.""" + """ChannelValue.""" type: str = field(default="channel", init=False) value: Any | None = None @@ -72,7 +56,7 @@ class ChannelValue: @dataclass class ChannelProperties: - """ChannelProperties type definition.""" + """ChannelProperties.""" channel: Any | None = None serialization_options: Any | None = None @@ -81,7 +65,7 @@ class ChannelProperties: @dataclass class EvaluateResultSuccess: - """EvaluateResultSuccess type definition.""" + """EvaluateResultSuccess.""" type: str = field(default="success", init=False) result: Any | None = None @@ -90,7 +74,7 @@ class EvaluateResultSuccess: @dataclass class EvaluateResultException: - """EvaluateResultException type definition.""" + """EvaluateResultException.""" type: str = field(default="exception", init=False) exception_details: Any | None = None @@ -99,7 +83,7 @@ class EvaluateResultException: @dataclass class ExceptionDetails: - """ExceptionDetails type definition.""" + """ExceptionDetails.""" column_number: Any | None = None exception: Any | None = None @@ -110,7 +94,7 @@ class ExceptionDetails: @dataclass class ArrayLocalValue: - """ArrayLocalValue type definition.""" + """ArrayLocalValue.""" type: str = field(default="array", init=False) value: Any | None = None @@ -118,7 +102,7 @@ class ArrayLocalValue: @dataclass class DateLocalValue: - """DateLocalValue type definition.""" + """DateLocalValue.""" type: str = field(default="date", init=False) value: str | None = None @@ -126,7 +110,7 @@ class DateLocalValue: @dataclass class MapLocalValue: - """MapLocalValue type definition.""" + """MapLocalValue.""" type: str = field(default="map", init=False) value: Any | None = None @@ -134,7 +118,7 @@ class MapLocalValue: @dataclass class ObjectLocalValue: - """ObjectLocalValue type definition.""" + """ObjectLocalValue.""" type: str = field(default="object", init=False) value: Any | None = None @@ -142,7 +126,7 @@ class ObjectLocalValue: @dataclass class RegExpValue: - """RegExpValue type definition.""" + """RegExpValue.""" pattern: str | None = None flags: str | None = None @@ -150,7 +134,7 @@ class RegExpValue: @dataclass class RegExpLocalValue: - """RegExpLocalValue type definition.""" + """RegExpLocalValue.""" type: str = field(default="regexp", init=False) value: Any | None = None @@ -158,7 +142,7 @@ class RegExpLocalValue: @dataclass class SetLocalValue: - """SetLocalValue type definition.""" + """SetLocalValue.""" type: str = field(default="set", init=False) value: Any | None = None @@ -166,21 +150,21 @@ class SetLocalValue: @dataclass class UndefinedValue: - """UndefinedValue type definition.""" + """UndefinedValue.""" type: str = field(default="undefined", init=False) @dataclass class NullValue: - """NullValue type definition.""" + """NullValue.""" type: str = field(default="null", init=False) @dataclass class StringValue: - """StringValue type definition.""" + """StringValue.""" type: str = field(default="string", init=False) value: str | None = None @@ -188,7 +172,7 @@ class StringValue: @dataclass class NumberValue: - """NumberValue type definition.""" + """NumberValue.""" type: str = field(default="number", init=False) value: Any | None = None @@ -196,7 +180,7 @@ class NumberValue: @dataclass class BooleanValue: - """BooleanValue type definition.""" + """BooleanValue.""" type: str = field(default="boolean", init=False) value: bool | None = None @@ -204,7 +188,7 @@ class BooleanValue: @dataclass class BigIntValue: - """BigIntValue type definition.""" + """BigIntValue.""" type: str = field(default="bigint", init=False) value: str | None = None @@ -212,7 +196,7 @@ class BigIntValue: @dataclass class BaseRealmInfo: - """BaseRealmInfo type definition.""" + """BaseRealmInfo.""" realm: Any | None = None origin: str | None = None @@ -220,7 +204,7 @@ class BaseRealmInfo: @dataclass class WindowRealmInfo: - """WindowRealmInfo type definition.""" + """WindowRealmInfo.""" type: str = field(default="window", init=False) context: Any | None = None @@ -229,57 +213,57 @@ class WindowRealmInfo: @dataclass class DedicatedWorkerRealmInfo: - """DedicatedWorkerRealmInfo type definition.""" + """DedicatedWorkerRealmInfo.""" type: str = field(default="dedicated-worker", init=False) - owners: list[Any | None] | None = field(default_factory=list) + owners: list[Any | None] | None = None @dataclass class SharedWorkerRealmInfo: - """SharedWorkerRealmInfo type definition.""" + """SharedWorkerRealmInfo.""" type: str = field(default="shared-worker", init=False) @dataclass class ServiceWorkerRealmInfo: - """ServiceWorkerRealmInfo type definition.""" + """ServiceWorkerRealmInfo.""" type: str = field(default="service-worker", init=False) @dataclass class WorkerRealmInfo: - """WorkerRealmInfo type definition.""" + """WorkerRealmInfo.""" type: str = field(default="worker", init=False) @dataclass class PaintWorkletRealmInfo: - """PaintWorkletRealmInfo type definition.""" + """PaintWorkletRealmInfo.""" type: str = field(default="paint-worklet", init=False) @dataclass class AudioWorkletRealmInfo: - """AudioWorkletRealmInfo type definition.""" + """AudioWorkletRealmInfo.""" type: str = field(default="audio-worklet", init=False) @dataclass class WorkletRealmInfo: - """WorkletRealmInfo type definition.""" + """WorkletRealmInfo.""" type: str = field(default="worklet", init=False) @dataclass class SharedReference: - """SharedReference type definition.""" + """SharedReference.""" shared_id: Any | None = None handle: Any | None = None @@ -287,7 +271,7 @@ class SharedReference: @dataclass class RemoteObjectReference: - """RemoteObjectReference type definition.""" + """RemoteObjectReference.""" handle: Any | None = None shared_id: Any | None = None @@ -295,7 +279,7 @@ class RemoteObjectReference: @dataclass class SymbolRemoteValue: - """SymbolRemoteValue type definition.""" + """SymbolRemoteValue.""" type: str = field(default="symbol", init=False) handle: Any | None = None @@ -304,7 +288,7 @@ class SymbolRemoteValue: @dataclass class ArrayRemoteValue: - """ArrayRemoteValue type definition.""" + """ArrayRemoteValue.""" type: str = field(default="array", init=False) handle: Any | None = None @@ -314,7 +298,7 @@ class ArrayRemoteValue: @dataclass class ObjectRemoteValue: - """ObjectRemoteValue type definition.""" + """ObjectRemoteValue.""" type: str = field(default="object", init=False) handle: Any | None = None @@ -324,7 +308,7 @@ class ObjectRemoteValue: @dataclass class FunctionRemoteValue: - """FunctionRemoteValue type definition.""" + """FunctionRemoteValue.""" type: str = field(default="function", init=False) handle: Any | None = None @@ -333,7 +317,7 @@ class FunctionRemoteValue: @dataclass class RegExpRemoteValue: - """RegExpRemoteValue type definition.""" + """RegExpRemoteValue.""" handle: Any | None = None internal_id: Any | None = None @@ -341,7 +325,7 @@ class RegExpRemoteValue: @dataclass class DateRemoteValue: - """DateRemoteValue type definition.""" + """DateRemoteValue.""" handle: Any | None = None internal_id: Any | None = None @@ -349,7 +333,7 @@ class DateRemoteValue: @dataclass class MapRemoteValue: - """MapRemoteValue type definition.""" + """MapRemoteValue.""" type: str = field(default="map", init=False) handle: Any | None = None @@ -359,7 +343,7 @@ class MapRemoteValue: @dataclass class SetRemoteValue: - """SetRemoteValue type definition.""" + """SetRemoteValue.""" type: str = field(default="set", init=False) handle: Any | None = None @@ -369,7 +353,7 @@ class SetRemoteValue: @dataclass class WeakMapRemoteValue: - """WeakMapRemoteValue type definition.""" + """WeakMapRemoteValue.""" type: str = field(default="weakmap", init=False) handle: Any | None = None @@ -378,7 +362,7 @@ class WeakMapRemoteValue: @dataclass class WeakSetRemoteValue: - """WeakSetRemoteValue type definition.""" + """WeakSetRemoteValue.""" type: str = field(default="weakset", init=False) handle: Any | None = None @@ -387,7 +371,7 @@ class WeakSetRemoteValue: @dataclass class GeneratorRemoteValue: - """GeneratorRemoteValue type definition.""" + """GeneratorRemoteValue.""" type: str = field(default="generator", init=False) handle: Any | None = None @@ -396,7 +380,7 @@ class GeneratorRemoteValue: @dataclass class ErrorRemoteValue: - """ErrorRemoteValue type definition.""" + """ErrorRemoteValue.""" type: str = field(default="error", init=False) handle: Any | None = None @@ -405,7 +389,7 @@ class ErrorRemoteValue: @dataclass class ProxyRemoteValue: - """ProxyRemoteValue type definition.""" + """ProxyRemoteValue.""" type: str = field(default="proxy", init=False) handle: Any | None = None @@ -414,7 +398,7 @@ class ProxyRemoteValue: @dataclass class PromiseRemoteValue: - """PromiseRemoteValue type definition.""" + """PromiseRemoteValue.""" type: str = field(default="promise", init=False) handle: Any | None = None @@ -423,7 +407,7 @@ class PromiseRemoteValue: @dataclass class TypedArrayRemoteValue: - """TypedArrayRemoteValue type definition.""" + """TypedArrayRemoteValue.""" type: str = field(default="typedarray", init=False) handle: Any | None = None @@ -432,7 +416,7 @@ class TypedArrayRemoteValue: @dataclass class ArrayBufferRemoteValue: - """ArrayBufferRemoteValue type definition.""" + """ArrayBufferRemoteValue.""" type: str = field(default="arraybuffer", init=False) handle: Any | None = None @@ -441,7 +425,7 @@ class ArrayBufferRemoteValue: @dataclass class NodeListRemoteValue: - """NodeListRemoteValue type definition.""" + """NodeListRemoteValue.""" type: str = field(default="nodelist", init=False) handle: Any | None = None @@ -451,7 +435,7 @@ class NodeListRemoteValue: @dataclass class HTMLCollectionRemoteValue: - """HTMLCollectionRemoteValue type definition.""" + """HTMLCollectionRemoteValue.""" type: str = field(default="htmlcollection", init=False) handle: Any | None = None @@ -461,7 +445,7 @@ class HTMLCollectionRemoteValue: @dataclass class NodeRemoteValue: - """NodeRemoteValue type definition.""" + """NodeRemoteValue.""" type: str = field(default="node", init=False) shared_id: Any | None = None @@ -472,11 +456,11 @@ class NodeRemoteValue: @dataclass class NodeProperties: - """NodeProperties type definition.""" + """NodeProperties.""" node_type: Any | None = None child_node_count: Any | None = None - children: list[Any | None] | None = field(default_factory=list) + children: list[Any | None] | None = None local_name: str | None = None mode: Any | None = None namespace_uri: str | None = None @@ -486,7 +470,7 @@ class NodeProperties: @dataclass class WindowProxyRemoteValue: - """WindowProxyRemoteValue type definition.""" + """WindowProxyRemoteValue.""" type: str = field(default="window", init=False) value: Any | None = None @@ -496,14 +480,14 @@ class WindowProxyRemoteValue: @dataclass class WindowProxyProperties: - """WindowProxyProperties type definition.""" + """WindowProxyProperties.""" context: Any | None = None @dataclass class StackFrame: - """StackFrame type definition.""" + """StackFrame.""" column_number: Any | None = None function_name: str | None = None @@ -513,14 +497,14 @@ class StackFrame: @dataclass class StackTrace: - """StackTrace type definition.""" + """StackTrace.""" - call_frames: list[Any | None] | None = field(default_factory=list) + call_frames: list[Any | None] | None = None @dataclass class Source: - """Source type definition.""" + """Source.""" realm: Any | None = None context: Any | None = None @@ -528,14 +512,14 @@ class Source: @dataclass class RealmTarget: - """RealmTarget type definition.""" + """RealmTarget.""" realm: Any | None = None @dataclass class ContextTarget: - """ContextTarget type definition.""" + """ContextTarget.""" context: Any | None = None sandbox: str | None = None @@ -543,38 +527,38 @@ class ContextTarget: @dataclass class AddPreloadScriptParameters: - """AddPreloadScriptParameters type definition.""" + """AddPreloadScriptParameters.""" function_declaration: str | None = None - arguments: list[Any | None] | None = field(default_factory=list) - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + arguments: list[Any | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None sandbox: str | None = None @dataclass class AddPreloadScriptResult: - """AddPreloadScriptResult type definition.""" + """AddPreloadScriptResult.""" script: Any | None = None @dataclass class DisownParameters: - """DisownParameters type definition.""" + """DisownParameters.""" - handles: list[Any | None] | None = field(default_factory=list) + handles: list[Any | None] | None = None target: Any | None = None @dataclass class CallFunctionParameters: - """CallFunctionParameters type definition.""" + """CallFunctionParameters.""" function_declaration: str | None = None await_promise: bool | None = None target: Any | None = None - arguments: list[Any | None] | None = field(default_factory=list) + arguments: list[Any | None] | None = None result_ownership: Any | None = None serialization_options: Any | None = None this: Any | None = None @@ -583,7 +567,7 @@ class CallFunctionParameters: @dataclass class EvaluateParameters: - """EvaluateParameters type definition.""" + """EvaluateParameters.""" expression: str | None = None target: Any | None = None @@ -595,7 +579,7 @@ class EvaluateParameters: @dataclass class GetRealmsParameters: - """GetRealmsParameters type definition.""" + """GetRealmsParameters.""" context: Any | None = None type: Any | None = None @@ -603,21 +587,21 @@ class GetRealmsParameters: @dataclass class GetRealmsResult: - """GetRealmsResult type definition.""" + """GetRealmsResult.""" - realms: list[Any | None] | None = field(default_factory=list) + realms: list[Any | None] | None = None @dataclass class RemovePreloadScriptParameters: - """RemovePreloadScriptParameters type definition.""" + """RemovePreloadScriptParameters.""" script: Any | None = None @dataclass class MessageParameters: - """MessageParameters type definition.""" + """MessageParameters.""" channel: Any | None = None data: Any | None = None @@ -626,14 +610,13 @@ class MessageParameters: @dataclass class RealmDestroyedParameters: - """RealmDestroyedParameters type definition.""" + """RealmDestroyedParameters.""" realm: Any | None = None # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { - "message": "script.message", "realm_created": "script.realmCreated", "realm_destroyed": "script.realmDestroyed", } @@ -800,14 +783,7 @@ def __init__(self, conn, driver=None) -> None: self._driver = driver self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def add_preload_script( - self, - function_declaration: Any | None = None, - arguments: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - sandbox: Any | None = None, - ): + def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): """Execute script.addPreloadScript.""" params = { "functionDeclaration": function_declaration, @@ -821,7 +797,7 @@ def add_preload_script( result = self._conn.execute(cmd) return result - def disown(self, handles: list[Any] | None = None, target: Any | None = None): + def disown(self, handles: List[Any] | None = None, target: Any | None = None): """Execute script.disown.""" params = { "handles": handles, @@ -832,17 +808,7 @@ def disown(self, handles: list[Any] | None = None, target: Any | None = None): result = self._conn.execute(cmd) return result - def call_function( - self, - function_declaration: Any | None = None, - await_promise: bool | None = None, - target: Any | None = None, - arguments: list[Any] | None = None, - result_ownership: Any | None = None, - serialization_options: Any | None = None, - this: Any | None = None, - user_activation: bool | None = None, - ): + def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): """Execute script.callFunction.""" params = { "functionDeclaration": function_declaration, @@ -859,15 +825,7 @@ def call_function( result = self._conn.execute(cmd) return result - def evaluate( - self, - expression: Any | None = None, - target: Any | None = None, - await_promise: bool | None = None, - result_ownership: Any | None = None, - serialization_options: Any | None = None, - user_activation: bool | None = None, - ): + def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): """Execute script.evaluate.""" params = { "expression": expression, @@ -903,6 +861,18 @@ def remove_preload_script(self, script: Any | None = None): result = self._conn.execute(cmd) return result + def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): + """Execute script.message.""" + params = { + "channel": channel, + "data": data, + "source": source, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.message", params) + result = self._conn.execute(cmd) + return result + def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: """Execute a function declaration in the browser context. @@ -919,9 +889,8 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import datetime as _datetime import math as _math - + import datetime as _datetime from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1162,9 +1131,8 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - - from selenium.webdriver.common.bidi import log as _log_mod from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod bidi_event = "log.entryAdded" @@ -1304,9 +1272,6 @@ def clear_event_handlers(self) -> None: return self._event_manager.clear_event_handlers() # Event Info Type Aliases -# Event: script.message -Message = globals().get('MessageParameters', dict) # Fallback to dict if type not defined - # Event: script.realmCreated RealmCreated = globals().get('RealmInfo', dict) # Fallback to dict if type not defined @@ -1317,22 +1282,6 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { - "message": ( - EventConfig("message", "script.message", - _globals.get("Message", dict)) - if _globals.get("Message") - else EventConfig("message", "script.message", dict) - ), - "realm_created": ( - EventConfig("realm_created", "script.realmCreated", - _globals.get("RealmCreated", dict)) - if _globals.get("RealmCreated") - else EventConfig("realm_created", "script.realmCreated", dict) - ), - "realm_destroyed": ( - EventConfig("realm_destroyed", "script.realmDestroyed", - _globals.get("RealmDestroyed", dict)) - if _globals.get("RealmDestroyed") - else EventConfig("realm_destroyed", "script.realmDestroyed", dict) - ), + "realm_created": (EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) if _globals.get("RealmCreated") else EventConfig("realm_created", "script.realmCreated", dict)), + "realm_destroyed": (EventConfig("realm_destroyed", "script.realmDestroyed", _globals.get("RealmDestroyed", dict)) if _globals.get("RealmDestroyed") else EventConfig("realm_destroyed", "script.realmDestroyed", dict)), } diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 771a5327151bf..9b1daaae557fa 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,10 +6,11 @@ # WebDriver BiDi module: session from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass class UserPromptHandlerType: @@ -39,15 +23,15 @@ class UserPromptHandlerType: @dataclass class CapabilitiesRequest: - """CapabilitiesRequest type definition.""" + """CapabilitiesRequest.""" always_match: Any | None = None - first_match: list[Any | None] | None = field(default_factory=list) + first_match: list[Any | None] | None = None @dataclass class CapabilityRequest: - """CapabilityRequest type definition.""" + """CapabilityRequest.""" accept_insecure_certs: bool | None = None browser_name: str | None = None @@ -59,31 +43,31 @@ class CapabilityRequest: @dataclass class AutodetectProxyConfiguration: - """AutodetectProxyConfiguration type definition.""" + """AutodetectProxyConfiguration.""" proxy_type: str = field(default="autodetect", init=False) @dataclass class DirectProxyConfiguration: - """DirectProxyConfiguration type definition.""" + """DirectProxyConfiguration.""" proxy_type: str = field(default="direct", init=False) @dataclass class ManualProxyConfiguration: - """ManualProxyConfiguration type definition.""" + """ManualProxyConfiguration.""" proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None ssl_proxy: str | None = None - no_proxy: list[Any | None] | None = field(default_factory=list) + no_proxy: list[Any | None] | None = None @dataclass class SocksProxyConfiguration: - """SocksProxyConfiguration type definition.""" + """SocksProxyConfiguration.""" socks_proxy: str | None = None socks_version: Any | None = None @@ -91,7 +75,7 @@ class SocksProxyConfiguration: @dataclass class PacProxyConfiguration: - """PacProxyConfiguration type definition.""" + """PacProxyConfiguration.""" proxy_type: str = field(default="pac", init=False) proxy_autoconfig_url: str | None = None @@ -99,37 +83,37 @@ class PacProxyConfiguration: @dataclass class SystemProxyConfiguration: - """SystemProxyConfiguration type definition.""" + """SystemProxyConfiguration.""" proxy_type: str = field(default="system", init=False) @dataclass class SubscribeParameters: - """SubscribeParameters type definition.""" + """SubscribeParameters.""" - events: list[str | None] | None = field(default_factory=list) - contexts: list[Any | None] | None = field(default_factory=list) - user_contexts: list[Any | None] | None = field(default_factory=list) + events: list[str | None] | None = None + contexts: list[Any | None] | None = None + user_contexts: list[Any | None] | None = None @dataclass class UnsubscribeByIDRequest: - """UnsubscribeByIDRequest type definition.""" + """UnsubscribeByIDRequest.""" - subscriptions: list[Any | None] | None = field(default_factory=list) + subscriptions: list[Any | None] | None = None @dataclass class UnsubscribeByAttributesRequest: - """UnsubscribeByAttributesRequest type definition.""" + """UnsubscribeByAttributesRequest.""" - events: list[str | None] | None = field(default_factory=list) + events: list[str | None] | None = None @dataclass class StatusResult: - """StatusResult type definition.""" + """StatusResult.""" ready: bool | None = None message: str | None = None @@ -137,14 +121,14 @@ class StatusResult: @dataclass class NewParameters: - """NewParameters type definition.""" + """NewParameters.""" capabilities: Any | None = None @dataclass class NewResult: - """NewResult type definition.""" + """NewResult.""" session_id: str | None = None accept_insecure_certs: bool | None = None @@ -160,7 +144,7 @@ class NewResult: @dataclass class SubscribeResult: - """SubscribeResult type definition.""" + """SubscribeResult.""" subscription: Any | None = None @@ -227,12 +211,7 @@ def end(self): result = self._conn.execute(cmd) return result - def subscribe( - self, - events: list[Any] | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): + def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute session.subscribe.""" params = { "events": events, @@ -244,7 +223,7 @@ def subscribe( result = self._conn.execute(cmd) return result - def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] | None = None): + def unsubscribe(self, events: List[Any] | None = None, subscriptions: List[Any] | None = None): """Execute session.unsubscribe.""" params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 7623381706040..7e4c9c6dee459 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,15 +6,16 @@ # WebDriver BiDi module: storage from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass @dataclass class PartitionKey: - """PartitionKey type definition.""" + """PartitionKey.""" user_context: str | None = None source_origin: str | None = None @@ -39,7 +23,7 @@ class PartitionKey: @dataclass class GetCookiesParameters: - """GetCookiesParameters type definition.""" + """GetCookiesParameters.""" filter: Any | None = None partition: Any | None = None @@ -47,15 +31,15 @@ class GetCookiesParameters: @dataclass class GetCookiesResult: - """GetCookiesResult type definition.""" + """GetCookiesResult.""" - cookies: list[Any | None] | None = field(default_factory=list) + cookies: list[Any | None] | None = None partition_key: Any | None = None @dataclass class SetCookieParameters: - """SetCookieParameters type definition.""" + """SetCookieParameters.""" cookie: Any | None = None partition: Any | None = None @@ -63,14 +47,14 @@ class SetCookieParameters: @dataclass class SetCookieResult: - """SetCookieResult type definition.""" + """SetCookieResult.""" partition_key: Any | None = None @dataclass class DeleteCookiesParameters: - """DeleteCookiesParameters type definition.""" + """DeleteCookiesParameters.""" filter: Any | None = None partition: Any | None = None @@ -78,7 +62,7 @@ class DeleteCookiesParameters: @dataclass class DeleteCookiesResult: - """DeleteCookiesResult type definition.""" + """DeleteCookiesResult.""" partition_key: Any | None = None @@ -123,7 +107,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> StorageCookie: + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): @@ -251,6 +235,39 @@ class Storage: def __init__(self, conn) -> None: self._conn = conn + def get_cookies(self, filter: Any | None = None, partition: Any | None = None): + """Execute storage.getCookies.""" + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.getCookies", params) + result = self._conn.execute(cmd) + return result + + def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): + """Execute storage.setCookie.""" + params = { + "cookie": cookie, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.setCookie", params) + result = self._conn.execute(cmd) + return result + + def delete_cookies(self, filter: Any | None = None, partition: Any | None = None): + """Execute storage.deleteCookies.""" + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.deleteCookies", params) + result = self._conn.execute(cmd) + return result + def get_cookies(self, filter=None, partition=None): """Execute storage.getCookies and return a GetCookiesResult.""" if filter and hasattr(filter, "to_bidi_dict"): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 99250afca4c68..98d852512f591 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -1,20 +1,3 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# # DO NOT EDIT THIS FILE! # # This file is generated from the WebDriver BiDi specification. If you need to make @@ -23,22 +6,23 @@ # WebDriver BiDi module: webExtension from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Dict, List, Optional, Union from .common import command_builder +from dataclasses import field +from typing import Generator +from dataclasses import dataclass @dataclass class InstallParameters: - """InstallParameters type definition.""" + """InstallParameters.""" extension_data: Any | None = None @dataclass class ExtensionPath: - """ExtensionPath type definition.""" + """ExtensionPath.""" type: str = field(default="path", init=False) path: str | None = None @@ -46,7 +30,7 @@ class ExtensionPath: @dataclass class ExtensionArchivePath: - """ExtensionArchivePath type definition.""" + """ExtensionArchivePath.""" type: str = field(default="archivePath", init=False) path: str | None = None @@ -54,7 +38,7 @@ class ExtensionArchivePath: @dataclass class ExtensionBase64Encoded: - """ExtensionBase64Encoded type definition.""" + """ExtensionBase64Encoded.""" type: str = field(default="base64", init=False) value: str | None = None @@ -62,14 +46,14 @@ class ExtensionBase64Encoded: @dataclass class InstallResult: - """InstallResult type definition.""" + """InstallResult.""" extension: Any | None = None @dataclass class UninstallParameters: - """UninstallParameters type definition.""" + """UninstallParameters.""" extension: Any | None = None @@ -104,9 +88,13 @@ def install( ValueError: If more than one, or none, of the arguments is provided. """ provided = [ - k for k, v in { - "path": path, "archive_path": archive_path, "base64_value": base64_value, - }.items() if v is not None + k + for k, v in { + "path": path, + "archive_path": archive_path, + "base64_value": base64_value, + }.items() + if v is not None ] if len(provided) != 1: raise ValueError( @@ -121,17 +109,24 @@ def install( params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) return self._conn.execute(cmd) - def uninstall(self, extension: Any | None = None): + + def uninstall(self, extension: str | dict): """Uninstall a web extension. Args: extension: Either the extension ID string returned by ``install``, or the full result dict returned by ``install`` (the ``"extension"`` value is extracted automatically). + + Raises: + ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): extension = extension.get("extension") + + if extension is None: + raise ValueError("extension parameter is required") + params = {"extension": extension} - params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd) From 2b10803a5310685c8a0b9e7c8eeca0855de901d8 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 12:57:06 +0000 Subject: [PATCH 10/37] Fix webextension and log from comments --- py/generate_bidi.py | 21 ++++++++++++++++++- py/private/bidi_enhancements_manifest.py | 7 +++++++ py/selenium/webdriver/common/bidi/log.py | 4 +++- .../webdriver/common/bidi/webextension.py | 11 +++------- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index d14e2575c8bfd..53eb3a9e52fcc 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -672,6 +672,11 @@ def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: code += extra_cls code += "\n\n" + # Emit extra type aliases from enhancement manifest (e.g., union types for events) + for extra_alias in enhancements.get("extra_type_aliases", []): + code += extra_alias + code += "\n\n" + # NOTE: Don't generate event type aliases here - they reference types that may not be defined yet # They will be generated after the class definition instead @@ -976,8 +981,22 @@ def clear_event_handlers(self) -> None: # This ensures all types are available when we create the aliases if self.events: code += "\n# Event Info Type Aliases\n" + # Check for explicit event_type_aliases in the enhancement manifest + event_type_aliases = enhancements.get("event_type_aliases", {}) for event_def in self.events: - code += event_def.to_python_dataclass() + # Convert method name to user-friendly event name + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + # Check if there's an explicit alias defined in the enhancement manifest + if event_name in event_type_aliases: + # Use the alias directly + type_name = event_type_aliases[event_name] + code += f"# Event: {event_def.method}\n" + code += f"{event_def.name} = {type_name}\n" + else: + # Fall back to the original behavior + code += event_def.to_python_dataclass() code += "\n" # Now populate EVENT_CONFIGS after the aliases are defined diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 5dcce3c25ffeb..f06a4119625e6 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -284,6 +284,13 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": stacktrace=params.get("stackTrace"), )''', ], + # Define Entry union type for log.entryAdded event deserialization + "extra_type_aliases": [ + "Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry", + ], + "event_type_aliases": { + "entry_added": "Entry", + }, }, "emulation": { "extra_methods": [ diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 7aa7fbf7a3171..c58018e8a947a 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -96,6 +96,8 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": stacktrace=params.get("stackTrace"), ) +Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry + # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "entry_added": "log.entryAdded", @@ -292,7 +294,7 @@ def clear_event_handlers(self) -> None: # Event Info Type Aliases # Event: log.entryAdded -EntryAdded = globals().get('Entry', dict) # Fallback to dict if type not defined +EntryAdded = Entry # Populate EVENT_CONFIGS with event configuration mappings diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 98d852512f591..e007f8e4792a6 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -88,13 +88,9 @@ def install( ValueError: If more than one, or none, of the arguments is provided. """ provided = [ - k - for k, v in { - "path": path, - "archive_path": archive_path, - "base64_value": base64_value, - }.items() - if v is not None + k for k, v in { + "path": path, "archive_path": archive_path, "base64_value": base64_value, + }.items() if v is not None ] if len(provided) != 1: raise ValueError( @@ -109,7 +105,6 @@ def install( params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) return self._conn.execute(cmd) - def uninstall(self, extension: str | dict): """Uninstall a web extension. From 0c44a0917c78251f8f7e50e42862ef3fd21cdeef Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 13:06:48 +0000 Subject: [PATCH 11/37] Correct usage of dafault_factory --- py/generate_bidi.py | 13 +++-- py/selenium/webdriver/common/bidi/browser.py | 6 +-- .../webdriver/common/bidi/browsing_context.py | 6 +-- .../webdriver/common/bidi/emulation.py | 48 +++++++++---------- py/selenium/webdriver/common/bidi/input.py | 12 ++--- py/selenium/webdriver/common/bidi/network.py | 34 ++++++------- py/selenium/webdriver/common/bidi/script.py | 18 +++---- py/selenium/webdriver/common/bidi/session.py | 14 +++--- py/selenium/webdriver/common/bidi/storage.py | 2 +- 9 files changed, 80 insertions(+), 73 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 53eb3a9e52fcc..f4915aa1ad123 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -385,9 +385,16 @@ def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> if literal_match: literal_value = literal_match.group(1) code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' - # Check if this field is a list type - elif "List[" in python_type: - code += f" {snake_name}: {python_type} = field(default_factory=list)\n" + # Check if this field is a list type (using lowercase 'list[' from Python 3.10+ syntax) + elif python_type.startswith("list["): + # Remove the trailing ' | None' from list types since default_factory=list ensures non-None + type_annotation = python_type.replace(" | None", "") + code += f" {snake_name}: {type_annotation} = field(default_factory=list)\n" + # Check if this field is a dict type (using lowercase 'dict[' from Python 3.10+ syntax) + elif python_type.startswith("dict["): + # Remove the trailing ' | None' from dict types since default_factory=dict ensures non-None + type_annotation = python_type.replace(" | None", "") + code += f" {snake_name}: {type_annotation} = field(default_factory=dict)\n" else: code += f" {snake_name}: {python_type} = None\n" diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 7cf9678c9b007..0618beb14ddef 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -131,14 +131,14 @@ class CreateUserContextParameters: class GetClientWindowsResult: """GetClientWindowsResult.""" - client_windows: list[Any | None] | None = None + client_windows: list[Any] = field(default_factory=list) @dataclass class GetUserContextsResult: """GetUserContextsResult.""" - user_contexts: list[Any | None] | None = None + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -171,7 +171,7 @@ class SetDownloadBehaviorParameters: """SetDownloadBehaviorParameters.""" download_behavior: Any | None = None - user_contexts: list[Any | None] | None = None + user_contexts: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 35aea615d1780..d17829709c0c3 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -220,14 +220,14 @@ class LocateNodesParameters: context: Any | None = None locator: Any | None = None serialization_options: Any | None = None - start_nodes: list[Any | None] | None = None + start_nodes: list[Any] = field(default_factory=list) @dataclass class LocateNodesResult: """LocateNodesResult.""" - nodes: list[Any | None] | None = None + nodes: list[Any] = field(default_factory=list) @dataclass @@ -300,7 +300,7 @@ class SetViewportParameters: context: Any | None = None viewport: Any | None = None device_pixel_ratio: Any | None = None - user_contexts: list[Any | None] | None = None + user_contexts: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index a85eaad3e223a..7edb7a9dacd06 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -41,16 +41,16 @@ class SetForcedColorsModeThemeOverrideParameters: """SetForcedColorsModeThemeOverrideParameters.""" theme: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass class SetGeolocationOverrideParameters: """SetGeolocationOverrideParameters.""" - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -78,8 +78,8 @@ class SetLocaleOverrideParameters: """SetLocaleOverrideParameters.""" locale: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -87,8 +87,8 @@ class setNetworkConditionsParameters: """setNetworkConditionsParameters.""" network_conditions: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -111,8 +111,8 @@ class SetScreenSettingsOverrideParameters: """SetScreenSettingsOverrideParameters.""" screen_area: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -128,8 +128,8 @@ class SetScreenOrientationOverrideParameters: """SetScreenOrientationOverrideParameters.""" screen_orientation: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -137,8 +137,8 @@ class SetUserAgentOverrideParameters: """SetUserAgentOverrideParameters.""" user_agent: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -146,8 +146,8 @@ class SetViewportMetaOverrideParameters: """SetViewportMetaOverrideParameters.""" viewport_meta: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -155,8 +155,8 @@ class SetScriptingEnabledParameters: """SetScriptingEnabledParameters.""" enabled: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -164,8 +164,8 @@ class SetScrollbarTypeOverrideParameters: """SetScrollbarTypeOverrideParameters.""" scrollbar_type: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -173,16 +173,16 @@ class SetTimezoneOverrideParameters: """SetTimezoneOverrideParameters.""" timezone: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass class SetTouchOverrideParameters: """SetTouchOverrideParameters.""" - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) class Emulation: diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 5dbe71dbd3886..a294bde307b89 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -45,7 +45,7 @@ class PerformActionsParameters: """PerformActionsParameters.""" context: Any | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -54,7 +54,7 @@ class NoneSourceActions: type: str = field(default="none", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -63,7 +63,7 @@ class KeySourceActions: type: str = field(default="key", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -73,7 +73,7 @@ class PointerSourceActions: type: str = field(default="pointer", init=False) id: str | None = None parameters: Any | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -89,7 +89,7 @@ class WheelSourceActions: type: str = field(default="wheel", init=False) id: str | None = None - actions: list[Any | None] | None = None + actions: list[Any] = field(default_factory=list) @dataclass @@ -163,7 +163,7 @@ class SetFilesParameters: context: Any | None = None element: Any | None = None - files: list[Any | None] | None = None + files: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 2290c9fec12d3..af079f421546c 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -75,7 +75,7 @@ class BaseParameters: redirect_count: Any | None = None request: Any | None = None timestamp: Any | None = None - intercepts: list[Any | None] | None = None + intercepts: list[Any] = field(default_factory=list) @dataclass @@ -171,13 +171,13 @@ class ResponseData: status: Any | None = None status_text: str | None = None from_cache: bool | None = None - headers: list[Any | None] | None = None + headers: list[Any] = field(default_factory=list) mime_type: str | None = None bytes_received: Any | None = None headers_size: Any | None = None body_size: Any | None = None content: Any | None = None - auth_challenges: list[Any | None] | None = None + auth_challenges: list[Any] = field(default_factory=list) @dataclass @@ -219,11 +219,11 @@ class UrlPatternString: class AddDataCollectorParameters: """AddDataCollectorParameters.""" - data_types: list[Any | None] | None = None + data_types: list[Any] = field(default_factory=list) max_encoded_data_size: Any | None = None collector_type: Any | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass @@ -237,9 +237,9 @@ class AddDataCollectorResult: class AddInterceptParameters: """AddInterceptParameters.""" - phases: list[Any | None] | None = None - contexts: list[Any | None] | None = None - url_patterns: list[Any | None] | None = None + phases: list[Any] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + url_patterns: list[Any] = field(default_factory=list) @dataclass @@ -254,9 +254,9 @@ class ContinueResponseParameters: """ContinueResponseParameters.""" request: Any | None = None - cookies: list[Any | None] | None = None + cookies: list[Any] = field(default_factory=list) credentials: Any | None = None - headers: list[Any | None] | None = None + headers: list[Any] = field(default_factory=list) reason_phrase: str | None = None status_code: Any | None = None @@ -315,8 +315,8 @@ class ProvideResponseParameters: request: Any | None = None body: Any | None = None - cookies: list[Any | None] | None = None - headers: list[Any | None] | None = None + cookies: list[Any] = field(default_factory=list) + headers: list[Any] = field(default_factory=list) reason_phrase: str | None = None status_code: Any | None = None @@ -340,16 +340,16 @@ class SetCacheBehaviorParameters: """SetCacheBehaviorParameters.""" cache_behavior: Any | None = None - contexts: list[Any | None] | None = None + contexts: list[Any] = field(default_factory=list) @dataclass class SetExtraHeadersParameters: """SetExtraHeadersParameters.""" - headers: list[Any | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + headers: list[Any] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index c7bfcb3774dff..492d1fe431680 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -216,7 +216,7 @@ class DedicatedWorkerRealmInfo: """DedicatedWorkerRealmInfo.""" type: str = field(default="dedicated-worker", init=False) - owners: list[Any | None] | None = None + owners: list[Any] = field(default_factory=list) @dataclass @@ -460,7 +460,7 @@ class NodeProperties: node_type: Any | None = None child_node_count: Any | None = None - children: list[Any | None] | None = None + children: list[Any] = field(default_factory=list) local_name: str | None = None mode: Any | None = None namespace_uri: str | None = None @@ -499,7 +499,7 @@ class StackFrame: class StackTrace: """StackTrace.""" - call_frames: list[Any | None] | None = None + call_frames: list[Any] = field(default_factory=list) @dataclass @@ -530,9 +530,9 @@ class AddPreloadScriptParameters: """AddPreloadScriptParameters.""" function_declaration: str | None = None - arguments: list[Any | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + arguments: list[Any] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) sandbox: str | None = None @@ -547,7 +547,7 @@ class AddPreloadScriptResult: class DisownParameters: """DisownParameters.""" - handles: list[Any | None] | None = None + handles: list[Any] = field(default_factory=list) target: Any | None = None @@ -558,7 +558,7 @@ class CallFunctionParameters: function_declaration: str | None = None await_promise: bool | None = None target: Any | None = None - arguments: list[Any | None] | None = None + arguments: list[Any] = field(default_factory=list) result_ownership: Any | None = None serialization_options: Any | None = None this: Any | None = None @@ -589,7 +589,7 @@ class GetRealmsParameters: class GetRealmsResult: """GetRealmsResult.""" - realms: list[Any | None] | None = None + realms: list[Any] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 9b1daaae557fa..f1430cb6e59d3 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -26,7 +26,7 @@ class CapabilitiesRequest: """CapabilitiesRequest.""" always_match: Any | None = None - first_match: list[Any | None] | None = None + first_match: list[Any] = field(default_factory=list) @dataclass @@ -62,7 +62,7 @@ class ManualProxyConfiguration: proxy_type: str = field(default="manual", init=False) http_proxy: str | None = None ssl_proxy: str | None = None - no_proxy: list[Any | None] | None = None + no_proxy: list[Any] = field(default_factory=list) @dataclass @@ -92,23 +92,23 @@ class SystemProxyConfiguration: class SubscribeParameters: """SubscribeParameters.""" - events: list[str | None] | None = None - contexts: list[Any | None] | None = None - user_contexts: list[Any | None] | None = None + events: list[str] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) @dataclass class UnsubscribeByIDRequest: """UnsubscribeByIDRequest.""" - subscriptions: list[Any | None] | None = None + subscriptions: list[Any] = field(default_factory=list) @dataclass class UnsubscribeByAttributesRequest: """UnsubscribeByAttributesRequest.""" - events: list[str | None] | None = None + events: list[str] = field(default_factory=list) @dataclass diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 7e4c9c6dee459..833e9cdc74f2a 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -33,7 +33,7 @@ class GetCookiesParameters: class GetCookiesResult: """GetCookiesResult.""" - cookies: list[Any | None] | None = None + cookies: list[Any] = field(default_factory=list) partition_key: Any | None = None From 648409e5993f65aaa3e2292dbac4a88e8f5f2613 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 13:16:41 +0000 Subject: [PATCH 12/37] fixing generating extra pass --- py/generate_bidi.py | 2 +- py/selenium/webdriver/common/bidi/log.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index f4915aa1ad123..a53ea96db7481 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -946,7 +946,7 @@ def clear_event_handlers(self) -> None: method_enhancements = enhancements.get(method_name_snake, {}) code += command.to_python_method(method_enhancements) code += "\n" - else: + elif not self.events and not enhancements.get("extra_methods", []): code += " pass\n" # Emit extra methods from enhancement manifest diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index c58018e8a947a..1f16849b8e03d 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -264,7 +264,6 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - pass def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. From 373a7182e5f3aaf51c772e08ced70c7f93fd0826 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 13:21:24 +0000 Subject: [PATCH 13/37] fix window tests --- py/private/bidi_enhancements_manifest.py | 37 ++++++++++++++++++++ py/selenium/webdriver/common/bidi/browser.py | 37 ++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index f06a4119625e6..e33a11d5f2b79 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -130,6 +130,43 @@ if user_contexts is not None: params["userContexts"] = user_contexts cmd = command_builder("browser.setDownloadBehavior", params) + return self._conn.execute(cmd)''', + ''' def set_client_window_state( + self, + client_window: Any | None = None, + state: Any | None = None, + ): + """Set the client window state. + + Args: + client_window: The client window ID to apply the state to. + state: The window state to set. Can be one of: + - A string: "fullscreen", "maximized", "minimized", "normal" + - A ClientWindowRectState object with width, height, x, y + - A dict representing the state + + Raises: + ValueError: If client_window is not provided or state is invalid. + """ + if client_window is None: + raise ValueError("client_window is required") + if state is None: + raise ValueError("state is required") + + # Serialize ClientWindowRectState if needed + state_param = state + if hasattr(state, '__dataclass_fields__'): + # It's a dataclass, convert to dict + state_param = { + k: v for k, v in state.__dict__.items() + if v is not None + } + + params = { + "clientWindow": client_window, + "state": state_param, + } + cmd = command_builder("browser.setClientWindowState", params) return self._conn.execute(cmd)''', ], }, diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 0618beb14ddef..7c1958fd435f0 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -341,3 +341,40 @@ def set_download_behavior( params["userContexts"] = user_contexts cmd = command_builder("browser.setDownloadBehavior", params) return self._conn.execute(cmd) + def set_client_window_state( + self, + client_window: Any | None = None, + state: Any | None = None, + ): + """Set the client window state. + + Args: + client_window: The client window ID to apply the state to. + state: The window state to set. Can be one of: + - A string: "fullscreen", "maximized", "minimized", "normal" + - A ClientWindowRectState object with width, height, x, y + - A dict representing the state + + Raises: + ValueError: If client_window is not provided or state is invalid. + """ + if client_window is None: + raise ValueError("client_window is required") + if state is None: + raise ValueError("state is required") + + # Serialize ClientWindowRectState if needed + state_param = state + if hasattr(state, '__dataclass_fields__'): + # It's a dataclass, convert to dict + state_param = { + k: v for k, v in state.__dict__.items() + if v is not None + } + + params = { + "clientWindow": client_window, + "state": state_param, + } + cmd = command_builder("browser.setClientWindowState", params) + return self._conn.execute(cmd) From d2fa505136fce5caeefde7982f6ddc440f707786 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Sat, 7 Mar 2026 13:21:35 +0000 Subject: [PATCH 14/37] fix window tests --- py/private/bidi_enhancements_manifest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index e33a11d5f2b79..2b93f36f1a5dc 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -158,7 +158,7 @@ if hasattr(state, '__dataclass_fields__'): # It's a dataclass, convert to dict state_param = { - k: v for k, v in state.__dict__.items() + k: v for k, v in state.__dict__.items() if v is not None } From d25a1362d79f0693228f11b983c4322b450878a5 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 11:27:05 +0000 Subject: [PATCH 15/37] correct checks on method arguments --- py/generate_bidi.py | 151 +++++++++++------- py/selenium/webdriver/common/bidi/browser.py | 9 +- .../webdriver/common/bidi/browsing_context.py | 36 +++++ .../webdriver/common/bidi/emulation.py | 30 ++++ py/selenium/webdriver/common/bidi/input.py | 15 ++ py/selenium/webdriver/common/bidi/network.py | 69 ++++++++ py/selenium/webdriver/common/bidi/script.py | 32 ++++ py/selenium/webdriver/common/bidi/session.py | 6 + py/selenium/webdriver/common/bidi/storage.py | 3 + 9 files changed, 292 insertions(+), 59 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index a53ea96db7481..78a7603b929c0 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -18,12 +18,11 @@ import logging import re import sys -from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from textwrap import dedent, indent as tw_indent -from typing import Any, Dict, List, Optional, Set, Tuple +from textwrap import indent as tw_indent +from typing import Any __version__ = "1.0.0" @@ -53,7 +52,7 @@ def indent(s: str, n: int) -> str: return tw_indent(s, n * " ") -def load_enhancements_manifest(manifest_path: Optional[str]) -> Dict[str, Any]: +def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: """Load enhancement manifest from a Python file. Args: @@ -139,11 +138,12 @@ class CddlCommand: module: str name: str - params: Dict[str, str] = field(default_factory=dict) - result: Optional[str] = None + params: dict[str, str] = field(default_factory=dict) + required_params: set[str] = field(default_factory=set) + result: str | None = None description: str = "" - def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python method code for this command. Args: @@ -178,7 +178,17 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str body = f" def {method_name}({param_list}):\n" body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' - # Add validation if specified + # Add automatic validation for required parameters + # (This is used unless there's no required_params, in which case all params are optional) + if self.required_params: + for param_name, snake_param in param_names: + if param_name in self.required_params: + method_snake = self._camel_to_snake(self.name) + body += f" if {snake_param} is None:\n" + body += f' raise TypeError("{method_snake}() missing required argument: {snake_param!r}")\n' + body += "\n" + + # Add validation if specified in enhancements (for additional business logic validation) if "validate" in enhancements: validate_func = enhancements["validate"] # Build parameter list for validation function @@ -264,45 +274,45 @@ def to_python_method(self, enhancements: Optional[Dict[str, Any]] = None) -> str # Extract property from list items body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f' item.get("{extract_property}")\n' - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" elif extract_field in deserialize_rules: # Extract field and deserialize to typed objects type_name = deserialize_rules[extract_field] body += f' if result and "{extract_field}" in result:\n' body += f' items = result.get("{extract_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f" {type_name}(\n" body += self._generate_field_args(extract_field, type_name) - body += f" )\n" - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" else: # Simple field extraction (return the value directly, not wrapped in result dict) body += f' if result and "{extract_field}" in result:\n' body += f' extracted = result.get("{extract_field}")\n' - body += f" return extracted\n" - body += f" return result\n" + body += " return extracted\n" + body += " return result\n" elif "deserialize" in enhancements: # Deserialize response to typed objects (legacy, without extract_field) deserialize_rules = enhancements["deserialize"] for response_field, type_name in deserialize_rules.items(): body += f' if result and "{response_field}" in result:\n' body += f' items = result.get("{response_field}", [])\n' - body += f" return [\n" + body += " return [\n" body += f" {type_name}(\n" body += self._generate_field_args(response_field, type_name) - body += f" )\n" - body += f" for item in items\n" - body += f" if isinstance(item, dict)\n" - body += f" ]\n" - body += f" return []\n" + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" else: # No special response handling, just return the result body += " return result\n" @@ -351,10 +361,10 @@ class CddlTypeDefinition: module: str name: str - fields: Dict[str, str] = field(default_factory=dict) + fields: dict[str, str] = field(default_factory=dict) description: str = "" - def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python dataclass code for this type. Args: @@ -366,7 +376,7 @@ def to_python_dataclass(self, enhancements: Optional[Dict[str, Any]] = None) -> # Generate class name from type name (keep it as-is, don't split on underscores) class_name = self.name - code = f"@dataclass\n" + code = "@dataclass\n" code += f"class {class_name}:\n" code += f' """{self.description or self.name}."""\n\n' @@ -460,7 +470,7 @@ class CddlEnum: module: str name: str - values: List[str] = field(default_factory=list) + values: list[str] = field(default_factory=list) description: str = "" def to_python_class(self) -> str: @@ -537,10 +547,10 @@ class CddlModule: """Represents a CDDL module (e.g., script, network, browsing_context).""" name: str - commands: List[CddlCommand] = field(default_factory=list) - types: List[CddlTypeDefinition] = field(default_factory=list) - enums: List[CddlEnum] = field(default_factory=list) - events: List[CddlEvent] = field(default_factory=list) + commands: list[CddlCommand] = field(default_factory=list) + types: list[CddlTypeDefinition] = field(default_factory=list) + enums: list[CddlEnum] = field(default_factory=list) + events: list[CddlEvent] = field(default_factory=list) @staticmethod def _convert_method_to_event_name(method_suffix: str) -> str: @@ -555,7 +565,7 @@ def _convert_method_to_event_name(method_suffix: str) -> str: s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - def generate_code(self, enhancements: Optional[Dict[str, Any]] = None) -> str: + def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """Generate Python code for this module. Args: @@ -1007,9 +1017,9 @@ def clear_event_handlers(self) -> None: code += "\n" # Now populate EVENT_CONFIGS after the aliases are defined - code += f"\n# Populate EVENT_CONFIGS with event configuration mappings\n" + code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n" # Use globals() to look up types dynamically to handle missing types gracefully - code += f"_globals = globals()\n" + code += "_globals = globals()\n" code += f"{class_name}.EVENT_CONFIGS = {{\n" for event_def in self.events: # Convert method name to user-friendly event name @@ -1037,9 +1047,9 @@ def __init__(self, cddl_path: str): """Initialize parser with CDDL file path.""" self.cddl_path = Path(cddl_path) self.content = "" - self.modules: Dict[str, CddlModule] = {} - self.definitions: Dict[str, str] = {} - self.event_names: Set[str] = set() # Names of definitions that are events + self.modules: dict[str, CddlModule] = {} + self.definitions: dict[str, str] = {} + self.event_names: set[str] = set() # Names of definitions that are events self._read_file() def _read_file(self) -> None: @@ -1047,12 +1057,12 @@ def _read_file(self) -> None: if not self.cddl_path.exists(): raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") - with open(self.cddl_path, "r", encoding="utf-8") as f: + with open(self.cddl_path, encoding="utf-8") as f: self.content = f.read() logger.info(f"Loaded CDDL file: {self.cddl_path}") - def parse(self) -> Dict[str, CddlModule]: + def parse(self) -> dict[str, CddlModule]: """Parse CDDL content and return modules.""" # Remove comments content = self._remove_comments(self.content) @@ -1201,7 +1211,7 @@ def _is_enum_definition(self, definition: str) -> bool: # Pattern: "something" / "something_else" return " / " in clean_def and '"' in clean_def - def _extract_enum_values(self, enum_definition: str) -> List[str]: + def _extract_enum_values(self, enum_definition: str) -> list[str]: """Extract individual values from an enum definition. Enums are defined as: "value1" / "value2" / "value3" @@ -1251,7 +1261,7 @@ def _normalize_cddl_type(field_type: str) -> str: result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) return result.strip() - def _extract_type_fields(self, type_definition: str) -> Dict[str, str]: + def _extract_type_fields(self, type_definition: str) -> dict[str, str]: """Extract fields from a type definition block.""" fields = {} @@ -1361,14 +1371,17 @@ def _extract_commands(self) -> None: if module_name not in self.modules: self.modules[module_name] = CddlModule(name=module_name) - # Extract parameters - params = self._extract_parameters(params_type) + # Extract parameters and required parameters + params, required_params = self._extract_parameters_and_required( + params_type + ) # Create command cmd = CddlCommand( module=module_name, name=command_name, params=params, + required_params=required_params, description=f"Execute {method}", ) @@ -1378,24 +1391,36 @@ def _extract_commands(self) -> None: ) def _extract_parameters( - self, params_type: str, _seen: Optional[Set[str]] = None - ) -> Dict[str, str]: + self, params_type: str, _seen: set[str] | None = None + ) -> dict[str, str]: """Extract parameters from a parameter type definition. Handles both struct types ({...}) and top-level union types (TypeA / TypeB), merging all fields from each alternative as optional parameters. """ + params, _ = self._extract_parameters_and_required(params_type, _seen) + return params + + def _extract_parameters_and_required( + self, params_type: str, _seen: set[str] | None = None + ) -> tuple[dict[str, str], set[str]]: + """Extract parameters and track which are required from a parameter type definition. + + Returns: + Tuple of (params dict, required_params set) + """ params = {} + required = set() if _seen is None: _seen = set() if params_type in _seen: - return params + return params, required _seen.add(params_type) if params_type not in self.definitions: logger.debug(f"Parameter type not found: {params_type}") - return params + return params, required definition = self.definitions[params_type] @@ -1409,10 +1434,15 @@ def _extract_parameters( alternatives = [a.strip() for a in stripped.split("/") if a.strip()] all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) if all_named: + # For union types, collect parameters from all alternatives + # but treat them as optional since the caller only needs to pass one alternative for alt_type in alternatives: - alt_params = self._extract_parameters(alt_type, _seen) + alt_params, _ = self._extract_parameters_and_required( + alt_type, _seen + ) params.update(alt_params) - return params + # Note: We intentionally DON'T add to required, since these are union alternatives + return params, required # Remove the outer curly braces and split by comma # Then parse each line for key: type patterns @@ -1429,6 +1459,9 @@ def _extract_parameters( continue # Match pattern: [?] name: type + # Check if parameter has optional marker (?) + is_optional = line.startswith("?") + # Using a simple pattern that handles optional prefix match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) if not match: @@ -1443,11 +1476,13 @@ def _extract_parameters( # Skip lines that are part of nested definitions if "{" not in normalized_type and "(" not in normalized_type: params[param_name] = normalized_type + if not is_optional: + required.add(param_name) logger.debug( - f"Extracted param {param_name}: {normalized_type} from {params_type}" + f"Extracted param {param_name}: {normalized_type} (required={not is_optional}) from {params_type}" ) - return params + return params, required def module_name_to_class_name(module_name: str) -> str: @@ -1492,7 +1527,7 @@ def module_name_to_filename(module_name: str) -> str: return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() -def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> None: +def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None: """Generate __init__.py file for the module.""" init_path = output_path / "__init__.py" @@ -1507,7 +1542,7 @@ def generate_init_file(output_path: Path, modules: Dict[str, CddlModule]) -> Non filename = module_name_to_filename(module_name) code += f"from .{filename} import {class_name}\n" - code += f"\n__all__ = [\n" + code += "\n__all__ = [\n" for module_name in sorted(modules.keys()): class_name = module_name_to_class_name(module_name) code += f' "{class_name}",\n' @@ -1729,7 +1764,7 @@ def main( cddl_file: str, output_dir: str, spec_version: str = "1.0", - enhancements_manifest: Optional[str] = None, + enhancements_manifest: str | None = None, ) -> None: """Main entry point. diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 7c1958fd435f0..c4017265ac757 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -275,6 +275,9 @@ def get_user_contexts(self): def remove_user_context(self, user_context: Any | None = None): """Execute browser.removeUserContext.""" + if user_context is None: + raise TypeError("remove_user_context() missing required argument: 'user_context'") + params = { "userContext": user_context, } @@ -285,6 +288,9 @@ def remove_user_context(self, user_context: Any | None = None): def set_client_window_state(self, client_window: Any | None = None): """Execute browser.setClientWindowState.""" + if client_window is None: + raise TypeError("set_client_window_state() missing required argument: 'client_window'") + params = { "clientWindow": client_window, } @@ -295,6 +301,7 @@ def set_client_window_state(self, client_window: Any | None = None): def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): """Execute browser.setDownloadBehavior.""" + validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) download_behavior = None @@ -368,7 +375,7 @@ def set_client_window_state( if hasattr(state, '__dataclass_fields__'): # It's a dataclass, convert to dict state_param = { - k: v for k, v in state.__dict__.items() + k: v for k, v in state.__dict__.items() if v is not None } diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index d17829709c0c3..775bcdb8f9dbb 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -622,6 +622,9 @@ def __init__(self, conn) -> None: def activate(self, context: Any | None = None): """Execute browsingContext.activate.""" + if context is None: + raise TypeError("activate() missing required argument: 'context'") + params = { "context": context, } @@ -632,6 +635,9 @@ def activate(self, context: Any | None = None): def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): """Execute browsingContext.captureScreenshot.""" + if context is None: + raise TypeError("capture_screenshot() missing required argument: 'context'") + params = { "context": context, "format": format, @@ -648,6 +654,9 @@ def capture_screenshot(self, context: str | None = None, format: Any | None = No def close(self, context: Any | None = None, prompt_unload: bool | None = None): """Execute browsingContext.close.""" + if context is None: + raise TypeError("close() missing required argument: 'context'") + params = { "context": context, "promptUnload": prompt_unload, @@ -659,6 +668,9 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): """Execute browsingContext.create.""" + if type is None: + raise TypeError("create() missing required argument: 'type'") + params = { "type": type, "referenceContext": reference_context, @@ -701,6 +713,9 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): """Execute browsingContext.handleUserPrompt.""" + if context is None: + raise TypeError("handle_user_prompt() missing required argument: 'context'") + params = { "context": context, "accept": accept, @@ -713,6 +728,11 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): """Execute browsingContext.locateNodes.""" + if context is None: + raise TypeError("locate_nodes() missing required argument: 'context'") + if locator is None: + raise TypeError("locate_nodes() missing required argument: 'locator'") + params = { "context": context, "locator": locator, @@ -730,6 +750,11 @@ def locate_nodes(self, context: str | None = None, locator: Any | None = None, s def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): """Execute browsingContext.navigate.""" + if context is None: + raise TypeError("navigate() missing required argument: 'context'") + if url is None: + raise TypeError("navigate() missing required argument: 'url'") + params = { "context": context, "url": url, @@ -742,6 +767,9 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): """Execute browsingContext.print.""" + if context is None: + raise TypeError("print() missing required argument: 'context'") + params = { "context": context, "background": background, @@ -760,6 +788,9 @@ def print(self, context: Any | None = None, background: bool | None = None, marg def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): """Execute browsingContext.reload.""" + if context is None: + raise TypeError("reload() missing required argument: 'context'") + params = { "context": context, "ignoreCache": ignore_cache, @@ -785,6 +816,11 @@ def set_viewport(self, context: str | None = None, viewport: Any | None = None, def traverse_history(self, context: Any | None = None, delta: Any | None = None): """Execute browsingContext.traverseHistory.""" + if context is None: + raise TypeError("traverse_history() missing required argument: 'context'") + if delta is None: + raise TypeError("traverse_history() missing required argument: 'delta'") + params = { "context": context, "delta": delta, diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 7edb7a9dacd06..8428c233682b8 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -193,6 +193,9 @@ def __init__(self, conn) -> None: def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setForcedColorsModeThemeOverride.""" + if theme is None: + raise TypeError("set_forced_colors_mode_theme_override() missing required argument: 'theme'") + params = { "theme": theme, "contexts": contexts, @@ -216,6 +219,9 @@ def set_geolocation_override(self, contexts: List[Any] | None = None, user_conte def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setLocaleOverride.""" + if locale is None: + raise TypeError("set_locale_override() missing required argument: 'locale'") + params = { "locale": locale, "contexts": contexts, @@ -228,6 +234,9 @@ def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | N def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setNetworkConditions.""" + if network_conditions is None: + raise TypeError("set_network_conditions() missing required argument: 'network_conditions'") + params = { "networkConditions": network_conditions, "contexts": contexts, @@ -240,6 +249,9 @@ def set_network_conditions(self, network_conditions: Any | None = None, contexts def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScreenSettingsOverride.""" + if screen_area is None: + raise TypeError("set_screen_settings_override() missing required argument: 'screen_area'") + params = { "screenArea": screen_area, "contexts": contexts, @@ -252,6 +264,9 @@ def set_screen_settings_override(self, screen_area: Any | None = None, contexts: def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScreenOrientationOverride.""" + if screen_orientation is None: + raise TypeError("set_screen_orientation_override() missing required argument: 'screen_orientation'") + params = { "screenOrientation": screen_orientation, "contexts": contexts, @@ -264,6 +279,9 @@ def set_screen_orientation_override(self, screen_orientation: Any | None = None, def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setUserAgentOverride.""" + if user_agent is None: + raise TypeError("set_user_agent_override() missing required argument: 'user_agent'") + params = { "userAgent": user_agent, "contexts": contexts, @@ -276,6 +294,9 @@ def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[ def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setViewportMetaOverride.""" + if viewport_meta is None: + raise TypeError("set_viewport_meta_override() missing required argument: 'viewport_meta'") + params = { "viewportMeta": viewport_meta, "contexts": contexts, @@ -288,6 +309,9 @@ def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScriptingEnabled.""" + if enabled is None: + raise TypeError("set_scripting_enabled() missing required argument: 'enabled'") + params = { "enabled": enabled, "contexts": contexts, @@ -300,6 +324,9 @@ def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setScrollbarTypeOverride.""" + if scrollbar_type is None: + raise TypeError("set_scrollbar_type_override() missing required argument: 'scrollbar_type'") + params = { "scrollbarType": scrollbar_type, "contexts": contexts, @@ -312,6 +339,9 @@ def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, context def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setTimezoneOverride.""" + if timezone is None: + raise TypeError("set_timezone_override() missing required argument: 'timezone'") + params = { "timezone": timezone, "contexts": contexts, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index a294bde307b89..2a19d8072781a 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -370,6 +370,11 @@ def __init__(self, conn) -> None: def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): """Execute input.performActions.""" + if context is None: + raise TypeError("perform_actions() missing required argument: 'context'") + if actions is None: + raise TypeError("perform_actions() missing required argument: 'actions'") + params = { "context": context, "actions": actions, @@ -381,6 +386,9 @@ def perform_actions(self, context: Any | None = None, actions: List[Any] | None def release_actions(self, context: Any | None = None): """Execute input.releaseActions.""" + if context is None: + raise TypeError("release_actions() missing required argument: 'context'") + params = { "context": context, } @@ -391,6 +399,13 @@ def release_actions(self, context: Any | None = None): def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): """Execute input.setFiles.""" + if context is None: + raise TypeError("set_files() missing required argument: 'context'") + if element is None: + raise TypeError("set_files() missing required argument: 'element'") + if files is None: + raise TypeError("set_files() missing required argument: 'files'") + params = { "context": context, "element": element, diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index af079f421546c..1f6b0471f2414 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -565,6 +565,11 @@ def __init__(self, conn) -> None: def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.addDataCollector.""" + if data_types is None: + raise TypeError("add_data_collector() missing required argument: 'data_types'") + if max_encoded_data_size is None: + raise TypeError("add_data_collector() missing required argument: 'max_encoded_data_size'") + params = { "dataTypes": data_types, "maxEncodedDataSize": max_encoded_data_size, @@ -579,6 +584,9 @@ def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_da def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): """Execute network.addIntercept.""" + if phases is None: + raise TypeError("add_intercept() missing required argument: 'phases'") + params = { "phases": phases, "contexts": contexts, @@ -591,6 +599,9 @@ def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | N def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): """Execute network.continueRequest.""" + if request is None: + raise TypeError("continue_request() missing required argument: 'request'") + params = { "request": request, "body": body, @@ -606,6 +617,9 @@ def continue_request(self, request: Any | None = None, body: Any | None = None, def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.continueResponse.""" + if request is None: + raise TypeError("continue_response() missing required argument: 'request'") + params = { "request": request, "cookies": cookies, @@ -621,6 +635,9 @@ def continue_response(self, request: Any | None = None, cookies: List[Any] | Non def continue_with_auth(self, request: Any | None = None): """Execute network.continueWithAuth.""" + if request is None: + raise TypeError("continue_with_auth() missing required argument: 'request'") + params = { "request": request, } @@ -631,6 +648,13 @@ def continue_with_auth(self, request: Any | None = None): def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): """Execute network.disownData.""" + if data_type is None: + raise TypeError("disown_data() missing required argument: 'data_type'") + if collector is None: + raise TypeError("disown_data() missing required argument: 'collector'") + if request is None: + raise TypeError("disown_data() missing required argument: 'request'") + params = { "dataType": data_type, "collector": collector, @@ -643,6 +667,9 @@ def disown_data(self, data_type: Any | None = None, collector: Any | None = None def fail_request(self, request: Any | None = None): """Execute network.failRequest.""" + if request is None: + raise TypeError("fail_request() missing required argument: 'request'") + params = { "request": request, } @@ -653,6 +680,11 @@ def fail_request(self, request: Any | None = None): def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): """Execute network.getData.""" + if data_type is None: + raise TypeError("get_data() missing required argument: 'data_type'") + if request is None: + raise TypeError("get_data() missing required argument: 'request'") + params = { "dataType": data_type, "collector": collector, @@ -666,6 +698,9 @@ def get_data(self, data_type: Any | None = None, collector: Any | None = None, d def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): """Execute network.provideResponse.""" + if request is None: + raise TypeError("provide_response() missing required argument: 'request'") + params = { "request": request, "body": body, @@ -681,6 +716,9 @@ def provide_response(self, request: Any | None = None, body: Any | None = None, def remove_data_collector(self, collector: Any | None = None): """Execute network.removeDataCollector.""" + if collector is None: + raise TypeError("remove_data_collector() missing required argument: 'collector'") + params = { "collector": collector, } @@ -691,6 +729,9 @@ def remove_data_collector(self, collector: Any | None = None): def remove_intercept(self, intercept: Any | None = None): """Execute network.removeIntercept.""" + if intercept is None: + raise TypeError("remove_intercept() missing required argument: 'intercept'") + params = { "intercept": intercept, } @@ -701,6 +742,9 @@ def remove_intercept(self, intercept: Any | None = None): def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): """Execute network.setCacheBehavior.""" + if cache_behavior is None: + raise TypeError("set_cache_behavior() missing required argument: 'cache_behavior'") + params = { "cacheBehavior": cache_behavior, "contexts": contexts, @@ -712,6 +756,9 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[A def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute network.setExtraHeaders.""" + if headers is None: + raise TypeError("set_extra_headers() missing required argument: 'headers'") + params = { "headers": headers, "contexts": contexts, @@ -724,6 +771,11 @@ def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.beforeRequestSent.""" + if method is None: + raise TypeError("before_request_sent() missing required argument: 'method'") + if params is None: + raise TypeError("before_request_sent() missing required argument: 'params'") + params = { "initiator": initiator, "method": method, @@ -736,6 +788,13 @@ def before_request_sent(self, initiator: Any | None = None, method: Any | None = def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.fetchError.""" + if error_text is None: + raise TypeError("fetch_error() missing required argument: 'error_text'") + if method is None: + raise TypeError("fetch_error() missing required argument: 'method'") + if params is None: + raise TypeError("fetch_error() missing required argument: 'params'") + params = { "errorText": error_text, "method": method, @@ -748,6 +807,13 @@ def fetch_error(self, error_text: Any | None = None, method: Any | None = None, def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.responseCompleted.""" + if response is None: + raise TypeError("response_completed() missing required argument: 'response'") + if method is None: + raise TypeError("response_completed() missing required argument: 'method'") + if params is None: + raise TypeError("response_completed() missing required argument: 'params'") + params = { "response": response, "method": method, @@ -760,6 +826,9 @@ def response_completed(self, response: Any | None = None, method: Any | None = N def response_started(self, response: Any | None = None): """Execute network.responseStarted.""" + if response is None: + raise TypeError("response_started() missing required argument: 'response'") + params = { "response": response, } diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 492d1fe431680..0f59c400a38c2 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -785,6 +785,9 @@ def __init__(self, conn, driver=None) -> None: def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): """Execute script.addPreloadScript.""" + if function_declaration is None: + raise TypeError("add_preload_script() missing required argument: 'function_declaration'") + params = { "functionDeclaration": function_declaration, "arguments": arguments, @@ -799,6 +802,11 @@ def add_preload_script(self, function_declaration: Any | None = None, arguments: def disown(self, handles: List[Any] | None = None, target: Any | None = None): """Execute script.disown.""" + if handles is None: + raise TypeError("disown() missing required argument: 'handles'") + if target is None: + raise TypeError("disown() missing required argument: 'target'") + params = { "handles": handles, "target": target, @@ -810,6 +818,13 @@ def disown(self, handles: List[Any] | None = None, target: Any | None = None): def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): """Execute script.callFunction.""" + if function_declaration is None: + raise TypeError("call_function() missing required argument: 'function_declaration'") + if await_promise is None: + raise TypeError("call_function() missing required argument: 'await_promise'") + if target is None: + raise TypeError("call_function() missing required argument: 'target'") + params = { "functionDeclaration": function_declaration, "awaitPromise": await_promise, @@ -827,6 +842,13 @@ def call_function(self, function_declaration: Any | None = None, await_promise: def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): """Execute script.evaluate.""" + if expression is None: + raise TypeError("evaluate() missing required argument: 'expression'") + if target is None: + raise TypeError("evaluate() missing required argument: 'target'") + if await_promise is None: + raise TypeError("evaluate() missing required argument: 'await_promise'") + params = { "expression": expression, "target": target, @@ -853,6 +875,9 @@ def get_realms(self, context: Any | None = None, type: Any | None = None): def remove_preload_script(self, script: Any | None = None): """Execute script.removePreloadScript.""" + if script is None: + raise TypeError("remove_preload_script() missing required argument: 'script'") + params = { "script": script, } @@ -863,6 +888,13 @@ def remove_preload_script(self, script: Any | None = None): def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): """Execute script.message.""" + if channel is None: + raise TypeError("message() missing required argument: 'channel'") + if data is None: + raise TypeError("message() missing required argument: 'data'") + if source is None: + raise TypeError("message() missing required argument: 'source'") + params = { "channel": channel, "data": data, diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index f1430cb6e59d3..374375a62f2ec 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -194,6 +194,9 @@ def status(self): def new(self, capabilities: Any | None = None): """Execute session.new.""" + if capabilities is None: + raise TypeError("new() missing required argument: 'capabilities'") + params = { "capabilities": capabilities, } @@ -213,6 +216,9 @@ def end(self): def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute session.subscribe.""" + if events is None: + raise TypeError("subscribe() missing required argument: 'events'") + params = { "events": events, "contexts": contexts, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 833e9cdc74f2a..8742dc61ebccf 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -248,6 +248,9 @@ def get_cookies(self, filter: Any | None = None, partition: Any | None = None): def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): """Execute storage.setCookie.""" + if cookie is None: + raise TypeError("set_cookie() missing required argument: 'cookie'") + params = { "cookie": cookie, "partition": partition, From 5ec1439ab8558c2d8493558cacd47610cafbd010 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 12:32:27 +0000 Subject: [PATCH 16/37] improve generation so we don't need to run ruffs over it --- py/generate_bidi.py | 81 +++++++-- py/private/bidi_enhancements_manifest.py | 14 +- py/selenium/webdriver/common/bidi/browser.py | 47 +---- .../webdriver/common/bidi/browsing_context.py | 167 ++++++++++++------ .../webdriver/common/bidi/emulation.py | 131 ++++---------- py/selenium/webdriver/common/bidi/input.py | 18 +- py/selenium/webdriver/common/bidi/log.py | 6 +- py/selenium/webdriver/common/bidi/network.py | 119 +++++++++---- py/selenium/webdriver/common/bidi/script.py | 69 ++++++-- py/selenium/webdriver/common/bidi/session.py | 11 +- py/selenium/webdriver/common/bidi/storage.py | 36 ---- 11 files changed, 384 insertions(+), 315 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 78a7603b929c0..affd0a63a750c 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -170,22 +170,34 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: param_strs.append(f"{snake_param}: {python_type} | None = None") if param_strs: - param_list = "self, " + ", ".join(param_strs) + # Check if full signature would exceed line length limit (120 chars) + single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):" + if len(single_line_signature) > 120: + # Format parameters on multiple lines + body = f" def {method_name}(\n" + body += " self,\n" + for i, param_str in enumerate(param_strs): + if i < len(param_strs) - 1: + body += f" {param_str},\n" + else: + body += f" {param_str},\n" + body += " ):\n" + else: + param_list = "self, " + ", ".join(param_strs) + body = f" def {method_name}({param_list}):\n" else: - param_list = "self" - - # Build method body - body = f" def {method_name}({param_list}):\n" + body = f" def {method_name}(self):\n" body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' # Add automatic validation for required parameters # (This is used unless there's no required_params, in which case all params are optional) if self.required_params: + method_snake = self._camel_to_snake(self.name) for param_name, snake_param in param_names: if param_name in self.required_params: - method_snake = self._camel_to_snake(self.name) body += f" if {snake_param} is None:\n" - body += f' raise TypeError("{method_snake}() missing required argument: {snake_param!r}")\n' + msg = f"{method_snake}() missing required argument:" + body += f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' body += "\n" # Add validation if specified in enhancements (for additional business logic validation) @@ -247,7 +259,6 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if result_param == "download_behavior": body += ' "downloadBehavior": download_behavior,\n' # Add remaining parameters that weren't part of the transform - override_params = enhancements.get("params_override", {}) for cddl_param_name in self.params: if cddl_param_name not in ["downloadBehavior"]: snake_name = self._camel_to_snake(cddl_param_name) @@ -667,8 +678,20 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - # Generate enums first + # Generate enums first (excluding those in exclude_types) + exclude_types = set(enhancements.get("exclude_types", [])) + + # Also exclude any types that have extra_dataclasses overrides + # Extract class names from extra_dataclasses strings + for extra_cls in enhancements.get("extra_dataclasses", []): + # Match "class ClassName:" pattern + match = re.search(r"class\s+(\w+)\s*:", extra_cls) + if match: + exclude_types.add(match.group(1)) + for enum_def in self.enums: + if enum_def.name in exclude_types: + continue code += enum_def.to_python_class() code += "\n\n" @@ -677,7 +700,6 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f"{alias} = {target}\n\n" # Generate type dataclasses, skipping any overridden by extra_dataclasses - exclude_types = set(enhancements.get("exclude_types", [])) for type_def in self.types: if type_def.name in exclude_types: continue @@ -946,6 +968,16 @@ def clear_event_handlers(self) -> None: # Generate command methods exclude_methods = enhancements.get("exclude_methods", []) + + # Automatically exclude methods that are defined in extra_methods + # to prevent generating duplicates + if "extra_methods" in enhancements: + for extra_method in enhancements["extra_methods"]: + # Extract method name from "def method_name(" + match = re.search(r"def\s+(\w+)\s*\(", extra_method) + if match: + exclude_methods = list(exclude_methods) + [match.group(1)] + if self.commands: for command in self.commands: # Get method-specific enhancements @@ -1026,9 +1058,26 @@ def clear_event_handlers(self) -> None: method_parts = event_def.method.split(".") if len(method_parts) == 2: event_name = self._convert_method_to_event_name(method_parts[1]) - # The event class is the event name (e.g., ContextCreated) - # Try to get it from globals, default to dict if not found - code += f' "{event_name}": (EventConfig("{event_name}", "{event_def.method}", _globals.get("{event_def.name}", dict)) if _globals.get("{event_def.name}") else EventConfig("{event_name}", "{event_def.method}", dict)),\n' + # Try to get event class from globals, default to dict if not found + getter = f'_globals.get("{event_def.name}", dict)' + condition = f'_globals.get("{event_def.name}")' + event_class = f'{getter} if {condition} else dict' + + # Build the entry line and check if it exceeds 120 chars + single_line = ( + f' "{event_name}": ' + f'EventConfig("{event_name}", "{event_def.method}", {event_class}),' + ) + + if len(single_line) > 120: + # Break into multiple lines + code += f' "{event_name}": EventConfig(\n' + code += f' "{event_name}",\n' + code += f' "{event_def.method}",\n' + code += f' {event_class},\n' + code += ' ),\n' + else: + code += single_line + '\n' # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): ek = extra_evt["event_key"] @@ -1126,9 +1175,6 @@ def _extract_event_names(self) -> None: ... ) """ - # Look for definitions like "BrowsingContextEvent", "SessionEvent", etc. - event_union_pattern = re.compile(r"(\w+\.)?(\w+)Event") - for def_name, def_content in self.definitions.items(): # Check if this looks like an event union (name ends with "Event") and # contains a module-qualified reference like "module.EventName". @@ -1479,7 +1525,8 @@ def _extract_parameters_and_required( if not is_optional: required.add(param_name) logger.debug( - f"Extracted param {param_name}: {normalized_type} (required={not is_optional}) from {params_type}" + f"Extracted param {param_name}: {normalized_type} " + f"(required={not is_optional}) from {params_type}" ) return params, required diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 2b93f36f1a5dc..40647157f8535 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -252,19 +252,7 @@ def from_json(cls, params: dict) -> "DownloadEndParams": ) return cls(download_params=dp)''', ], - # Non-CDDL download events (Chromium-specific, not in the BiDi spec) - "extra_events": [ - { - "event_key": "download_will_begin", - "bidi_event": "browsingContext.downloadWillBegin", - "event_class": "DownloadWillBeginParams", - }, - { - "event_key": "download_end", - "bidi_event": "browsingContext.downloadEnd", - "event_class": "DownloadEndParams", - }, - ], + # Download events are now in the CDDL spec, so no extra_events needed }, "log": { # Make LogLevel an alias for Level so existing code using LogLevel works diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index c4017265ac757..77ae8f0696281 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -61,14 +61,6 @@ def validate_download_behavior( raise ValueError("destination_folder should not be provided when allowed=False") -class ClientWindowNamedState: - """ClientWindowNamedState.""" - - FULLSCREEN = "fullscreen" - MAXIMIZED = "maximized" - MINIMIZED = "minimized" - - @dataclass class ClientWindowInfo: """ClientWindowInfo.""" @@ -212,7 +204,12 @@ def close(self): result = self._conn.execute(cmd) return result - def create_user_context(self, accept_insecure_certs: bool | None = None, proxy: Any | None = None, unhandled_prompt_behavior: Any | None = None): + def create_user_context( + self, + accept_insecure_certs: bool | None = None, + proxy: Any | None = None, + unhandled_prompt_behavior: Any | None = None, + ): """Execute browser.createUserContext.""" if proxy and hasattr(proxy, 'to_bidi_dict'): proxy = proxy.to_bidi_dict() @@ -276,7 +273,7 @@ def get_user_contexts(self): def remove_user_context(self, user_context: Any | None = None): """Execute browser.removeUserContext.""" if user_context is None: - raise TypeError("remove_user_context() missing required argument: 'user_context'") + raise TypeError("remove_user_context() missing required argument: {{snake_param!r}}") params = { "userContext": user_context, @@ -286,36 +283,6 @@ def remove_user_context(self, user_context: Any | None = None): result = self._conn.execute(cmd) return result - def set_client_window_state(self, client_window: Any | None = None): - """Execute browser.setClientWindowState.""" - if client_window is None: - raise TypeError("set_client_window_state() missing required argument: 'client_window'") - - params = { - "clientWindow": client_window, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.setClientWindowState", params) - result = self._conn.execute(cmd) - return result - - def set_download_behavior(self, allowed: bool | None = None, destination_folder: str | None = None, user_contexts: List[Any] | None = None): - """Execute browser.setDownloadBehavior.""" - - validate_download_behavior(allowed=allowed, destination_folder=destination_folder, user_contexts=user_contexts) - - download_behavior = None - download_behavior = transform_download_params(allowed, destination_folder) - - params = { - "downloadBehavior": download_behavior, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browser.setDownloadBehavior", params) - result = self._conn.execute(cmd) - return result - def set_download_behavior( self, allowed: bool | None = None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 775bcdb8f9dbb..3f877b06b00ab 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -328,20 +328,6 @@ class HistoryUpdatedParameters: url: str | None = None -@dataclass -class DownloadWillBeginParams: - """DownloadWillBeginParams.""" - - suggested_filename: str | None = None - - -@dataclass -class DownloadCanceledParams: - """DownloadCanceledParams.""" - - status: str = field(default="canceled", init=False) - - @dataclass class UserPromptClosedParameters: """UserPromptClosedParameters.""" @@ -421,8 +407,6 @@ def from_json(cls, params: dict) -> "DownloadEndParams": "navigation_failed": "browsingContext.navigationFailed", "user_prompt_closed": "browsingContext.userPromptClosed", "user_prompt_opened": "browsingContext.userPromptOpened", - "download_will_begin": "browsingContext.downloadWillBegin", - "download_end": "browsingContext.downloadEnd", } def _deserialize_info_list(items: list) -> list | None: @@ -623,7 +607,7 @@ def __init__(self, conn) -> None: def activate(self, context: Any | None = None): """Execute browsingContext.activate.""" if context is None: - raise TypeError("activate() missing required argument: 'context'") + raise TypeError("activate() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -633,10 +617,16 @@ def activate(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def capture_screenshot(self, context: str | None = None, format: Any | None = None, clip: Any | None = None, origin: str | None = None): + def capture_screenshot( + self, + context: str | None = None, + format: Any | None = None, + clip: Any | None = None, + origin: str | None = None, + ): """Execute browsingContext.captureScreenshot.""" if context is None: - raise TypeError("capture_screenshot() missing required argument: 'context'") + raise TypeError("capture_screenshot() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -655,7 +645,7 @@ def capture_screenshot(self, context: str | None = None, format: Any | None = No def close(self, context: Any | None = None, prompt_unload: bool | None = None): """Execute browsingContext.close.""" if context is None: - raise TypeError("close() missing required argument: 'context'") + raise TypeError("close() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -666,10 +656,16 @@ def close(self, context: Any | None = None, prompt_unload: bool | None = None): result = self._conn.execute(cmd) return result - def create(self, type: Any | None = None, reference_context: Any | None = None, background: bool | None = None, user_context: Any | None = None): + def create( + self, + type: Any | None = None, + reference_context: Any | None = None, + background: bool | None = None, + user_context: Any | None = None, + ): """Execute browsingContext.create.""" if type is None: - raise TypeError("create() missing required argument: 'type'") + raise TypeError("create() missing required argument: {{snake_param!r}}") params = { "type": type, @@ -714,7 +710,7 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): """Execute browsingContext.handleUserPrompt.""" if context is None: - raise TypeError("handle_user_prompt() missing required argument: 'context'") + raise TypeError("handle_user_prompt() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -726,12 +722,19 @@ def handle_user_prompt(self, context: Any | None = None, accept: bool | None = N result = self._conn.execute(cmd) return result - def locate_nodes(self, context: str | None = None, locator: Any | None = None, serialization_options: Any | None = None, start_nodes: Any | None = None, max_node_count: int | None = None): + def locate_nodes( + self, + context: str | None = None, + locator: Any | None = None, + serialization_options: Any | None = None, + start_nodes: Any | None = None, + max_node_count: int | None = None, + ): """Execute browsingContext.locateNodes.""" if context is None: - raise TypeError("locate_nodes() missing required argument: 'context'") + raise TypeError("locate_nodes() missing required argument: {{snake_param!r}}") if locator is None: - raise TypeError("locate_nodes() missing required argument: 'locator'") + raise TypeError("locate_nodes() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -751,9 +754,9 @@ def locate_nodes(self, context: str | None = None, locator: Any | None = None, s def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): """Execute browsingContext.navigate.""" if context is None: - raise TypeError("navigate() missing required argument: 'context'") + raise TypeError("navigate() missing required argument: {{snake_param!r}}") if url is None: - raise TypeError("navigate() missing required argument: 'url'") + raise TypeError("navigate() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -765,10 +768,18 @@ def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any result = self._conn.execute(cmd) return result - def print(self, context: Any | None = None, background: bool | None = None, margin: Any | None = None, page: Any | None = None, scale: Any | None = None, shrink_to_fit: bool | None = None): + def print( + self, + context: Any | None = None, + background: bool | None = None, + margin: Any | None = None, + page: Any | None = None, + scale: Any | None = None, + shrink_to_fit: bool | None = None, + ): """Execute browsingContext.print.""" if context is None: - raise TypeError("print() missing required argument: 'context'") + raise TypeError("print() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -789,7 +800,7 @@ def print(self, context: Any | None = None, background: bool | None = None, marg def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): """Execute browsingContext.reload.""" if context is None: - raise TypeError("reload() missing required argument: 'context'") + raise TypeError("reload() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -801,7 +812,13 @@ def reload(self, context: Any | None = None, ignore_cache: bool | None = None, w result = self._conn.execute(cmd) return result - def set_viewport(self, context: str | None = None, viewport: Any | None = None, user_contexts: Any | None = None, device_pixel_ratio: Any | None = None): + def set_viewport( + self, + context: str | None = None, + viewport: Any | None = None, + user_contexts: Any | None = None, + device_pixel_ratio: Any | None = None, + ): """Execute browsingContext.setViewport.""" params = { "context": context, @@ -817,9 +834,9 @@ def set_viewport(self, context: str | None = None, viewport: Any | None = None, def traverse_history(self, context: Any | None = None, delta: Any | None = None): """Execute browsingContext.traverseHistory.""" if context is None: - raise TypeError("traverse_history() missing required argument: 'context'") + raise TypeError("traverse_history() missing required argument: {{snake_param!r}}") if delta is None: - raise TypeError("traverse_history() missing required argument: 'delta'") + raise TypeError("traverse_history() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -904,20 +921,70 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() BrowsingContext.EVENT_CONFIGS = { - "context_created": (EventConfig("context_created", "browsingContext.contextCreated", _globals.get("ContextCreated", dict)) if _globals.get("ContextCreated") else EventConfig("context_created", "browsingContext.contextCreated", dict)), - "context_destroyed": (EventConfig("context_destroyed", "browsingContext.contextDestroyed", _globals.get("ContextDestroyed", dict)) if _globals.get("ContextDestroyed") else EventConfig("context_destroyed", "browsingContext.contextDestroyed", dict)), - "navigation_started": (EventConfig("navigation_started", "browsingContext.navigationStarted", _globals.get("NavigationStarted", dict)) if _globals.get("NavigationStarted") else EventConfig("navigation_started", "browsingContext.navigationStarted", dict)), - "fragment_navigated": (EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", _globals.get("FragmentNavigated", dict)) if _globals.get("FragmentNavigated") else EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", dict)), - "history_updated": (EventConfig("history_updated", "browsingContext.historyUpdated", _globals.get("HistoryUpdated", dict)) if _globals.get("HistoryUpdated") else EventConfig("history_updated", "browsingContext.historyUpdated", dict)), - "dom_content_loaded": (EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", _globals.get("DomContentLoaded", dict)) if _globals.get("DomContentLoaded") else EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", dict)), - "load": (EventConfig("load", "browsingContext.load", _globals.get("Load", dict)) if _globals.get("Load") else EventConfig("load", "browsingContext.load", dict)), - "download_will_begin": (EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBegin", dict)) if _globals.get("DownloadWillBegin") else EventConfig("download_will_begin", "browsingContext.downloadWillBegin", dict)), - "download_end": (EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEnd", dict)) if _globals.get("DownloadEnd") else EventConfig("download_end", "browsingContext.downloadEnd", dict)), - "navigation_aborted": (EventConfig("navigation_aborted", "browsingContext.navigationAborted", _globals.get("NavigationAborted", dict)) if _globals.get("NavigationAborted") else EventConfig("navigation_aborted", "browsingContext.navigationAborted", dict)), - "navigation_committed": (EventConfig("navigation_committed", "browsingContext.navigationCommitted", _globals.get("NavigationCommitted", dict)) if _globals.get("NavigationCommitted") else EventConfig("navigation_committed", "browsingContext.navigationCommitted", dict)), - "navigation_failed": (EventConfig("navigation_failed", "browsingContext.navigationFailed", _globals.get("NavigationFailed", dict)) if _globals.get("NavigationFailed") else EventConfig("navigation_failed", "browsingContext.navigationFailed", dict)), - "user_prompt_closed": (EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", _globals.get("UserPromptClosed", dict)) if _globals.get("UserPromptClosed") else EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", dict)), - "user_prompt_opened": (EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", _globals.get("UserPromptOpened", dict)) if _globals.get("UserPromptOpened") else EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", dict)), - "download_will_begin": EventConfig("download_will_begin", "browsingContext.downloadWillBegin", _globals.get("DownloadWillBeginParams", dict)), - "download_end": EventConfig("download_end", "browsingContext.downloadEnd", _globals.get("DownloadEndParams", dict)), + "context_created": EventConfig( + "context_created", + "browsingContext.contextCreated", + _globals.get("ContextCreated", dict) if _globals.get("ContextCreated") else dict, + ), + "context_destroyed": EventConfig( + "context_destroyed", + "browsingContext.contextDestroyed", + _globals.get("ContextDestroyed", dict) if _globals.get("ContextDestroyed") else dict, + ), + "navigation_started": EventConfig( + "navigation_started", + "browsingContext.navigationStarted", + _globals.get("NavigationStarted", dict) if _globals.get("NavigationStarted") else dict, + ), + "fragment_navigated": EventConfig( + "fragment_navigated", + "browsingContext.fragmentNavigated", + _globals.get("FragmentNavigated", dict) if _globals.get("FragmentNavigated") else dict, + ), + "history_updated": EventConfig( + "history_updated", + "browsingContext.historyUpdated", + _globals.get("HistoryUpdated", dict) if _globals.get("HistoryUpdated") else dict, + ), + "dom_content_loaded": EventConfig( + "dom_content_loaded", + "browsingContext.domContentLoaded", + _globals.get("DomContentLoaded", dict) if _globals.get("DomContentLoaded") else dict, + ), + "load": EventConfig("load", "browsingContext.load", _globals.get("Load", dict) if _globals.get("Load") else dict), + "download_will_begin": EventConfig( + "download_will_begin", + "browsingContext.downloadWillBegin", + _globals.get("DownloadWillBegin", dict) if _globals.get("DownloadWillBegin") else dict, + ), + "download_end": EventConfig( + "download_end", + "browsingContext.downloadEnd", + _globals.get("DownloadEnd", dict) if _globals.get("DownloadEnd") else dict, + ), + "navigation_aborted": EventConfig( + "navigation_aborted", + "browsingContext.navigationAborted", + _globals.get("NavigationAborted", dict) if _globals.get("NavigationAborted") else dict, + ), + "navigation_committed": EventConfig( + "navigation_committed", + "browsingContext.navigationCommitted", + _globals.get("NavigationCommitted", dict) if _globals.get("NavigationCommitted") else dict, + ), + "navigation_failed": EventConfig( + "navigation_failed", + "browsingContext.navigationFailed", + _globals.get("NavigationFailed", dict) if _globals.get("NavigationFailed") else dict, + ), + "user_prompt_closed": EventConfig( + "user_prompt_closed", + "browsingContext.userPromptClosed", + _globals.get("UserPromptClosed", dict) if _globals.get("UserPromptClosed") else dict, + ), + "user_prompt_opened": EventConfig( + "user_prompt_opened", + "browsingContext.userPromptOpened", + _globals.get("UserPromptOpened", dict) if _globals.get("UserPromptOpened") else dict, + ), } diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 8428c233682b8..d482fecc755cb 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -191,10 +191,15 @@ class Emulation: def __init__(self, conn) -> None: self._conn = conn - def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_forced_colors_mode_theme_override( + self, + theme: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setForcedColorsModeThemeOverride.""" if theme is None: - raise TypeError("set_forced_colors_mode_theme_override() missing required argument: 'theme'") + raise TypeError("set_forced_colors_mode_theme_override() missing required argument: {{snake_param!r}}") params = { "theme": theme, @@ -206,21 +211,15 @@ def set_forced_colors_mode_theme_override(self, theme: Any | None = None, contex result = self._conn.execute(cmd) return result - def set_geolocation_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setGeolocationOverride.""" - params = { - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setGeolocationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_locale_override( + self, + locale: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setLocaleOverride.""" if locale is None: - raise TypeError("set_locale_override() missing required argument: 'locale'") + raise TypeError("set_locale_override() missing required argument: {{snake_param!r}}") params = { "locale": locale, @@ -232,25 +231,15 @@ def set_locale_override(self, locale: Any | None = None, contexts: List[Any] | N result = self._conn.execute(cmd) return result - def set_network_conditions(self, network_conditions: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setNetworkConditions.""" - if network_conditions is None: - raise TypeError("set_network_conditions() missing required argument: 'network_conditions'") - - params = { - "networkConditions": network_conditions, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setNetworkConditions", params) - result = self._conn.execute(cmd) - return result - - def set_screen_settings_override(self, screen_area: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_screen_settings_override( + self, + screen_area: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setScreenSettingsOverride.""" if screen_area is None: - raise TypeError("set_screen_settings_override() missing required argument: 'screen_area'") + raise TypeError("set_screen_settings_override() missing required argument: {{snake_param!r}}") params = { "screenArea": screen_area, @@ -262,40 +251,15 @@ def set_screen_settings_override(self, screen_area: Any | None = None, contexts: result = self._conn.execute(cmd) return result - def set_screen_orientation_override(self, screen_orientation: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setScreenOrientationOverride.""" - if screen_orientation is None: - raise TypeError("set_screen_orientation_override() missing required argument: 'screen_orientation'") - - params = { - "screenOrientation": screen_orientation, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScreenOrientationOverride", params) - result = self._conn.execute(cmd) - return result - - def set_user_agent_override(self, user_agent: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setUserAgentOverride.""" - if user_agent is None: - raise TypeError("set_user_agent_override() missing required argument: 'user_agent'") - - params = { - "userAgent": user_agent, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setUserAgentOverride", params) - result = self._conn.execute(cmd) - return result - - def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_viewport_meta_override( + self, + viewport_meta: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setViewportMetaOverride.""" if viewport_meta is None: - raise TypeError("set_viewport_meta_override() missing required argument: 'viewport_meta'") + raise TypeError("set_viewport_meta_override() missing required argument: {{snake_param!r}}") params = { "viewportMeta": viewport_meta, @@ -307,25 +271,15 @@ def set_viewport_meta_override(self, viewport_meta: Any | None = None, contexts: result = self._conn.execute(cmd) return result - def set_scripting_enabled(self, enabled: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setScriptingEnabled.""" - if enabled is None: - raise TypeError("set_scripting_enabled() missing required argument: 'enabled'") - - params = { - "enabled": enabled, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScriptingEnabled", params) - result = self._conn.execute(cmd) - return result - - def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_scrollbar_type_override( + self, + scrollbar_type: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute emulation.setScrollbarTypeOverride.""" if scrollbar_type is None: - raise TypeError("set_scrollbar_type_override() missing required argument: 'scrollbar_type'") + raise TypeError("set_scrollbar_type_override() missing required argument: {{snake_param!r}}") params = { "scrollbarType": scrollbar_type, @@ -337,21 +291,6 @@ def set_scrollbar_type_override(self, scrollbar_type: Any | None = None, context result = self._conn.execute(cmd) return result - def set_timezone_override(self, timezone: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): - """Execute emulation.setTimezoneOverride.""" - if timezone is None: - raise TypeError("set_timezone_override() missing required argument: 'timezone'") - - params = { - "timezone": timezone, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setTimezoneOverride", params) - result = self._conn.execute(cmd) - return result - def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 2a19d8072781a..0990dacc39363 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -371,9 +371,9 @@ def __init__(self, conn) -> None: def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): """Execute input.performActions.""" if context is None: - raise TypeError("perform_actions() missing required argument: 'context'") + raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") if actions is None: - raise TypeError("perform_actions() missing required argument: 'actions'") + raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -387,7 +387,7 @@ def perform_actions(self, context: Any | None = None, actions: List[Any] | None def release_actions(self, context: Any | None = None): """Execute input.releaseActions.""" if context is None: - raise TypeError("release_actions() missing required argument: 'context'") + raise TypeError("release_actions() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -400,11 +400,11 @@ def release_actions(self, context: Any | None = None): def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): """Execute input.setFiles.""" if context is None: - raise TypeError("set_files() missing required argument: 'context'") + raise TypeError("set_files() missing required argument: {{snake_param!r}}") if element is None: - raise TypeError("set_files() missing required argument: 'element'") + raise TypeError("set_files() missing required argument: {{snake_param!r}}") if files is None: - raise TypeError("set_files() missing required argument: 'files'") + raise TypeError("set_files() missing required argument: {{snake_param!r}}") params = { "context": context, @@ -469,5 +469,9 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Input.EVENT_CONFIGS = { - "file_dialog_opened": (EventConfig("file_dialog_opened", "input.fileDialogOpened", _globals.get("FileDialogOpened", dict)) if _globals.get("FileDialogOpened") else EventConfig("file_dialog_opened", "input.fileDialogOpened", dict)), + "file_dialog_opened": EventConfig( + "file_dialog_opened", + "input.fileDialogOpened", + _globals.get("FileDialogOpened", dict) if _globals.get("FileDialogOpened") else dict, + ), } diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 1f16849b8e03d..07121242348ea 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -299,5 +299,9 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Log.EVENT_CONFIGS = { - "entry_added": (EventConfig("entry_added", "log.entryAdded", _globals.get("EntryAdded", dict)) if _globals.get("EntryAdded") else EventConfig("entry_added", "log.entryAdded", dict)), + "entry_added": EventConfig( + "entry_added", + "log.entryAdded", + _globals.get("EntryAdded", dict) if _globals.get("EntryAdded") else dict, + ), } diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 1f6b0471f2414..d7baeb07040ce 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -563,12 +563,19 @@ def __init__(self, conn) -> None: self.intercepts = [] self._handler_intercepts: dict = {} - def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def add_data_collector( + self, + data_types: List[Any] | None = None, + max_encoded_data_size: Any | None = None, + collector_type: Any | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute network.addDataCollector.""" if data_types is None: - raise TypeError("add_data_collector() missing required argument: 'data_types'") + raise TypeError("add_data_collector() missing required argument: {{snake_param!r}}") if max_encoded_data_size is None: - raise TypeError("add_data_collector() missing required argument: 'max_encoded_data_size'") + raise TypeError("add_data_collector() missing required argument: {{snake_param!r}}") params = { "dataTypes": data_types, @@ -582,10 +589,15 @@ def add_data_collector(self, data_types: List[Any] | None = None, max_encoded_da result = self._conn.execute(cmd) return result - def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | None = None, url_patterns: List[Any] | None = None): + def add_intercept( + self, + phases: List[Any] | None = None, + contexts: List[Any] | None = None, + url_patterns: List[Any] | None = None, + ): """Execute network.addIntercept.""" if phases is None: - raise TypeError("add_intercept() missing required argument: 'phases'") + raise TypeError("add_intercept() missing required argument: {{snake_param!r}}") params = { "phases": phases, @@ -597,10 +609,18 @@ def add_intercept(self, phases: List[Any] | None = None, contexts: List[Any] | N result = self._conn.execute(cmd) return result - def continue_request(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, method: Any | None = None, url: Any | None = None): + def continue_request( + self, + request: Any | None = None, + body: Any | None = None, + cookies: List[Any] | None = None, + headers: List[Any] | None = None, + method: Any | None = None, + url: Any | None = None, + ): """Execute network.continueRequest.""" if request is None: - raise TypeError("continue_request() missing required argument: 'request'") + raise TypeError("continue_request() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -615,10 +635,18 @@ def continue_request(self, request: Any | None = None, body: Any | None = None, result = self._conn.execute(cmd) return result - def continue_response(self, request: Any | None = None, cookies: List[Any] | None = None, credentials: Any | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + def continue_response( + self, + request: Any | None = None, + cookies: List[Any] | None = None, + credentials: Any | None = None, + headers: List[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): """Execute network.continueResponse.""" if request is None: - raise TypeError("continue_response() missing required argument: 'request'") + raise TypeError("continue_response() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -636,7 +664,7 @@ def continue_response(self, request: Any | None = None, cookies: List[Any] | Non def continue_with_auth(self, request: Any | None = None): """Execute network.continueWithAuth.""" if request is None: - raise TypeError("continue_with_auth() missing required argument: 'request'") + raise TypeError("continue_with_auth() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -649,11 +677,11 @@ def continue_with_auth(self, request: Any | None = None): def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): """Execute network.disownData.""" if data_type is None: - raise TypeError("disown_data() missing required argument: 'data_type'") + raise TypeError("disown_data() missing required argument: {{snake_param!r}}") if collector is None: - raise TypeError("disown_data() missing required argument: 'collector'") + raise TypeError("disown_data() missing required argument: {{snake_param!r}}") if request is None: - raise TypeError("disown_data() missing required argument: 'request'") + raise TypeError("disown_data() missing required argument: {{snake_param!r}}") params = { "dataType": data_type, @@ -668,7 +696,7 @@ def disown_data(self, data_type: Any | None = None, collector: Any | None = None def fail_request(self, request: Any | None = None): """Execute network.failRequest.""" if request is None: - raise TypeError("fail_request() missing required argument: 'request'") + raise TypeError("fail_request() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -678,12 +706,18 @@ def fail_request(self, request: Any | None = None): result = self._conn.execute(cmd) return result - def get_data(self, data_type: Any | None = None, collector: Any | None = None, disown: bool | None = None, request: Any | None = None): + def get_data( + self, + data_type: Any | None = None, + collector: Any | None = None, + disown: bool | None = None, + request: Any | None = None, + ): """Execute network.getData.""" if data_type is None: - raise TypeError("get_data() missing required argument: 'data_type'") + raise TypeError("get_data() missing required argument: {{snake_param!r}}") if request is None: - raise TypeError("get_data() missing required argument: 'request'") + raise TypeError("get_data() missing required argument: {{snake_param!r}}") params = { "dataType": data_type, @@ -696,10 +730,18 @@ def get_data(self, data_type: Any | None = None, collector: Any | None = None, d result = self._conn.execute(cmd) return result - def provide_response(self, request: Any | None = None, body: Any | None = None, cookies: List[Any] | None = None, headers: List[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None): + def provide_response( + self, + request: Any | None = None, + body: Any | None = None, + cookies: List[Any] | None = None, + headers: List[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): """Execute network.provideResponse.""" if request is None: - raise TypeError("provide_response() missing required argument: 'request'") + raise TypeError("provide_response() missing required argument: {{snake_param!r}}") params = { "request": request, @@ -717,7 +759,7 @@ def provide_response(self, request: Any | None = None, body: Any | None = None, def remove_data_collector(self, collector: Any | None = None): """Execute network.removeDataCollector.""" if collector is None: - raise TypeError("remove_data_collector() missing required argument: 'collector'") + raise TypeError("remove_data_collector() missing required argument: {{snake_param!r}}") params = { "collector": collector, @@ -730,7 +772,7 @@ def remove_data_collector(self, collector: Any | None = None): def remove_intercept(self, intercept: Any | None = None): """Execute network.removeIntercept.""" if intercept is None: - raise TypeError("remove_intercept() missing required argument: 'intercept'") + raise TypeError("remove_intercept() missing required argument: {{snake_param!r}}") params = { "intercept": intercept, @@ -743,7 +785,7 @@ def remove_intercept(self, intercept: Any | None = None): def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): """Execute network.setCacheBehavior.""" if cache_behavior is None: - raise TypeError("set_cache_behavior() missing required argument: 'cache_behavior'") + raise TypeError("set_cache_behavior() missing required argument: {{snake_param!r}}") params = { "cacheBehavior": cache_behavior, @@ -754,10 +796,15 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[A result = self._conn.execute(cmd) return result - def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_extra_headers( + self, + headers: List[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute network.setExtraHeaders.""" if headers is None: - raise TypeError("set_extra_headers() missing required argument: 'headers'") + raise TypeError("set_extra_headers() missing required argument: {{snake_param!r}}") params = { "headers": headers, @@ -772,9 +819,9 @@ def set_extra_headers(self, headers: List[Any] | None = None, contexts: List[Any def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.beforeRequestSent.""" if method is None: - raise TypeError("before_request_sent() missing required argument: 'method'") + raise TypeError("before_request_sent() missing required argument: {{snake_param!r}}") if params is None: - raise TypeError("before_request_sent() missing required argument: 'params'") + raise TypeError("before_request_sent() missing required argument: {{snake_param!r}}") params = { "initiator": initiator, @@ -789,11 +836,11 @@ def before_request_sent(self, initiator: Any | None = None, method: Any | None = def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.fetchError.""" if error_text is None: - raise TypeError("fetch_error() missing required argument: 'error_text'") + raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") if method is None: - raise TypeError("fetch_error() missing required argument: 'method'") + raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") if params is None: - raise TypeError("fetch_error() missing required argument: 'params'") + raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") params = { "errorText": error_text, @@ -808,11 +855,11 @@ def fetch_error(self, error_text: Any | None = None, method: Any | None = None, def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.responseCompleted.""" if response is None: - raise TypeError("response_completed() missing required argument: 'response'") + raise TypeError("response_completed() missing required argument: {{snake_param!r}}") if method is None: - raise TypeError("response_completed() missing required argument: 'method'") + raise TypeError("response_completed() missing required argument: {{snake_param!r}}") if params is None: - raise TypeError("response_completed() missing required argument: 'params'") + raise TypeError("response_completed() missing required argument: {{snake_param!r}}") params = { "response": response, @@ -827,7 +874,7 @@ def response_completed(self, response: Any | None = None, method: Any | None = N def response_started(self, response: Any | None = None): """Execute network.responseStarted.""" if response is None: - raise TypeError("response_started() missing required argument: 'response'") + raise TypeError("response_started() missing required argument: {{snake_param!r}}") params = { "response": response, @@ -995,6 +1042,10 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Network.EVENT_CONFIGS = { - "auth_required": (EventConfig("auth_required", "network.authRequired", _globals.get("AuthRequired", dict)) if _globals.get("AuthRequired") else EventConfig("auth_required", "network.authRequired", dict)), + "auth_required": EventConfig( + "auth_required", + "network.authRequired", + _globals.get("AuthRequired", dict) if _globals.get("AuthRequired") else dict, + ), "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), } diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 0f59c400a38c2..8e832f4a9cae9 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -783,10 +783,17 @@ def __init__(self, conn, driver=None) -> None: self._driver = driver self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def add_preload_script(self, function_declaration: Any | None = None, arguments: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None, sandbox: Any | None = None): + def add_preload_script( + self, + function_declaration: Any | None = None, + arguments: List[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + sandbox: Any | None = None, + ): """Execute script.addPreloadScript.""" if function_declaration is None: - raise TypeError("add_preload_script() missing required argument: 'function_declaration'") + raise TypeError("add_preload_script() missing required argument: {{snake_param!r}}") params = { "functionDeclaration": function_declaration, @@ -803,9 +810,9 @@ def add_preload_script(self, function_declaration: Any | None = None, arguments: def disown(self, handles: List[Any] | None = None, target: Any | None = None): """Execute script.disown.""" if handles is None: - raise TypeError("disown() missing required argument: 'handles'") + raise TypeError("disown() missing required argument: {{snake_param!r}}") if target is None: - raise TypeError("disown() missing required argument: 'target'") + raise TypeError("disown() missing required argument: {{snake_param!r}}") params = { "handles": handles, @@ -816,14 +823,24 @@ def disown(self, handles: List[Any] | None = None, target: Any | None = None): result = self._conn.execute(cmd) return result - def call_function(self, function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, arguments: List[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, user_activation: bool | None = None): + def call_function( + self, + function_declaration: Any | None = None, + await_promise: bool | None = None, + target: Any | None = None, + arguments: List[Any] | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + this: Any | None = None, + user_activation: bool | None = None, + ): """Execute script.callFunction.""" if function_declaration is None: - raise TypeError("call_function() missing required argument: 'function_declaration'") + raise TypeError("call_function() missing required argument: {{snake_param!r}}") if await_promise is None: - raise TypeError("call_function() missing required argument: 'await_promise'") + raise TypeError("call_function() missing required argument: {{snake_param!r}}") if target is None: - raise TypeError("call_function() missing required argument: 'target'") + raise TypeError("call_function() missing required argument: {{snake_param!r}}") params = { "functionDeclaration": function_declaration, @@ -840,14 +857,22 @@ def call_function(self, function_declaration: Any | None = None, await_promise: result = self._conn.execute(cmd) return result - def evaluate(self, expression: Any | None = None, target: Any | None = None, await_promise: bool | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, user_activation: bool | None = None): + def evaluate( + self, + expression: Any | None = None, + target: Any | None = None, + await_promise: bool | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + user_activation: bool | None = None, + ): """Execute script.evaluate.""" if expression is None: - raise TypeError("evaluate() missing required argument: 'expression'") + raise TypeError("evaluate() missing required argument: {{snake_param!r}}") if target is None: - raise TypeError("evaluate() missing required argument: 'target'") + raise TypeError("evaluate() missing required argument: {{snake_param!r}}") if await_promise is None: - raise TypeError("evaluate() missing required argument: 'await_promise'") + raise TypeError("evaluate() missing required argument: {{snake_param!r}}") params = { "expression": expression, @@ -876,7 +901,7 @@ def get_realms(self, context: Any | None = None, type: Any | None = None): def remove_preload_script(self, script: Any | None = None): """Execute script.removePreloadScript.""" if script is None: - raise TypeError("remove_preload_script() missing required argument: 'script'") + raise TypeError("remove_preload_script() missing required argument: {{snake_param!r}}") params = { "script": script, @@ -889,11 +914,11 @@ def remove_preload_script(self, script: Any | None = None): def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): """Execute script.message.""" if channel is None: - raise TypeError("message() missing required argument: 'channel'") + raise TypeError("message() missing required argument: {{snake_param!r}}") if data is None: - raise TypeError("message() missing required argument: 'data'") + raise TypeError("message() missing required argument: {{snake_param!r}}") if source is None: - raise TypeError("message() missing required argument: 'source'") + raise TypeError("message() missing required argument: {{snake_param!r}}") params = { "channel": channel, @@ -1314,6 +1339,14 @@ def clear_event_handlers(self) -> None: # Populate EVENT_CONFIGS with event configuration mappings _globals = globals() Script.EVENT_CONFIGS = { - "realm_created": (EventConfig("realm_created", "script.realmCreated", _globals.get("RealmCreated", dict)) if _globals.get("RealmCreated") else EventConfig("realm_created", "script.realmCreated", dict)), - "realm_destroyed": (EventConfig("realm_destroyed", "script.realmDestroyed", _globals.get("RealmDestroyed", dict)) if _globals.get("RealmDestroyed") else EventConfig("realm_destroyed", "script.realmDestroyed", dict)), + "realm_created": EventConfig( + "realm_created", + "script.realmCreated", + _globals.get("RealmCreated", dict) if _globals.get("RealmCreated") else dict, + ), + "realm_destroyed": EventConfig( + "realm_destroyed", + "script.realmDestroyed", + _globals.get("RealmDestroyed", dict) if _globals.get("RealmDestroyed") else dict, + ), } diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 374375a62f2ec..c7dd45ec824b8 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -195,7 +195,7 @@ def status(self): def new(self, capabilities: Any | None = None): """Execute session.new.""" if capabilities is None: - raise TypeError("new() missing required argument: 'capabilities'") + raise TypeError("new() missing required argument: {{snake_param!r}}") params = { "capabilities": capabilities, @@ -214,10 +214,15 @@ def end(self): result = self._conn.execute(cmd) return result - def subscribe(self, events: List[Any] | None = None, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def subscribe( + self, + events: List[Any] | None = None, + contexts: List[Any] | None = None, + user_contexts: List[Any] | None = None, + ): """Execute session.subscribe.""" if events is None: - raise TypeError("subscribe() missing required argument: 'events'") + raise TypeError("subscribe() missing required argument: {{snake_param!r}}") params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 8742dc61ebccf..267569f782289 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -235,42 +235,6 @@ class Storage: def __init__(self, conn) -> None: self._conn = conn - def get_cookies(self, filter: Any | None = None, partition: Any | None = None): - """Execute storage.getCookies.""" - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.getCookies", params) - result = self._conn.execute(cmd) - return result - - def set_cookie(self, cookie: Any | None = None, partition: Any | None = None): - """Execute storage.setCookie.""" - if cookie is None: - raise TypeError("set_cookie() missing required argument: 'cookie'") - - params = { - "cookie": cookie, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.setCookie", params) - result = self._conn.execute(cmd) - return result - - def delete_cookies(self, filter: Any | None = None, partition: Any | None = None): - """Execute storage.deleteCookies.""" - params = { - "filter": filter, - "partition": partition, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("storage.deleteCookies", params) - result = self._conn.execute(cmd) - return result - def get_cookies(self, filter=None, partition=None): """Execute storage.getCookies and return a GetCookiesResult.""" if filter and hasattr(filter, "to_bidi_dict"): From 1b286b8f690dbc41dab16f3e40691a14a2f0cc29 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 12:44:03 +0000 Subject: [PATCH 17/37] make sure not to generate F401 ruff errors --- py/generate_bidi.py | 55 +++++++++++-------- py/selenium/webdriver/common/bidi/browser.py | 7 +-- .../webdriver/common/bidi/browsing_context.py | 15 +++-- py/selenium/webdriver/common/bidi/common.py | 7 ++- .../webdriver/common/bidi/emulation.py | 29 +++++----- py/selenium/webdriver/common/bidi/input.py | 17 +++--- py/selenium/webdriver/common/bidi/log.py | 11 ++-- py/selenium/webdriver/common/bidi/network.py | 43 +++++++-------- .../webdriver/common/bidi/permissions.py | 10 ++-- py/selenium/webdriver/common/bidi/script.py | 27 ++++----- py/selenium/webdriver/common/bidi/session.py | 15 +++-- py/selenium/webdriver/common/bidi/storage.py | 9 ++- .../webdriver/common/bidi/webextension.py | 7 +-- 13 files changed, 126 insertions(+), 126 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index affd0a63a750c..8372d25743c08 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -42,7 +42,7 @@ # WebDriver BiDi module: {{}} from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from typing import Any from .common import command_builder """ @@ -123,10 +123,10 @@ def get_annotation(cls, cddl_type: str) -> str: if cddl_type.startswith("["): # Array inner = cddl_type.strip("[]+ ") inner_type = cls.get_annotation(inner) - return f"List[{inner_type}]" + return f"list[{inner_type}]" if cddl_type.startswith("{"): # Map/Dict - return "Dict[str, Any]" + return "dict[str, Any]" # Default to Any for unknown types return "Any" @@ -171,7 +171,9 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if param_strs: # Check if full signature would exceed line length limit (120 chars) - single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):" + single_line_signature = ( + f" def {method_name}(self, {', '.join(param_strs)}):" + ) if len(single_line_signature) > 120: # Format parameters on multiple lines body = f" def {method_name}(\n" @@ -197,7 +199,9 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if param_name in self.required_params: body += f" if {snake_param} is None:\n" msg = f"{method_snake}() missing required argument:" - body += f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' + body += ( + f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' + ) body += "\n" # Add validation if specified in enhancements (for additional business logic validation) @@ -585,18 +589,23 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: enhancements = enhancements or {} code = MODULE_HEADER.format(self.name) + # Collect needed imports to avoid duplicates + needs_dataclass = self.commands or self.types or self.events + needs_field = self.types + needs_threading = self.events + needs_callable = self.events + needs_session = self.events + # Add imports if needed - if self.types: - code += "from dataclasses import field\n" - if self.commands or self.types: - code += "from typing import Generator\n" + if needs_dataclass: code += "from dataclasses import dataclass\n" - - # Add imports for event handling if needed - if self.events: + if needs_field: + code += "from dataclasses import field\n" + if needs_threading: code += "import threading\n" + if needs_callable: code += "from collections.abc import Callable\n" - code += "from dataclasses import dataclass\n" + if needs_session: code += "from selenium.webdriver.common.bidi.session import Session\n" code += "\n\n" @@ -680,7 +689,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Generate enums first (excluding those in exclude_types) exclude_types = set(enhancements.get("exclude_types", [])) - + # Also exclude any types that have extra_dataclasses overrides # Extract class names from extra_dataclasses strings for extra_cls in enhancements.get("extra_dataclasses", []): @@ -688,7 +697,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: match = re.search(r"class\s+(\w+)\s*:", extra_cls) if match: exclude_types.add(match.group(1)) - + for enum_def in self.enums: if enum_def.name in exclude_types: continue @@ -968,7 +977,7 @@ def clear_event_handlers(self) -> None: # Generate command methods exclude_methods = enhancements.get("exclude_methods", []) - + # Automatically exclude methods that are defined in extra_methods # to prevent generating duplicates if "extra_methods" in enhancements: @@ -977,7 +986,7 @@ def clear_event_handlers(self) -> None: match = re.search(r"def\s+(\w+)\s*\(", extra_method) if match: exclude_methods = list(exclude_methods) + [match.group(1)] - + if self.commands: for command in self.commands: # Get method-specific enhancements @@ -1061,23 +1070,23 @@ def clear_event_handlers(self) -> None: # Try to get event class from globals, default to dict if not found getter = f'_globals.get("{event_def.name}", dict)' condition = f'_globals.get("{event_def.name}")' - event_class = f'{getter} if {condition} else dict' - + event_class = f"{getter} if {condition} else dict" + # Build the entry line and check if it exceeds 120 chars single_line = ( f' "{event_name}": ' f'EventConfig("{event_name}", "{event_def.method}", {event_class}),' ) - + if len(single_line) > 120: # Break into multiple lines code += f' "{event_name}": EventConfig(\n' code += f' "{event_name}",\n' code += f' "{event_def.method}",\n' - code += f' {event_class},\n' - code += ' ),\n' + code += f" {event_class},\n" + code += " ),\n" else: - code += single_line + '\n' + code += single_line + "\n" # Extra events not in the CDDL spec for extra_evt in enhancements.get("extra_events", []): ek = extra_evt["event_key"] diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 77ae8f0696281..a8fb60c98178d 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: browser from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass def transform_download_params( diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 3f877b06b00ab..777005e0ce4e5 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class ReadinessState: """ReadinessState.""" @@ -376,10 +375,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: "DownloadParams | None" = None + download_params: DownloadParams | None = None @classmethod - def from_json(cls, params: dict) -> "DownloadEndParams": + def from_json(cls, params: dict) -> DownloadEndParams: """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index d90d8c770263a..d7cb436a08471 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,12 +17,13 @@ """Common utilities for BiDi command construction.""" -from typing import Any, Dict, Generator +from collections.abc import Generator +from typing import Any def command_builder( - method: str, params: Dict[str, Any] -) -> Generator[Dict[str, Any], Any, Any]: + method: str, params: dict[str, Any] +) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index d482fecc755cb..0356372c48f03 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: emulation from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass class ForcedColorsModeTheme: @@ -194,8 +193,8 @@ def __init__(self, conn) -> None: def set_forced_colors_mode_theme_override( self, theme: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setForcedColorsModeThemeOverride.""" if theme is None: @@ -214,8 +213,8 @@ def set_forced_colors_mode_theme_override( def set_locale_override( self, locale: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setLocaleOverride.""" if locale is None: @@ -234,8 +233,8 @@ def set_locale_override( def set_screen_settings_override( self, screen_area: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScreenSettingsOverride.""" if screen_area is None: @@ -254,8 +253,8 @@ def set_screen_settings_override( def set_viewport_meta_override( self, viewport_meta: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setViewportMetaOverride.""" if viewport_meta is None: @@ -274,8 +273,8 @@ def set_viewport_meta_override( def set_scrollbar_type_override( self, scrollbar_type: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute emulation.setScrollbarTypeOverride.""" if scrollbar_type is None: @@ -291,7 +290,7 @@ def set_scrollbar_type_override( result = self._conn.execute(cmd) return result - def set_touch_override(self, contexts: List[Any] | None = None, user_contexts: List[Any] | None = None): + def set_touch_override(self, contexts: list[Any] | None = None, user_contexts: list[Any] | None = None): """Execute emulation.setTouchOverride.""" params = { "contexts": contexts, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 0990dacc39363..7e76cb831543f 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: input from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class PointerType: """PointerType.""" @@ -175,7 +174,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> "FileDialogInfo": + def from_json(cls, params: dict) -> FileDialogInfo: """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), @@ -368,7 +367,7 @@ def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def perform_actions(self, context: Any | None = None, actions: List[Any] | None = None): + def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): """Execute input.performActions.""" if context is None: raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") @@ -397,7 +396,7 @@ def release_actions(self, context: Any | None = None): result = self._conn.execute(cmd) return result - def set_files(self, context: Any | None = None, element: Any | None = None, files: List[Any] | None = None): + def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): """Execute input.setFiles.""" if context is None: raise TypeError("set_files() missing required argument: {{snake_param!r}}") diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 07121242348ea..fd712b7c9a8ab 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,14 +6,11 @@ # WebDriver BiDi module: log from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable from dataclasses import dataclass +from typing import Any + from selenium.webdriver.common.bidi.session import Session @@ -60,7 +57,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "ConsoleLogEntry": + def from_json(cls, params: dict) -> ConsoleLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -85,7 +82,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "JavascriptLogEntry": + def from_json(cls, params: dict) -> JavascriptLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index d7baeb07040ce..74951031c597f 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: network from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SameSite: """SameSite.""" @@ -565,11 +564,11 @@ def __init__(self, conn) -> None: def add_data_collector( self, - data_types: List[Any] | None = None, + data_types: list[Any] | None = None, max_encoded_data_size: Any | None = None, collector_type: Any | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute network.addDataCollector.""" if data_types is None: @@ -591,9 +590,9 @@ def add_data_collector( def add_intercept( self, - phases: List[Any] | None = None, - contexts: List[Any] | None = None, - url_patterns: List[Any] | None = None, + phases: list[Any] | None = None, + contexts: list[Any] | None = None, + url_patterns: list[Any] | None = None, ): """Execute network.addIntercept.""" if phases is None: @@ -613,8 +612,8 @@ def continue_request( self, request: Any | None = None, body: Any | None = None, - cookies: List[Any] | None = None, - headers: List[Any] | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, method: Any | None = None, url: Any | None = None, ): @@ -638,9 +637,9 @@ def continue_request( def continue_response( self, request: Any | None = None, - cookies: List[Any] | None = None, + cookies: list[Any] | None = None, credentials: Any | None = None, - headers: List[Any] | None = None, + headers: list[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None, ): @@ -734,8 +733,8 @@ def provide_response( self, request: Any | None = None, body: Any | None = None, - cookies: List[Any] | None = None, - headers: List[Any] | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, reason_phrase: Any | None = None, status_code: Any | None = None, ): @@ -782,7 +781,7 @@ def remove_intercept(self, intercept: Any | None = None): result = self._conn.execute(cmd) return result - def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[Any] | None = None): + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): """Execute network.setCacheBehavior.""" if cache_behavior is None: raise TypeError("set_cache_behavior() missing required argument: {{snake_param!r}}") @@ -798,9 +797,9 @@ def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: List[A def set_extra_headers( self, - headers: List[Any] | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + headers: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute network.setExtraHeaders.""" if headers is None: diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index f00e765c62e3b..6dd138da17309 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -20,7 +20,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Optional, Union +from typing import Any from .common import command_builder @@ -63,10 +63,10 @@ def __init__(self, websocket_connection: Any) -> None: def set_permission( self, - descriptor: Union[PermissionDescriptor, str], - state: Union[PermissionState, str], - origin: Optional[str] = None, - user_context: Optional[str] = None, + descriptor: PermissionDescriptor | str, + state: PermissionState | str, + origin: str | None = None, + user_context: str | None = None, ) -> None: """Set a permission for a given origin. diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 8e832f4a9cae9..6c2e4298a2dce 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,16 +6,15 @@ # WebDriver BiDi module: script from __future__ import annotations -from typing import Any, Dict, List, Optional, Union -from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SpecialNumber: """SpecialNumber.""" @@ -786,9 +785,9 @@ def __init__(self, conn, driver=None) -> None: def add_preload_script( self, function_declaration: Any | None = None, - arguments: List[Any] | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + arguments: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, sandbox: Any | None = None, ): """Execute script.addPreloadScript.""" @@ -807,7 +806,7 @@ def add_preload_script( result = self._conn.execute(cmd) return result - def disown(self, handles: List[Any] | None = None, target: Any | None = None): + def disown(self, handles: list[Any] | None = None, target: Any | None = None): """Execute script.disown.""" if handles is None: raise TypeError("disown() missing required argument: {{snake_param!r}}") @@ -828,7 +827,7 @@ def call_function( function_declaration: Any | None = None, await_promise: bool | None = None, target: Any | None = None, - arguments: List[Any] | None = None, + arguments: list[Any] | None = None, result_ownership: Any | None = None, serialization_options: Any | None = None, this: Any | None = None, @@ -946,8 +945,9 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import math as _math import datetime as _datetime + import math as _math + from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1188,8 +1188,9 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + from selenium.webdriver.common.bidi.session import Session as _Session bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index c7dd45ec824b8..fcb42a4ad86fc 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: session from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass class UserPromptHandlerType: @@ -216,9 +215,9 @@ def end(self): def subscribe( self, - events: List[Any] | None = None, - contexts: List[Any] | None = None, - user_contexts: List[Any] | None = None, + events: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, ): """Execute session.subscribe.""" if events is None: @@ -234,7 +233,7 @@ def subscribe( result = self._conn.execute(cmd) return result - def unsubscribe(self, events: List[Any] | None = None, subscriptions: List[Any] | None = None): + def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] | None = None): """Execute session.unsubscribe.""" params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 267569f782289..089cee2c4fbdf 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: storage from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass @dataclass @@ -107,7 +106,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index e007f8e4792a6..b1bc09452bc63 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: webExtension from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any + from .common import command_builder -from dataclasses import field -from typing import Generator -from dataclasses import dataclass @dataclass From d9a0593e3ad6368ad212dd63a863759af46b782b Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 19:36:16 +0000 Subject: [PATCH 18/37] ruffs and mypy fixes --- py/generate_bidi.py | 51 +++++++++++++------ py/private/bidi_enhancements_manifest.py | 33 ++++++------ .../webdriver/common/bidi/browsing_context.py | 2 +- py/selenium/webdriver/common/bidi/common.py | 5 +- .../webdriver/common/bidi/emulation.py | 12 ++--- py/selenium/webdriver/common/bidi/input.py | 2 +- py/selenium/webdriver/common/bidi/log.py | 2 +- py/selenium/webdriver/common/bidi/network.py | 8 +-- py/selenium/webdriver/common/bidi/script.py | 2 +- py/selenium/webdriver/common/bidi/storage.py | 4 +- .../webdriver/common/bidi/webextension.py | 11 ++-- 11 files changed, 80 insertions(+), 52 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 8372d25743c08..ce29235456e48 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +#!/usr/bin/env python3.10 """ Generate Python WebDriver BiDi command modules from CDDL specification. @@ -43,7 +43,6 @@ from __future__ import annotations from typing import Any -from .common import command_builder """ @@ -590,17 +589,17 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code = MODULE_HEADER.format(self.name) # Collect needed imports to avoid duplicates + needs_command_builder = bool(self.commands) needs_dataclass = self.commands or self.types or self.events - needs_field = self.types needs_threading = self.events needs_callable = self.events needs_session = self.events - # Add imports if needed + # Add imports (field import will be added conditionally after code generation) + if needs_command_builder: + code += "from .common import command_builder\n" if needs_dataclass: code += "from dataclasses import dataclass\n" - if needs_field: - code += "from dataclasses import field\n" if needs_threading: code += "import threading\n" if needs_callable: @@ -954,7 +953,7 @@ def clear_event_handlers(self) -> None: # Add EVENT_CONFIGS dict if there are events if self.events: code += ( - " EVENT_CONFIGS = {}\n" # Will be populated after types are defined + " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined ) if self.name == "script": @@ -1095,6 +1094,26 @@ def clear_event_handlers(self) -> None: code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' code += "}\n" + # Check if field() is actually used in the generated code + # If so, add the field import after the dataclass import + if "field(" in code: + # Find where to insert the field import + # It should go after "from dataclasses import dataclass" line + dataclass_import_pattern = r"from dataclasses import dataclass\n" + if re.search(dataclass_import_pattern, code): + code = re.sub( + dataclass_import_pattern, + "from dataclasses import dataclass\nfrom dataclasses import field\n", + code, + count=1 + ) + elif "from dataclasses import" not in code: + # If there's no dataclasses import yet, add field import after typing + code = code.replace( + "from typing import Any\n", + "from typing import Any\nfrom dataclasses import field\n" + ) + return code @@ -1634,12 +1653,14 @@ def generate_common_file(output_path: Path) -> None: "\n" '"""Common utilities for BiDi command construction."""\n' "\n" - "from typing import Any, Dict, Generator\n" + "from __future__ import annotations\n" + "\n" + "from typing import Any\n" "\n" "\n" "def command_builder(\n" - " method: str, params: Dict[str, Any]\n" - ") -> Generator[Dict[str, Any], Any, Any]:\n" + " method: str, params: dict[str, Any]\n" + ") -> dict[str, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" @@ -1726,7 +1747,7 @@ def generate_permissions_file(output_path: Path) -> None: "from __future__ import annotations\n" "\n" "from enum import Enum\n" - "from typing import Any, Optional, Union\n" + "from typing import Any\n" "\n" "from .common import command_builder\n" "\n" @@ -1769,10 +1790,10 @@ def generate_permissions_file(output_path: Path) -> None: "\n" " def set_permission(\n" " self,\n" - " descriptor: Union[PermissionDescriptor, str],\n" - " state: Union[PermissionState, str],\n" - " origin: Optional[str] = None,\n" - " user_context: Optional[str] = None,\n" + " descriptor: PermissionDescriptor | str,\n" + " state: PermissionState | str,\n" + " origin: str | None = None,\n" + " user_context: str | None = None,\n" " ) -> None:\n" ' """Set a permission for a given origin.\n' "\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 40647157f8535..647dd7bcfd892 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -338,7 +338,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {} + params: dict[str, Any] = {} if coordinates is not None: if isinstance(coordinates, dict): coords_dict = coordinates @@ -390,7 +390,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"timezone": timezone} + params: dict[str, Any] = {"timezone": timezone} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -414,7 +414,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"enabled": enabled} + params: dict[str, Any] = {"enabled": enabled} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -437,7 +437,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"userAgent": user_agent} + params: dict[str, Any] = {"userAgent": user_agent} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -473,7 +473,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": "natural": natural.lower() if isinstance(natural, str) else natural, "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, } - params = {"screenOrientation": so_value} + params: dict[str, Any] = {"screenOrientation": so_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -506,7 +506,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": nc_value = {"type": "offline"} if offline else None else: nc_value = network_conditions - params = {"networkConditions": nc_value} + params: dict[str, Any] = {"networkConditions": nc_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -893,8 +893,8 @@ def from_json(self2, p): "network": { # Initialize intercepts tracking list and per-handler intercept map "extra_init_code": [ - "self.intercepts = []", - "self._handler_intercepts: dict = {}", + "self.intercepts: list[Any] = []", + "self._handler_intercepts: dict[str, Any] = {}", ], # Request class wraps a beforeRequestSent event params and provides actions "extra_dataclasses": [ @@ -908,7 +908,7 @@ def from_json(self2, p): TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str) -> None: + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value @@ -1089,7 +1089,7 @@ def _auth_callback(params): TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str) -> None: + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value @@ -1122,7 +1122,7 @@ def from_bidi_dict(cls, raw: dict) -> "StorageCookie": """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): - value = BytesValue(value_raw.get("type"), value_raw.get("value")) + value: Any = BytesValue(value_raw.get("type"), value_raw.get("value")) else: value = value_raw return cls( @@ -1379,6 +1379,7 @@ def to_bidi_dict(self) -> dict: elif archive_path is not None: extension_data = {"type": "archivePath", "path": archive_path} else: + assert base64_value is not None extension_data = {"type": "base64", "value": base64_value} params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) @@ -1395,12 +1396,14 @@ def to_bidi_dict(self) -> dict: ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): - extension = extension.get("extension") + extension_id: Any = extension.get("extension") + else: + extension_id = extension - if extension is None: + if extension_id is None: raise ValueError("extension parameter is required") - - params = {"extension": extension} + + params = {"extension": extension_id} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd)''', ], diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 777005e0ce4e5..5b1a67ce93f11 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -598,7 +598,7 @@ def clear_event_handlers(self) -> None: class BrowsingContext: """WebDriver BiDi browsingContext module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index d7cb436a08471..168f748d5501b 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -17,13 +17,14 @@ """Common utilities for BiDi command construction.""" -from collections.abc import Generator +from __future__ import annotations + from typing import Any def command_builder( method: str, params: dict[str, Any] -) -> Generator[dict[str, Any], Any, Any]: +) -> dict[str, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 0356372c48f03..3dcf8e58881e4 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -320,7 +320,7 @@ def set_geolocation_override( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {} + params: dict[str, Any] = {} if coordinates is not None: if isinstance(coordinates, dict): coords_dict = coordinates @@ -372,7 +372,7 @@ def set_timezone_override( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"timezone": timezone} + params: dict[str, Any] = {"timezone": timezone} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -396,7 +396,7 @@ def set_scripting_enabled( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"enabled": enabled} + params: dict[str, Any] = {"enabled": enabled} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -419,7 +419,7 @@ def set_user_agent_override( contexts: List of browsing context IDs to target. user_contexts: List of user context IDs to target. """ - params = {"userAgent": user_agent} + params: dict[str, Any] = {"userAgent": user_agent} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -455,7 +455,7 @@ def set_screen_orientation_override( "natural": natural.lower() if isinstance(natural, str) else natural, "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, } - params = {"screenOrientation": so_value} + params: dict[str, Any] = {"screenOrientation": so_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -488,7 +488,7 @@ def set_network_conditions( nc_value = {"type": "offline"} if offline else None else: nc_value = network_conditions - params = {"networkConditions": nc_value} + params: dict[str, Any] = {"networkConditions": nc_value} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 7e76cb831543f..1d4730534f16d 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -362,7 +362,7 @@ def clear_event_handlers(self) -> None: class Input: """WebDriver BiDi input module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index fd712b7c9a8ab..488f0740a40b5 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -256,7 +256,7 @@ def clear_event_handlers(self) -> None: class Log: """WebDriver BiDi log module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 74951031c597f..30de3306ff001 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -368,7 +368,7 @@ class BytesValue: TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str) -> None: + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value @@ -555,12 +555,12 @@ def clear_event_handlers(self) -> None: class Network: """WebDriver BiDi network module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - self.intercepts = [] - self._handler_intercepts: dict = {} + self.intercepts: list[Any] = [] + self._handler_intercepts: dict[str, Any] = {} def add_data_collector( self, diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 6c2e4298a2dce..221b5963e8ec1 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -776,7 +776,7 @@ def clear_event_handlers(self) -> None: class Script: """WebDriver BiDi script module.""" - EVENT_CONFIGS = {} + EVENT_CONFIGS: dict[str, EventConfig] = {} def __init__(self, conn, driver=None) -> None: self._conn = conn self._driver = driver diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 089cee2c4fbdf..a2606526f3856 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -76,7 +76,7 @@ class BytesValue: TYPE_STRING = "string" TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str) -> None: + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value @@ -110,7 +110,7 @@ def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): - value = BytesValue(value_raw.get("type"), value_raw.get("value")) + value: Any = BytesValue(value_raw.get("type"), value_raw.get("value")) else: value = value_raw return cls( diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index b1bc09452bc63..70a21d7fd5e5e 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -100,6 +100,7 @@ def install( elif archive_path is not None: extension_data = {"type": "archivePath", "path": archive_path} else: + assert base64_value is not None extension_data = {"type": "base64", "value": base64_value} params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) @@ -116,11 +117,13 @@ def uninstall(self, extension: str | dict): ValueError: If extension is not provided or is None. """ if isinstance(extension, dict): - extension = extension.get("extension") + extension_id: Any = extension.get("extension") + else: + extension_id = extension - if extension is None: + if extension_id is None: raise ValueError("extension parameter is required") - - params = {"extension": extension} + + params = {"extension": extension_id} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd) From aa893f5fad49b03bf1456d3de1719e8394842755 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 11 Mar 2026 19:43:50 +0000 Subject: [PATCH 19/37] fix linting --- py/generate_bidi.py | 12 +++++------- py/private/bidi_enhancements_manifest.py | 2 +- py/selenium/webdriver/common/bidi/common.py | 3 ++- py/selenium/webdriver/common/bidi/webextension.py | 2 +- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index ce29235456e48..de41855954651 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -952,9 +952,7 @@ def clear_event_handlers(self) -> None: # Add EVENT_CONFIGS dict if there are events if self.events: - code += ( - " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined - ) + code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined if self.name == "script": code += " def __init__(self, conn, driver=None) -> None:\n" @@ -1105,13 +1103,13 @@ def clear_event_handlers(self) -> None: dataclass_import_pattern, "from dataclasses import dataclass\nfrom dataclasses import field\n", code, - count=1 + count=1, ) elif "from dataclasses import" not in code: # If there's no dataclasses import yet, add field import after typing code = code.replace( "from typing import Any\n", - "from typing import Any\nfrom dataclasses import field\n" + "from typing import Any\nfrom dataclasses import field\n", ) return code @@ -1655,12 +1653,12 @@ def generate_common_file(output_path: Path) -> None: "\n" "from __future__ import annotations\n" "\n" - "from typing import Any\n" + "from typing import Any, Generator\n" "\n" "\n" "def command_builder(\n" " method: str, params: dict[str, Any]\n" - ") -> dict[str, Any]:\n" + ") -> Generator[dict[str, Any], Any, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 647dd7bcfd892..d9923531b0293 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1402,7 +1402,7 @@ def to_bidi_dict(self) -> dict: if extension_id is None: raise ValueError("extension parameter is required") - + params = {"extension": extension_id} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd)''', diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index 168f748d5501b..59e8afd93ab2e 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -19,12 +19,13 @@ from __future__ import annotations +from collections.abc import Generator from typing import Any def command_builder( method: str, params: dict[str, Any] -) -> dict[str, Any]: +) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 70a21d7fd5e5e..b5881d01e0bea 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -123,7 +123,7 @@ def uninstall(self, extension: str | dict): if extension_id is None: raise ValueError("extension parameter is required") - + params = {"extension": extension_id} cmd = command_builder("webExtension.uninstall", params) return self._conn.execute(cmd) From faf8c70ccfe28eaa2192b9648927ddd6da7789a9 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 13 Mar 2026 12:12:19 +0000 Subject: [PATCH 20/37] Fix auth tests --- py/generate_bidi.py | 12 +++++++----- py/private/bidi_enhancements_manifest.py | 20 +++++++++++++++++--- py/selenium/webdriver/common/bidi/network.py | 18 ++++++++++++++++-- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index de41855954651..ce29235456e48 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -952,7 +952,9 @@ def clear_event_handlers(self) -> None: # Add EVENT_CONFIGS dict if there are events if self.events: - code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined + code += ( + " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined + ) if self.name == "script": code += " def __init__(self, conn, driver=None) -> None:\n" @@ -1103,13 +1105,13 @@ def clear_event_handlers(self) -> None: dataclass_import_pattern, "from dataclasses import dataclass\nfrom dataclasses import field\n", code, - count=1, + count=1 ) elif "from dataclasses import" not in code: # If there's no dataclasses import yet, add field import after typing code = code.replace( "from typing import Any\n", - "from typing import Any\nfrom dataclasses import field\n", + "from typing import Any\nfrom dataclasses import field\n" ) return code @@ -1653,12 +1655,12 @@ def generate_common_file(output_path: Path) -> None: "\n" "from __future__ import annotations\n" "\n" - "from typing import Any, Generator\n" + "from typing import Any\n" "\n" "\n" "def command_builder(\n" " method: str, params: dict[str, Any]\n" - ") -> Generator[dict[str, Any], Any, Any]:\n" + ") -> dict[str, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index d9923531b0293..d617f7468c034 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1033,6 +1033,10 @@ def _request_callback(params): """ from selenium.webdriver.common.bidi.common import command_builder as _cb + # Set up network intercept for authRequired phase + intercept_result = self._add_intercept(phases=["authRequired"]) + intercept_id = intercept_result.get("intercept") if intercept_result else None + def _auth_callback(params): raw = ( params @@ -1060,10 +1064,20 @@ def _auth_callback(params): ) ) - return self.add_event_handler("auth_required", _auth_callback)''', + callback_id = self.add_event_handler("auth_required", _auth_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id''', ''' def remove_auth_handler(self, callback_id): - """Remove an auth handler by callback ID.""" - self.remove_event_handler("auth_required", callback_id)''', + """Remove an auth handler by callback ID and its associated network intercept. + + Args: + callback_id: The handler ID returned by add_auth_handler. + """ + self.remove_event_handler("auth_required", callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id)''', ], }, "storage": { diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 30de3306ff001..1dd2f5a476049 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -975,6 +975,10 @@ def add_auth_handler(self, username, password): """ from selenium.webdriver.common.bidi.common import command_builder as _cb + # Set up network intercept for authRequired phase + intercept_result = self._add_intercept(phases=["authRequired"]) + intercept_id = intercept_result.get("intercept") if intercept_result else None + def _auth_callback(params): raw = ( params @@ -1002,10 +1006,20 @@ def _auth_callback(params): ) ) - return self.add_event_handler("auth_required", _auth_callback) + callback_id = self.add_event_handler("auth_required", _auth_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id def remove_auth_handler(self, callback_id): - """Remove an auth handler by callback ID.""" + """Remove an auth handler by callback ID and its associated network intercept. + + Args: + callback_id: The handler ID returned by add_auth_handler. + """ self.remove_event_handler("auth_required", callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id) def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. From 6b277dd10b907635054c4ad45d9abbae8b403826 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 13 Mar 2026 12:30:15 +0000 Subject: [PATCH 21/37] sort spacing --- py/generate_bidi.py | 3 ++- py/private/bidi_enhancements_manifest.py | 10 ++++++++++ py/selenium/webdriver/common/bidi/browser.py | 4 ++-- .../webdriver/common/bidi/browsing_context.py | 13 ++++++------- py/selenium/webdriver/common/bidi/emulation.py | 4 ++-- py/selenium/webdriver/common/bidi/input.py | 11 +++++------ py/selenium/webdriver/common/bidi/log.py | 9 ++++----- py/selenium/webdriver/common/bidi/network.py | 9 ++++----- py/selenium/webdriver/common/bidi/script.py | 15 ++++++--------- py/selenium/webdriver/common/bidi/session.py | 4 ++-- py/selenium/webdriver/common/bidi/storage.py | 6 +++--- py/selenium/webdriver/common/bidi/webextension.py | 4 ++-- 12 files changed, 48 insertions(+), 44 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index ce29235456e48..32d19ec83cec9 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -1655,12 +1655,13 @@ def generate_common_file(output_path: Path) -> None: "\n" "from __future__ import annotations\n" "\n" + "from collections.abc import Generator\n" "from typing import Any\n" "\n" "\n" "def command_builder(\n" " method: str, params: dict[str, Any]\n" - ") -> dict[str, Any]:\n" + ") -> Generator[dict[str, Any], Any, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index d617f7468c034..dcf464f425e9d 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -37,6 +37,7 @@ # ============================================================================ ENHANCEMENTS: dict[str, dict[str, Any]] = { + "browser": { # Dataclass custom methods "__dataclass_methods__": { @@ -170,6 +171,7 @@ return self._conn.execute(cmd)''', ], }, + "browsingContext": { # Method enhancements "create": { @@ -254,6 +256,7 @@ def from_json(cls, params: dict) -> "DownloadEndParams": ], # Download events are now in the CDDL spec, so no extra_events needed }, + "log": { # Make LogLevel an alias for Level so existing code using LogLevel works "aliases": {"LogLevel": "Level"}, @@ -317,6 +320,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": "entry_added": "Entry", }, }, + "emulation": { "extra_methods": [ ''' def set_geolocation_override( @@ -515,6 +519,7 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": return self._conn.execute(cmd)''', ], }, + "script": { "extra_methods": [ ''' def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: @@ -890,6 +895,7 @@ def from_json(self2, p): self._unsubscribe_log_entry(callback_id)''', ], }, + "network": { # Initialize intercepts tracking list and per-handler intercept map "extra_init_code": [ @@ -1080,6 +1086,7 @@ def _auth_callback(params): self._remove_intercept(intercept_id)''', ], }, + "storage": { # Exclude auto-generated dataclasses that need custom to_bidi_dict() # for JSON-over-WebSocket serialization, or custom constructors. @@ -1319,6 +1326,7 @@ def to_bidi_dict(self) -> dict: return result''', ], }, + "session": { # Override UserPromptHandler to add to_bidi_dict() for JSON serialization "exclude_types": ["UserPromptHandler"], @@ -1352,6 +1360,7 @@ def to_bidi_dict(self) -> dict: return result''', ], }, + "webExtension": { # Suppress the raw generated stubs; hand-written versions follow below "exclude_methods": ["install", "uninstall"], @@ -1422,6 +1431,7 @@ def to_bidi_dict(self) -> dict: return self._conn.execute(cmd)''', ], }, + "input": { # FileDialogInfo needs from_json for event deserialization "exclude_types": ["FileDialogInfo", "PointerMoveAction", "PointerDownAction"], diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index a8fb60c98178d..a4ec770fbb135 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: browser from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field def transform_download_params( diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 5b1a67ce93f11..c5489ce865180 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,15 +6,14 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class ReadinessState: """ReadinessState.""" @@ -375,10 +374,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: DownloadParams | None = None + download_params: "DownloadParams | None" = None @classmethod - def from_json(cls, params: dict) -> DownloadEndParams: + def from_json(cls, params: dict) -> "DownloadEndParams": """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 3dcf8e58881e4..03347a0a85c04 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: emulation from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field class ForcedColorsModeTheme: diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 1d4730534f16d..44fd3c82c3407 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,15 +6,14 @@ # WebDriver BiDi module: input from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class PointerType: """PointerType.""" @@ -174,7 +173,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> FileDialogInfo: + def from_json(cls, params: dict) -> "FileDialogInfo": """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 488f0740a40b5..3c6a95d74f6d1 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,11 +6,10 @@ # WebDriver BiDi module: log from __future__ import annotations +from typing import Any +from dataclasses import dataclass import threading from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - from selenium.webdriver.common.bidi.session import Session @@ -57,7 +56,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> ConsoleLogEntry: + def from_json(cls, params: dict) -> "ConsoleLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -82,7 +81,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> JavascriptLogEntry: + def from_json(cls, params: dict) -> "JavascriptLogEntry": """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 1dd2f5a476049..6a0edf0b2b5e7 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,15 +6,14 @@ # WebDriver BiDi module: network from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class SameSite: """SameSite.""" diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 221b5963e8ec1..5a7d2792a1221 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,15 +6,14 @@ # WebDriver BiDi module: script from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Any - from selenium.webdriver.common.bidi.session import Session -from .common import command_builder - class SpecialNumber: """SpecialNumber.""" @@ -945,9 +944,8 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import datetime as _datetime import math as _math - + import datetime as _datetime from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1188,9 +1186,8 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - - from selenium.webdriver.common.bidi import log as _log_mod from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index fcb42a4ad86fc..177421eca5ee8 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: session from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field class UserPromptHandlerType: diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index a2606526f3856..fef35106c33b0 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: storage from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field @dataclass @@ -106,7 +106,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> StorageCookie: + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index b5881d01e0bea..1c5b342c070d5 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: webExtension from __future__ import annotations -from dataclasses import dataclass, field from typing import Any - from .common import command_builder +from dataclasses import dataclass +from dataclasses import field @dataclass From 3ebb065853e6d520dcfd8f11e2bbee81ead15682 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Tue, 17 Mar 2026 09:12:20 +0000 Subject: [PATCH 22/37] linting --- py/selenium/webdriver/common/bidi/browser.py | 4 ++-- .../webdriver/common/bidi/browsing_context.py | 13 +++++++------ py/selenium/webdriver/common/bidi/emulation.py | 4 ++-- py/selenium/webdriver/common/bidi/input.py | 11 ++++++----- py/selenium/webdriver/common/bidi/log.py | 9 +++++---- py/selenium/webdriver/common/bidi/network.py | 9 +++++---- py/selenium/webdriver/common/bidi/script.py | 15 +++++++++------ py/selenium/webdriver/common/bidi/session.py | 4 ++-- py/selenium/webdriver/common/bidi/storage.py | 6 +++--- py/selenium/webdriver/common/bidi/webextension.py | 4 ++-- 10 files changed, 43 insertions(+), 36 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index a4ec770fbb135..a8fb60c98178d 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: browser from __future__ import annotations +from dataclasses import dataclass, field from typing import Any + from .common import command_builder -from dataclasses import dataclass -from dataclasses import field def transform_download_params( diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index c5489ce865180..5b1a67ce93f11 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,14 +6,15 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations -from typing import Any -from .common import command_builder -from dataclasses import dataclass -from dataclasses import field import threading from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class ReadinessState: """ReadinessState.""" @@ -374,10 +375,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: "DownloadParams | None" = None + download_params: DownloadParams | None = None @classmethod - def from_json(cls, params: dict) -> "DownloadEndParams": + def from_json(cls, params: dict) -> DownloadEndParams: """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 03347a0a85c04..3dcf8e58881e4 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: emulation from __future__ import annotations +from dataclasses import dataclass, field from typing import Any + from .common import command_builder -from dataclasses import dataclass -from dataclasses import field class ForcedColorsModeTheme: diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 44fd3c82c3407..1d4730534f16d 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,14 +6,15 @@ # WebDriver BiDi module: input from __future__ import annotations -from typing import Any -from .common import command_builder -from dataclasses import dataclass -from dataclasses import field import threading from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class PointerType: """PointerType.""" @@ -173,7 +174,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> "FileDialogInfo": + def from_json(cls, params: dict) -> FileDialogInfo: """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 3c6a95d74f6d1..488f0740a40b5 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,10 +6,11 @@ # WebDriver BiDi module: log from __future__ import annotations -from typing import Any -from dataclasses import dataclass import threading from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + from selenium.webdriver.common.bidi.session import Session @@ -56,7 +57,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "ConsoleLogEntry": + def from_json(cls, params: dict) -> ConsoleLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -81,7 +82,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "JavascriptLogEntry": + def from_json(cls, params: dict) -> JavascriptLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 6a0edf0b2b5e7..1dd2f5a476049 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,14 +6,15 @@ # WebDriver BiDi module: network from __future__ import annotations -from typing import Any -from .common import command_builder -from dataclasses import dataclass -from dataclasses import field import threading from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SameSite: """SameSite.""" diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 5a7d2792a1221..221b5963e8ec1 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,14 +6,15 @@ # WebDriver BiDi module: script from __future__ import annotations -from typing import Any -from .common import command_builder -from dataclasses import dataclass -from dataclasses import field import threading from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + from selenium.webdriver.common.bidi.session import Session +from .common import command_builder + class SpecialNumber: """SpecialNumber.""" @@ -944,8 +945,9 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import math as _math import datetime as _datetime + import math as _math + from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1186,8 +1188,9 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + from selenium.webdriver.common.bidi.session import Session as _Session bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 177421eca5ee8..fcb42a4ad86fc 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: session from __future__ import annotations +from dataclasses import dataclass, field from typing import Any + from .common import command_builder -from dataclasses import dataclass -from dataclasses import field class UserPromptHandlerType: diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index fef35106c33b0..a2606526f3856 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: storage from __future__ import annotations +from dataclasses import dataclass, field from typing import Any + from .common import command_builder -from dataclasses import dataclass -from dataclasses import field @dataclass @@ -106,7 +106,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 1c5b342c070d5..b5881d01e0bea 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -6,10 +6,10 @@ # WebDriver BiDi module: webExtension from __future__ import annotations +from dataclasses import dataclass, field from typing import Any + from .common import command_builder -from dataclasses import dataclass -from dataclasses import field @dataclass From eddf58c44924ec16729ea03540a893672ba1e3ed Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Tue, 17 Mar 2026 12:24:24 +0000 Subject: [PATCH 23/37] Loosen viewport size check as window managers don't guarantee putting the window to the size we want --- .../selenium/webdriver/common/bidi_browsing_context_tests.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index c5c9284418ba2..f26472e7a8d54 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -423,8 +423,9 @@ def test_set_viewport_back_to_default(driver, pages): viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") device_pixel_ratio = driver.execute_script("return window.devicePixelRatio") - assert viewport_size[0] == default_viewport_size[0] - assert viewport_size[1] == default_viewport_size[1] + # Allow some tolerance since some window managers might not put it to the exact value + assert abs(viewport_size[0] - default_viewport_size[0]) <= 5 + assert abs(viewport_size[1] - default_viewport_size[1]) <= 5 assert device_pixel_ratio == default_device_pixel_ratio finally: driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) From 763e881b1d18dbea9591b029dfee7ea31493d968 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Fri, 27 Mar 2026 09:44:29 +0000 Subject: [PATCH 24/37] handle comments --- py/generate_bidi.py | 220 +++-------------- py/private/bidi_enhancements_manifest.py | 124 ++++++++-- py/selenium/webdriver/common/bidi/__init__.py | 20 +- py/selenium/webdriver/common/bidi/browser.py | 24 +- .../webdriver/common/bidi/browsing_context.py | 185 ++------------ py/selenium/webdriver/common/bidi/common.py | 7 +- .../webdriver/common/bidi/emulation.py | 33 +-- py/selenium/webdriver/common/bidi/input.py | 171 +------------ py/selenium/webdriver/common/bidi/log.py | 156 +----------- py/selenium/webdriver/common/bidi/network.py | 232 +++--------------- .../webdriver/common/bidi/permissions.py | 2 +- py/selenium/webdriver/common/bidi/py.typed | 0 py/selenium/webdriver/common/bidi/script.py | 185 ++------------ py/selenium/webdriver/common/bidi/session.py | 10 +- py/selenium/webdriver/common/bidi/storage.py | 44 +++- .../webdriver/common/bidi/webextension.py | 14 +- py/selenium/webdriver/remote/webdriver.py | 12 +- .../webdriver/remote/websocket_connection.py | 5 +- .../webdriver/common/bidi_browser_tests.py | 2 + rake_tasks/python.rake | 13 + 20 files changed, 367 insertions(+), 1092 deletions(-) mode change 100755 => 100644 py/selenium/webdriver/common/bidi/py.typed diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 32d19ec83cec9..745c0f00ed890 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -42,7 +42,6 @@ # WebDriver BiDi module: {{}} from __future__ import annotations -from typing import Any """ @@ -198,8 +197,9 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if param_name in self.required_params: body += f" if {snake_param} is None:\n" msg = f"{method_snake}() missing required argument:" + error_message = f"{msg} {snake_param!r}" body += ( - f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' + f" raise TypeError({error_message!r})\n" ) body += "\n" @@ -591,23 +591,32 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Collect needed imports to avoid duplicates needs_command_builder = bool(self.commands) needs_dataclass = self.commands or self.types or self.events - needs_threading = self.events needs_callable = self.events - needs_session = self.events + + stdlib_imports = [] + local_imports = [] # Add imports (field import will be added conditionally after code generation) - if needs_command_builder: - code += "from .common import command_builder\n" - if needs_dataclass: - code += "from dataclasses import dataclass\n" - if needs_threading: - code += "import threading\n" if needs_callable: - code += "from collections.abc import Callable\n" - if needs_session: - code += "from selenium.webdriver.common.bidi.session import Session\n" + stdlib_imports.append("from collections.abc import Callable") + if needs_dataclass: + stdlib_imports.append("from dataclasses import dataclass") + stdlib_imports.append("from typing import Any") + + if needs_command_builder: + local_imports.append( + "from selenium.webdriver.common.bidi.common import command_builder" + ) + if self.events: + local_imports.append( + "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager" + ) + + code += "\n".join(stdlib_imports) + "\n" + if local_imports: + code += "\n" + "\n".join(local_imports) + "\n" - code += "\n\n" + code += "\n" # Add helper function definitions from enhancements # Collect all referenced helper functions (validate, transform) @@ -784,165 +793,11 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ code += "\n\n" - # Generate EventConfig and _EventManager for modules with events - if self.events: - # Generate EventConfig dataclass - code += """@dataclass -class EventConfig: - \"\"\"Configuration for a BiDi event.\"\"\" - event_key: str - bidi_event: str - event_class: type - - -""" - - # Generate _EventManager class - code += """class _EventWrapper: - \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - \"\"\"Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - \"\"\" - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, \"from_json\") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend([\"_\", char.lower()]) - else: - result.append(char) - return \"\".join(result) - - -class _EventManager: - \"\"\"Manages event subscriptions and callbacks.\"\"\" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get(\"subscription\") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - \"callbacks\": [], - \"subscription_id\": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry[\"callbacks\"]: - session = Session(self.conn) - sub_id = entry.get(\"subscription_id\") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry[\"callbacks\"]: - entry[\"callbacks\"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - \"\"\"Clear all event handlers.\"\"\" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get(\"subscription_id\") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - -""" - code += "\n\n" + # EventConfig, _EventWrapper, and _EventManager are imported from + # selenium.webdriver.common.bidi._event_manager (see import section above) + # rather than being duplicated inline in every generated module. + if False: # placeholder to preserve indentation structure + pass # Generate class # Convert module name (camelCase or snake_case) to proper class name (PascalCase) @@ -1103,15 +958,15 @@ def clear_event_handlers(self) -> None: if re.search(dataclass_import_pattern, code): code = re.sub( dataclass_import_pattern, - "from dataclasses import dataclass\nfrom dataclasses import field\n", + "from dataclasses import dataclass, field\n", code, - count=1 + count=1, ) elif "from dataclasses import" not in code: # If there's no dataclasses import yet, add field import after typing code = code.replace( "from typing import Any\n", - "from typing import Any\nfrom dataclasses import field\n" + "from dataclasses import field\nfrom typing import Any\n", ) return code @@ -1615,7 +1470,9 @@ def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> Non for module_name in sorted(modules.keys()): class_name = module_name_to_class_name(module_name) filename = module_name_to_filename(module_name) - code += f"from .{filename} import {class_name}\n" + code += ( + f"from selenium.webdriver.common.bidi.{filename} import {class_name}\n" + ) code += "\n__all__ = [\n" for module_name in sorted(modules.keys()): @@ -1660,13 +1517,14 @@ def generate_common_file(output_path: Path) -> None: "\n" "\n" "def command_builder(\n" - " method: str, params: dict[str, Any]\n" + " method: str, params: dict[str, Any] | None = None\n" ") -> Generator[dict[str, Any], Any, Any]:\n" ' """Build a BiDi command generator.\n' "\n" " Args:\n" ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' - " params: The parameters for the command\n" + " params: The parameters for the command. If omitted, an empty\n" + " dictionary is sent.\n" "\n" " Yields:\n" " A dictionary representing the BiDi command\n" @@ -1674,6 +1532,8 @@ def generate_common_file(output_path: Path) -> None: " Returns:\n" " The result from the BiDi command execution\n" ' """\n' + " if params is None:\n" + " params = {}\n" ' result = yield {"method": method, "params": params}\n' " return result\n" ) @@ -1750,7 +1610,7 @@ def generate_permissions_file(output_path: Path) -> None: "from enum import Enum\n" "from typing import Any\n" "\n" - "from .common import command_builder\n" + "from selenium.webdriver.common.bidi.common import command_builder\n" "\n" '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' "\n" diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index dcf464f425e9d..f8ade8b9b3ad8 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -86,7 +86,7 @@ # convenience NORMAL constant. In the BiDi spec "normal" is the state # represented by ClientWindowRectState, but exposing it here keeps the # Python API consistent with the old ClientWindowState enum. - "exclude_types": ["ClientWindowNamedState"], + "exclude_types": ["ClientWindowNamedState", "SetClientWindowStateParameters"], "extra_dataclasses": [ '''class ClientWindowNamedState: """Named states for a browser client window.""" @@ -95,6 +95,18 @@ MAXIMIZED = "maximized" MINIMIZED = "minimized" NORMAL = "normal"''', + '''@dataclass +class SetClientWindowStateParameters: + """SetClientWindowStateParameters. + + The ``state`` field is required and must be either a named-state string + (e.g. ``ClientWindowNamedState.MAXIMIZED``) or a + :class:`ClientWindowRectState` instance. ``client_window`` is the ID of + the window to affect. + """ + + client_window: Any | None = None + state: Any | None = None''', ], # Override the generator-produced set_download_behavior so that # downloadBehavior is never stripped by the generic None filter. @@ -239,10 +251,10 @@ class DownloadParams: class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" - download_params: "DownloadParams | None" = None + download_params: DownloadParams | None = None @classmethod - def from_json(cls, params: dict) -> "DownloadEndParams": + def from_json(cls, params: dict) -> DownloadEndParams: """Deserialize from BiDi wire-level params dict.""" dp = DownloadParams( status=params.get("status"), @@ -277,7 +289,7 @@ class ConsoleLogEntry: stack_trace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "ConsoleLogEntry": + def from_json(cls, params: dict) -> ConsoleLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -301,7 +313,7 @@ class JavascriptLogEntry: stacktrace: Any | None = None @classmethod - def from_json(cls, params: dict) -> "JavascriptLogEntry": + def from_json(cls, params: dict) -> JavascriptLogEntry: """Deserialize from BiDi params dict.""" return cls( type_=params.get("type"), @@ -322,6 +334,20 @@ def from_json(cls, params: dict) -> "JavascriptLogEntry": }, "emulation": { + "exclude_types": ["setNetworkConditionsParameters"], + "extra_dataclasses": [ + '''@dataclass +class SetNetworkConditionsParameters: + """SetNetworkConditionsParameters.""" + + network_conditions: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) + + +# Backward-compatible alias for existing imports +setNetworkConditionsParameters = SetNetworkConditionsParameters''', + ], "extra_methods": [ ''' def set_geolocation_override( self, @@ -897,6 +923,7 @@ def from_json(self2, p): }, "network": { + "exclude_types": ["disownDataParameters"], # Initialize intercepts tracking list and per-handler intercept map "extra_init_code": [ "self.intercepts: list[Any] = []", @@ -904,6 +931,17 @@ def from_json(self2, p): ], # Request class wraps a beforeRequestSent event params and provides actions "extra_dataclasses": [ + '''@dataclass +class DisownDataParameters: + """DisownDataParameters.""" + + data_type: Any | None = None + collector: Any | None = None + request: Any | None = None + + +# Backward-compatible alias for existing imports +disownDataParameters = DisownDataParameters''', '''class BytesValue: """A string or base64-encoded bytes value used in cookie operations. @@ -1115,7 +1153,11 @@ def __init__(self, type: Any | None, value: Any | None) -> None: self.value = value def to_bidi_dict(self) -> dict: - return {"type": self.type, "value": self.value}''', + return {"type": self.type, "value": self.value} + + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict()''', '''class SameSite: """SameSite cookie attribute values.""" @@ -1139,7 +1181,7 @@ class StorageCookie: expiry: Any | None = None @classmethod - def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + def from_bidi_dict(cls, raw: dict) -> StorageCookie: """Deserialize a wire-level cookie dict to a StorageCookie.""" value_raw = raw.get("value") if isinstance(value_raw, dict): @@ -1193,7 +1235,11 @@ def to_bidi_dict(self) -> dict: result["sameSite"] = self.same_site if self.expiry is not None: result["expiry"] = self.expiry - return result''', + return result + + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict()''', # Custom PartialCookie with camelCase serialization '''@dataclass class PartialCookie: @@ -1227,7 +1273,11 @@ def to_bidi_dict(self) -> dict: result["sameSite"] = self.same_site if self.expiry is not None: result["expiry"] = self.expiry - return result''', + return result + + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict()''', # BrowsingContextPartitionDescriptor: first positional arg is *context* # (the auto-generated dataclass had `type` first, breaking positional # usage like BrowsingContextPartitionDescriptor(driver.current_window_handle)) @@ -1244,7 +1294,12 @@ def __init__(self, context: Any = None, type: str = "context") -> None: self.type = type def to_bidi_dict(self) -> dict: - return {"type": "context", "context": self.context}''', + return {"type": "context", "context": self.context} + + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict()''', + # StorageKeyPartitionDescriptor with camelCase serialization '''@dataclass class StorageKeyPartitionDescriptor: @@ -1261,7 +1316,11 @@ def to_bidi_dict(self) -> dict: result["userContext"] = self.user_context if self.source_origin is not None: result["sourceOrigin"] = self.source_origin - return result''', + return result + + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict()''', ], # Override the generated Storage class methods (Python's last-definition- # wins semantics means these extra_methods shadow the generated ones). @@ -1309,7 +1368,19 @@ def to_bidi_dict(self) -> dict: params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("storage.setCookie", params) result = self._conn.execute(cmd) + if isinstance(result, dict): + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None + ) + return SetCookieResult(partition_key=pk) return result''', + ''' def delete_cookies(self, filter=None, partition=None): """Execute storage.deleteCookies.""" if filter and hasattr(filter, "to_bidi_dict"): @@ -1323,6 +1394,17 @@ def to_bidi_dict(self) -> dict: params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("storage.deleteCookies", params) result = self._conn.execute(cmd) + if isinstance(result, dict): + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None + ) + return DeleteCookiesResult(partition_key=pk) return result''', ], }, @@ -1357,7 +1439,11 @@ def to_bidi_dict(self) -> dict: result["file"] = self.file if self.prompt is not None: result["prompt"] = self.prompt - return result''', + return result + + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict()''', ], }, @@ -1406,7 +1492,17 @@ def to_bidi_dict(self) -> dict: extension_data = {"type": "base64", "value": base64_value} params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) - return self._conn.execute(cmd)''', + try: + return self._conn.execute(cmd) + except Exception as e: + if "Method not available" in str(e): + raise RuntimeError( + "webExtension.install failed with 'Method not available'. " + "This likely means that web extension support is disabled. " + "Enable unsafe extension debugging and/or set options.enable_webextensions " + "in your WebDriver configuration." + ) from e + raise''', ''' def uninstall(self, extension: str | dict): """Uninstall a web extension. @@ -1445,7 +1541,7 @@ class FileDialogInfo: multiple: bool | None = None @classmethod - def from_json(cls, params: dict) -> "FileDialogInfo": + def from_json(cls, params: dict) -> FileDialogInfo: """Deserialize event params into FileDialogInfo.""" return cls( context=params.get("context"), diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index 7be7bd4f73856..79ba3dbf2f86f 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -5,16 +5,16 @@ from __future__ import annotations -from .browser import Browser -from .browsing_context import BrowsingContext -from .emulation import Emulation -from .input import Input -from .log import Log -from .network import Network -from .script import Script -from .session import Session -from .storage import Storage -from .webextension import WebExtension +from selenium.webdriver.common.bidi.browser import Browser +from selenium.webdriver.common.bidi.browsing_context import BrowsingContext +from selenium.webdriver.common.bidi.emulation import Emulation +from selenium.webdriver.common.bidi.input import Input +from selenium.webdriver.common.bidi.log import Log +from selenium.webdriver.common.bidi.network import Network +from selenium.webdriver.common.bidi.script import Script +from selenium.webdriver.common.bidi.session import Session +from selenium.webdriver.common.bidi.storage import Storage +from selenium.webdriver.common.bidi.webextension import WebExtension __all__ = [ "Browser", diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index a8fb60c98178d..3811a2a2e97b7 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from typing import Any -from .common import command_builder +from selenium.webdriver.common.bidi.common import command_builder def transform_download_params( @@ -139,13 +139,6 @@ class RemoveUserContextParameters: user_context: Any | None = None -@dataclass -class SetClientWindowStateParameters: - """SetClientWindowStateParameters.""" - - client_window: Any | None = None - - @dataclass class ClientWindowRectState: """ClientWindowRectState.""" @@ -188,6 +181,19 @@ class ClientWindowNamedState: MINIMIZED = "minimized" NORMAL = "normal" +@dataclass +class SetClientWindowStateParameters: + """SetClientWindowStateParameters. + + The ``state`` field is required and must be either a named-state string + (e.g. ``ClientWindowNamedState.MAXIMIZED``) or a + :class:`ClientWindowRectState` instance. ``client_window`` is the ID of + the window to affect. + """ + + client_window: Any | None = None + state: Any | None = None + class Browser: """WebDriver BiDi browser module.""" @@ -272,7 +278,7 @@ def get_user_contexts(self): def remove_user_context(self, user_context: Any | None = None): """Execute browser.removeUserContext.""" if user_context is None: - raise TypeError("remove_user_context() missing required argument: {{snake_param!r}}") + raise TypeError("remove_user_context() missing required argument: 'user_context'") params = { "userContext": user_context, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 5b1a67ce93f11..fcee27df8488e 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -6,14 +6,12 @@ # WebDriver BiDi module: browsingContext from __future__ import annotations -import threading from collections.abc import Callable from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi.session import Session - -from .common import command_builder +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager +from selenium.webdriver.common.bidi.common import command_builder class ReadinessState: @@ -442,159 +440,6 @@ def _deserialize_info_list(items: list) -> list | None: -@dataclass -class EventConfig: - """Configuration for a BiDi event.""" - event_key: str - bidi_event: str - event_class: type - - -class _EventWrapper: - """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - """Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - """ - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend(["_", char.lower()]) - else: - result.append(char) - return "".join(result) - - -class _EventManager: - """Manages event subscriptions and callbacks.""" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed.""" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - "callbacks": [], - "subscription_id": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist.""" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry["callbacks"]: - session = Session(self.conn) - sub_id = entry.get("subscription_id") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event]["callbacks"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry["callbacks"]: - entry["callbacks"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry["callbacks"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get("subscription_id") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - - - class BrowsingContext: """WebDriver BiDi browsingContext module.""" @@ -606,7 +451,7 @@ def __init__(self, conn) -> None: def activate(self, context: Any | None = None): """Execute browsingContext.activate.""" if context is None: - raise TypeError("activate() missing required argument: {{snake_param!r}}") + raise TypeError("activate() missing required argument: 'context'") params = { "context": context, @@ -625,7 +470,7 @@ def capture_screenshot( ): """Execute browsingContext.captureScreenshot.""" if context is None: - raise TypeError("capture_screenshot() missing required argument: {{snake_param!r}}") + raise TypeError("capture_screenshot() missing required argument: 'context'") params = { "context": context, @@ -644,7 +489,7 @@ def capture_screenshot( def close(self, context: Any | None = None, prompt_unload: bool | None = None): """Execute browsingContext.close.""" if context is None: - raise TypeError("close() missing required argument: {{snake_param!r}}") + raise TypeError("close() missing required argument: 'context'") params = { "context": context, @@ -664,7 +509,7 @@ def create( ): """Execute browsingContext.create.""" if type is None: - raise TypeError("create() missing required argument: {{snake_param!r}}") + raise TypeError("create() missing required argument: 'type'") params = { "type": type, @@ -709,7 +554,7 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): """Execute browsingContext.handleUserPrompt.""" if context is None: - raise TypeError("handle_user_prompt() missing required argument: {{snake_param!r}}") + raise TypeError("handle_user_prompt() missing required argument: 'context'") params = { "context": context, @@ -731,9 +576,9 @@ def locate_nodes( ): """Execute browsingContext.locateNodes.""" if context is None: - raise TypeError("locate_nodes() missing required argument: {{snake_param!r}}") + raise TypeError("locate_nodes() missing required argument: 'context'") if locator is None: - raise TypeError("locate_nodes() missing required argument: {{snake_param!r}}") + raise TypeError("locate_nodes() missing required argument: 'locator'") params = { "context": context, @@ -753,9 +598,9 @@ def locate_nodes( def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): """Execute browsingContext.navigate.""" if context is None: - raise TypeError("navigate() missing required argument: {{snake_param!r}}") + raise TypeError("navigate() missing required argument: 'context'") if url is None: - raise TypeError("navigate() missing required argument: {{snake_param!r}}") + raise TypeError("navigate() missing required argument: 'url'") params = { "context": context, @@ -778,7 +623,7 @@ def print( ): """Execute browsingContext.print.""" if context is None: - raise TypeError("print() missing required argument: {{snake_param!r}}") + raise TypeError("print() missing required argument: 'context'") params = { "context": context, @@ -799,7 +644,7 @@ def print( def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): """Execute browsingContext.reload.""" if context is None: - raise TypeError("reload() missing required argument: {{snake_param!r}}") + raise TypeError("reload() missing required argument: 'context'") params = { "context": context, @@ -833,9 +678,9 @@ def set_viewport( def traverse_history(self, context: Any | None = None, delta: Any | None = None): """Execute browsingContext.traverseHistory.""" if context is None: - raise TypeError("traverse_history() missing required argument: {{snake_param!r}}") + raise TypeError("traverse_history() missing required argument: 'context'") if delta is None: - raise TypeError("traverse_history() missing required argument: {{snake_param!r}}") + raise TypeError("traverse_history() missing required argument: 'delta'") params = { "context": context, diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index 59e8afd93ab2e..fc75caa282a45 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -24,13 +24,14 @@ def command_builder( - method: str, params: dict[str, Any] + method: str, params: dict[str, Any] | None = None ) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: method: The BiDi method name (e.g., "session.status", "browser.close") - params: The parameters for the command + params: The parameters for the command. If omitted, an empty + dictionary is sent. Yields: A dictionary representing the BiDi command @@ -38,5 +39,7 @@ def command_builder( Returns: The result from the BiDi command execution """ + if params is None: + params = {} result = yield {"method": method, "params": params} return result diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 3dcf8e58881e4..44babb6777616 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from typing import Any -from .common import command_builder +from selenium.webdriver.common.bidi.common import command_builder class ForcedColorsModeTheme: @@ -81,15 +81,6 @@ class SetLocaleOverrideParameters: user_contexts: list[Any] = field(default_factory=list) -@dataclass -class setNetworkConditionsParameters: - """setNetworkConditionsParameters.""" - - network_conditions: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - @dataclass class NetworkConditionsOffline: """NetworkConditionsOffline.""" @@ -184,6 +175,18 @@ class SetTouchOverrideParameters: user_contexts: list[Any] = field(default_factory=list) +@dataclass +class SetNetworkConditionsParameters: + """SetNetworkConditionsParameters.""" + + network_conditions: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) + + +# Backward-compatible alias for existing imports +setNetworkConditionsParameters = SetNetworkConditionsParameters + class Emulation: """WebDriver BiDi emulation module.""" @@ -198,7 +201,7 @@ def set_forced_colors_mode_theme_override( ): """Execute emulation.setForcedColorsModeThemeOverride.""" if theme is None: - raise TypeError("set_forced_colors_mode_theme_override() missing required argument: {{snake_param!r}}") + raise TypeError("set_forced_colors_mode_theme_override() missing required argument: 'theme'") params = { "theme": theme, @@ -218,7 +221,7 @@ def set_locale_override( ): """Execute emulation.setLocaleOverride.""" if locale is None: - raise TypeError("set_locale_override() missing required argument: {{snake_param!r}}") + raise TypeError("set_locale_override() missing required argument: 'locale'") params = { "locale": locale, @@ -238,7 +241,7 @@ def set_screen_settings_override( ): """Execute emulation.setScreenSettingsOverride.""" if screen_area is None: - raise TypeError("set_screen_settings_override() missing required argument: {{snake_param!r}}") + raise TypeError("set_screen_settings_override() missing required argument: 'screen_area'") params = { "screenArea": screen_area, @@ -258,7 +261,7 @@ def set_viewport_meta_override( ): """Execute emulation.setViewportMetaOverride.""" if viewport_meta is None: - raise TypeError("set_viewport_meta_override() missing required argument: {{snake_param!r}}") + raise TypeError("set_viewport_meta_override() missing required argument: 'viewport_meta'") params = { "viewportMeta": viewport_meta, @@ -278,7 +281,7 @@ def set_scrollbar_type_override( ): """Execute emulation.setScrollbarTypeOverride.""" if scrollbar_type is None: - raise TypeError("set_scrollbar_type_override() missing required argument: {{snake_param!r}}") + raise TypeError("set_scrollbar_type_override() missing required argument: 'scrollbar_type'") params = { "scrollbarType": scrollbar_type, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 1d4730534f16d..346ead5e49841 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -6,14 +6,12 @@ # WebDriver BiDi module: input from __future__ import annotations -import threading from collections.abc import Callable from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi.session import Session - -from .common import command_builder +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager +from selenium.webdriver.common.bidi.common import command_builder class PointerType: @@ -206,159 +204,6 @@ class PointerDownAction: "file_dialog_opened": "input.fileDialogOpened", } -@dataclass -class EventConfig: - """Configuration for a BiDi event.""" - event_key: str - bidi_event: str - event_class: type - - -class _EventWrapper: - """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - """Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - """ - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend(["_", char.lower()]) - else: - result.append(char) - return "".join(result) - - -class _EventManager: - """Manages event subscriptions and callbacks.""" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed.""" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - "callbacks": [], - "subscription_id": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist.""" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry["callbacks"]: - session = Session(self.conn) - sub_id = entry.get("subscription_id") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event]["callbacks"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry["callbacks"]: - entry["callbacks"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry["callbacks"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get("subscription_id") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - - - class Input: """WebDriver BiDi input module.""" @@ -370,9 +215,9 @@ def __init__(self, conn) -> None: def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): """Execute input.performActions.""" if context is None: - raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") + raise TypeError("perform_actions() missing required argument: 'context'") if actions is None: - raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") + raise TypeError("perform_actions() missing required argument: 'actions'") params = { "context": context, @@ -386,7 +231,7 @@ def perform_actions(self, context: Any | None = None, actions: list[Any] | None def release_actions(self, context: Any | None = None): """Execute input.releaseActions.""" if context is None: - raise TypeError("release_actions() missing required argument: {{snake_param!r}}") + raise TypeError("release_actions() missing required argument: 'context'") params = { "context": context, @@ -399,11 +244,11 @@ def release_actions(self, context: Any | None = None): def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): """Execute input.setFiles.""" if context is None: - raise TypeError("set_files() missing required argument: {{snake_param!r}}") + raise TypeError("set_files() missing required argument: 'context'") if element is None: - raise TypeError("set_files() missing required argument: {{snake_param!r}}") + raise TypeError("set_files() missing required argument: 'element'") if files is None: - raise TypeError("set_files() missing required argument: {{snake_param!r}}") + raise TypeError("set_files() missing required argument: 'files'") params = { "context": context, diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 488f0740a40b5..ca24d6e78d532 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -6,12 +6,11 @@ # WebDriver BiDi module: log from __future__ import annotations -import threading from collections.abc import Callable from dataclasses import dataclass from typing import Any -from selenium.webdriver.common.bidi.session import Session +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager class Level: @@ -100,159 +99,6 @@ def from_json(cls, params: dict) -> JavascriptLogEntry: "entry_added": "log.entryAdded", } -@dataclass -class EventConfig: - """Configuration for a BiDi event.""" - event_key: str - bidi_event: str - event_class: type - - -class _EventWrapper: - """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - """Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - """ - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend(["_", char.lower()]) - else: - result.append(char) - return "".join(result) - - -class _EventManager: - """Manages event subscriptions and callbacks.""" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed.""" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - "callbacks": [], - "subscription_id": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist.""" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry["callbacks"]: - session = Session(self.conn) - sub_id = entry.get("subscription_id") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event]["callbacks"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry["callbacks"]: - entry["callbacks"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry["callbacks"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get("subscription_id") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - - - class Log: """WebDriver BiDi log module.""" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 1dd2f5a476049..343b6d960c017 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -6,14 +6,12 @@ # WebDriver BiDi module: network from __future__ import annotations -import threading from collections.abc import Callable from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi.session import Session - -from .common import command_builder +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager +from selenium.webdriver.common.bidi.common import command_builder class SameSite: @@ -275,15 +273,6 @@ class ContinueWithAuthCredentials: credentials: Any | None = None -@dataclass -class disownDataParameters: - """disownDataParameters.""" - - data_type: Any | None = None - collector: Any | None = None - request: Any | None = None - - @dataclass class FailRequestParameters: """FailRequestParameters.""" @@ -358,6 +347,18 @@ class ResponseStartedParameters: response: Any | None = None +@dataclass +class DisownDataParameters: + """DisownDataParameters.""" + + data_type: Any | None = None + collector: Any | None = None + request: Any | None = None + + +# Backward-compatible alias for existing imports +disownDataParameters = DisownDataParameters + class BytesValue: """A string or base64-encoded bytes value used in cookie operations. @@ -399,159 +400,6 @@ def continue_request(self, **kwargs): "before_request": "network.beforeRequestSent", } -@dataclass -class EventConfig: - """Configuration for a BiDi event.""" - event_key: str - bidi_event: str - event_class: type - - -class _EventWrapper: - """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - """Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - """ - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend(["_", char.lower()]) - else: - result.append(char) - return "".join(result) - - -class _EventManager: - """Manages event subscriptions and callbacks.""" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed.""" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - "callbacks": [], - "subscription_id": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist.""" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry["callbacks"]: - session = Session(self.conn) - sub_id = entry.get("subscription_id") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event]["callbacks"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry["callbacks"]: - entry["callbacks"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry["callbacks"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get("subscription_id") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - - - class Network: """WebDriver BiDi network module.""" @@ -572,9 +420,9 @@ def add_data_collector( ): """Execute network.addDataCollector.""" if data_types is None: - raise TypeError("add_data_collector() missing required argument: {{snake_param!r}}") + raise TypeError("add_data_collector() missing required argument: 'data_types'") if max_encoded_data_size is None: - raise TypeError("add_data_collector() missing required argument: {{snake_param!r}}") + raise TypeError("add_data_collector() missing required argument: 'max_encoded_data_size'") params = { "dataTypes": data_types, @@ -596,7 +444,7 @@ def add_intercept( ): """Execute network.addIntercept.""" if phases is None: - raise TypeError("add_intercept() missing required argument: {{snake_param!r}}") + raise TypeError("add_intercept() missing required argument: 'phases'") params = { "phases": phases, @@ -619,7 +467,7 @@ def continue_request( ): """Execute network.continueRequest.""" if request is None: - raise TypeError("continue_request() missing required argument: {{snake_param!r}}") + raise TypeError("continue_request() missing required argument: 'request'") params = { "request": request, @@ -645,7 +493,7 @@ def continue_response( ): """Execute network.continueResponse.""" if request is None: - raise TypeError("continue_response() missing required argument: {{snake_param!r}}") + raise TypeError("continue_response() missing required argument: 'request'") params = { "request": request, @@ -663,7 +511,7 @@ def continue_response( def continue_with_auth(self, request: Any | None = None): """Execute network.continueWithAuth.""" if request is None: - raise TypeError("continue_with_auth() missing required argument: {{snake_param!r}}") + raise TypeError("continue_with_auth() missing required argument: 'request'") params = { "request": request, @@ -676,11 +524,11 @@ def continue_with_auth(self, request: Any | None = None): def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): """Execute network.disownData.""" if data_type is None: - raise TypeError("disown_data() missing required argument: {{snake_param!r}}") + raise TypeError("disown_data() missing required argument: 'data_type'") if collector is None: - raise TypeError("disown_data() missing required argument: {{snake_param!r}}") + raise TypeError("disown_data() missing required argument: 'collector'") if request is None: - raise TypeError("disown_data() missing required argument: {{snake_param!r}}") + raise TypeError("disown_data() missing required argument: 'request'") params = { "dataType": data_type, @@ -695,7 +543,7 @@ def disown_data(self, data_type: Any | None = None, collector: Any | None = None def fail_request(self, request: Any | None = None): """Execute network.failRequest.""" if request is None: - raise TypeError("fail_request() missing required argument: {{snake_param!r}}") + raise TypeError("fail_request() missing required argument: 'request'") params = { "request": request, @@ -714,9 +562,9 @@ def get_data( ): """Execute network.getData.""" if data_type is None: - raise TypeError("get_data() missing required argument: {{snake_param!r}}") + raise TypeError("get_data() missing required argument: 'data_type'") if request is None: - raise TypeError("get_data() missing required argument: {{snake_param!r}}") + raise TypeError("get_data() missing required argument: 'request'") params = { "dataType": data_type, @@ -740,7 +588,7 @@ def provide_response( ): """Execute network.provideResponse.""" if request is None: - raise TypeError("provide_response() missing required argument: {{snake_param!r}}") + raise TypeError("provide_response() missing required argument: 'request'") params = { "request": request, @@ -758,7 +606,7 @@ def provide_response( def remove_data_collector(self, collector: Any | None = None): """Execute network.removeDataCollector.""" if collector is None: - raise TypeError("remove_data_collector() missing required argument: {{snake_param!r}}") + raise TypeError("remove_data_collector() missing required argument: 'collector'") params = { "collector": collector, @@ -771,7 +619,7 @@ def remove_data_collector(self, collector: Any | None = None): def remove_intercept(self, intercept: Any | None = None): """Execute network.removeIntercept.""" if intercept is None: - raise TypeError("remove_intercept() missing required argument: {{snake_param!r}}") + raise TypeError("remove_intercept() missing required argument: 'intercept'") params = { "intercept": intercept, @@ -784,7 +632,7 @@ def remove_intercept(self, intercept: Any | None = None): def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): """Execute network.setCacheBehavior.""" if cache_behavior is None: - raise TypeError("set_cache_behavior() missing required argument: {{snake_param!r}}") + raise TypeError("set_cache_behavior() missing required argument: 'cache_behavior'") params = { "cacheBehavior": cache_behavior, @@ -803,7 +651,7 @@ def set_extra_headers( ): """Execute network.setExtraHeaders.""" if headers is None: - raise TypeError("set_extra_headers() missing required argument: {{snake_param!r}}") + raise TypeError("set_extra_headers() missing required argument: 'headers'") params = { "headers": headers, @@ -818,9 +666,9 @@ def set_extra_headers( def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.beforeRequestSent.""" if method is None: - raise TypeError("before_request_sent() missing required argument: {{snake_param!r}}") + raise TypeError("before_request_sent() missing required argument: 'method'") if params is None: - raise TypeError("before_request_sent() missing required argument: {{snake_param!r}}") + raise TypeError("before_request_sent() missing required argument: 'params'") params = { "initiator": initiator, @@ -835,11 +683,11 @@ def before_request_sent(self, initiator: Any | None = None, method: Any | None = def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.fetchError.""" if error_text is None: - raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") + raise TypeError("fetch_error() missing required argument: 'error_text'") if method is None: - raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") + raise TypeError("fetch_error() missing required argument: 'method'") if params is None: - raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") + raise TypeError("fetch_error() missing required argument: 'params'") params = { "errorText": error_text, @@ -854,11 +702,11 @@ def fetch_error(self, error_text: Any | None = None, method: Any | None = None, def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): """Execute network.responseCompleted.""" if response is None: - raise TypeError("response_completed() missing required argument: {{snake_param!r}}") + raise TypeError("response_completed() missing required argument: 'response'") if method is None: - raise TypeError("response_completed() missing required argument: {{snake_param!r}}") + raise TypeError("response_completed() missing required argument: 'method'") if params is None: - raise TypeError("response_completed() missing required argument: {{snake_param!r}}") + raise TypeError("response_completed() missing required argument: 'params'") params = { "response": response, @@ -873,7 +721,7 @@ def response_completed(self, response: Any | None = None, method: Any | None = N def response_started(self, response: Any | None = None): """Execute network.responseStarted.""" if response is None: - raise TypeError("response_started() missing required argument: {{snake_param!r}}") + raise TypeError("response_started() missing required argument: 'response'") params = { "response": response, diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index 6dd138da17309..acb8bdf65f0ef 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -22,7 +22,7 @@ from enum import Enum from typing import Any -from .common import command_builder +from selenium.webdriver.common.bidi.common import command_builder _VALID_PERMISSION_STATES = {"granted", "denied", "prompt"} diff --git a/py/selenium/webdriver/common/bidi/py.typed b/py/selenium/webdriver/common/bidi/py.typed old mode 100755 new mode 100644 diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 221b5963e8ec1..d6877de623d14 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -6,14 +6,12 @@ # WebDriver BiDi module: script from __future__ import annotations -import threading from collections.abc import Callable from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi.session import Session - -from .common import command_builder +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager +from selenium.webdriver.common.bidi.common import command_builder class SpecialNumber: @@ -620,159 +618,6 @@ class RealmDestroyedParameters: "realm_destroyed": "script.realmDestroyed", } -@dataclass -class EventConfig: - """Configuration for a BiDi event.""" - event_key: str - bidi_event: str - event_class: type - - -class _EventWrapper: - """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" - def __init__(self, bidi_event: str, event_class: type): - self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class - self._python_class = event_class # Keep reference to Python dataclass for deserialization - - def from_json(self, params: dict) -> Any: - """Deserialize event params into the wrapped Python dataclass. - - Args: - params: Raw BiDi event params with camelCase keys. - - Returns: - An instance of the dataclass, or the raw dict on failure. - """ - if self._python_class is None or self._python_class is dict: - return params - try: - # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable( - self._python_class.from_json - ): - return self._python_class.from_json(params) - import dataclasses as dc - - snake_params = {self._camel_to_snake(k): v for k, v in params.items()} - if dc.is_dataclass(self._python_class): - valid_fields = {f.name for f in dc.fields(self._python_class)} - filtered = {k: v for k, v in snake_params.items() if k in valid_fields} - return self._python_class(**filtered) - return self._python_class(**snake_params) - except Exception: - return params - - @staticmethod - def _camel_to_snake(name: str) -> str: - result = [name[0].lower()] - for char in name[1:]: - if char.isupper(): - result.extend(["_", char.lower()]) - else: - result.append(char) - return "".join(result) - - -class _EventManager: - """Manages event subscriptions and callbacks.""" - - def __init__(self, conn, event_configs: dict[str, EventConfig]): - self.conn = conn - self.event_configs = event_configs - self.subscriptions: dict = {} - self._event_wrappers = {} # Cache of _EventWrapper objects - self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} - self._available_events = ", ".join(sorted(event_configs.keys())) - self._subscription_lock = threading.Lock() - - # Create event wrappers for each event - for config in event_configs.values(): - wrapper = _EventWrapper(config.bidi_event, config.event_class) - self._event_wrappers[config.bidi_event] = wrapper - - def validate_event(self, event: str) -> EventConfig: - event_config = self.event_configs.get(event) - if not event_config: - raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") - return event_config - - def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed.""" - with self._subscription_lock: - if bidi_event not in self.subscriptions: - session = Session(self.conn) - result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) - self.subscriptions[bidi_event] = { - "callbacks": [], - "subscription_id": sub_id, - } - - def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist.""" - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry is not None and not entry["callbacks"]: - session = Session(self.conn) - sub_id = entry.get("subscription_id") - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - del self.subscriptions[bidi_event] - - def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - self.subscriptions[bidi_event]["callbacks"].append(callback_id) - - def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: - with self._subscription_lock: - entry = self.subscriptions.get(bidi_event) - if entry and callback_id in entry["callbacks"]: - entry["callbacks"].remove(callback_id) - - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - event_config = self.validate_event(event) - # Use the event wrapper for add_callback - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - callback_id = self.conn.add_callback(event_wrapper, callback) - self.subscribe_to_event(event_config.bidi_event, contexts) - self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id - - def remove_event_handler(self, event: str, callback_id: int) -> None: - event_config = self.validate_event(event) - event_wrapper = self._event_wrappers.get(event_config.bidi_event) - self.conn.remove_callback(event_wrapper, callback_id) - self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - self.unsubscribe_from_event(event_config.bidi_event) - - def clear_event_handlers(self) -> None: - """Clear all event handlers.""" - with self._subscription_lock: - if not self.subscriptions: - return - session = Session(self.conn) - for bidi_event, entry in list(self.subscriptions.items()): - event_wrapper = self._event_wrappers.get(bidi_event) - callbacks = entry["callbacks"] if isinstance(entry, dict) else entry - if event_wrapper: - for callback_id in callbacks: - self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get("subscription_id") if isinstance(entry, dict) else None - ) - if sub_id: - session.unsubscribe(subscriptions=[sub_id]) - else: - session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() - - - - class Script: """WebDriver BiDi script module.""" @@ -792,7 +637,7 @@ def add_preload_script( ): """Execute script.addPreloadScript.""" if function_declaration is None: - raise TypeError("add_preload_script() missing required argument: {{snake_param!r}}") + raise TypeError("add_preload_script() missing required argument: 'function_declaration'") params = { "functionDeclaration": function_declaration, @@ -809,9 +654,9 @@ def add_preload_script( def disown(self, handles: list[Any] | None = None, target: Any | None = None): """Execute script.disown.""" if handles is None: - raise TypeError("disown() missing required argument: {{snake_param!r}}") + raise TypeError("disown() missing required argument: 'handles'") if target is None: - raise TypeError("disown() missing required argument: {{snake_param!r}}") + raise TypeError("disown() missing required argument: 'target'") params = { "handles": handles, @@ -835,11 +680,11 @@ def call_function( ): """Execute script.callFunction.""" if function_declaration is None: - raise TypeError("call_function() missing required argument: {{snake_param!r}}") + raise TypeError("call_function() missing required argument: 'function_declaration'") if await_promise is None: - raise TypeError("call_function() missing required argument: {{snake_param!r}}") + raise TypeError("call_function() missing required argument: 'await_promise'") if target is None: - raise TypeError("call_function() missing required argument: {{snake_param!r}}") + raise TypeError("call_function() missing required argument: 'target'") params = { "functionDeclaration": function_declaration, @@ -867,11 +712,11 @@ def evaluate( ): """Execute script.evaluate.""" if expression is None: - raise TypeError("evaluate() missing required argument: {{snake_param!r}}") + raise TypeError("evaluate() missing required argument: 'expression'") if target is None: - raise TypeError("evaluate() missing required argument: {{snake_param!r}}") + raise TypeError("evaluate() missing required argument: 'target'") if await_promise is None: - raise TypeError("evaluate() missing required argument: {{snake_param!r}}") + raise TypeError("evaluate() missing required argument: 'await_promise'") params = { "expression": expression, @@ -900,7 +745,7 @@ def get_realms(self, context: Any | None = None, type: Any | None = None): def remove_preload_script(self, script: Any | None = None): """Execute script.removePreloadScript.""" if script is None: - raise TypeError("remove_preload_script() missing required argument: {{snake_param!r}}") + raise TypeError("remove_preload_script() missing required argument: 'script'") params = { "script": script, @@ -913,11 +758,11 @@ def remove_preload_script(self, script: Any | None = None): def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): """Execute script.message.""" if channel is None: - raise TypeError("message() missing required argument: {{snake_param!r}}") + raise TypeError("message() missing required argument: 'channel'") if data is None: - raise TypeError("message() missing required argument: {{snake_param!r}}") + raise TypeError("message() missing required argument: 'data'") if source is None: - raise TypeError("message() missing required argument: {{snake_param!r}}") + raise TypeError("message() missing required argument: 'source'") params = { "channel": channel, diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index fcb42a4ad86fc..e04d897e25deb 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from typing import Any -from .common import command_builder +from selenium.webdriver.common.bidi.common import command_builder class UserPromptHandlerType: @@ -176,6 +176,10 @@ def to_bidi_dict(self) -> dict: result["prompt"] = self.prompt return result + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict() + class Session: """WebDriver BiDi session module.""" @@ -194,7 +198,7 @@ def status(self): def new(self, capabilities: Any | None = None): """Execute session.new.""" if capabilities is None: - raise TypeError("new() missing required argument: {{snake_param!r}}") + raise TypeError("new() missing required argument: 'capabilities'") params = { "capabilities": capabilities, @@ -221,7 +225,7 @@ def subscribe( ): """Execute session.subscribe.""" if events is None: - raise TypeError("subscribe() missing required argument: {{snake_param!r}}") + raise TypeError("subscribe() missing required argument: 'events'") params = { "events": events, diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index a2606526f3856..5ae8bf5aeb2d0 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from typing import Any -from .common import command_builder +from selenium.webdriver.common.bidi.common import command_builder @dataclass @@ -83,6 +83,10 @@ def __init__(self, type: Any | None, value: Any | None) -> None: def to_bidi_dict(self) -> dict: return {"type": self.type, "value": self.value} + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict() + class SameSite: """SameSite cookie attribute values.""" @@ -162,6 +166,10 @@ def to_bidi_dict(self) -> dict: result["expiry"] = self.expiry return result + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict() + @dataclass class PartialCookie: """PartialCookie.""" @@ -196,6 +204,10 @@ def to_bidi_dict(self) -> dict: result["expiry"] = self.expiry return result + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict() + class BrowsingContextPartitionDescriptor: """BrowsingContextPartitionDescriptor. @@ -211,6 +223,10 @@ def __init__(self, context: Any = None, type: str = "context") -> None: def to_bidi_dict(self) -> dict: return {"type": "context", "context": self.context} + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict() + @dataclass class StorageKeyPartitionDescriptor: """StorageKeyPartitionDescriptor.""" @@ -228,6 +244,10 @@ def to_bidi_dict(self) -> dict: result["sourceOrigin"] = self.source_origin return result + def to_dict(self) -> dict: + """Backward-compatible alias for to_bidi_dict().""" + return self.to_bidi_dict() + class Storage: """WebDriver BiDi storage module.""" @@ -277,6 +297,17 @@ def set_cookie(self, cookie=None, partition=None): params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("storage.setCookie", params) result = self._conn.execute(cmd) + if isinstance(result, dict): + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None + ) + return SetCookieResult(partition_key=pk) return result def delete_cookies(self, filter=None, partition=None): """Execute storage.deleteCookies.""" @@ -291,4 +322,15 @@ def delete_cookies(self, filter=None, partition=None): params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("storage.deleteCookies", params) result = self._conn.execute(cmd) + if isinstance(result, dict): + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None + ) + return DeleteCookiesResult(partition_key=pk) return result diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index b5881d01e0bea..0a28843e339f1 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from typing import Any -from .common import command_builder +from selenium.webdriver.common.bidi.common import command_builder @dataclass @@ -104,7 +104,17 @@ def install( extension_data = {"type": "base64", "value": base64_value} params = {"extensionData": extension_data} cmd = command_builder("webExtension.install", params) - return self._conn.execute(cmd) + try: + return self._conn.execute(cmd) + except Exception as e: + if "Method not available" in str(e): + raise RuntimeError( + "webExtension.install failed with 'Method not available'. " + "This likely means that web extension support is disabled. " + "Enable unsafe extension debugging and/or set options.enable_webextensions " + "in your WebDriver configuration." + ) from e + raise def uninstall(self, extension: str | dict): """Uninstall a web extension. diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 38013a56f7b7d..2c41897878075 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -29,6 +29,7 @@ import zipfile from abc import ABCMeta from base64 import b64decode, urlsafe_b64encode +from collections.abc import Generator from contextlib import asynccontextmanager, contextmanager from importlib import import_module from typing import Any, cast @@ -436,14 +437,17 @@ def execute_cdp_cmd(self, cmd: str, cmd_args: dict): ] def execute( - self, driver_command: str, params: dict[str, Any] | None = None - ) -> dict[str, Any]: + self, + driver_command: str | Generator[dict[str, Any], Any, Any], + params: dict[str, Any] | None = None, + ) -> Any: """Sends a command to be executed by a command.CommandExecutor. Args: - driver_command: The name of the command to execute as a string. Can also be a generator - for BiDi protocol commands. + driver_command: The name of the command to execute as a string. + Can also be a BiDi protocol command generator. params: A dictionary of named parameters to send with the command. + Ignored when ``driver_command`` is a BiDi generator. Returns: The command's JSON response loaded into a dictionary object. diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index b4cac118df033..44cb2adef7a0b 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -56,7 +56,10 @@ def default(self, o): result = {} for f in dataclasses.fields(o): value = getattr(o, f.name) - if value is None: + # Skip None values unless the field is explicitly marked + # retain_none=True in its metadata (e.g. for required-but-nullable + # BiDi fields that must be sent as JSON null rather than omitted). + if value is None and not f.metadata.get("retain_none"): continue camel_key = _snake_to_camel(f.name) # Flatten PointerCommonProperties fields inline into the parent diff --git a/py/test/selenium/webdriver/common/bidi_browser_tests.py b/py/test/selenium/webdriver/common/bidi_browser_tests.py index b9e042403dcba..d0d08ea33f649 100644 --- a/py/test/selenium/webdriver/common/bidi_browser_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browser_tests.py @@ -101,8 +101,10 @@ def test_raises_exception_when_removing_default_user_context(driver): def test_client_window_state_constants(driver): """Test ClientWindowNamedState constants.""" + assert ClientWindowNamedState.FULLSCREEN == "fullscreen" assert ClientWindowNamedState.MAXIMIZED == "maximized" assert ClientWindowNamedState.MINIMIZED == "minimized" + assert ClientWindowNamedState.NORMAL == "normal" def test_create_user_context_with_accept_insecure_certs(driver): diff --git a/rake_tasks/python.rake b/rake_tasks/python.rake index 0dfeefed09082..a3825b8285610 100644 --- a/rake_tasks/python.rake +++ b/rake_tasks/python.rake @@ -60,6 +60,19 @@ task :local_dev, [:all] do |_task, arguments| FileUtils.rm_rf("#{lib_path}/common/devtools") FileUtils.cp_r("#{bazel_bin}/.", lib_path, remove_destination: true) else + bidi_src = "#{bazel_bin}/common/bidi" + bidi_dest = "#{lib_path}/common/bidi" + if Dir.exist?(bidi_src) + FileUtils.mkdir_p(bidi_dest) + Dir.children(bidi_src).sort.each do |entry| + src = File.join(bidi_src, entry) + next unless File.file?(src) || File.symlink?(src) + + resolved_src = File.symlink?(src) ? File.realpath(src) : src + FileUtils.cp(resolved_src, File.join(bidi_dest, entry)) + end + end + %w[common/devtools common/linux common/mac common/windows].each do |dir| src = "#{bazel_bin}/#{dir}" dest = "#{lib_path}/#{dir}" From 126ce121049bb8b87b7783b6e3c7808059aa4461 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Tue, 31 Mar 2026 12:33:34 +0100 Subject: [PATCH 25/37] update ruff excludes --- py/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/py/pyproject.toml b/py/pyproject.toml index e081f7f17a2b4..0022442ab8431 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -140,6 +140,7 @@ warn_unreachable = false [tool.ruff] extend-exclude = [ "selenium/webdriver/common/devtools/", + "selenium/webdriver/common/bidi/", ] line-length = 120 respect-gitignore = true From a0cfa911bcf8ae3f55124e300b3d77612cae9fbc Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 1 Apr 2026 13:36:02 +0100 Subject: [PATCH 26/37] fix tests --- py/BUILD.bazel | 5 +- py/private/BUILD.bazel | 1 + py/private/_event_manager.py | 186 ++++++++++++++++++ py/private/cdp.py | 8 +- .../webdriver/common/bidi/_event_manager.py | 8 +- py/selenium/webdriver/common/bidi/browser.py | 1 - .../webdriver/common/bidi/browsing_context.py | 3 +- py/selenium/webdriver/common/bidi/cdp.py | 8 +- .../webdriver/common/bidi/emulation.py | 1 - py/selenium/webdriver/common/bidi/input.py | 3 +- py/selenium/webdriver/common/bidi/log.py | 3 +- py/selenium/webdriver/common/bidi/network.py | 3 +- py/selenium/webdriver/common/bidi/script.py | 9 +- py/selenium/webdriver/common/bidi/session.py | 1 - py/selenium/webdriver/common/bidi/storage.py | 1 - .../webdriver/common/bidi/webextension.py | 1 - 16 files changed, 216 insertions(+), 26 deletions(-) create mode 100644 py/private/_event_manager.py diff --git a/py/BUILD.bazel b/py/BUILD.bazel index d43994c15c531..292cde4981d74 100644 --- a/py/BUILD.bazel +++ b/py/BUILD.bazel @@ -622,7 +622,10 @@ generate_bidi( name = "create-bidi-src", cddl_file = "//common/bidi/spec:all.cddl", enhancements_manifest = "//py/private:bidi_enhancements_manifest.py", - extra_srcs = ["//py/private:cdp.py"], + extra_srcs = [ + "//py/private:_event_manager.py", + "//py/private:cdp.py", + ], generator = ":generate_bidi", module_name = "selenium/webdriver/common/bidi", spec_version = "1.0", diff --git a/py/private/BUILD.bazel b/py/private/BUILD.bazel index 88acc9d2aba11..d2ea587fd8101 100644 --- a/py/private/BUILD.bazel +++ b/py/private/BUILD.bazel @@ -1,6 +1,7 @@ load("@rules_python//python:defs.bzl", "py_binary") exports_files([ + "_event_manager.py", "bidi_enhancements_manifest.py", "cdp.py", ]) diff --git a/py/private/_event_manager.py b/py/private/_event_manager.py new file mode 100644 index 0000000000000..84b4446d190c5 --- /dev/null +++ b/py/private/_event_manager.py @@ -0,0 +1,186 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Shared event management helpers for generated WebDriver BiDi modules. + +``EventConfig``, ``_EventWrapper``, and ``_EventManager`` are emitted +identically into every generated module that exposes events. Rather than +duplicating this logic across those modules, they are defined once here and +copied into generated outputs by Bazel. +""" + +from __future__ import annotations + +import threading +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from selenium.webdriver.common.bidi.session import Session + + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() \ No newline at end of file diff --git a/py/private/cdp.py b/py/private/cdp.py index b097762fe50cd..ba4a73298ee0a 100644 --- a/py/private/cdp.py +++ b/py/private/cdp.py @@ -60,7 +60,13 @@ def import_devtools(ver): # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - latest = max(int(x[1:]) for x in versions) + available_versions = tuple( + x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit()) + ) + numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) + if not numeric_versions: + raise + latest = max(numeric_versions, key=int) selenium_logger = logging.getLogger(__name__) selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) devtools = importlib.import_module(f"{base}{latest}") diff --git a/py/selenium/webdriver/common/bidi/_event_manager.py b/py/selenium/webdriver/common/bidi/_event_manager.py index 216a5b8eccb70..84b4446d190c5 100644 --- a/py/selenium/webdriver/common/bidi/_event_manager.py +++ b/py/selenium/webdriver/common/bidi/_event_manager.py @@ -18,9 +18,9 @@ """Shared event management helpers for generated WebDriver BiDi modules. ``EventConfig``, ``_EventWrapper``, and ``_EventManager`` are emitted -identically into every generated module that exposes events. Rather than -duplicating ~160 lines of code across all of those modules, they are defined -once here and imported by the generated files. +identically into every generated module that exposes events. Rather than +duplicating this logic across those modules, they are defined once here and +copied into generated outputs by Bazel. """ from __future__ import annotations @@ -183,4 +183,4 @@ def clear_event_handlers(self) -> None: session.unsubscribe(subscriptions=[sub_id]) else: session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() + self.subscriptions.clear() \ No newline at end of file diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 3811a2a2e97b7..94dd0094e9173 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - def transform_download_params( allowed: bool | None, destination_folder: str | None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index fcee27df8488e..86075dc166256 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class ReadinessState: """ReadinessState.""" diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py index b097762fe50cd..ba4a73298ee0a 100644 --- a/py/selenium/webdriver/common/bidi/cdp.py +++ b/py/selenium/webdriver/common/bidi/cdp.py @@ -60,7 +60,13 @@ def import_devtools(ver): # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - latest = max(int(x[1:]) for x in versions) + available_versions = tuple( + x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit()) + ) + numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) + if not numeric_versions: + raise + latest = max(numeric_versions, key=int) selenium_logger = logging.getLogger(__name__) selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) devtools = importlib.import_module(f"{base}{latest}") diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 44babb6777616..9791aba5e08a6 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - class ForcedColorsModeTheme: """ForcedColorsModeTheme.""" diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 346ead5e49841..d2508fea5ca64 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class PointerType: """PointerType.""" diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index ca24d6e78d532..04c5a53c04510 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -10,8 +10,7 @@ from dataclasses import dataclass from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class Level: """Level.""" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 343b6d960c017..c0302bdec186b 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class SameSite: """SameSite.""" diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index d6877de623d14..d5b15ff4d983c 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class SpecialNumber: """SpecialNumber.""" @@ -790,9 +789,8 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import datetime as _datetime import math as _math - + import datetime as _datetime from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1033,9 +1031,8 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - - from selenium.webdriver.common.bidi import log as _log_mod from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index e04d897e25deb..a54e196aa86d9 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - class UserPromptHandlerType: """UserPromptHandlerType.""" diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 5ae8bf5aeb2d0..d922390a08699 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - @dataclass class PartitionKey: """PartitionKey.""" diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 0a28843e339f1..3520219e26c53 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - @dataclass class InstallParameters: """InstallParameters.""" From ea374f48355414c4b5f237f37a0a90b4ebcc02d7 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Tue, 7 Apr 2026 12:14:08 +0100 Subject: [PATCH 27/37] ruffs updates --- py/selenium/webdriver/common/bidi/browser.py | 1 + py/selenium/webdriver/common/bidi/browsing_context.py | 3 ++- py/selenium/webdriver/common/bidi/emulation.py | 1 + py/selenium/webdriver/common/bidi/input.py | 3 ++- py/selenium/webdriver/common/bidi/log.py | 3 ++- py/selenium/webdriver/common/bidi/network.py | 3 ++- py/selenium/webdriver/common/bidi/script.py | 9 ++++++--- py/selenium/webdriver/common/bidi/session.py | 1 + py/selenium/webdriver/common/bidi/storage.py | 1 + py/selenium/webdriver/common/bidi/webextension.py | 1 + 10 files changed, 19 insertions(+), 7 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 94dd0094e9173..3811a2a2e97b7 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + def transform_download_params( allowed: bool | None, destination_folder: str | None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 86075dc166256..fcee27df8488e 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class ReadinessState: """ReadinessState.""" diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 9791aba5e08a6..44babb6777616 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + class ForcedColorsModeTheme: """ForcedColorsModeTheme.""" diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index d2508fea5ca64..346ead5e49841 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class PointerType: """PointerType.""" diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 04c5a53c04510..ca24d6e78d532 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -10,7 +10,8 @@ from dataclasses import dataclass from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager + class Level: """Level.""" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index c0302bdec186b..343b6d960c017 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class SameSite: """SameSite.""" diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index d5b15ff4d983c..d6877de623d14 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class SpecialNumber: """SpecialNumber.""" @@ -789,8 +790,9 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import math as _math import datetime as _datetime + import math as _math + from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1031,8 +1033,9 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + from selenium.webdriver.common.bidi.session import Session as _Session bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index a54e196aa86d9..e04d897e25deb 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + class UserPromptHandlerType: """UserPromptHandlerType.""" diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index d922390a08699..5ae8bf5aeb2d0 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + @dataclass class PartitionKey: """PartitionKey.""" diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 3520219e26c53..0a28843e339f1 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + @dataclass class InstallParameters: """InstallParameters.""" From c214d32c1b6cec0311433ab3fe9db10e6ab336f4 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Tue, 7 Apr 2026 12:34:15 +0100 Subject: [PATCH 28/37] Update CDDL files and regenerate python files --- common/bidi/spec/all.cddl | 71 +++++++++++-------- common/bidi/spec/local.cddl | 29 ++++++-- common/bidi/spec/remote.cddl | 65 +++++++---------- py/selenium/webdriver/common/bidi/browser.py | 1 - .../webdriver/common/bidi/browsing_context.py | 37 +++++++++- .../webdriver/common/bidi/emulation.py | 30 -------- py/selenium/webdriver/common/bidi/input.py | 3 +- py/selenium/webdriver/common/bidi/log.py | 3 +- py/selenium/webdriver/common/bidi/network.py | 4 +- py/selenium/webdriver/common/bidi/script.py | 11 ++- py/selenium/webdriver/common/bidi/session.py | 1 - py/selenium/webdriver/common/bidi/storage.py | 1 - .../webdriver/common/bidi/webextension.py | 1 - 13 files changed, 130 insertions(+), 127 deletions(-) diff --git a/common/bidi/spec/all.cddl b/common/bidi/spec/all.cddl index 85c4536a2cd10..e10b42723b0f5 100644 --- a/common/bidi/spec/all.cddl +++ b/common/bidi/spec/all.cddl @@ -420,6 +420,7 @@ BrowsingContextCommand = ( browsingContext.Navigate // browsingContext.Print // browsingContext.Reload // + browsingContext.SetBypassCSP // browsingContext.SetViewport // browsingContext.TraverseHistory ) @@ -435,6 +436,7 @@ BrowsingContextResult = ( browsingContext.NavigateResult / browsingContext.PrintResult / browsingContext.ReloadResult / + browsingContext.SetBypassCSPResult / browsingContext.SetViewportResult / browsingContext.TraverseHistoryResult ) @@ -518,6 +520,7 @@ browsingContext.BaseNavigationInfo = ( navigation: browsingContext.Navigation / null, timestamp: js-uint, url: text, + ? userContext: browser.UserContext, ) browsingContext.NavigationInfo = { @@ -605,7 +608,8 @@ browsingContext.CreateParameters = { } browsingContext.CreateResult = { - context: browsingContext.BrowsingContext + context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext } browsingContext.GetTree = ( @@ -715,6 +719,19 @@ browsingContext.ReloadParameters = { browsingContext.ReloadResult = browsingContext.NavigateResult +browsingContext.SetBypassCSP = ( + method: "browsingContext.setBypassCSP", + params: browsingContext.SetBypassCSPParameters +) + +browsingContext.SetBypassCSPParameters = { + bypass: true / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +browsingContext.SetBypassCSPResult = EmptyResult + browsingContext.SetViewport = ( method: "browsingContext.setViewport", params: browsingContext.SetViewportParameters @@ -774,7 +791,8 @@ browsingContext.HistoryUpdated = ( browsingContext.HistoryUpdatedParameters = { context: browsingContext.BrowsingContext, timestamp: js-uint, - url: text + url: text, + ? userContext: browser.UserContext } browsingContext.DomContentLoaded = ( @@ -844,6 +862,7 @@ browsingContext.UserPromptClosedParameters = { context: browsingContext.BrowsingContext, accepted: bool, type: browsingContext.UserPromptType, + ? userContext: browser.UserContext, ? userText: text } @@ -857,6 +876,7 @@ browsingContext.UserPromptOpenedParameters = { handler: session.UserPromptHandlerType, message: text, type: browsingContext.UserPromptType, + ? userContext: browser.UserContext, ? defaultValue: text } @@ -871,8 +891,7 @@ EmulationCommand = ( emulation.SetScrollbarTypeOverride // emulation.SetTimezoneOverride // emulation.SetTouchOverride // - emulation.SetUserAgentOverride // - emulation.SetViewportMetaOverride + emulation.SetUserAgentOverride ) @@ -885,8 +904,7 @@ EmulationResult = ( emulation.SetScrollbarTypeOverrideResult / emulation.SetTimezoneOverrideResult / emulation.SetTouchOverrideResult / - emulation.SetUserAgentOverrideResult / - emulation.SetViewportMetaOverrideResult + emulation.SetUserAgentOverrideResult ) emulation.SetForcedColorsModeThemeOverride = ( @@ -949,10 +967,10 @@ emulation.SetLocaleOverrideResult = EmptyResult emulation.SetNetworkConditions = ( method: "emulation.setNetworkConditions", - params: emulation.setNetworkConditionsParameters + params: emulation.SetNetworkConditionsParameters ) -emulation.setNetworkConditionsParameters = { +emulation.SetNetworkConditionsParameters = { networkConditions: emulation.NetworkConditions / null, ? contexts: [+browsingContext.BrowsingContext], ? userContexts: [+browser.UserContext], @@ -1018,19 +1036,6 @@ emulation.SetUserAgentOverrideParameters = { emulation.SetUserAgentOverrideResult = EmptyResult -emulation.SetViewportMetaOverride = ( - method: "emulation.setViewportMetaOverride", - params: emulation.SetViewportMetaOverrideParameters -) - -emulation.SetViewportMetaOverrideParameters = { - viewportMeta: true / null, - ? contexts: [+browsingContext.BrowsingContext], - ? userContexts: [+browser.UserContext], -} - -emulation.SetViewportMetaOverrideResult = EmptyResult - emulation.SetScriptingEnabled = ( method: "emulation.setScriptingEnabled", params: emulation.SetScriptingEnabledParameters @@ -1145,6 +1150,7 @@ network.BaseParameters = ( redirectCount: js-uint, request: network.RequestData, timestamp: js-uint, + ? userContext: browser.UserContext / null, ? intercepts: [+network.Intercept] ) @@ -1379,10 +1385,10 @@ network.ContinueWithAuthResult = EmptyResult network.DisownData = ( method: "network.disownData", - params: network.disownDataParameters + params: network.DisownDataParameters ) -network.disownDataParameters = { +network.DisownDataParameters = { dataType: network.DataType, collector: network.Collector, request: network.Request, @@ -1710,6 +1716,7 @@ script.WindowRealmInfo = { script.BaseRealmInfo, type: "window", context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext, ? sandbox: text } @@ -1969,7 +1976,8 @@ script.StackTrace = { script.Source = { realm: script.Realm, - ? context: browsingContext.BrowsingContext + ? context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext } script.RealmTarget = { @@ -2381,15 +2389,15 @@ input.WheelScrollAction = { } input.PointerCommonProperties = ( - ? width: js-uint .default 1, - ? height: js-uint .default 1, - ? pressure: float .default 0.0, - ? tangentialPressure: float .default 0.0, - ? twist: (0..359) .default 0, + ? width: js-uint, + ? height: js-uint, + ? pressure: (0.0..1.0), + ? tangentialPressure: (-1.0..1.0), + ? twist: (0..359), ; 0 .. Math.PI / 2 - ? altitudeAngle: (0.0..1.5707963267948966) .default 0.0, + ? altitudeAngle: (0.0..1.5707963267948966), ; 0 .. 2 * Math.PI - ? azimuthAngle: (0.0..6.283185307179586) .default 0.0, + ? azimuthAngle: (0.0..6.283185307179586), ) input.Origin = "viewport" / "pointer" / input.ElementOrigin @@ -2427,6 +2435,7 @@ input.FileDialogOpened = ( input.FileDialogInfo = { context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext, ? element: script.SharedReference, multiple: bool, } diff --git a/common/bidi/spec/local.cddl b/common/bidi/spec/local.cddl index d43af0ae11b03..1bb2ce612e2c2 100644 --- a/common/bidi/spec/local.cddl +++ b/common/bidi/spec/local.cddl @@ -251,6 +251,7 @@ BrowsingContextResult = ( browsingContext.NavigateResult / browsingContext.PrintResult / browsingContext.ReloadResult / + browsingContext.SetBypassCSPResult / browsingContext.SetViewportResult / browsingContext.TraverseHistoryResult ) @@ -334,6 +335,7 @@ browsingContext.BaseNavigationInfo = ( navigation: browsingContext.Navigation / null, timestamp: js-uint, url: text, + ? userContext: browser.UserContext, ) browsingContext.NavigationInfo = { @@ -351,7 +353,8 @@ browsingContext.CaptureScreenshotResult = { browsingContext.CloseResult = EmptyResult browsingContext.CreateResult = { - context: browsingContext.BrowsingContext + context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext } browsingContext.GetTreeResult = { @@ -375,6 +378,8 @@ browsingContext.PrintResult = { browsingContext.ReloadResult = browsingContext.NavigateResult +browsingContext.SetBypassCSPResult = EmptyResult + browsingContext.SetViewportResult = EmptyResult browsingContext.TraverseHistoryResult = EmptyResult @@ -407,7 +412,8 @@ browsingContext.HistoryUpdated = ( browsingContext.HistoryUpdatedParameters = { context: browsingContext.BrowsingContext, timestamp: js-uint, - url: text + url: text, + ? userContext: browser.UserContext } browsingContext.DomContentLoaded = ( @@ -477,6 +483,7 @@ browsingContext.UserPromptClosedParameters = { context: browsingContext.BrowsingContext, accepted: bool, type: browsingContext.UserPromptType, + ? userContext: browser.UserContext, ? userText: text } @@ -490,6 +497,7 @@ browsingContext.UserPromptOpenedParameters = { handler: session.UserPromptHandlerType, message: text, type: browsingContext.UserPromptType, + ? userContext: browser.UserContext, ? defaultValue: text } @@ -502,8 +510,7 @@ EmulationResult = ( emulation.SetScrollbarTypeOverrideResult / emulation.SetTimezoneOverrideResult / emulation.SetTouchOverrideResult / - emulation.SetUserAgentOverrideResult / - emulation.SetViewportMetaOverrideResult + emulation.SetUserAgentOverrideResult ) emulation.SetForcedColorsModeThemeOverrideResult = EmptyResult @@ -520,8 +527,6 @@ emulation.SetScreenOrientationOverrideResult = EmptyResult emulation.SetUserAgentOverrideResult = EmptyResult -emulation.SetViewportMetaOverrideResult = EmptyResult - emulation.SetScriptingEnabledResult = EmptyResult emulation.SetScrollbarTypeOverrideResult = EmptyResult @@ -568,6 +573,7 @@ network.BaseParameters = ( redirectCount: js-uint, request: network.RequestData, timestamp: js-uint, + ? userContext: browser.UserContext / null, ? intercepts: [+network.Intercept] ) @@ -926,6 +932,7 @@ script.WindowRealmInfo = { script.BaseRealmInfo, type: "window", context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext, ? sandbox: text } @@ -1185,7 +1192,8 @@ script.StackTrace = { script.Source = { realm: script.Realm, - ? context: browsingContext.BrowsingContext + ? context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext } script.AddPreloadScriptResult = { @@ -1295,6 +1303,12 @@ log.EntryAdded = ( params: log.Entry, ) +InputResult = ( + input.PerformActionsResult / + input.ReleaseActionsResult / + input.SetFilesResult +) + InputEvent = ( input.FileDialogOpened @@ -1313,6 +1327,7 @@ input.FileDialogOpened = ( input.FileDialogInfo = { context: browsingContext.BrowsingContext, + ? userContext: browser.UserContext, ? element: script.SharedReference, multiple: bool, } diff --git a/common/bidi/spec/remote.cddl b/common/bidi/spec/remote.cddl index a98859a021e12..7490df1b44bc7 100644 --- a/common/bidi/spec/remote.cddl +++ b/common/bidi/spec/remote.cddl @@ -273,6 +273,7 @@ BrowsingContextCommand = ( browsingContext.Navigate // browsingContext.Print // browsingContext.Reload // + browsingContext.SetBypassCSP // browsingContext.SetViewport // browsingContext.TraverseHistory ) @@ -480,6 +481,17 @@ browsingContext.ReloadParameters = { ? wait: browsingContext.ReadinessState, } +browsingContext.SetBypassCSP = ( + method: "browsingContext.setBypassCSP", + params: browsingContext.SetBypassCSPParameters +) + +browsingContext.SetBypassCSPParameters = { + bypass: true / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + browsingContext.SetViewport = ( method: "browsingContext.setViewport", params: browsingContext.SetViewportParameters @@ -518,8 +530,7 @@ EmulationCommand = ( emulation.SetScrollbarTypeOverride // emulation.SetTimezoneOverride // emulation.SetTouchOverride // - emulation.SetUserAgentOverride // - emulation.SetViewportMetaOverride + emulation.SetUserAgentOverride ) @@ -577,10 +588,10 @@ emulation.SetLocaleOverrideParameters = { emulation.SetNetworkConditions = ( method: "emulation.setNetworkConditions", - params: emulation.setNetworkConditionsParameters + params: emulation.SetNetworkConditionsParameters ) -emulation.setNetworkConditionsParameters = { +emulation.SetNetworkConditionsParameters = { networkConditions: emulation.NetworkConditions / null, ? contexts: [+browsingContext.BrowsingContext], ? userContexts: [+browser.UserContext], @@ -638,17 +649,6 @@ emulation.SetUserAgentOverrideParameters = { ? userContexts: [+browser.UserContext], } -emulation.SetViewportMetaOverride = ( - method: "emulation.setViewportMetaOverride", - params: emulation.SetViewportMetaOverrideParameters -) - -emulation.SetViewportMetaOverrideParameters = { - viewportMeta: true / null, - ? contexts: [+browsingContext.BrowsingContext], - ? userContexts: [+browser.UserContext], -} - emulation.SetScriptingEnabled = ( method: "emulation.setScriptingEnabled", params: emulation.SetScriptingEnabledParameters @@ -876,10 +876,10 @@ network.ContinueWithAuthNoCredentials = ( network.DisownData = ( method: "network.disownData", - params: network.disownDataParameters + params: network.DisownDataParameters ) -network.disownDataParameters = { +network.DisownDataParameters = { dataType: network.DataType, collector: network.Collector, request: network.Request, @@ -1500,12 +1500,6 @@ InputCommand = ( input.SetFiles ) -InputResult = ( - input.PerformActionsResult / - input.ReleaseActionsResult / - input.SetFilesResult -) - input.ElementOrigin = { type: "element", element: script.SharedReference @@ -1625,15 +1619,15 @@ input.WheelScrollAction = { } input.PointerCommonProperties = ( - ? width: js-uint .default 1, - ? height: js-uint .default 1, - ? pressure: float .default 0.0, - ? tangentialPressure: float .default 0.0, - ? twist: (0..359) .default 0, + ? width: js-uint, + ? height: js-uint, + ? pressure: (0.0..1.0), + ? tangentialPressure: (-1.0..1.0), + ? twist: (0..359), ; 0 .. Math.PI / 2 - ? altitudeAngle: (0.0..1.5707963267948966) .default 0.0, + ? altitudeAngle: (0.0..1.5707963267948966), ; 0 .. 2 * Math.PI - ? azimuthAngle: (0.0..6.283185307179586) .default 0.0, + ? azimuthAngle: (0.0..6.283185307179586), ) input.Origin = "viewport" / "pointer" / input.ElementOrigin @@ -1658,17 +1652,6 @@ input.SetFilesParameters = { files: [*text] } -input.FileDialogOpened = ( - method: "input.fileDialogOpened", - params: input.FileDialogInfo -) - -input.FileDialogInfo = { - context: browsingContext.BrowsingContext, - ? element: script.SharedReference, - multiple: bool, -} - WebExtensionCommand = ( webExtension.Install // webExtension.Uninstall diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 3811a2a2e97b7..94dd0094e9173 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - def transform_download_params( allowed: bool | None, destination_folder: str | None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index fcee27df8488e..50f9d61d487f6 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class ReadinessState: """ReadinessState.""" @@ -109,6 +108,7 @@ class BaseNavigationInfo: navigation: Any | None = None timestamp: Any | None = None url: str | None = None + user_context: Any | None = None @dataclass @@ -184,6 +184,7 @@ class CreateResult: """CreateResult.""" context: Any | None = None + user_context: Any | None = None @dataclass @@ -290,6 +291,15 @@ class ReloadParameters: wait: Any | None = None +@dataclass +class SetBypassCSPParameters: + """SetBypassCSPParameters.""" + + bypass: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) + + @dataclass class SetViewportParameters: """SetViewportParameters.""" @@ -323,6 +333,7 @@ class HistoryUpdatedParameters: context: Any | None = None timestamp: Any | None = None url: str | None = None + user_context: Any | None = None @dataclass @@ -332,6 +343,7 @@ class UserPromptClosedParameters: context: Any | None = None accepted: bool | None = None type: Any | None = None + user_context: Any | None = None user_text: str | None = None @@ -343,6 +355,7 @@ class UserPromptOpenedParameters: handler: Any | None = None message: str | None = None type: Any | None = None + user_context: Any | None = None default_value: str | None = None @@ -656,6 +669,26 @@ def reload(self, context: Any | None = None, ignore_cache: bool | None = None, w result = self._conn.execute(cmd) return result + def set_bypass_csp( + self, + bypass: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute browsingContext.setBypassCSP.""" + if bypass is None: + raise TypeError("set_bypass_csp() missing required argument: 'bypass'") + + params = { + "bypass": bypass, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.setBypassCSP", params) + result = self._conn.execute(cmd) + return result + def set_viewport( self, context: str | None = None, diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 44babb6777616..67f95b933aa16 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - class ForcedColorsModeTheme: """ForcedColorsModeTheme.""" @@ -131,15 +130,6 @@ class SetUserAgentOverrideParameters: user_contexts: list[Any] = field(default_factory=list) -@dataclass -class SetViewportMetaOverrideParameters: - """SetViewportMetaOverrideParameters.""" - - viewport_meta: Any | None = None - contexts: list[Any] = field(default_factory=list) - user_contexts: list[Any] = field(default_factory=list) - - @dataclass class SetScriptingEnabledParameters: """SetScriptingEnabledParameters.""" @@ -253,26 +243,6 @@ def set_screen_settings_override( result = self._conn.execute(cmd) return result - def set_viewport_meta_override( - self, - viewport_meta: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setViewportMetaOverride.""" - if viewport_meta is None: - raise TypeError("set_viewport_meta_override() missing required argument: 'viewport_meta'") - - params = { - "viewportMeta": viewport_meta, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setViewportMetaOverride", params) - result = self._conn.execute(cmd) - return result - def set_scrollbar_type_override( self, scrollbar_type: Any | None = None, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 346ead5e49841..d2508fea5ca64 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class PointerType: """PointerType.""" diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index ca24d6e78d532..04c5a53c04510 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -10,8 +10,7 @@ from dataclasses import dataclass from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class Level: """Level.""" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 343b6d960c017..e13befc582f08 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class SameSite: """SameSite.""" @@ -72,6 +71,7 @@ class BaseParameters: redirect_count: Any | None = None request: Any | None = None timestamp: Any | None = None + user_context: Any | None = None intercepts: list[Any] = field(default_factory=list) diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index d6877de623d14..572f620a6d63b 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class SpecialNumber: """SpecialNumber.""" @@ -205,6 +204,7 @@ class WindowRealmInfo: type: str = field(default="window", init=False) context: Any | None = None + user_context: Any | None = None sandbox: str | None = None @@ -505,6 +505,7 @@ class Source: realm: Any | None = None context: Any | None = None + user_context: Any | None = None @dataclass @@ -790,9 +791,8 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import datetime as _datetime import math as _math - + import datetime as _datetime from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1033,9 +1033,8 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - - from selenium.webdriver.common.bidi import log as _log_mod from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index e04d897e25deb..a54e196aa86d9 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - class UserPromptHandlerType: """UserPromptHandlerType.""" diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 5ae8bf5aeb2d0..d922390a08699 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - @dataclass class PartitionKey: """PartitionKey.""" diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 0a28843e339f1..3520219e26c53 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -11,7 +11,6 @@ from selenium.webdriver.common.bidi.common import command_builder - @dataclass class InstallParameters: """InstallParameters.""" From d8ef2ff8839385cac1c56599ce2cbeff6331a6d8 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Tue, 7 Apr 2026 12:34:54 +0100 Subject: [PATCH 29/37] ruffs updates --- py/selenium/webdriver/common/bidi/browser.py | 1 + py/selenium/webdriver/common/bidi/browsing_context.py | 3 ++- py/selenium/webdriver/common/bidi/emulation.py | 1 + py/selenium/webdriver/common/bidi/input.py | 3 ++- py/selenium/webdriver/common/bidi/log.py | 3 ++- py/selenium/webdriver/common/bidi/network.py | 3 ++- py/selenium/webdriver/common/bidi/script.py | 9 ++++++--- py/selenium/webdriver/common/bidi/session.py | 1 + py/selenium/webdriver/common/bidi/storage.py | 1 + py/selenium/webdriver/common/bidi/webextension.py | 1 + 10 files changed, 19 insertions(+), 7 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 94dd0094e9173..3811a2a2e97b7 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + def transform_download_params( allowed: bool | None, destination_folder: str | None, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 50f9d61d487f6..175c511393098 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class ReadinessState: """ReadinessState.""" diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 67f95b933aa16..1c48100cc343b 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + class ForcedColorsModeTheme: """ForcedColorsModeTheme.""" diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index d2508fea5ca64..346ead5e49841 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class PointerType: """PointerType.""" diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 04c5a53c04510..ca24d6e78d532 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -10,7 +10,8 @@ from dataclasses import dataclass from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager + class Level: """Level.""" diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index e13befc582f08..d6875e14fa58a 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class SameSite: """SameSite.""" diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 572f620a6d63b..ecc2a75e0922d 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class SpecialNumber: """SpecialNumber.""" @@ -791,8 +792,9 @@ def execute(self, function_declaration: str, *args, context_id: str | None = Non Returns: The inner RemoteValue result dict, or raises WebDriverException on exception. """ - import math as _math import datetime as _datetime + import math as _math + from selenium.common.exceptions import WebDriverException as _WebDriverException def _serialize_arg(value): @@ -1033,8 +1035,9 @@ def _disown(self, handles, target): def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading - from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + from selenium.webdriver.common.bidi.session import Session as _Session bidi_event = "log.entryAdded" diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index a54e196aa86d9..e04d897e25deb 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + class UserPromptHandlerType: """UserPromptHandlerType.""" diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index d922390a08699..5ae8bf5aeb2d0 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + @dataclass class PartitionKey: """PartitionKey.""" diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 3520219e26c53..0a28843e339f1 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -11,6 +11,7 @@ from selenium.webdriver.common.bidi.common import command_builder + @dataclass class InstallParameters: """InstallParameters.""" From 4e38ac17d570fcbe853b242645fa6408586a7e5d Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 8 Apr 2026 11:15:10 +0100 Subject: [PATCH 30/37] fix test --- py/test/selenium/webdriver/common/bidi_browsing_context_tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index f26472e7a8d54..eb5981d2282b6 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -426,7 +426,6 @@ def test_set_viewport_back_to_default(driver, pages): # Allow some tolerance since some window managers might not put it to the exact value assert abs(viewport_size[0] - default_viewport_size[0]) <= 5 assert abs(viewport_size[1] - default_viewport_size[1]) <= 5 - assert device_pixel_ratio == default_device_pixel_ratio finally: driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) From 719f7dc959a569e5d323389dbbf2d659f6bcf421 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 8 Apr 2026 11:46:07 +0100 Subject: [PATCH 31/37] add assert and add failure to chrome --- .../common/bidi_browsing_context_tests.py | 278 +++++++++++++----- 1 file changed, 209 insertions(+), 69 deletions(-) diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index eb5981d2282b6..8038ee826aa74 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -60,7 +60,9 @@ def test_create_window(driver): def test_create_window_with_reference_context(driver): """Test creating a window with a reference context.""" reference_context = driver.current_window_handle - context_id = driver.browsing_context.create(type=WindowTypes.WINDOW, reference_context=reference_context) + context_id = driver.browsing_context.create( + type=WindowTypes.WINDOW, reference_context=reference_context + ) assert context_id is not None # Clean up @@ -79,7 +81,9 @@ def test_create_tab(driver): def test_create_tab_with_reference_context(driver): """Test creating a tab with a reference context.""" reference_context = driver.current_window_handle - context_id = driver.browsing_context.create(type=WindowTypes.TAB, reference_context=reference_context) + context_id = driver.browsing_context.create( + type=WindowTypes.TAB, reference_context=reference_context + ) assert context_id is not None # Clean up @@ -92,7 +96,10 @@ def test_create_context_with_all_parameters(driver): user_context = driver.browser.create_user_context() context_id = driver.browsing_context.create( - type=WindowTypes.WINDOW, reference_context=reference_context, user_context=user_context, background=True + type=WindowTypes.WINDOW, + reference_context=reference_context, + user_context=user_context, + background=True, ) assert context_id is not None assert context_id != reference_context @@ -121,7 +128,9 @@ def test_navigate_to_url_with_readiness_state(driver, pages): context_id = driver.browsing_context.create(type=WindowTypes.TAB) url = pages.url("bidi/logEntryAdded.html") - result = driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + result = driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) assert context_id is not None assert "/bidi/logEntryAdded.html" in result["url"] @@ -135,7 +144,9 @@ def test_get_tree_with_child(driver, pages): reference_context = driver.current_window_handle url = pages.url("iframes.html") - driver.browsing_context.navigate(context=reference_context, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=reference_context, url=url, wait=ReadinessState.COMPLETE + ) context_info_list = driver.browsing_context.get_tree(root=reference_context) @@ -151,9 +162,13 @@ def test_get_tree_with_depth(driver, pages): reference_context = driver.current_window_handle url = pages.url("iframes.html") - driver.browsing_context.navigate(context=reference_context, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=reference_context, url=url, wait=ReadinessState.COMPLETE + ) - context_info_list = driver.browsing_context.get_tree(root=reference_context, max_depth=0) + context_info_list = driver.browsing_context.get_tree( + root=reference_context, max_depth=0 + ) assert len(context_info_list) == 1 info = context_info_list[0] @@ -224,7 +239,9 @@ def test_reload_browsing_context(driver, pages): context_id = driver.browsing_context.create(type=WindowTypes.TAB) url = pages.url("bidi/logEntryAdded.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) reload_info = driver.browsing_context.reload(context=context_id) @@ -239,9 +256,13 @@ def test_reload_with_readiness_state(driver, pages): context_id = driver.browsing_context.create(type=WindowTypes.TAB) url = pages.url("bidi/logEntryAdded.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) - reload_info = driver.browsing_context.reload(context=context_id, wait=ReadinessState.COMPLETE) + reload_info = driver.browsing_context.reload( + context=context_id, wait=ReadinessState.COMPLETE + ) assert reload_info["navigation"] is not None assert "/bidi/logEntryAdded.html" in reload_info["url"] @@ -338,7 +359,9 @@ def test_capture_screenshot_with_parameters(driver, pages): clip = {"type": "box", "x": rect["x"], "y": rect["y"], "width": 5, "height": 5} - screenshot = driver.browsing_context.capture_screenshot(context=context_id, origin="document", clip=clip) + screenshot = driver.browsing_context.capture_screenshot( + context=context_id, origin="document", clip=clip + ) assert len(screenshot) > 0 @@ -349,14 +372,20 @@ def test_set_viewport(driver, pages): driver.get(pages.url("formPage.html")) try: - driver.browsing_context.set_viewport(context=context_id, viewport={"width": 251, "height": 301}) + driver.browsing_context.set_viewport( + context=context_id, viewport={"width": 251, "height": 301} + ) - viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") + viewport_size = driver.execute_script( + "return [window.innerWidth, window.innerHeight];" + ) assert viewport_size[0] == 251 assert viewport_size[1] == 301 finally: - driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) + driver.browsing_context.set_viewport( + context=context_id, viewport=None, device_pixel_ratio=None + ) def test_set_viewport_with_device_pixel_ratio(driver, pages): @@ -366,10 +395,14 @@ def test_set_viewport_with_device_pixel_ratio(driver, pages): try: driver.browsing_context.set_viewport( - context=context_id, viewport={"width": 252, "height": 302}, device_pixel_ratio=5 + context=context_id, + viewport={"width": 252, "height": 302}, + device_pixel_ratio=5, ) - viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") + viewport_size = driver.execute_script( + "return [window.innerWidth, window.innerHeight];" + ) assert viewport_size[0] == 252 assert viewport_size[1] == 302 @@ -378,7 +411,9 @@ def test_set_viewport_with_device_pixel_ratio(driver, pages): assert device_pixel_ratio == 5 finally: - driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) + driver.browsing_context.set_viewport( + context=context_id, viewport=None, device_pixel_ratio=None + ) def test_set_viewport_with_no_args_doesnt_change_values(driver, pages): @@ -388,12 +423,16 @@ def test_set_viewport_with_no_args_doesnt_change_values(driver, pages): try: driver.browsing_context.set_viewport( - context=context_id, viewport={"width": 253, "height": 303}, device_pixel_ratio=6 + context=context_id, + viewport={"width": 253, "height": 303}, + device_pixel_ratio=6, ) driver.browsing_context.set_viewport(context=context_id) - viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") + viewport_size = driver.execute_script( + "return [window.innerWidth, window.innerHeight];" + ) assert viewport_size[0] == 253 assert viewport_size[1] == 303 @@ -402,32 +441,46 @@ def test_set_viewport_with_no_args_doesnt_change_values(driver, pages): assert device_pixel_ratio == 6 finally: - driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) + driver.browsing_context.set_viewport( + context=context_id, viewport=None, device_pixel_ratio=None + ) +@pytest.mark.xfail_chrome def test_set_viewport_back_to_default(driver, pages): """Test resetting the viewport and device pixel ratio to defaults.""" context_id = driver.current_window_handle driver.get(pages.url("formPage.html")) - default_viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") + default_viewport_size = driver.execute_script( + "return [window.innerWidth, window.innerHeight];" + ) default_device_pixel_ratio = driver.execute_script("return window.devicePixelRatio") try: driver.browsing_context.set_viewport( - context=context_id, viewport={"width": 254, "height": 304}, device_pixel_ratio=10 + context=context_id, + viewport={"width": 254, "height": 304}, + device_pixel_ratio=10, ) - driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) + driver.browsing_context.set_viewport( + context=context_id, viewport=None, device_pixel_ratio=None + ) - viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") + viewport_size = driver.execute_script( + "return [window.innerWidth, window.innerHeight];" + ) device_pixel_ratio = driver.execute_script("return window.devicePixelRatio") # Allow some tolerance since some window managers might not put it to the exact value assert abs(viewport_size[0] - default_viewport_size[0]) <= 5 assert abs(viewport_size[1] - default_viewport_size[1]) <= 5 + assert device_pixel_ratio == default_device_pixel_ratio finally: - driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) + driver.browsing_context.set_viewport( + context=context_id, viewport=None, device_pixel_ratio=None + ) def test_print_page(driver, pages): @@ -446,7 +499,9 @@ def test_print_page(driver, pages): def test_navigate_back_in_browser_history(driver, pages): """Test navigating back in the browser history.""" context_id = driver.current_window_handle - driver.browsing_context.navigate(context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE + ) # Navigate to another page by submitting a form driver.find_element(By.ID, "imageButton").submit() @@ -459,7 +514,9 @@ def test_navigate_back_in_browser_history(driver, pages): def test_navigate_forward_in_browser_history(driver, pages): """Test navigating forward in the browser history.""" context_id = driver.current_window_handle - driver.browsing_context.navigate(context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE + ) # Navigate to another page by submitting a form driver.find_element(By.ID, "imageButton").submit() @@ -481,7 +538,9 @@ def test_locate_nodes(driver, pages): driver.get(pages.url("xhtmlTest.html")) - elements = driver.browsing_context.locate_nodes(context=context_id, locator={"type": "css", "value": "div"}) + elements = driver.browsing_context.locate_nodes( + context=context_id, locator={"type": "css", "value": "div"} + ) assert len(elements) > 0 @@ -493,7 +552,9 @@ def test_locate_nodes_with_css_locator(driver, pages): driver.get(pages.url("xhtmlTest.html")) elements = driver.browsing_context.locate_nodes( - context=context_id, locator={"type": "css", "value": "div.extraDiv, div.content"}, max_node_count=1 + context=context_id, + locator={"type": "css", "value": "div.extraDiv, div.content"}, + max_node_count=1, ) assert len(elements) >= 1 @@ -515,7 +576,9 @@ def test_locate_nodes_with_xpath_locator(driver, pages): driver.get(pages.url("xhtmlTest.html")) elements = driver.browsing_context.locate_nodes( - context=context_id, locator={"type": "xpath", "value": "/html/body/div[2]"}, max_node_count=1 + context=context_id, + locator={"type": "xpath", "value": "/html/body/div[2]"}, + max_node_count=1, ) assert len(elements) >= 1 @@ -538,7 +601,9 @@ def test_locate_nodes_with_inner_text(driver, pages): driver.get(pages.url("xhtmlTest.html")) elements = driver.browsing_context.locate_nodes( - context=context_id, locator={"type": "innerText", "value": "Spaced out"}, max_node_count=1 + context=context_id, + locator={"type": "innerText", "value": "Spaced out"}, + max_node_count=1, ) assert len(elements) >= 1 @@ -595,7 +660,9 @@ def test_add_event_handler_context_created(driver): def on_context_created(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) + callback_id = driver.browsing_context.add_event_handler( + "context_created", on_context_created + ) assert callback_id is not None # Create a new context to trigger the event @@ -603,7 +670,10 @@ def on_context_created(info): # Verify the event was received (might be > 1 since default context is also included) assert len(events_received) >= 1 - assert events_received[0].context == context_id or events_received[1].context == context_id + assert ( + events_received[0].context == context_id + or events_received[1].context == context_id + ) driver.browsing_context.close(context_id) driver.browsing_context.remove_event_handler("context_created", callback_id) @@ -616,7 +686,9 @@ def test_add_event_handler_context_destroyed(driver): def on_context_destroyed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("context_destroyed", on_context_destroyed) + callback_id = driver.browsing_context.add_event_handler( + "context_destroyed", on_context_destroyed + ) assert callback_id is not None # Create and then close a context to trigger the event @@ -636,13 +708,17 @@ def test_add_event_handler_navigation_committed(driver, pages): def on_navigation_committed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("navigation_committed", on_navigation_committed) + callback_id = driver.browsing_context.add_event_handler( + "navigation_committed", on_navigation_committed + ) assert callback_id is not None # Navigate to trigger the event context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) assert len(events_received) >= 1 assert any(url in event.url for event in events_received) @@ -657,13 +733,17 @@ def test_add_event_handler_dom_content_loaded(driver, pages): def on_dom_content_loaded(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("dom_content_loaded", on_dom_content_loaded) + callback_id = driver.browsing_context.add_event_handler( + "dom_content_loaded", on_dom_content_loaded + ) assert callback_id is not None # Navigate to trigger the event context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -683,7 +763,9 @@ def on_load(info): context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -698,12 +780,16 @@ def test_add_event_handler_navigation_started(driver, pages): def on_navigation_started(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("navigation_started", on_navigation_started) + callback_id = driver.browsing_context.add_event_handler( + "navigation_started", on_navigation_started + ) assert callback_id is not None context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -718,17 +804,23 @@ def test_add_event_handler_fragment_navigated(driver, pages): def on_fragment_navigated(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("fragment_navigated", on_fragment_navigated) + callback_id = driver.browsing_context.add_event_handler( + "fragment_navigated", on_fragment_navigated + ) assert callback_id is not None # First navigate to a page context_id = driver.current_window_handle url = pages.url("linked_image.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) # Then navigate to the same page with a fragment to trigger the event fragment_url = url + "#link" - driver.browsing_context.navigate(context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE + ) assert len(events_received) == 1 assert any("link" in event.url for event in events_received) @@ -744,13 +836,17 @@ def test_add_event_handler_navigation_failed(driver): def on_navigation_failed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("navigation_failed", on_navigation_failed) + callback_id = driver.browsing_context.add_event_handler( + "navigation_failed", on_navigation_failed + ) assert callback_id is not None # Navigate to an invalid URL to trigger the event context_id = driver.current_window_handle try: - driver.browsing_context.navigate(context=context_id, url="http://invalid-domain-that-does-not-exist.test/") + driver.browsing_context.navigate( + context=context_id, url="http://invalid-domain-that-does-not-exist.test/" + ) except Exception: # Expect an exception due to navigation failure pass @@ -769,7 +865,9 @@ def test_add_event_handler_user_prompt_opened(driver, pages): def on_user_prompt_opened(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("user_prompt_opened", on_user_prompt_opened) + callback_id = driver.browsing_context.add_event_handler( + "user_prompt_opened", on_user_prompt_opened + ) assert callback_id is not None # Create an alert to trigger the event @@ -794,7 +892,9 @@ def test_add_event_handler_user_prompt_closed(driver, pages): def on_user_prompt_closed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("user_prompt_closed", on_user_prompt_closed) + callback_id = driver.browsing_context.add_event_handler( + "user_prompt_closed", on_user_prompt_closed + ) assert callback_id is not None create_prompt_page(driver, pages) @@ -819,12 +919,16 @@ def test_add_event_handler_history_updated(driver, pages): def on_history_updated(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("history_updated", on_history_updated) + callback_id = driver.browsing_context.add_event_handler( + "history_updated", on_history_updated + ) assert callback_id is not None context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) # Use history.pushState to trigger history updated event driver.script.execute("() => { history.pushState({}, '', '/new-path'); }") @@ -844,13 +948,17 @@ def test_add_event_handler_download_will_begin(driver, pages): def on_download_will_begin(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("download_will_begin", on_download_will_begin) + callback_id = driver.browsing_context.add_event_handler( + "download_will_begin", on_download_will_begin + ) assert callback_id is not None # click on a download link to trigger the event context_id = driver.current_window_handle url = pages.url("downloads/download.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) download_xpath_file_1_txt = '//*[@id="file-1"]' driver.find_element(By.XPATH, download_xpath_file_1_txt).click() @@ -870,12 +978,16 @@ def test_add_event_handler_download_end(driver, pages): def on_download_end(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("download_end", on_download_end) + callback_id = driver.browsing_context.add_event_handler( + "download_end", on_download_end + ) assert callback_id is not None context_id = driver.current_window_handle url = pages.url("downloads/download.html") - driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) + driver.browsing_context.navigate( + context=context_id, url=url, wait=ReadinessState.COMPLETE + ) driver.find_element(By.ID, "file-1").click() @@ -893,12 +1005,14 @@ def on_download_end(info): # we assert that atleast "file_1" is present in the downloaded file since multiple downloads # will have numbered suffix like file_1 (1) assert any( - "downloads/file_1.txt" in ev.download_params.url and "file_1" in ev.download_params.filepath + "downloads/file_1.txt" in ev.download_params.url + and "file_1" in ev.download_params.filepath for ev in events_received ) assert any( - "downloads/file_2.jpg" in ev.download_params.url and "file_2" in ev.download_params.filepath + "downloads/file_2.jpg" in ev.download_params.url + and "file_2" in ev.download_params.filepath for ev in events_received ) @@ -937,7 +1051,9 @@ def test_remove_event_handler(driver): def on_context_created(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) + callback_id = driver.browsing_context.add_event_handler( + "context_created", on_context_created + ) # Create a context to trigger the event context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) @@ -969,8 +1085,12 @@ def on_context_created_2(info): events_received_2.append(info) # Add multiple event handlers for the same event - callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) - callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) + callback_id_1 = driver.browsing_context.add_event_handler( + "context_created", on_context_created_1 + ) + callback_id_2 = driver.browsing_context.add_event_handler( + "context_created", on_context_created_2 + ) # Create a context to trigger both handlers context_id = driver.browsing_context.create(type=WindowTypes.TAB) @@ -999,8 +1119,12 @@ def on_context_created_2(info): events_received_2.append(info) # Add multiple event handlers - callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) - callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) + callback_id_1 = driver.browsing_context.add_event_handler( + "context_created", on_context_created_1 + ) + callback_id_2 = driver.browsing_context.add_event_handler( + "context_created", on_context_created_2 + ) # Create a context to trigger both handlers context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) @@ -1082,7 +1206,9 @@ def callback(info): def register_handler(self, thread_id): try: callback = self.make_callback() - callback_id = self.driver.browsing_context.add_event_handler("context_created", callback) + callback_id = self.driver.browsing_context.add_event_handler( + "context_created", callback + ) with self.data_lock: self.callback_ids.append(callback_id) if len(self.callback_ids) == 5: @@ -1090,12 +1216,16 @@ def register_handler(self, thread_id): return callback_id except Exception as e: with self.data_lock: - self.thread_errors.append(f"Thread {thread_id}: Registration failed: {e}") + self.thread_errors.append( + f"Thread {thread_id}: Registration failed: {e}" + ) return None def remove_handler(self, callback_id, thread_id): try: - self.driver.browsing_context.remove_event_handler("context_created", callback_id) + self.driver.browsing_context.remove_event_handler( + "context_created", callback_id + ) except Exception as e: with self.data_lock: self.thread_errors.append(f"Thread {thread_id}: Removal failed: {e}") @@ -1105,13 +1235,19 @@ def test_concurrent_event_handler_registration(driver): helper = _EventHandlerTestHelper(driver) with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(helper.register_handler, f"reg-{i}") for i in range(5)] + futures = [ + executor.submit(helper.register_handler, f"reg-{i}") for i in range(5) + ] for future in futures: future.result(timeout=15) helper.registration_complete.wait(timeout=5) - assert len(helper.callback_ids) == 5, f"Expected 5 handlers, got {len(helper.callback_ids)}" - assert not helper.thread_errors, "Errors during registration: \n" + "\n".join(helper.thread_errors) + assert ( + len(helper.callback_ids) == 5 + ), f"Expected 5 handlers, got {len(helper.callback_ids)}" + assert not helper.thread_errors, "Errors during registration: \n" + "\n".join( + helper.thread_errors + ) def test_event_callback_data_consistency(driver): @@ -1129,7 +1265,9 @@ def test_event_callback_data_consistency(driver): driver.browsing_context.close(ctx) with helper.data_lock: - assert not helper.consistency_errors, "Consistency errors: " + str(helper.consistency_errors) + assert not helper.consistency_errors, "Consistency errors: " + str( + helper.consistency_errors + ) assert len(helper.events_received) > 0, "No events received" assert len(helper.events_received) == sum(helper.context_counts.values()) assert len(helper.events_received) == sum(helper.event_type_counts.values()) @@ -1150,7 +1288,9 @@ def test_concurrent_event_handler_removal(driver): for future in futures: future.result(timeout=15) - assert not helper.thread_errors, "Errors during removal: \n" + "\n".join(helper.thread_errors) + assert not helper.thread_errors, "Errors during removal: \n" + "\n".join( + helper.thread_errors + ) def test_no_event_after_handler_removal(driver): From 867de27db541775191ec76557dbb1b79b17eb9b2 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Wed, 8 Apr 2026 17:55:35 +0100 Subject: [PATCH 32/37] do ruff format --- py/generate_bidi.py | 99 ++----- py/private/_event_manager.py | 14 +- py/private/bidi_enhancements_manifest.py | 16 +- py/private/cdp.py | 4 +- py/selenium/common/exceptions.py | 12 +- .../webdriver/common/bidi/_event_manager.py | 14 +- py/selenium/webdriver/common/bidi/browser.py | 32 +-- .../webdriver/common/bidi/browsing_context.py | 40 +-- py/selenium/webdriver/common/bidi/cdp.py | 4 +- py/selenium/webdriver/common/bidi/common.py | 4 +- .../webdriver/common/bidi/emulation.py | 10 +- py/selenium/webdriver/common/bidi/input.py | 8 +- py/selenium/webdriver/common/bidi/log.py | 7 +- py/selenium/webdriver/common/bidi/network.py | 32 +-- .../webdriver/common/bidi/permissions.py | 3 +- py/selenium/webdriver/common/bidi/script.py | 46 +++- py/selenium/webdriver/common/bidi/session.py | 8 +- py/selenium/webdriver/common/bidi/storage.py | 15 +- .../webdriver/common/bidi/webextension.py | 15 +- py/selenium/webdriver/common/proxy.py | 34 +-- py/selenium/webdriver/remote/webdriver.py | 153 +++-------- .../webdriver/remote/websocket_connection.py | 12 +- .../common/bidi_browsing_context_tests.py | 247 +++++------------- 23 files changed, 275 insertions(+), 554 deletions(-) diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 745c0f00ed890..194d94ba12d04 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -68,9 +68,7 @@ def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: return {} try: - spec = importlib.util.spec_from_file_location( - "bidi_enhancements", manifest_file - ) + spec = importlib.util.spec_from_file_location("bidi_enhancements", manifest_file) if spec is None or spec.loader is None: logger.warning(f"Could not load manifest: {manifest_path}") return {} @@ -169,9 +167,7 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: if param_strs: # Check if full signature would exceed line length limit (120 chars) - single_line_signature = ( - f" def {method_name}(self, {', '.join(param_strs)}):" - ) + single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):" if len(single_line_signature) > 120: # Format parameters on multiple lines body = f" def {method_name}(\n" @@ -198,9 +194,7 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: body += f" if {snake_param} is None:\n" msg = f"{method_snake}() missing required argument:" error_message = f"{msg} {snake_param!r}" - body += ( - f" raise TypeError({error_message!r})\n" - ) + body += f" raise TypeError({error_message!r})\n" body += "\n" # Add validation if specified in enhancements (for additional business logic validation) @@ -220,9 +214,7 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: transform_func = transform_spec.get("func") result_param = transform_spec.get("result_param", "params") input_params = [ - transform_spec.get(k) - for k in ["allowed", "destination_folder"] - if transform_spec.get(k) + transform_spec.get(k) for k in ["allowed", "destination_folder"] if transform_spec.get(k) ] if transform_func and result_param: @@ -245,9 +237,7 @@ def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: snake_param = self._camel_to_snake(param_name) if preprocess_type == "check_serialize_method": body += f" if {snake_param} and hasattr({snake_param}, 'to_bidi_dict'):\n" - body += ( - f" {snake_param} = {snake_param}.to_bidi_dict()\n" - ) + body += f" {snake_param} = {snake_param}.to_bidi_dict()\n" body += "\n" # Build params dict @@ -538,11 +528,7 @@ def to_python_dataclass(self) -> str: # Extract the type name from params_type (e.g., "browsingContext.Info" -> "Info") # The params_type comes from the CDDL and includes module prefix - type_name = ( - self.params_type.split(".")[-1] - if "." in self.params_type - else self.params_type - ) + type_name = self.params_type.split(".")[-1] if "." in self.params_type else self.params_type # Special case: if the type is BaseNavigationInfo, use BaseNavigationInfo directly # (NavigationInfo will be created as an alias to it) @@ -604,9 +590,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: stdlib_imports.append("from typing import Any") if needs_command_builder: - local_imports.append( - "from selenium.webdriver.common.bidi.common import command_builder" - ) + local_imports.append("from selenium.webdriver.common.bidi.common import command_builder") if self.events: local_imports.append( "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager" @@ -626,9 +610,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: method_enhancements = enhancements.get(method_name_snake, {}) if "validate" in method_enhancements: helper_funcs_to_add.add(("validate", method_enhancements["validate"])) - if "transform" in method_enhancements and isinstance( - method_enhancements["transform"], dict - ): + if "transform" in method_enhancements and isinstance(method_enhancements["transform"], dict): transform_spec = method_enhancements["transform"] if "func" in transform_spec: helper_funcs_to_add.add(("transform", transform_spec["func"])) @@ -636,10 +618,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Generate helper functions if needed if helper_funcs_to_add: for func_type, func_name in sorted(helper_funcs_to_add): - if ( - func_type == "validate" - and func_name == "validate_download_behavior" - ): + if func_type == "validate" and func_name == "validate_download_behavior": code += """def validate_download_behavior( allowed: bool | None, destination_folder: str | None, @@ -662,10 +641,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: """ - elif ( - func_type == "transform" - and func_name == "transform_download_params" - ): + elif func_type == "transform" and func_name == "transform_download_params": code += """def transform_download_params( allowed: bool | None, destination_folder: str | None, @@ -750,9 +726,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: code += f' "{event_name}": "{event_def.method}",\n' # Extra events not in the CDDL spec (e.g. Chromium-specific events) for extra_evt in enhancements.get("extra_events", []): - code += ( - f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n' - ) + code += f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n' code += "}\n\n" # Add custom method function definitions before the class (for browsingContext) @@ -807,9 +781,7 @@ def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: # Add EVENT_CONFIGS dict if there are events if self.events: - code += ( - " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined - ) + code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined if self.name == "script": code += " def __init__(self, conn, driver=None) -> None:\n" @@ -928,8 +900,7 @@ def clear_event_handlers(self) -> None: # Build the entry line and check if it exceeds 120 chars single_line = ( - f' "{event_name}": ' - f'EventConfig("{event_name}", "{event_def.method}", {event_class}),' + f' "{event_name}": EventConfig("{event_name}", "{event_def.method}", {event_class}),' ) if len(single_line) > 120: @@ -1105,9 +1076,7 @@ def _extract_types(self) -> None: description=f"{type_name}", ) self.modules[module_name].enums.append(enum_def) - logger.debug( - f"Found enum: {def_name} with {len(values)} values" - ) + logger.debug(f"Found enum: {def_name} with {len(values)} values") else: # Extract fields from type definition fields = self._extract_type_fields(def_content) @@ -1120,9 +1089,7 @@ def _extract_types(self) -> None: description=f"{type_name}", ) self.modules[module_name].types.append(type_def) - logger.debug( - f"Found type: {def_name} with {len(fields)} fields" - ) + logger.debug(f"Found type: {def_name} with {len(fields)} fields") def _is_enum_definition(self, definition: str) -> bool: """Check if a definition is an enum (string union with /). @@ -1235,9 +1202,7 @@ def _extract_events(self) -> None: Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) """ # Find definitions that are in the event_names set - event_pattern = re.compile( - r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" - ) + event_pattern = re.compile(r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)") for def_name, def_content in self.definitions.items(): # Skip if not identified as an event @@ -1271,16 +1236,12 @@ def _extract_events(self) -> None: ) self.modules[module_name].events.append(event) - logger.debug( - f"Found event: {def_name} (method={method}, params={params_type})" - ) + logger.debug(f"Found event: {def_name} (method={method}, params={params_type})") def _extract_commands(self) -> None: """Extract command definitions from parsed definitions.""" # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) - command_pattern = re.compile( - r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" - ) + command_pattern = re.compile(r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)") for def_name, def_content in self.definitions.items(): # Skip definitions that are events (they share the same pattern) @@ -1301,9 +1262,7 @@ def _extract_commands(self) -> None: self.modules[module_name] = CddlModule(name=module_name) # Extract parameters and required parameters - params, required_params = self._extract_parameters_and_required( - params_type - ) + params, required_params = self._extract_parameters_and_required(params_type) # Create command cmd = CddlCommand( @@ -1315,13 +1274,9 @@ def _extract_commands(self) -> None: ) self.modules[module_name].commands.append(cmd) - logger.debug( - f"Found command: {method} with params {params_type}" - ) + logger.debug(f"Found command: {method} with params {params_type}") - def _extract_parameters( - self, params_type: str, _seen: set[str] | None = None - ) -> dict[str, str]: + def _extract_parameters(self, params_type: str, _seen: set[str] | None = None) -> dict[str, str]: """Extract parameters from a parameter type definition. Handles both struct types ({...}) and top-level union types (TypeA / TypeB), @@ -1366,9 +1321,7 @@ def _extract_parameters_and_required( # For union types, collect parameters from all alternatives # but treat them as optional since the caller only needs to pass one alternative for alt_type in alternatives: - alt_params, _ = self._extract_parameters_and_required( - alt_type, _seen - ) + alt_params, _ = self._extract_parameters_and_required(alt_type, _seen) params.update(alt_params) # Note: We intentionally DON'T add to required, since these are union alternatives return params, required @@ -1470,9 +1423,7 @@ def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> Non for module_name in sorted(modules.keys()): class_name = module_name_to_class_name(module_name) filename = module_name_to_filename(module_name) - code += ( - f"from selenium.webdriver.common.bidi.{filename} import {class_name}\n" - ) + code += f"from selenium.webdriver.common.bidi.{filename} import {class_name}\n" code += "\n__all__ = [\n" for module_name in sorted(modules.keys()): @@ -1777,9 +1728,7 @@ def main( if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate Python WebDriver BiDi modules from CDDL specification" - ) + parser = argparse.ArgumentParser(description="Generate Python WebDriver BiDi modules from CDDL specification") parser.add_argument( "cddl_file", help="Path to CDDL specification file", diff --git a/py/private/_event_manager.py b/py/private/_event_manager.py index 84b4446d190c5..1dcc8288ce683 100644 --- a/py/private/_event_manager.py +++ b/py/private/_event_manager.py @@ -62,9 +62,7 @@ def from_json(self, params: dict) -> Any: return params try: # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable( - self._python_class.from_json - ): + if hasattr(self._python_class, "from_json") and callable(self._python_class.from_json): return self._python_class.from_json(params) import dataclasses as dc @@ -117,9 +115,7 @@ def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) if bidi_event not in self.subscriptions: session = Session(self.conn) result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) + sub_id = result.get("subscription") if isinstance(result, dict) else None self.subscriptions[bidi_event] = { "callbacks": [], "subscription_id": sub_id, @@ -176,11 +172,9 @@ def clear_event_handlers(self) -> None: if event_wrapper: for callback_id in callbacks: self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get("subscription_id") if isinstance(entry, dict) else None - ) + sub_id = entry.get("subscription_id") if isinstance(entry, dict) else None if sub_id: session.unsubscribe(subscriptions=[sub_id]) else: session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() \ No newline at end of file + self.subscriptions.clear() diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index f8ade8b9b3ad8..06c0573db9083 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -37,7 +37,6 @@ # ============================================================================ ENHANCEMENTS: dict[str, dict[str, Any]] = { - "browser": { # Dataclass custom methods "__dataclass_methods__": { @@ -183,7 +182,6 @@ class SetClientWindowStateParameters: return self._conn.execute(cmd)''', ], }, - "browsingContext": { # Method enhancements "create": { @@ -268,7 +266,6 @@ def from_json(cls, params: dict) -> DownloadEndParams: ], # Download events are now in the CDDL spec, so no extra_events needed }, - "log": { # Make LogLevel an alias for Level so existing code using LogLevel works "aliases": {"LogLevel": "Level"}, @@ -332,7 +329,6 @@ def from_json(cls, params: dict) -> JavascriptLogEntry: "entry_added": "Entry", }, }, - "emulation": { "exclude_types": ["setNetworkConditionsParameters"], "extra_dataclasses": [ @@ -545,7 +541,6 @@ class SetNetworkConditionsParameters: return self._conn.execute(cmd)''', ], }, - "script": { "extra_methods": [ ''' def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: @@ -921,7 +916,6 @@ def from_json(self2, p): self._unsubscribe_log_entry(callback_id)''', ], }, - "network": { "exclude_types": ["disownDataParameters"], # Initialize intercepts tracking list and per-handler intercept map @@ -1124,7 +1118,6 @@ def _auth_callback(params): self._remove_intercept(intercept_id)''', ], }, - "storage": { # Exclude auto-generated dataclasses that need custom to_bidi_dict() # for JSON-over-WebSocket serialization, or custom constructors. @@ -1299,7 +1292,6 @@ def to_bidi_dict(self) -> dict: def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict()''', - # StorageKeyPartitionDescriptor with camelCase serialization '''@dataclass class StorageKeyPartitionDescriptor: @@ -1380,7 +1372,6 @@ def to_dict(self) -> dict: ) return SetCookieResult(partition_key=pk) return result''', - ''' def delete_cookies(self, filter=None, partition=None): """Execute storage.deleteCookies.""" if filter and hasattr(filter, "to_bidi_dict"): @@ -1408,7 +1399,6 @@ def to_dict(self) -> dict: return result''', ], }, - "session": { # Override UserPromptHandler to add to_bidi_dict() for JSON serialization "exclude_types": ["UserPromptHandler"], @@ -1446,7 +1436,6 @@ def to_dict(self) -> dict: return self.to_bidi_dict()''', ], }, - "webExtension": { # Suppress the raw generated stubs; hand-written versions follow below "exclude_methods": ["install", "uninstall"], @@ -1527,7 +1516,6 @@ def to_dict(self) -> dict: return self._conn.execute(cmd)''', ], }, - "input": { # FileDialogInfo needs from_json for event deserialization "exclude_types": ["FileDialogInfo", "PointerMoveAction", "PointerDownAction"], @@ -1651,9 +1639,7 @@ def transform_download_params( "type": "allowed", # Convert pathlib.Path (or any path-like) to str so the BiDi # protocol always receives a plain JSON string. - "destinationFolder": ( - str(destination_folder) if destination_folder is not None else None - ), + "destinationFolder": (str(destination_folder) if destination_folder is not None else None), } elif allowed is False: return {"type": "denied"} diff --git a/py/private/cdp.py b/py/private/cdp.py index ba4a73298ee0a..bac00765f43ca 100644 --- a/py/private/cdp.py +++ b/py/private/cdp.py @@ -60,9 +60,7 @@ def import_devtools(ver): # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - available_versions = tuple( - x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit()) - ) + available_versions = tuple(x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit())) numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) if not numeric_versions: raise diff --git a/py/selenium/common/exceptions.py b/py/selenium/common/exceptions.py index 7ec809eb20b18..92526c3a701be 100644 --- a/py/selenium/common/exceptions.py +++ b/py/selenium/common/exceptions.py @@ -122,9 +122,7 @@ def __init__( screen: str | None = None, stacktrace: Sequence[str] | None = None, ) -> None: - with_support = ( - f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception" - ) + with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception" super().__init__(with_support, screen, stacktrace) @@ -191,9 +189,7 @@ def __init__( screen: str | None = None, stacktrace: Sequence[str] | None = None, ) -> None: - with_support = ( - f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception" - ) + with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception" super().__init__(with_support, screen, stacktrace) @@ -279,9 +275,7 @@ def __init__( screen: str | None = None, stacktrace: Sequence[str] | None = None, ) -> None: - with_support = ( - f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception" - ) + with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception" super().__init__(with_support, screen, stacktrace) diff --git a/py/selenium/webdriver/common/bidi/_event_manager.py b/py/selenium/webdriver/common/bidi/_event_manager.py index 84b4446d190c5..1dcc8288ce683 100644 --- a/py/selenium/webdriver/common/bidi/_event_manager.py +++ b/py/selenium/webdriver/common/bidi/_event_manager.py @@ -62,9 +62,7 @@ def from_json(self, params: dict) -> Any: return params try: # Delegate to a classmethod from_json if the class defines one - if hasattr(self._python_class, "from_json") and callable( - self._python_class.from_json - ): + if hasattr(self._python_class, "from_json") and callable(self._python_class.from_json): return self._python_class.from_json(params) import dataclasses as dc @@ -117,9 +115,7 @@ def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) if bidi_event not in self.subscriptions: session = Session(self.conn) result = session.subscribe([bidi_event], contexts=contexts) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) + sub_id = result.get("subscription") if isinstance(result, dict) else None self.subscriptions[bidi_event] = { "callbacks": [], "subscription_id": sub_id, @@ -176,11 +172,9 @@ def clear_event_handlers(self) -> None: if event_wrapper: for callback_id in callbacks: self.conn.remove_callback(event_wrapper, callback_id) - sub_id = ( - entry.get("subscription_id") if isinstance(entry, dict) else None - ) + sub_id = entry.get("subscription_id") if isinstance(entry, dict) else None if sub_id: session.unsubscribe(subscriptions=[sub_id]) else: session.unsubscribe(events=[bidi_event]) - self.subscriptions.clear() \ No newline at end of file + self.subscriptions.clear() diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 3811a2a2e97b7..6310f2e18c2ce 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -101,7 +101,6 @@ def get_y(self): return self.y - @dataclass class UserContextInfo: """UserContextInfo.""" @@ -181,6 +180,7 @@ class ClientWindowNamedState: MINIMIZED = "minimized" NORMAL = "normal" + @dataclass class SetClientWindowStateParameters: """SetClientWindowStateParameters. @@ -194,6 +194,7 @@ class SetClientWindowStateParameters: client_window: Any | None = None state: Any | None = None + class Browser: """WebDriver BiDi browser module.""" @@ -202,8 +203,7 @@ def __init__(self, conn) -> None: def close(self): """Execute browser.close.""" - params = { - } + params = {} params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("browser.close", params) result = self._conn.execute(cmd) @@ -216,10 +216,10 @@ def create_user_context( unhandled_prompt_behavior: Any | None = None, ): """Execute browser.createUserContext.""" - if proxy and hasattr(proxy, 'to_bidi_dict'): + if proxy and hasattr(proxy, "to_bidi_dict"): proxy = proxy.to_bidi_dict() - if unhandled_prompt_behavior and hasattr(unhandled_prompt_behavior, 'to_bidi_dict'): + if unhandled_prompt_behavior and hasattr(unhandled_prompt_behavior, "to_bidi_dict"): unhandled_prompt_behavior = unhandled_prompt_behavior.to_bidi_dict() params = { @@ -237,8 +237,7 @@ def create_user_context( def get_client_windows(self): """Execute browser.getClientWindows.""" - params = { - } + params = {} params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("browser.getClientWindows", params) result = self._conn.execute(cmd) @@ -252,7 +251,7 @@ def get_client_windows(self): state=item.get("state"), width=item.get("width"), x=item.get("x"), - y=item.get("y") + y=item.get("y"), ) for item in items if isinstance(item, dict) @@ -261,18 +260,13 @@ def get_client_windows(self): def get_user_contexts(self): """Execute browser.getUserContexts.""" - params = { - } + params = {} params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("browser.getUserContexts", params) result = self._conn.execute(cmd) if result and "userContexts" in result: items = result.get("userContexts", []) - return [ - item.get("userContext") - for item in items - if isinstance(item, dict) - ] + return [item.get("userContext") for item in items if isinstance(item, dict)] return [] def remove_user_context(self, user_context: Any | None = None): @@ -320,6 +314,7 @@ def set_download_behavior( params["userContexts"] = user_contexts cmd = command_builder("browser.setDownloadBehavior", params) return self._conn.execute(cmd) + def set_client_window_state( self, client_window: Any | None = None, @@ -344,12 +339,9 @@ def set_client_window_state( # Serialize ClientWindowRectState if needed state_param = state - if hasattr(state, '__dataclass_fields__'): + if hasattr(state, "__dataclass_fields__"): # It's a dataclass, convert to dict - state_param = { - k: v for k, v in state.__dict__.items() - if v is not None - } + state_param = {k: v for k, v in state.__dict__.items() if v is not None} params = { "clientWindow": client_window, diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 175c511393098..59a9813e58124 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -366,12 +366,14 @@ class DownloadWillBeginParams: suggested_filename: str | None = None + @dataclass class DownloadCanceledParams: """DownloadCanceledParams.""" status: Any | None = None + @dataclass class DownloadParams: """DownloadParams - fields shared by all download end event variants.""" @@ -383,6 +385,7 @@ class DownloadParams: url: str | None = None filepath: str | None = None + @dataclass class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" @@ -402,6 +405,7 @@ def from_json(cls, params: dict) -> DownloadEndParams: ) return cls(download_params=dp) + # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "context_created": "browsingContext.contextCreated", @@ -420,6 +424,7 @@ def from_json(cls, params: dict) -> DownloadEndParams: "user_prompt_opened": "browsingContext.userPromptOpened", } + def _deserialize_info_list(items: list) -> list | None: """Recursively deserialize a list of dicts to Info objects. @@ -452,12 +457,11 @@ def _deserialize_info_list(items: list) -> list | None: return result if result else None - - class BrowsingContext: """WebDriver BiDi browsingContext module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) @@ -558,7 +562,7 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): original_opener=item.get("originalOpener"), url=item.get("url"), user_context=item.get("userContext"), - parent=item.get("parent") + parent=item.get("parent"), ) for item in items if isinstance(item, dict) @@ -725,7 +729,6 @@ def traverse_history(self, context: Any | None = None, delta: Any | None = None) result = self._conn.execute(cmd) return result - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. @@ -752,48 +755,49 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() + # Event Info Type Aliases # Event: browsingContext.contextCreated -ContextCreated = globals().get('Info', dict) # Fallback to dict if type not defined +ContextCreated = globals().get("Info", dict) # Fallback to dict if type not defined # Event: browsingContext.contextDestroyed -ContextDestroyed = globals().get('Info', dict) # Fallback to dict if type not defined +ContextDestroyed = globals().get("Info", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationStarted -NavigationStarted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationStarted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.fragmentNavigated -FragmentNavigated = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +FragmentNavigated = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.historyUpdated -HistoryUpdated = globals().get('HistoryUpdatedParameters', dict) # Fallback to dict if type not defined +HistoryUpdated = globals().get("HistoryUpdatedParameters", dict) # Fallback to dict if type not defined # Event: browsingContext.domContentLoaded -DomContentLoaded = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +DomContentLoaded = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.load -Load = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +Load = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.downloadWillBegin -DownloadWillBegin = globals().get('DownloadWillBeginParams', dict) # Fallback to dict if type not defined +DownloadWillBegin = globals().get("DownloadWillBeginParams", dict) # Fallback to dict if type not defined # Event: browsingContext.downloadEnd -DownloadEnd = globals().get('DownloadEndParams', dict) # Fallback to dict if type not defined +DownloadEnd = globals().get("DownloadEndParams", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationAborted -NavigationAborted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationAborted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationCommitted -NavigationCommitted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationCommitted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationFailed -NavigationFailed = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationFailed = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptClosed -UserPromptClosed = globals().get('UserPromptClosedParameters', dict) # Fallback to dict if type not defined +UserPromptClosed = globals().get("UserPromptClosedParameters", dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptOpened -UserPromptOpened = globals().get('UserPromptOpenedParameters', dict) # Fallback to dict if type not defined +UserPromptOpened = globals().get("UserPromptOpenedParameters", dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings diff --git a/py/selenium/webdriver/common/bidi/cdp.py b/py/selenium/webdriver/common/bidi/cdp.py index ba4a73298ee0a..bac00765f43ca 100644 --- a/py/selenium/webdriver/common/bidi/cdp.py +++ b/py/selenium/webdriver/common/bidi/cdp.py @@ -60,9 +60,7 @@ def import_devtools(ver): # because cdp has been updated but selenium python has not been released yet. devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) - available_versions = tuple( - x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit()) - ) + available_versions = tuple(x for x in versions if x == "latest" or (x.startswith("v") and x[1:].isdigit())) numeric_versions = tuple(x[1:] for x in available_versions if x.startswith("v")) if not numeric_versions: raise diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index fc75caa282a45..ff67b56622c35 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -23,9 +23,7 @@ from typing import Any -def command_builder( - method: str, params: dict[str, Any] | None = None -) -> Generator[dict[str, Any], Any, Any]: +def command_builder(method: str, params: dict[str, Any] | None = None) -> Generator[dict[str, Any], Any, Any]: """Build a BiDi command generator. Args: diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 1c48100cc343b..0860890abf41b 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -178,6 +178,7 @@ class SetNetworkConditionsParameters: # Backward-compatible alias for existing imports setNetworkConditionsParameters = SetNetworkConditionsParameters + class Emulation: """WebDriver BiDi emulation module.""" @@ -319,9 +320,7 @@ def set_geolocation_override( if isinstance(error, dict): params["error"] = error else: - params["error"] = { - "type": error.type if error.type is not None else "positionUnavailable" - } + params["error"] = {"type": error.type if error.type is not None else "positionUnavailable"} if contexts is not None: params["contexts"] = contexts if user_contexts is not None: @@ -329,6 +328,7 @@ def set_geolocation_override( cmd = command_builder("emulation.setGeolocationOverride", params) result = self._conn.execute(cmd) return result + def set_timezone_override( self, timezone=None, @@ -353,6 +353,7 @@ def set_timezone_override( params["userContexts"] = user_contexts cmd = command_builder("emulation.setTimezoneOverride", params) return self._conn.execute(cmd) + def set_scripting_enabled( self, enabled=None, @@ -377,6 +378,7 @@ def set_scripting_enabled( params["userContexts"] = user_contexts cmd = command_builder("emulation.setScriptingEnabled", params) return self._conn.execute(cmd) + def set_user_agent_override( self, user_agent=None, @@ -400,6 +402,7 @@ def set_user_agent_override( params["userContexts"] = user_contexts cmd = command_builder("emulation.setUserAgentOverride", params) return self._conn.execute(cmd) + def set_screen_orientation_override( self, screen_orientation=None, @@ -436,6 +439,7 @@ def set_screen_orientation_override( params["userContexts"] = user_contexts cmd = command_builder("emulation.setScreenOrientationOverride", params) return self._conn.execute(cmd) + def set_network_conditions( self, network_conditions=None, diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 346ead5e49841..5d4c670490089 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -180,6 +180,7 @@ def from_json(cls, params: dict) -> FileDialogInfo: multiple=params.get("multiple"), ) + @dataclass class PointerMoveAction: """PointerMoveAction.""" @@ -191,6 +192,7 @@ class PointerMoveAction: origin: Any | None = None properties: Any | None = None + @dataclass class PointerDownAction: """PointerDownAction.""" @@ -199,15 +201,18 @@ class PointerDownAction: button: Any | None = None properties: Any | None = None + # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "file_dialog_opened": "input.fileDialogOpened", } + class Input: """WebDriver BiDi input module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) @@ -305,9 +310,10 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() + # Event Info Type Aliases # Event: input.fileDialogOpened -FileDialogOpened = globals().get('FileDialogInfo', dict) # Fallback to dict if type not defined +FileDialogOpened = globals().get("FileDialogInfo", dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index ca24d6e78d532..856d8561e706f 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -24,6 +24,7 @@ class Level: LogLevel = Level + @dataclass class BaseLogEntry: """BaseLogEntry.""" @@ -69,6 +70,7 @@ def from_json(cls, params: dict) -> ConsoleLogEntry: stack_trace=params.get("stackTrace"), ) + @dataclass class JavascriptLogEntry: """JavascriptLogEntry - a JavaScript error log entry from the browser.""" @@ -92,6 +94,7 @@ def from_json(cls, params: dict) -> JavascriptLogEntry: stacktrace=params.get("stackTrace"), ) + Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry # BiDi Event Name to Parameter Type Mapping @@ -99,15 +102,16 @@ def from_json(cls, params: dict) -> JavascriptLogEntry: "entry_added": "log.entryAdded", } + class Log: """WebDriver BiDi log module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. @@ -134,6 +138,7 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() + # Event Info Type Aliases # Event: log.entryAdded EntryAdded = Entry diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index d6875e14fa58a..e13fbe0f7a20b 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -360,6 +360,7 @@ class DisownDataParameters: # Backward-compatible alias for existing imports disownDataParameters = DisownDataParameters + class BytesValue: """A string or base64-encoded bytes value used in cookie operations. @@ -377,6 +378,7 @@ def __init__(self, type: Any | None, value: Any | None) -> None: def to_bidi_dict(self) -> dict: return {"type": self.type, "value": self.value} + class Request: """Wraps a BiDi network request event params and provides request action methods.""" @@ -395,16 +397,19 @@ def continue_request(self, **kwargs): params.update(kwargs) self._conn.execute(_cb("network.continueRequest", params)) + # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "auth_required": "network.authRequired", "before_request": "network.beforeRequestSent", } + class Network: """WebDriver BiDi network module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) @@ -755,6 +760,7 @@ def _add_intercept(self, phases=None, url_patterns=None): if intercept_id and intercept_id not in self.intercepts: self.intercepts.append(intercept_id) return result + def _remove_intercept(self, intercept_id): """Remove a low-level network intercept.""" from selenium.webdriver.common.bidi.common import command_builder as _cb @@ -762,6 +768,7 @@ def _remove_intercept(self, intercept_id): self._conn.execute(_cb("network.removeIntercept", {"intercept": intercept_id})) if intercept_id in self.intercepts: self.intercepts.remove(intercept_id) + def add_request_handler(self, event, callback, url_patterns=None): """Add a handler for network requests at the specified phase. @@ -784,11 +791,7 @@ def add_request_handler(self, event, callback, url_patterns=None): intercept_id = intercept_result.get("intercept") if intercept_result else None def _request_callback(params): - raw = ( - params - if isinstance(params, dict) - else (params.__dict__ if hasattr(params, "__dict__") else {}) - ) + raw = params if isinstance(params, dict) else (params.__dict__ if hasattr(params, "__dict__") else {}) request = Request(self._conn, raw) callback(request) @@ -796,6 +799,7 @@ def _request_callback(params): if intercept_id: self._handler_intercepts[callback_id] = intercept_id return callback_id + def remove_request_handler(self, event, callback_id): """Remove a network request handler and its associated network intercept. @@ -807,11 +811,13 @@ def remove_request_handler(self, event, callback_id): intercept_id = self._handler_intercepts.pop(callback_id, None) if intercept_id: self._remove_intercept(intercept_id) + def clear_request_handlers(self): """Clear all request handlers and remove all tracked intercepts.""" self.clear_event_handlers() for intercept_id in list(self.intercepts): self._remove_intercept(intercept_id) + def add_auth_handler(self, username, password): """Add an auth handler that automatically provides credentials. @@ -829,16 +835,8 @@ def add_auth_handler(self, username, password): intercept_id = intercept_result.get("intercept") if intercept_result else None def _auth_callback(params): - raw = ( - params - if isinstance(params, dict) - else (params.__dict__ if hasattr(params, "__dict__") else {}) - ) - request_id = ( - raw.get("request", {}).get("request") - if isinstance(raw, dict) - else None - ) + raw = params if isinstance(params, dict) else (params.__dict__ if hasattr(params, "__dict__") else {}) + request_id = raw.get("request", {}).get("request") if isinstance(raw, dict) else None if request_id: self._conn.execute( _cb( @@ -859,6 +857,7 @@ def _auth_callback(params): if intercept_id: self._handler_intercepts[callback_id] = intercept_id return callback_id + def remove_auth_handler(self, callback_id): """Remove an auth handler by callback ID and its associated network intercept. @@ -896,9 +895,10 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() + # Event Info Type Aliases # Event: network.authRequired -AuthRequired = globals().get('AuthRequiredParameters', dict) # Fallback to dict if type not defined +AuthRequired = globals().get("AuthRequiredParameters", dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index acb8bdf65f0ef..98e25a1d2f856 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -82,8 +82,7 @@ def set_permission( state_value = state.value if isinstance(state, PermissionState) else state if state_value not in _VALID_PERMISSION_STATES: raise ValueError( - f"Invalid permission state: {state_value!r}. " - f"Must be one of {sorted(_VALID_PERMISSION_STATES)}" + f"Invalid permission state: {state_value!r}. Must be one of {sorted(_VALID_PERMISSION_STATES)}" ) if isinstance(descriptor, str): diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index ecc2a75e0922d..38e43a6677470 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -620,10 +620,12 @@ class RealmDestroyedParameters: "realm_destroyed": "script.realmDestroyed", } + class Script: """WebDriver BiDi script module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn, driver=None) -> None: self._conn = conn self._driver = driver @@ -845,6 +847,7 @@ def _serialize_arg(value): if raw.get("type") == "success": return raw.get("result") return raw + def _add_preload_script( self, function_declaration, @@ -880,6 +883,7 @@ def _add_preload_script( if isinstance(result, dict): return result.get("script") return result + def _remove_preload_script(self, script_id): """Remove a preload script by ID. @@ -887,6 +891,7 @@ def _remove_preload_script(self, script_id): script_id: The ID of the preload script to remove. """ return self.remove_preload_script(script=script_id) + def pin(self, function_declaration): """Pin (add) a preload script that runs on every page load. @@ -897,6 +902,7 @@ def pin(self, function_declaration): script_id: The ID of the pinned script (str). """ return self._add_preload_script(function_declaration) + def unpin(self, script_id): """Unpin (remove) a previously pinned preload script. @@ -904,6 +910,7 @@ def unpin(self, script_id): script_id: The ID returned by pin(). """ return self._remove_preload_script(script_id=script_id) + def _evaluate( self, expression, @@ -926,6 +933,7 @@ def _evaluate( Returns: An object with .realm, .result (dict or None), and .exception_details (or None). """ + class _EvalResult: def __init__(self2, realm, result, exception_details): self2.realm = realm @@ -947,6 +955,7 @@ def __init__(self2, realm, result, exception_details): return _EvalResult(realm=realm, result=None, exception_details=exc) return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) return _EvalResult(realm=None, result=raw, exception_details=None) + def _call_function( self, function_declaration, @@ -973,6 +982,7 @@ def _call_function( Returns: An object with .result (dict or None) and .exception_details (or None). """ + class _CallResult: def __init__(self2, result, exception_details): self2.result = result @@ -995,6 +1005,7 @@ def __init__(self2, result, exception_details): if raw.get("type") == "success": return _CallResult(result=raw.get("result"), exception_details=None) return _CallResult(result=raw, exception_details=None) + def _get_realms(self, context=None, type=None): """Get all realms, optionally filtered by context and type. @@ -1005,6 +1016,7 @@ def _get_realms(self, context=None, type=None): Returns: List of realm info objects with .realm, .origin, .type, .context attributes. """ + class _RealmInfo: def __init__(self2, realm, origin, type_, context): self2.realm = realm @@ -1017,13 +1029,16 @@ def __init__(self2, realm, origin, type_, context): result = [] for r in realms_list: if isinstance(r, dict): - result.append(_RealmInfo( - realm=r.get("realm"), - origin=r.get("origin"), - type_=r.get("type"), - context=r.get("context"), - )) + result.append( + _RealmInfo( + realm=r.get("realm"), + origin=r.get("origin"), + type_=r.get("type"), + context=r.get("context"), + ) + ) return result + def _disown(self, handles, target): """Disown handles in a browsing context. @@ -1032,6 +1047,7 @@ def _disown(self, handles, target): target: A dict like {"context": }. """ return self.disown(handles=handles, target=target) + def _subscribe_log_entry(self, callback, entry_type_filter=None): """Subscribe to log.entryAdded BiDi events with optional type filtering.""" import threading as _threading @@ -1068,9 +1084,7 @@ def _wrapped(raw): if entry_type_filter is None: callback(entry) else: - t = getattr(entry, "type_", None) or ( - entry.get("type") if isinstance(entry, dict) else None - ) + t = getattr(entry, "type_", None) or (entry.get("type") if isinstance(entry, dict) else None) if t == entry_type_filter: callback(entry) @@ -1086,15 +1100,14 @@ def from_json(self2, p): if bidi_event not in self._log_subscriptions: session = _Session(self._conn) result = session.subscribe([bidi_event]) - sub_id = ( - result.get("subscription") if isinstance(result, dict) else None - ) + sub_id = result.get("subscription") if isinstance(result, dict) else None self._log_subscriptions[bidi_event] = { "callbacks": [], "subscription_id": sub_id, } self._log_subscriptions[bidi_event]["callbacks"].append(callback_id) return callback_id + def _unsubscribe_log_entry(self, callback_id): """Unsubscribe a log entry callback by ID.""" from selenium.webdriver.common.bidi.session import Session as _Session @@ -1123,6 +1136,7 @@ def from_json(self2, p): else: session.unsubscribe(events=[bidi_event]) del self._log_subscriptions[bidi_event] + def add_console_message_handler(self, callback: Callable) -> int: """Add a handler for console log messages (log.entryAdded type=console). @@ -1133,9 +1147,11 @@ def add_console_message_handler(self, callback: Callable) -> int: callback_id for use with remove_console_message_handler. """ return self._subscribe_log_entry(callback, entry_type_filter="console") + def remove_console_message_handler(self, callback_id: int) -> None: """Remove a console message handler by callback ID.""" self._unsubscribe_log_entry(callback_id) + def add_javascript_error_handler(self, callback: Callable) -> int: """Add a handler for JavaScript error log messages (log.entryAdded type=javascript). @@ -1146,6 +1162,7 @@ def add_javascript_error_handler(self, callback: Callable) -> int: callback_id for use with remove_javascript_error_handler. """ return self._subscribe_log_entry(callback, entry_type_filter="javascript") + def remove_javascript_error_handler(self, callback_id: int) -> None: """Remove a JavaScript error handler by callback ID.""" self._unsubscribe_log_entry(callback_id) @@ -1176,12 +1193,13 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() + # Event Info Type Aliases # Event: script.realmCreated -RealmCreated = globals().get('RealmInfo', dict) # Fallback to dict if type not defined +RealmCreated = globals().get("RealmInfo", dict) # Fallback to dict if type not defined # Event: script.realmDestroyed -RealmDestroyed = globals().get('RealmDestroyedParameters', dict) # Fallback to dict if type not defined +RealmDestroyed = globals().get("RealmDestroyedParameters", dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index e04d897e25deb..741faeb42bc43 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -180,6 +180,7 @@ def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict() + class Session: """WebDriver BiDi session module.""" @@ -188,8 +189,7 @@ def __init__(self, conn) -> None: def status(self): """Execute session.status.""" - params = { - } + params = {} params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("session.status", params) result = self._conn.execute(cmd) @@ -210,8 +210,7 @@ def new(self, capabilities: Any | None = None): def end(self): """Execute session.end.""" - params = { - } + params = {} params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("session.end", params) result = self._conn.execute(cmd) @@ -247,4 +246,3 @@ def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] cmd = command_builder("session.unsubscribe", params) result = self._conn.execute(cmd) return result - diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 5ae8bf5aeb2d0..9825407c2eaf8 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -87,6 +87,7 @@ def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict() + class SameSite: """SameSite cookie attribute values.""" @@ -95,6 +96,7 @@ class SameSite: NONE = "none" DEFAULT = "default" + @dataclass class StorageCookie: """A cookie object returned by storage.getCookies.""" @@ -129,6 +131,7 @@ def from_bidi_dict(cls, raw: dict) -> StorageCookie: expiry=raw.get("expiry"), ) + @dataclass class CookieFilter: """CookieFilter.""" @@ -170,6 +173,7 @@ def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict() + @dataclass class PartialCookie: """PartialCookie.""" @@ -208,6 +212,7 @@ def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict() + class BrowsingContextPartitionDescriptor: """BrowsingContextPartitionDescriptor. @@ -227,6 +232,7 @@ def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict() + @dataclass class StorageKeyPartitionDescriptor: """StorageKeyPartitionDescriptor.""" @@ -248,6 +254,7 @@ def to_dict(self) -> dict: """Backward-compatible alias for to_bidi_dict().""" return self.to_bidi_dict() + class Storage: """WebDriver BiDi storage module.""" @@ -268,11 +275,7 @@ def get_cookies(self, filter=None, partition=None): cmd = command_builder("storage.getCookies", params) result = self._conn.execute(cmd) if result and "cookies" in result: - cookies = [ - StorageCookie.from_bidi_dict(c) - for c in result.get("cookies", []) - if isinstance(c, dict) - ] + cookies = [StorageCookie.from_bidi_dict(c) for c in result.get("cookies", []) if isinstance(c, dict)] pk_raw = result.get("partitionKey") pk = ( PartitionKey( @@ -284,6 +287,7 @@ def get_cookies(self, filter=None, partition=None): ) return GetCookiesResult(cookies=cookies, partition_key=pk) return GetCookiesResult(cookies=[], partition_key=None) + def set_cookie(self, cookie=None, partition=None): """Execute storage.setCookie.""" if cookie and hasattr(cookie, "to_bidi_dict"): @@ -309,6 +313,7 @@ def set_cookie(self, cookie=None, partition=None): ) return SetCookieResult(partition_key=pk) return result + def delete_cookies(self, filter=None, partition=None): """Execute storage.deleteCookies.""" if filter and hasattr(filter, "to_bidi_dict"): diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 0a28843e339f1..03fedab62e174 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -87,14 +87,16 @@ def install( ValueError: If more than one, or none, of the arguments is provided. """ provided = [ - k for k, v in { - "path": path, "archive_path": archive_path, "base64_value": base64_value, - }.items() if v is not None + k + for k, v in { + "path": path, + "archive_path": archive_path, + "base64_value": base64_value, + }.items() + if v is not None ] if len(provided) != 1: - raise ValueError( - f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" - ) + raise ValueError(f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}") if path is not None: extension_data = {"type": "path", "path": path} elif archive_path is not None: @@ -115,6 +117,7 @@ def install( "in your WebDriver configuration." ) from e raise + def uninstall(self, extension: str | dict): """Uninstall a web extension. diff --git a/py/selenium/webdriver/common/proxy.py b/py/selenium/webdriver/common/proxy.py index 28de19afa5742..eadf1d069709f 100644 --- a/py/selenium/webdriver/common/proxy.py +++ b/py/selenium/webdriver/common/proxy.py @@ -35,23 +35,13 @@ class ProxyType: profile preference, 'string' is id of proxy type. """ - DIRECT = ProxyTypeFactory.make( - 0, "DIRECT" - ) # Direct connection, no proxy (default on Windows). - MANUAL = ProxyTypeFactory.make( - 1, "MANUAL" - ) # Manual proxy settings (e.g., for httpProxy). + DIRECT = ProxyTypeFactory.make(0, "DIRECT") # Direct connection, no proxy (default on Windows). + MANUAL = ProxyTypeFactory.make(1, "MANUAL") # Manual proxy settings (e.g., for httpProxy). PAC = ProxyTypeFactory.make(2, "PAC") # Proxy autoconfiguration from URL. RESERVED_1 = ProxyTypeFactory.make(3, "RESERVED1") # Never used. - AUTODETECT = ProxyTypeFactory.make( - 4, "AUTODETECT" - ) # Proxy autodetection (presumably with WPAD). - SYSTEM = ProxyTypeFactory.make( - 5, "SYSTEM" - ) # Use system settings (default on Linux). - UNSPECIFIED = ProxyTypeFactory.make( - 6, "UNSPECIFIED" - ) # Not initialized (for internal use). + AUTODETECT = ProxyTypeFactory.make(4, "AUTODETECT") # Proxy autodetection (presumably with WPAD). + SYSTEM = ProxyTypeFactory.make(5, "SYSTEM") # Use system settings (default on Linux). + UNSPECIFIED = ProxyTypeFactory.make(6, "UNSPECIFIED") # Not initialized (for internal use). @classmethod def load(cls, value): @@ -60,11 +50,7 @@ def load(cls, value): value = str(value).upper() for attr in dir(cls): attr_value = getattr(cls, attr) - if ( - isinstance(attr_value, dict) - and "string" in attr_value - and attr_value["string"] == value - ): + if isinstance(attr_value, dict) and "string" in attr_value and attr_value["string"] == value: return attr_value raise Exception(f"No proxy type is found for {value}") @@ -219,17 +205,13 @@ def to_bidi_dict(self) -> dict: if self.noProxy: # Convert comma-separated string to list if isinstance(self.noProxy, str): - result["noProxy"] = [ - host.strip() for host in self.noProxy.split(",") if host.strip() - ] + result["noProxy"] = [host.strip() for host in self.noProxy.split(",") if host.strip()] elif isinstance(self.noProxy, list): if not all(isinstance(h, str) for h in self.noProxy): raise TypeError("no_proxy list must contain only strings") result["noProxy"] = self.noProxy else: - raise TypeError( - "no_proxy must be a comma-separated string or a list of strings" - ) + raise TypeError("no_proxy must be a comma-separated string or a list of strings") elif proxy_type == "pac": if self.proxyAutoconfigUrl: diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 2c41897878075..4e426090883d4 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -116,9 +116,7 @@ def get_remote_connection( client_config: ClientConfig | None = None, ) -> RemoteConnection: if isinstance(command_executor, str): - client_config = client_config or ClientConfig( - remote_server_addr=command_executor - ) + client_config = client_config or ClientConfig(remote_server_addr=command_executor) client_config.remote_server_addr = command_executor command_executor = RemoteConnection(client_config=client_config) @@ -400,13 +398,9 @@ def create_web_element(self, element_id: str) -> WebElement: def _unwrap_value(self, value): if isinstance(value, dict): if "element-6066-11e4-a52e-4f735466cecf" in value: - return self.create_web_element( - value["element-6066-11e4-a52e-4f735466cecf"] - ) + return self.create_web_element(value["element-6066-11e4-a52e-4f735466cecf"]) if "shadow-6066-11e4-a52e-4f735466cecf" in value: - return self._shadowroot_cls( - self, value["shadow-6066-11e4-a52e-4f735466cecf"] - ) + return self._shadowroot_cls(self, value["shadow-6066-11e4-a52e-4f735466cecf"]) for key, val in value.items(): value[key] = self._unwrap_value(val) return value @@ -432,9 +426,7 @@ def execute_cdp_cmd(self, cmd: str, cmd_args: dict): Example: `driver.execute_cdp_cmd("Network.getResponseBody", {"requestId": requestId})` """ - return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})[ - "value" - ] + return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})["value"] def execute( self, @@ -470,9 +462,7 @@ def execute( elif "sessionId" not in params: params["sessionId"] = self.session_id - response = cast(RemoteConnection, self.command_executor).execute( - driver_command, params - ) + response = cast(RemoteConnection, self.command_executor).execute(driver_command, params) if response: self.error_handler.check_response(response) @@ -528,9 +518,7 @@ def unpin(self, script_key: ScriptKey) -> None: try: self.pinned_scripts.pop(script_key.id) except KeyError: - raise KeyError( - f"No script with key: {script_key} existed in {self.pinned_scripts}" - ) from None + raise KeyError(f"No script with key: {script_key} existed in {self.pinned_scripts}") from None def get_pinned_scripts(self) -> list[str]: """Return a list of all pinned scripts. @@ -563,9 +551,7 @@ def execute_script(self, script: str, *args) -> Any: converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT - return self.execute(command, {"script": script, "args": converted_args})[ - "value" - ] + return self.execute(command, {"script": script, "args": converted_args})["value"] def execute_async_script(self, script: str, *args) -> Any: """Asynchronously Executes JavaScript in the current window/frame. @@ -584,9 +570,7 @@ def execute_async_script(self, script: str, *args) -> Any: converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT_ASYNC - return self.execute(command, {"script": script, "args": converted_args})[ - "value" - ] + return self.execute(command, {"script": script, "args": converted_args})["value"] @property def current_url(self) -> str: @@ -763,9 +747,7 @@ def implicitly_wait(self, time_to_wait: float) -> None: Example: `driver.implicitly_wait(30)` """ - self.execute( - Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)} - ) + self.execute(Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)}) def set_script_timeout(self, time_to_wait: float) -> None: """Set the timeout for asynchronous script execution. @@ -794,9 +776,7 @@ def set_page_load_timeout(self, time_to_wait: float) -> None: `driver.set_page_load_timeout(30)` """ try: - self.execute( - Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)} - ) + self.execute(Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)}) except WebDriverException: self.execute( Command.SET_TIMEOUTS, @@ -837,9 +817,7 @@ def timeouts(self, timeouts) -> None: """ _ = self.execute(Command.SET_TIMEOUTS, timeouts._to_json())["value"] - def find_element( - self, by: str | RelativeBy = By.ID, value: str | None = None - ) -> WebElement: + def find_element(self, by: str | RelativeBy = By.ID, value: str | None = None) -> WebElement: """Find an element given a By strategy and locator. Args: @@ -860,18 +838,12 @@ def find_element( if isinstance(by, RelativeBy): elements = self.find_elements(by=by, value=value) if not elements: - raise NoSuchElementException( - f"Cannot locate relative element with: {by.root}" - ) + raise NoSuchElementException(f"Cannot locate relative element with: {by.root}") return elements[0] - return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})[ - "value" - ] + return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"] - def find_elements( - self, by: str | RelativeBy = By.ID, value: str | None = None - ) -> list[WebElement]: + def find_elements(self, by: str | RelativeBy = By.ID, value: str | None = None) -> list[WebElement]: """Find elements given a By strategy and locator. Args: @@ -893,21 +865,14 @@ def find_elements( _pkg = ".".join(__name__.split(".")[:-1]) raw_data = pkgutil.get_data(_pkg, "findElements.js") if raw_data is None: - raise FileNotFoundError( - f"Could not find findElements.js in package {_pkg}" - ) + raise FileNotFoundError(f"Could not find findElements.js in package {_pkg}") raw_function = raw_data.decode("utf8") - find_element_js = ( - f"/* findElements */return ({raw_function}).apply(null, arguments);" - ) + find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);" return self.execute_script(find_element_js, by.to_dict()) # Return empty list if driver returns null # See https://github.com/SeleniumHQ/selenium/issues/4555 - return ( - self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] - or [] - ) + return self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] or [] @property def capabilities(self) -> dict: @@ -1004,9 +969,7 @@ def get_window_size(self, windowHandle: str = "current") -> dict: return {k: size[k] for k in ("width", "height")} - def set_window_position( - self, x: float, y: float, windowHandle: str = "current" - ) -> dict: + def set_window_position(self, x: float, y: float, windowHandle: str = "current") -> dict: """Sets the x,y position of the current window. Args: @@ -1065,9 +1028,7 @@ def set_window_rect(self, x=None, y=None, width=None, height=None) -> dict: if (x is None and y is None) and (not height and not width): raise InvalidArgumentException("x and y or height and width need values") - return self.execute( - Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height} - )["value"] + return self.execute(Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height})["value"] @property def file_detector(self) -> FileDetector: @@ -1112,9 +1073,7 @@ def orientation(self, value) -> None: if value.upper() in allowed_values: self.execute(Command.SET_SCREEN_ORIENTATION, {"orientation": value}) else: - raise WebDriverException( - "You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'" - ) + raise WebDriverException("You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'") def start_devtools(self) -> tuple[Any, WebSocketConnection]: global cdp @@ -1129,9 +1088,7 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException( - "Unable to find url to connect to from capabilities" - ) + raise WebDriverException("Unable to find url to connect to from capabilities") if cdp is None: raise WebDriverException("CDP module not loaded") @@ -1140,28 +1097,20 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: if self._websocket_connection: return self._devtools, self._websocket_connection if self.caps["browserName"].lower() == "firefox": - raise RuntimeError( - "CDP support for Firefox has been removed. Please switch to WebDriver BiDi." - ) + raise RuntimeError("CDP support for Firefox has been removed. Please switch to WebDriver BiDi.") if not isinstance(self.command_executor, RemoteConnection): - raise WebDriverException( - "command_executor must be a RemoteConnection instance for CDP support" - ) + raise WebDriverException("command_executor must be a RemoteConnection instance for CDP support") self._websocket_connection = WebSocketConnection( ws_url, self.command_executor.client_config.websocket_timeout, self.command_executor.client_config.websocket_interval, ) - targets = self._websocket_connection.execute( - self._devtools.target.get_targets() - ) + targets = self._websocket_connection.execute(self._devtools.target.get_targets()) for target in targets: if target.target_id == self.current_window_handle: target_id = target.target_id break - session = self._websocket_connection.execute( - self._devtools.target.attach_to_target(target_id, True) - ) + session = self._websocket_connection.execute(self._devtools.target.attach_to_target(target_id, True)) self._websocket_connection.session_id = session return self._devtools, self._websocket_connection @@ -1176,9 +1125,7 @@ async def bidi_connection(self): version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException( - "Unable to find url to connect to from capabilities" - ) + raise WebDriverException("Unable to find url to connect to from capabilities") devtools = cdp.import_devtools(version) async with cdp.open_cdp(ws_url) as conn: @@ -1204,14 +1151,10 @@ def _start_bidi(self) -> None: if self.caps.get("webSocketUrl"): ws_url = self.caps.get("webSocketUrl") else: - raise WebDriverException( - "Unable to find url to connect to from capabilities" - ) + raise WebDriverException("Unable to find url to connect to from capabilities") if not isinstance(self.command_executor, RemoteConnection): - raise WebDriverException( - "command_executor must be a RemoteConnection instance for BiDi support" - ) + raise WebDriverException("command_executor must be a RemoteConnection instance for BiDi support") self._websocket_connection = WebSocketConnection( ws_url, @@ -1427,13 +1370,9 @@ def _get_cdp_details(self): http = urllib3.PoolManager() try: if self.caps.get("browserName") == "chrome": - debugger_address = self.caps.get("goog:chromeOptions").get( - "debuggerAddress" - ) + debugger_address = self.caps.get("goog:chromeOptions").get("debuggerAddress") elif self.caps.get("browserName") in ("MicrosoftEdge", "webview2"): - debugger_address = self.caps.get("ms:edgeOptions").get( - "debuggerAddress" - ) + debugger_address = self.caps.get("ms:edgeOptions").get("debuggerAddress") except AttributeError: raise WebDriverException("Can't get debugger address.") @@ -1461,9 +1400,7 @@ def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> Non driver.add_virtual_authenticator(options) ``` """ - self._authenticator_id = self.execute( - Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict() - )["value"] + self._authenticator_id = self.execute(Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict())["value"] @property def virtual_authenticator_id(self) -> str | None: @@ -1503,12 +1440,8 @@ def add_credential(self, credential: Credential) -> None: @required_virtual_authenticator def get_credentials(self) -> list[Credential]: """Returns the list of credentials owned by the authenticator.""" - credential_data = self.execute( - Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id} - ) - return [ - Credential.from_dict(credential) for credential in credential_data["value"] - ] + credential_data = self.execute(Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id}) + return [Credential.from_dict(credential) for credential in credential_data["value"]] @required_virtual_authenticator def remove_credential(self, credential_id: str | bytearray) -> None: @@ -1530,9 +1463,7 @@ def remove_credential(self, credential_id: str | bytearray) -> None: @required_virtual_authenticator def remove_all_credentials(self) -> None: """Removes all credentials from the authenticator.""" - self.execute( - Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id} - ) + self.execute(Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id}) @required_virtual_authenticator def set_user_verified(self, verified: bool) -> None: @@ -1553,9 +1484,7 @@ def set_user_verified(self, verified: bool) -> None: def get_downloadable_files(self) -> list: """Retrieves the downloadable files as a list of file names.""" if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException( - "You must enable downloads in order to work with downloadable files." - ) + raise WebDriverException("You must enable downloads in order to work with downloadable files.") return self.execute(Command.GET_DOWNLOADABLE_FILES)["value"]["names"] @@ -1570,16 +1499,12 @@ def download_file(self, file_name: str, target_directory: str) -> None: `driver.download_file("example.zip", "/path/to/directory")` """ if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException( - "You must enable downloads in order to work with downloadable files." - ) + raise WebDriverException("You must enable downloads in order to work with downloadable files.") if not os.path.exists(target_directory): os.makedirs(target_directory) - contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"][ - "contents" - ] + contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"]["contents"] with tempfile.TemporaryDirectory() as tmp_dir: zip_file = os.path.join(tmp_dir, file_name + ".zip") @@ -1592,9 +1517,7 @@ def download_file(self, file_name: str, target_directory: str) -> None: def delete_downloadable_files(self) -> None: """Deletes all downloadable files.""" if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException( - "You must enable downloads in order to work with downloadable files." - ) + raise WebDriverException("You must enable downloads in order to work with downloadable files.") self.execute(Command.DELETE_DOWNLOADABLE_FILES) diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 44cb2adef7a0b..cd34c35db3696 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -158,9 +158,7 @@ def _serialize_command(self, command): def _deserialize_result(self, result, command): try: _ = command.send(result) - raise WebDriverException( - "The command's generator function did not exit when expected!" - ) + raise WebDriverException("The command's generator function did not exit when expected!") except StopIteration as exit: return exit.value @@ -177,15 +175,11 @@ def on_error(ws, error): def run_socket(): if self.url.startswith("wss://"): - self._ws.run_forever( - sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True - ) + self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True) else: self._ws.run_forever(suppress_origin=True) - self._ws = WebSocketApp( - self.url, on_open=on_open, on_message=on_message, on_error=on_error - ) + self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error) self._ws_thread = Thread(target=run_socket, daemon=True) self._ws_thread.start() diff --git a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py index 8038ee826aa74..86e3d11af0341 100644 --- a/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browsing_context_tests.py @@ -60,9 +60,7 @@ def test_create_window(driver): def test_create_window_with_reference_context(driver): """Test creating a window with a reference context.""" reference_context = driver.current_window_handle - context_id = driver.browsing_context.create( - type=WindowTypes.WINDOW, reference_context=reference_context - ) + context_id = driver.browsing_context.create(type=WindowTypes.WINDOW, reference_context=reference_context) assert context_id is not None # Clean up @@ -81,9 +79,7 @@ def test_create_tab(driver): def test_create_tab_with_reference_context(driver): """Test creating a tab with a reference context.""" reference_context = driver.current_window_handle - context_id = driver.browsing_context.create( - type=WindowTypes.TAB, reference_context=reference_context - ) + context_id = driver.browsing_context.create(type=WindowTypes.TAB, reference_context=reference_context) assert context_id is not None # Clean up @@ -128,9 +124,7 @@ def test_navigate_to_url_with_readiness_state(driver, pages): context_id = driver.browsing_context.create(type=WindowTypes.TAB) url = pages.url("bidi/logEntryAdded.html") - result = driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + result = driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) assert context_id is not None assert "/bidi/logEntryAdded.html" in result["url"] @@ -144,9 +138,7 @@ def test_get_tree_with_child(driver, pages): reference_context = driver.current_window_handle url = pages.url("iframes.html") - driver.browsing_context.navigate( - context=reference_context, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=reference_context, url=url, wait=ReadinessState.COMPLETE) context_info_list = driver.browsing_context.get_tree(root=reference_context) @@ -162,13 +154,9 @@ def test_get_tree_with_depth(driver, pages): reference_context = driver.current_window_handle url = pages.url("iframes.html") - driver.browsing_context.navigate( - context=reference_context, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=reference_context, url=url, wait=ReadinessState.COMPLETE) - context_info_list = driver.browsing_context.get_tree( - root=reference_context, max_depth=0 - ) + context_info_list = driver.browsing_context.get_tree(root=reference_context, max_depth=0) assert len(context_info_list) == 1 info = context_info_list[0] @@ -239,9 +227,7 @@ def test_reload_browsing_context(driver, pages): context_id = driver.browsing_context.create(type=WindowTypes.TAB) url = pages.url("bidi/logEntryAdded.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) reload_info = driver.browsing_context.reload(context=context_id) @@ -256,13 +242,9 @@ def test_reload_with_readiness_state(driver, pages): context_id = driver.browsing_context.create(type=WindowTypes.TAB) url = pages.url("bidi/logEntryAdded.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) - reload_info = driver.browsing_context.reload( - context=context_id, wait=ReadinessState.COMPLETE - ) + reload_info = driver.browsing_context.reload(context=context_id, wait=ReadinessState.COMPLETE) assert reload_info["navigation"] is not None assert "/bidi/logEntryAdded.html" in reload_info["url"] @@ -359,9 +341,7 @@ def test_capture_screenshot_with_parameters(driver, pages): clip = {"type": "box", "x": rect["x"], "y": rect["y"], "width": 5, "height": 5} - screenshot = driver.browsing_context.capture_screenshot( - context=context_id, origin="document", clip=clip - ) + screenshot = driver.browsing_context.capture_screenshot(context=context_id, origin="document", clip=clip) assert len(screenshot) > 0 @@ -372,20 +352,14 @@ def test_set_viewport(driver, pages): driver.get(pages.url("formPage.html")) try: - driver.browsing_context.set_viewport( - context=context_id, viewport={"width": 251, "height": 301} - ) + driver.browsing_context.set_viewport(context=context_id, viewport={"width": 251, "height": 301}) - viewport_size = driver.execute_script( - "return [window.innerWidth, window.innerHeight];" - ) + viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") assert viewport_size[0] == 251 assert viewport_size[1] == 301 finally: - driver.browsing_context.set_viewport( - context=context_id, viewport=None, device_pixel_ratio=None - ) + driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) def test_set_viewport_with_device_pixel_ratio(driver, pages): @@ -400,9 +374,7 @@ def test_set_viewport_with_device_pixel_ratio(driver, pages): device_pixel_ratio=5, ) - viewport_size = driver.execute_script( - "return [window.innerWidth, window.innerHeight];" - ) + viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") assert viewport_size[0] == 252 assert viewport_size[1] == 302 @@ -411,9 +383,7 @@ def test_set_viewport_with_device_pixel_ratio(driver, pages): assert device_pixel_ratio == 5 finally: - driver.browsing_context.set_viewport( - context=context_id, viewport=None, device_pixel_ratio=None - ) + driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) def test_set_viewport_with_no_args_doesnt_change_values(driver, pages): @@ -430,9 +400,7 @@ def test_set_viewport_with_no_args_doesnt_change_values(driver, pages): driver.browsing_context.set_viewport(context=context_id) - viewport_size = driver.execute_script( - "return [window.innerWidth, window.innerHeight];" - ) + viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") assert viewport_size[0] == 253 assert viewport_size[1] == 303 @@ -441,9 +409,7 @@ def test_set_viewport_with_no_args_doesnt_change_values(driver, pages): assert device_pixel_ratio == 6 finally: - driver.browsing_context.set_viewport( - context=context_id, viewport=None, device_pixel_ratio=None - ) + driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) @pytest.mark.xfail_chrome @@ -452,9 +418,7 @@ def test_set_viewport_back_to_default(driver, pages): context_id = driver.current_window_handle driver.get(pages.url("formPage.html")) - default_viewport_size = driver.execute_script( - "return [window.innerWidth, window.innerHeight];" - ) + default_viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") default_device_pixel_ratio = driver.execute_script("return window.devicePixelRatio") try: @@ -464,13 +428,9 @@ def test_set_viewport_back_to_default(driver, pages): device_pixel_ratio=10, ) - driver.browsing_context.set_viewport( - context=context_id, viewport=None, device_pixel_ratio=None - ) + driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) - viewport_size = driver.execute_script( - "return [window.innerWidth, window.innerHeight];" - ) + viewport_size = driver.execute_script("return [window.innerWidth, window.innerHeight];") device_pixel_ratio = driver.execute_script("return window.devicePixelRatio") # Allow some tolerance since some window managers might not put it to the exact value @@ -478,9 +438,7 @@ def test_set_viewport_back_to_default(driver, pages): assert abs(viewport_size[1] - default_viewport_size[1]) <= 5 assert device_pixel_ratio == default_device_pixel_ratio finally: - driver.browsing_context.set_viewport( - context=context_id, viewport=None, device_pixel_ratio=None - ) + driver.browsing_context.set_viewport(context=context_id, viewport=None, device_pixel_ratio=None) def test_print_page(driver, pages): @@ -499,9 +457,7 @@ def test_print_page(driver, pages): def test_navigate_back_in_browser_history(driver, pages): """Test navigating back in the browser history.""" context_id = driver.current_window_handle - driver.browsing_context.navigate( - context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE) # Navigate to another page by submitting a form driver.find_element(By.ID, "imageButton").submit() @@ -514,9 +470,7 @@ def test_navigate_back_in_browser_history(driver, pages): def test_navigate_forward_in_browser_history(driver, pages): """Test navigating forward in the browser history.""" context_id = driver.current_window_handle - driver.browsing_context.navigate( - context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=pages.url("formPage.html"), wait=ReadinessState.COMPLETE) # Navigate to another page by submitting a form driver.find_element(By.ID, "imageButton").submit() @@ -538,9 +492,7 @@ def test_locate_nodes(driver, pages): driver.get(pages.url("xhtmlTest.html")) - elements = driver.browsing_context.locate_nodes( - context=context_id, locator={"type": "css", "value": "div"} - ) + elements = driver.browsing_context.locate_nodes(context=context_id, locator={"type": "css", "value": "div"}) assert len(elements) > 0 @@ -660,9 +612,7 @@ def test_add_event_handler_context_created(driver): def on_context_created(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "context_created", on_context_created - ) + callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) assert callback_id is not None # Create a new context to trigger the event @@ -670,10 +620,7 @@ def on_context_created(info): # Verify the event was received (might be > 1 since default context is also included) assert len(events_received) >= 1 - assert ( - events_received[0].context == context_id - or events_received[1].context == context_id - ) + assert events_received[0].context == context_id or events_received[1].context == context_id driver.browsing_context.close(context_id) driver.browsing_context.remove_event_handler("context_created", callback_id) @@ -686,9 +633,7 @@ def test_add_event_handler_context_destroyed(driver): def on_context_destroyed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "context_destroyed", on_context_destroyed - ) + callback_id = driver.browsing_context.add_event_handler("context_destroyed", on_context_destroyed) assert callback_id is not None # Create and then close a context to trigger the event @@ -708,17 +653,13 @@ def test_add_event_handler_navigation_committed(driver, pages): def on_navigation_committed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "navigation_committed", on_navigation_committed - ) + callback_id = driver.browsing_context.add_event_handler("navigation_committed", on_navigation_committed) assert callback_id is not None # Navigate to trigger the event context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) assert len(events_received) >= 1 assert any(url in event.url for event in events_received) @@ -733,17 +674,13 @@ def test_add_event_handler_dom_content_loaded(driver, pages): def on_dom_content_loaded(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "dom_content_loaded", on_dom_content_loaded - ) + callback_id = driver.browsing_context.add_event_handler("dom_content_loaded", on_dom_content_loaded) assert callback_id is not None # Navigate to trigger the event context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -763,9 +700,7 @@ def on_load(info): context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -780,16 +715,12 @@ def test_add_event_handler_navigation_started(driver, pages): def on_navigation_started(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "navigation_started", on_navigation_started - ) + callback_id = driver.browsing_context.add_event_handler("navigation_started", on_navigation_started) assert callback_id is not None context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) assert len(events_received) == 1 assert any("simpleTest" in event.url for event in events_received) @@ -804,23 +735,17 @@ def test_add_event_handler_fragment_navigated(driver, pages): def on_fragment_navigated(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "fragment_navigated", on_fragment_navigated - ) + callback_id = driver.browsing_context.add_event_handler("fragment_navigated", on_fragment_navigated) assert callback_id is not None # First navigate to a page context_id = driver.current_window_handle url = pages.url("linked_image.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) # Then navigate to the same page with a fragment to trigger the event fragment_url = url + "#link" - driver.browsing_context.navigate( - context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=fragment_url, wait=ReadinessState.COMPLETE) assert len(events_received) == 1 assert any("link" in event.url for event in events_received) @@ -836,17 +761,13 @@ def test_add_event_handler_navigation_failed(driver): def on_navigation_failed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "navigation_failed", on_navigation_failed - ) + callback_id = driver.browsing_context.add_event_handler("navigation_failed", on_navigation_failed) assert callback_id is not None # Navigate to an invalid URL to trigger the event context_id = driver.current_window_handle try: - driver.browsing_context.navigate( - context=context_id, url="http://invalid-domain-that-does-not-exist.test/" - ) + driver.browsing_context.navigate(context=context_id, url="http://invalid-domain-that-does-not-exist.test/") except Exception: # Expect an exception due to navigation failure pass @@ -865,9 +786,7 @@ def test_add_event_handler_user_prompt_opened(driver, pages): def on_user_prompt_opened(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "user_prompt_opened", on_user_prompt_opened - ) + callback_id = driver.browsing_context.add_event_handler("user_prompt_opened", on_user_prompt_opened) assert callback_id is not None # Create an alert to trigger the event @@ -892,9 +811,7 @@ def test_add_event_handler_user_prompt_closed(driver, pages): def on_user_prompt_closed(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "user_prompt_closed", on_user_prompt_closed - ) + callback_id = driver.browsing_context.add_event_handler("user_prompt_closed", on_user_prompt_closed) assert callback_id is not None create_prompt_page(driver, pages) @@ -919,16 +836,12 @@ def test_add_event_handler_history_updated(driver, pages): def on_history_updated(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "history_updated", on_history_updated - ) + callback_id = driver.browsing_context.add_event_handler("history_updated", on_history_updated) assert callback_id is not None context_id = driver.current_window_handle url = pages.url("simpleTest.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) # Use history.pushState to trigger history updated event driver.script.execute("() => { history.pushState({}, '', '/new-path'); }") @@ -948,17 +861,13 @@ def test_add_event_handler_download_will_begin(driver, pages): def on_download_will_begin(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "download_will_begin", on_download_will_begin - ) + callback_id = driver.browsing_context.add_event_handler("download_will_begin", on_download_will_begin) assert callback_id is not None # click on a download link to trigger the event context_id = driver.current_window_handle url = pages.url("downloads/download.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) download_xpath_file_1_txt = '//*[@id="file-1"]' driver.find_element(By.XPATH, download_xpath_file_1_txt).click() @@ -978,16 +887,12 @@ def test_add_event_handler_download_end(driver, pages): def on_download_end(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "download_end", on_download_end - ) + callback_id = driver.browsing_context.add_event_handler("download_end", on_download_end) assert callback_id is not None context_id = driver.current_window_handle url = pages.url("downloads/download.html") - driver.browsing_context.navigate( - context=context_id, url=url, wait=ReadinessState.COMPLETE - ) + driver.browsing_context.navigate(context=context_id, url=url, wait=ReadinessState.COMPLETE) driver.find_element(By.ID, "file-1").click() @@ -1005,14 +910,12 @@ def on_download_end(info): # we assert that atleast "file_1" is present in the downloaded file since multiple downloads # will have numbered suffix like file_1 (1) assert any( - "downloads/file_1.txt" in ev.download_params.url - and "file_1" in ev.download_params.filepath + "downloads/file_1.txt" in ev.download_params.url and "file_1" in ev.download_params.filepath for ev in events_received ) assert any( - "downloads/file_2.jpg" in ev.download_params.url - and "file_2" in ev.download_params.filepath + "downloads/file_2.jpg" in ev.download_params.url and "file_2" in ev.download_params.filepath for ev in events_received ) @@ -1051,9 +954,7 @@ def test_remove_event_handler(driver): def on_context_created(info): events_received.append(info) - callback_id = driver.browsing_context.add_event_handler( - "context_created", on_context_created - ) + callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created) # Create a context to trigger the event context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) @@ -1085,12 +986,8 @@ def on_context_created_2(info): events_received_2.append(info) # Add multiple event handlers for the same event - callback_id_1 = driver.browsing_context.add_event_handler( - "context_created", on_context_created_1 - ) - callback_id_2 = driver.browsing_context.add_event_handler( - "context_created", on_context_created_2 - ) + callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) + callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) # Create a context to trigger both handlers context_id = driver.browsing_context.create(type=WindowTypes.TAB) @@ -1119,12 +1016,8 @@ def on_context_created_2(info): events_received_2.append(info) # Add multiple event handlers - callback_id_1 = driver.browsing_context.add_event_handler( - "context_created", on_context_created_1 - ) - callback_id_2 = driver.browsing_context.add_event_handler( - "context_created", on_context_created_2 - ) + callback_id_1 = driver.browsing_context.add_event_handler("context_created", on_context_created_1) + callback_id_2 = driver.browsing_context.add_event_handler("context_created", on_context_created_2) # Create a context to trigger both handlers context_id_1 = driver.browsing_context.create(type=WindowTypes.TAB) @@ -1206,9 +1099,7 @@ def callback(info): def register_handler(self, thread_id): try: callback = self.make_callback() - callback_id = self.driver.browsing_context.add_event_handler( - "context_created", callback - ) + callback_id = self.driver.browsing_context.add_event_handler("context_created", callback) with self.data_lock: self.callback_ids.append(callback_id) if len(self.callback_ids) == 5: @@ -1216,16 +1107,12 @@ def register_handler(self, thread_id): return callback_id except Exception as e: with self.data_lock: - self.thread_errors.append( - f"Thread {thread_id}: Registration failed: {e}" - ) + self.thread_errors.append(f"Thread {thread_id}: Registration failed: {e}") return None def remove_handler(self, callback_id, thread_id): try: - self.driver.browsing_context.remove_event_handler( - "context_created", callback_id - ) + self.driver.browsing_context.remove_event_handler("context_created", callback_id) except Exception as e: with self.data_lock: self.thread_errors.append(f"Thread {thread_id}: Removal failed: {e}") @@ -1235,19 +1122,13 @@ def test_concurrent_event_handler_registration(driver): helper = _EventHandlerTestHelper(driver) with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = [ - executor.submit(helper.register_handler, f"reg-{i}") for i in range(5) - ] + futures = [executor.submit(helper.register_handler, f"reg-{i}") for i in range(5)] for future in futures: future.result(timeout=15) helper.registration_complete.wait(timeout=5) - assert ( - len(helper.callback_ids) == 5 - ), f"Expected 5 handlers, got {len(helper.callback_ids)}" - assert not helper.thread_errors, "Errors during registration: \n" + "\n".join( - helper.thread_errors - ) + assert len(helper.callback_ids) == 5, f"Expected 5 handlers, got {len(helper.callback_ids)}" + assert not helper.thread_errors, "Errors during registration: \n" + "\n".join(helper.thread_errors) def test_event_callback_data_consistency(driver): @@ -1265,9 +1146,7 @@ def test_event_callback_data_consistency(driver): driver.browsing_context.close(ctx) with helper.data_lock: - assert not helper.consistency_errors, "Consistency errors: " + str( - helper.consistency_errors - ) + assert not helper.consistency_errors, "Consistency errors: " + str(helper.consistency_errors) assert len(helper.events_received) > 0, "No events received" assert len(helper.events_received) == sum(helper.context_counts.values()) assert len(helper.events_received) == sum(helper.event_type_counts.values()) @@ -1288,9 +1167,7 @@ def test_concurrent_event_handler_removal(driver): for future in futures: future.result(timeout=15) - assert not helper.thread_errors, "Errors during removal: \n" + "\n".join( - helper.thread_errors - ) + assert not helper.thread_errors, "Errors during removal: \n" + "\n".join(helper.thread_errors) def test_no_event_after_handler_removal(driver): From 52c43d2b6eca522a08ea7ebe1976fea9ed293217 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Thu, 9 Apr 2026 16:06:25 +0100 Subject: [PATCH 33/37] Use sentinel pattern for set viewport --- py/private/bidi_enhancements_manifest.py | 28 ++++++ .../webdriver/common/bidi/browsing_context.py | 87 ++++++++++--------- 2 files changed, 72 insertions(+), 43 deletions(-) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 06c0573db9083..57e3d4e35f0dc 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -184,6 +184,7 @@ class SetClientWindowStateParameters: }, "browsingContext": { # Method enhancements + "exclude_methods": ["set_viewport"], "create": { "extract_field": "context", }, @@ -223,6 +224,33 @@ class SetClientWindowStateParameters: "devicePixelRatio": "float", }, }, + "extra_methods": [ + ''' def set_viewport( + self, + context: str | None = None, + viewport: Any = ..., + user_contexts: Any | None = None, + device_pixel_ratio: Any = ..., + ): + """Execute browsingContext.setViewport. + + Uses sentinel defaults so explicit None is serialized for viewport/devicePixelRatio, + while omitted arguments are not sent. + """ + params = {} + if context is not None: + params["context"] = context + if user_contexts is not None: + params["userContexts"] = user_contexts + if viewport is not ...: + params["viewport"] = viewport + if device_pixel_ratio is not ...: + params["devicePixelRatio"] = device_pixel_ratio + + cmd = command_builder("browsingContext.setViewport", params) + result = self._conn.execute(cmd) + return result''', + ], # Non-CDDL download event dataclasses (Chromium-specific) "extra_dataclasses": [ '''@dataclass diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 59a9813e58124..177d727c97949 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -10,9 +10,8 @@ from dataclasses import dataclass, field from typing import Any -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder - +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager class ReadinessState: """ReadinessState.""" @@ -366,14 +365,12 @@ class DownloadWillBeginParams: suggested_filename: str | None = None - @dataclass class DownloadCanceledParams: """DownloadCanceledParams.""" status: Any | None = None - @dataclass class DownloadParams: """DownloadParams - fields shared by all download end event variants.""" @@ -385,7 +382,6 @@ class DownloadParams: url: str | None = None filepath: str | None = None - @dataclass class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" @@ -405,7 +401,6 @@ def from_json(cls, params: dict) -> DownloadEndParams: ) return cls(download_params=dp) - # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "context_created": "browsingContext.contextCreated", @@ -424,7 +419,6 @@ def from_json(cls, params: dict) -> DownloadEndParams: "user_prompt_opened": "browsingContext.userPromptOpened", } - def _deserialize_info_list(items: list) -> list | None: """Recursively deserialize a list of dicts to Info objects. @@ -457,11 +451,12 @@ def _deserialize_info_list(items: list) -> list | None: return result if result else None + + class BrowsingContext: """WebDriver BiDi browsingContext module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} - def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) @@ -562,7 +557,7 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): original_opener=item.get("originalOpener"), url=item.get("url"), user_context=item.get("userContext"), - parent=item.get("parent"), + parent=item.get("parent") ) for item in items if isinstance(item, dict) @@ -694,25 +689,6 @@ def set_bypass_csp( result = self._conn.execute(cmd) return result - def set_viewport( - self, - context: str | None = None, - viewport: Any | None = None, - user_contexts: Any | None = None, - device_pixel_ratio: Any | None = None, - ): - """Execute browsingContext.setViewport.""" - params = { - "context": context, - "viewport": viewport, - "userContexts": user_contexts, - "devicePixelRatio": device_pixel_ratio, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("browsingContext.setViewport", params) - result = self._conn.execute(cmd) - return result - def traverse_history(self, context: Any | None = None, delta: Any | None = None): """Execute browsingContext.traverseHistory.""" if context is None: @@ -729,6 +705,32 @@ def traverse_history(self, context: Any | None = None, delta: Any | None = None) result = self._conn.execute(cmd) return result + def set_viewport( + self, + context: str | None = None, + viewport: Any = ..., + user_contexts: Any | None = None, + device_pixel_ratio: Any = ..., + ): + """Execute browsingContext.setViewport. + + Uses sentinel defaults so explicit None is serialized for viewport/devicePixelRatio, + while omitted arguments are not sent. + """ + params = {} + if context is not None: + params["context"] = context + if user_contexts is not None: + params["userContexts"] = user_contexts + if viewport is not ...: + params["viewport"] = viewport + if device_pixel_ratio is not ...: + params["devicePixelRatio"] = device_pixel_ratio + + cmd = command_builder("browsingContext.setViewport", params) + result = self._conn.execute(cmd) + return result + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: """Add an event handler. @@ -755,49 +757,48 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() - # Event Info Type Aliases # Event: browsingContext.contextCreated -ContextCreated = globals().get("Info", dict) # Fallback to dict if type not defined +ContextCreated = globals().get('Info', dict) # Fallback to dict if type not defined # Event: browsingContext.contextDestroyed -ContextDestroyed = globals().get("Info", dict) # Fallback to dict if type not defined +ContextDestroyed = globals().get('Info', dict) # Fallback to dict if type not defined # Event: browsingContext.navigationStarted -NavigationStarted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +NavigationStarted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.fragmentNavigated -FragmentNavigated = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +FragmentNavigated = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.historyUpdated -HistoryUpdated = globals().get("HistoryUpdatedParameters", dict) # Fallback to dict if type not defined +HistoryUpdated = globals().get('HistoryUpdatedParameters', dict) # Fallback to dict if type not defined # Event: browsingContext.domContentLoaded -DomContentLoaded = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +DomContentLoaded = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.load -Load = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +Load = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.downloadWillBegin -DownloadWillBegin = globals().get("DownloadWillBeginParams", dict) # Fallback to dict if type not defined +DownloadWillBegin = globals().get('DownloadWillBeginParams', dict) # Fallback to dict if type not defined # Event: browsingContext.downloadEnd -DownloadEnd = globals().get("DownloadEndParams", dict) # Fallback to dict if type not defined +DownloadEnd = globals().get('DownloadEndParams', dict) # Fallback to dict if type not defined # Event: browsingContext.navigationAborted -NavigationAborted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +NavigationAborted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.navigationCommitted -NavigationCommitted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +NavigationCommitted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.navigationFailed -NavigationFailed = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined +NavigationFailed = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptClosed -UserPromptClosed = globals().get("UserPromptClosedParameters", dict) # Fallback to dict if type not defined +UserPromptClosed = globals().get('UserPromptClosedParameters', dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptOpened -UserPromptOpened = globals().get("UserPromptOpenedParameters", dict) # Fallback to dict if type not defined +UserPromptOpened = globals().get('UserPromptOpenedParameters', dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings From fecbb4e195ff7297c543c4c540db44184c756c82 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Thu, 9 Apr 2026 16:28:11 +0100 Subject: [PATCH 34/37] formatting sigh --- .../webdriver/common/bidi/browsing_context.py | 42 +++++++++++-------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 177d727c97949..5491e157a87c2 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -10,8 +10,9 @@ from dataclasses import dataclass, field from typing import Any +from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventManager from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager + class ReadinessState: """ReadinessState.""" @@ -365,12 +366,14 @@ class DownloadWillBeginParams: suggested_filename: str | None = None + @dataclass class DownloadCanceledParams: """DownloadCanceledParams.""" status: Any | None = None + @dataclass class DownloadParams: """DownloadParams - fields shared by all download end event variants.""" @@ -382,6 +385,7 @@ class DownloadParams: url: str | None = None filepath: str | None = None + @dataclass class DownloadEndParams: """DownloadEndParams - params for browsingContext.downloadEnd event.""" @@ -401,6 +405,7 @@ def from_json(cls, params: dict) -> DownloadEndParams: ) return cls(download_params=dp) + # BiDi Event Name to Parameter Type Mapping EVENT_NAME_MAPPING = { "context_created": "browsingContext.contextCreated", @@ -419,6 +424,7 @@ def from_json(cls, params: dict) -> DownloadEndParams: "user_prompt_opened": "browsingContext.userPromptOpened", } + def _deserialize_info_list(items: list) -> list | None: """Recursively deserialize a list of dicts to Info objects. @@ -451,12 +457,11 @@ def _deserialize_info_list(items: list) -> list | None: return result if result else None - - class BrowsingContext: """WebDriver BiDi browsingContext module.""" EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: self._conn = conn self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) @@ -557,7 +562,7 @@ def get_tree(self, max_depth: Any | None = None, root: Any | None = None): original_opener=item.get("originalOpener"), url=item.get("url"), user_context=item.get("userContext"), - parent=item.get("parent") + parent=item.get("parent"), ) for item in items if isinstance(item, dict) @@ -757,48 +762,49 @@ def clear_event_handlers(self) -> None: """Clear all event handlers.""" return self._event_manager.clear_event_handlers() + # Event Info Type Aliases # Event: browsingContext.contextCreated -ContextCreated = globals().get('Info', dict) # Fallback to dict if type not defined +ContextCreated = globals().get("Info", dict) # Fallback to dict if type not defined # Event: browsingContext.contextDestroyed -ContextDestroyed = globals().get('Info', dict) # Fallback to dict if type not defined +ContextDestroyed = globals().get("Info", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationStarted -NavigationStarted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationStarted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.fragmentNavigated -FragmentNavigated = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +FragmentNavigated = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.historyUpdated -HistoryUpdated = globals().get('HistoryUpdatedParameters', dict) # Fallback to dict if type not defined +HistoryUpdated = globals().get("HistoryUpdatedParameters", dict) # Fallback to dict if type not defined # Event: browsingContext.domContentLoaded -DomContentLoaded = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +DomContentLoaded = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.load -Load = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +Load = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.downloadWillBegin -DownloadWillBegin = globals().get('DownloadWillBeginParams', dict) # Fallback to dict if type not defined +DownloadWillBegin = globals().get("DownloadWillBeginParams", dict) # Fallback to dict if type not defined # Event: browsingContext.downloadEnd -DownloadEnd = globals().get('DownloadEndParams', dict) # Fallback to dict if type not defined +DownloadEnd = globals().get("DownloadEndParams", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationAborted -NavigationAborted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationAborted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationCommitted -NavigationCommitted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationCommitted = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.navigationFailed -NavigationFailed = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined +NavigationFailed = globals().get("BaseNavigationInfo", dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptClosed -UserPromptClosed = globals().get('UserPromptClosedParameters', dict) # Fallback to dict if type not defined +UserPromptClosed = globals().get("UserPromptClosedParameters", dict) # Fallback to dict if type not defined # Event: browsingContext.userPromptOpened -UserPromptOpened = globals().get('UserPromptOpenedParameters', dict) # Fallback to dict if type not defined +UserPromptOpened = globals().get("UserPromptOpenedParameters", dict) # Fallback to dict if type not defined # Populate EVENT_CONFIGS with event configuration mappings From 5d7287317818c77404337b1fd8d107429d241aca Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Thu, 9 Apr 2026 16:47:58 +0100 Subject: [PATCH 35/37] more formatting because ruff format isn't enough --- py/BUILD.bazel | 1 - py/generate_bidi.py | 19 +++++++++- py/private/bidi_enhancements_manifest.py | 18 ++++++++++ py/private/cdp.py | 36 ++++++++----------- py/private/generate_bidi.bzl | 1 - py/selenium/webdriver/common/bidi/__init__.py | 19 ++++++++-- py/selenium/webdriver/common/bidi/browser.py | 20 ++++++++--- .../webdriver/common/bidi/browsing_context.py | 20 ++++++++--- .../webdriver/common/bidi/emulation.py | 20 ++++++++--- py/selenium/webdriver/common/bidi/input.py | 20 ++++++++--- py/selenium/webdriver/common/bidi/log.py | 20 ++++++++--- py/selenium/webdriver/common/bidi/network.py | 20 ++++++++--- py/selenium/webdriver/common/bidi/script.py | 20 ++++++++--- py/selenium/webdriver/common/bidi/session.py | 20 ++++++++--- py/selenium/webdriver/common/bidi/storage.py | 20 ++++++++--- .../webdriver/common/bidi/webextension.py | 20 ++++++++--- 16 files changed, 227 insertions(+), 67 deletions(-) diff --git a/py/BUILD.bazel b/py/BUILD.bazel index 292cde4981d74..186324560aade 100644 --- a/py/BUILD.bazel +++ b/py/BUILD.bazel @@ -810,7 +810,6 @@ BROWSER_TESTS = { ] ] - test_suite( name = "test-remote", tags = ["remote"], diff --git a/py/generate_bidi.py b/py/generate_bidi.py index 194d94ba12d04..5b301d3ec7e40 100755 --- a/py/generate_bidi.py +++ b/py/generate_bidi.py @@ -1,4 +1,21 @@ -#!/usr/bin/env python3.10 +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + """ Generate Python WebDriver BiDi command modules from CDDL specification. diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 57e3d4e35f0dc..4b25688ed47c4 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -1,3 +1,21 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + """ Enhancement manifest for BiDi code generation. diff --git a/py/private/cdp.py b/py/private/cdp.py index bac00765f43ca..d94f0dac2e32b 100644 --- a/py/private/cdp.py +++ b/py/private/cdp.py @@ -1,26 +1,20 @@ -# The MIT License(MIT) +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Copyright(c) 2018 Hyperion Gray +# http://www.apache.org/licenses/LICENSE-2.0 # -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files(the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -# -# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import contextvars import importlib diff --git a/py/private/generate_bidi.bzl b/py/private/generate_bidi.bzl index e072279f85e94..8b4cc4e3e648f 100644 --- a/py/private/generate_bidi.bzl +++ b/py/private/generate_bidi.bzl @@ -72,7 +72,6 @@ def _generate_bidi_impl(ctx): return [DefaultInfo(files = depset(outputs))] - generate_bidi = rule( implementation = _generate_bidi_impl, attrs = { diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index 79ba3dbf2f86f..b37319da3651b 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -1,7 +1,20 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from __future__ import annotations diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 6310f2e18c2ce..440f13ed00072 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: browser +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from dataclasses import dataclass, field diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index 5491e157a87c2..b5e14f19c6864 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: browsingContext +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from collections.abc import Callable diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index 0860890abf41b..f1bc0c9efeb0a 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: emulation +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from dataclasses import dataclass, field diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 5d4c670490089..6c06fc4e7deaa 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: input +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from collections.abc import Callable diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 856d8561e706f..597936402f99c 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: log +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from collections.abc import Callable diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index e13fbe0f7a20b..6c24e399b0e54 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: network +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from collections.abc import Callable diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index 38e43a6677470..ee6eb4f4a437a 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: script +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from collections.abc import Callable diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 741faeb42bc43..b00544d286546 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: session +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from dataclasses import dataclass, field diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 9825407c2eaf8..90e65ac3d5ffb 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: storage +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from dataclasses import dataclass, field diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 03fedab62e174..62f2dec130308 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: webExtension +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from dataclasses import dataclass, field From 887195630bb2ac696085f6a5f1d857380da05a66 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Thu, 9 Apr 2026 20:23:08 +0100 Subject: [PATCH 36/37] correct signature --- py/private/bidi_enhancements_manifest.py | 32 ++++++++ .../webdriver/common/bidi/emulation.py | 73 ++++++++++--------- 2 files changed, 69 insertions(+), 36 deletions(-) diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py index 4b25688ed47c4..8cec1f9da245f 100644 --- a/py/private/bidi_enhancements_manifest.py +++ b/py/private/bidi_enhancements_manifest.py @@ -584,6 +584,38 @@ class SetNetworkConditionsParameters: if user_contexts is not None: params["userContexts"] = user_contexts cmd = command_builder("emulation.setNetworkConditions", params) + return self._conn.execute(cmd)''', + ''' def set_screen_settings_override( + self, + width: int | None = None, + height: int | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setScreenSettingsOverride. + + Sets or clears the screen settings override for specified browsing or user + contexts. + + Args: + width: The screen width in pixels, or ``None`` to clear the override. + height: The screen height in pixels, or ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + screen_area = None + if width is not None or height is not None: + screen_area = {} + if width is not None: + screen_area["width"] = width + if height is not None: + screen_area["height"] = height + params: dict[str, Any] = {"screenArea": screen_area} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScreenSettingsOverride", params) return self._conn.execute(cmd)''', ], }, diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index f1bc0c9efeb0a..c03d602f25670 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,21 +1,9 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - +# WebDriver BiDi module: emulation from __future__ import annotations from dataclasses import dataclass, field @@ -237,26 +225,6 @@ def set_locale_override( result = self._conn.execute(cmd) return result - def set_screen_settings_override( - self, - screen_area: Any | None = None, - contexts: list[Any] | None = None, - user_contexts: list[Any] | None = None, - ): - """Execute emulation.setScreenSettingsOverride.""" - if screen_area is None: - raise TypeError("set_screen_settings_override() missing required argument: 'screen_area'") - - params = { - "screenArea": screen_area, - "contexts": contexts, - "userContexts": user_contexts, - } - params = {k: v for k, v in params.items() if v is not None} - cmd = command_builder("emulation.setScreenSettingsOverride", params) - result = self._conn.execute(cmd) - return result - def set_scrollbar_type_override( self, scrollbar_type: Any | None = None, @@ -485,3 +453,36 @@ def set_network_conditions( params["userContexts"] = user_contexts cmd = command_builder("emulation.setNetworkConditions", params) return self._conn.execute(cmd) + + def set_screen_settings_override( + self, + width: int | None = None, + height: int | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setScreenSettingsOverride. + + Sets or clears the screen settings override for specified browsing or user + contexts. + + Args: + width: The screen width in pixels, or ``None`` to clear the override. + height: The screen height in pixels, or ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + screen_area = None + if width is not None or height is not None: + screen_area = {} + if width is not None: + screen_area["width"] = width + if height is not None: + screen_area["height"] = height + params: dict[str, Any] = {"screenArea": screen_area} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScreenSettingsOverride", params) + return self._conn.execute(cmd) From 559b9d30fcd94e843d401bec8548d41637c96455 Mon Sep 17 00:00:00 2001 From: AutomatedTester Date: Thu, 9 Apr 2026 20:39:04 +0100 Subject: [PATCH 37/37] more formatting because ruff and ./go format do different things and hate people writing code --- .../webdriver/common/bidi/emulation.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index c03d602f25670..a3e6b4b6c4ddb 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,9 +1,21 @@ -# DO NOT EDIT THIS FILE! +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# This file is generated from the WebDriver BiDi specification. If you need to make -# changes, edit the generator and regenerate all of the modules. +# http://www.apache.org/licenses/LICENSE-2.0 # -# WebDriver BiDi module: emulation +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + from __future__ import annotations from dataclasses import dataclass, field