1818package sync
1919
2020import (
21+ "context"
2122 "fmt"
2223 "runtime"
2324 "syscall"
24- "time"
2525 "unsafe"
2626
2727 discovery "github.com/arduino/pluggable-discovery-protocol-handler/v2"
@@ -32,11 +32,13 @@ import (
3232
3333//sys getModuleHandle(moduleName *byte) (handle syscall.Handle, err error) = GetModuleHandleA
3434//sys registerClass(wndClass *wndClass) (atom uint16, err error) = user32.RegisterClassA
35+ //sys unregisterClass(className *byte) (err error) = user32.UnregisterClassA
3536//sys defWindowProc(hwnd syscall.Handle, msg uint32, wParam uintptr, lParam uintptr) (lResult uintptr) = user32.DefWindowProcW
3637//sys createWindowEx(exstyle uint32, className *byte, windowText *byte, style uint32, x int32, y int32, width int32, height int32, parent syscall.Handle, menu syscall.Handle, hInstance syscall.Handle, lpParam uintptr) (hwnd syscall.Handle, err error) = user32.CreateWindowExA
38+ //sys destroyWindowEx(hwnd syscall.Handle) (err error) = user32.DestroyWindow
3739//sys registerDeviceNotification(recipient syscall.Handle, filter *devBroadcastDeviceInterface, flags uint32) (devHandle syscall.Handle, err error) = user32.RegisterDeviceNotificationA
38- //sys getMessage(msg *msg, hwnd syscall.Handle, msgFilterMin uint32, msgFilterMax uint32 ) (res int32, err error) = user32.GetMessageA
39- //sys translateMessage (msg *msg) (res bool ) = user32.TranslateMessage
40+ //sys unregisterDeviceNotification(deviceHandle syscall.Handle) (err error) = user32.UnregisterDeviceNotification
41+ //sys getMessage (msg *msg, hwnd syscall.Handle, msgFilterMin uint32, msgFilterMax uint32 ) (err error ) = user32.GetMessageA
4042//sys dispatchMessage(msg *msg) (res int32, err error) = user32.DispatchMessageA
4143
4244type wndClass struct {
@@ -67,13 +69,7 @@ type msg struct {
6769 lPrivate int32
6870}
6971
70- const wsExDlgModalFrame = 0x00000001
7172const wsExTopmost = 0x00000008
72- const wsExTransparent = 0x00000020
73- const wsExMDIChild = 0x00000040
74- const wsExToolWindow = 0x00000080
75- const wsExAppWindow = 0x00040000
76- const wsExLayered = 0x00080000
7773
7874type guid struct {
7975 data1 uint32
@@ -90,30 +86,33 @@ type devBroadcastDeviceInterface struct {
9086 szName uint16
9187}
9288
93- //var usbEventGUID = guid{???} // TODO
89+ // USB devices GUID used to filter notifications
90+ var usbEventGUID guid = guid {
91+ data1 : 0x10bfdca5 ,
92+ data2 : 0x3065 ,
93+ data3 : 0xd211 ,
94+ data4 : [8 ]byte {0x90 , 0x1f , 0x00 , 0xc0 , 0x4f , 0xb9 , 0x51 , 0xed },
95+ }
9496
9597const deviceNotifyWindowHandle = 0
96- const deviceNotifySserviceHandle = 1
9798const deviceNotifyAllInterfaceClasses = 4
98-
9999const dbtDevtypeDeviceInterface = 5
100100
101- func init () {
102- runtime .LockOSThread ()
103- }
101+ type WindowProcCallback func (hwnd syscall.Handle , msg uint32 , wParam uintptr , lParam uintptr ) uintptr
104102
105103// Start the sync process, successful events will be passed to eventCB, errors to errorCB.
106104// Returns a channel used to stop the sync process.
107105// Returns error if sync process can't be started.
108106func Start (eventCB discovery.EventCallback , errorCB discovery.ErrorCallback ) (chan <- bool , error ) {
109- startResult := make (chan error )
110- event := make ( chan bool , 1 )
111- go func () {
112- initAndRunWindowHandler ( startResult , event )
113- }()
114- if err := <- startResult ; err != nil {
115- return nil , err
107+ eventsChan := make (chan bool , 1 )
108+ windowCallback := func ( hwnd syscall. Handle , msg uint32 , wParam uintptr , lParam uintptr ) uintptr {
109+ select {
110+ case eventsChan <- true :
111+ default :
112+ }
113+ return defWindowProc ( hwnd , msg , wParam , lParam )
116114 }
115+
117116 go func () {
118117 current , err := enumerator .GetDetailedPortsList ()
119118 if err != nil {
@@ -125,18 +124,15 @@ func Start(eventCB discovery.EventCallback, errorCB discovery.ErrorCallback) (ch
125124 }
126125
127126 for {
128- <- event
129-
130- // Wait 100 ms to pile up events
131- time .Sleep (100 * time .Millisecond )
132127 select {
133- case <- event :
128+ case ev := <- eventsChan :
134129 // Just one event could be queued because the channel has size 1
135130 // (more events coming after this one are discarded on send)
131+ if ! ev {
132+ return
133+ }
136134 default :
137135 }
138-
139- // Send updates
140136 updates , err := enumerator .GetDetailedPortsList ()
141137 if err != nil {
142138 errorCB (fmt .Sprintf ("Error enumerating serial ports: %s" , err ))
@@ -146,72 +142,127 @@ func Start(eventCB discovery.EventCallback, errorCB discovery.ErrorCallback) (ch
146142 current = updates
147143 }
148144 }()
149- quit := make (chan bool )
145+
146+ // Context used to stop the goroutine that consume the window messages
147+ ctx , cancel := context .WithCancel (context .Background ())
148+
149+ stopper := make (chan bool )
150150 go func () {
151- <- quit
152- // TODO: implement termination channel
151+ // Lock this goroutine to the same OS thread for its whole execution,
152+ // if this is not done destruction of the windows will fail since
153+ // it must be done in the same thread that creates it
154+ runtime .LockOSThread ()
155+ defer close (eventsChan )
156+
157+ // We must create the window used to receive notifications in the same
158+ // thread that destroys it otherwise it would fail
159+ windowHandle , className , err := createWindow (windowCallback )
160+ if err != nil {
161+ errorCB (err .Error ())
162+ return
163+ }
164+ defer func () {
165+ if err := destroyWindow (windowHandle , className ); err != nil {
166+ errorCB (err .Error ())
167+ }
168+ }()
169+
170+ notificationsDevHandle , err := registerNotifications (windowHandle )
171+ if err != nil {
172+ errorCB (err .Error ())
173+ return
174+ }
175+ defer func () {
176+ if err := unregisterNotifications (notificationsDevHandle ); err != nil {
177+ errorCB (err .Error ())
178+ }
179+ }()
180+ defer cancel ()
181+
182+ // To consume messages we need the window handle, so we must start
183+ // this goroutine in here and not outside the one that handles
184+ // creation and destruction of the window used to receive notifications
185+ go func () {
186+ if err := consumeMessages (ctx , windowHandle ); err != nil {
187+ errorCB (err .Error ())
188+ }
189+ }()
190+
191+ <- stopper
153192 }()
154- return quit , nil
193+ return stopper , nil
155194}
156195
157- func initAndRunWindowHandler ( startResult chan <- error , event chan <- bool ) {
158- handle , err := getModuleHandle (nil )
196+ func createWindow ( windowCallback WindowProcCallback ) (syscall. Handle , * byte , error ) {
197+ moduleHandle , err := getModuleHandle (nil )
159198 if err != nil {
160- startResult <- err
161- return
199+ return syscall .InvalidHandle , nil , err
162200 }
163201
164- wndProc := func (hwnd syscall.Handle , msg uint32 , wParam uintptr , lParam uintptr ) uintptr {
165- select {
166- case event <- true :
167- default :
168- }
169- return defWindowProc (hwnd , msg , wParam , lParam )
202+ className , err := syscall .BytePtrFromString ("arduino-serialdiscovery" )
203+ if err != nil {
204+ return syscall .InvalidHandle , nil , err
170205 }
171-
172- className := syscall .StringBytePtr ("serialdiscovery" )
173206 windowClass := & wndClass {
174- instance : handle ,
207+ instance : moduleHandle ,
175208 className : className ,
176- wndProc : syscall .NewCallback (wndProc ),
209+ wndProc : syscall .NewCallback (windowCallback ),
177210 }
178211 if _ , err := registerClass (windowClass ); err != nil {
179- startResult <- fmt .Errorf ("registering new window: %s" , err )
180- return
212+ return syscall .InvalidHandle , nil , fmt .Errorf ("registering new window: %s" , err )
181213 }
182214
183- hwnd , err := createWindowEx (wsExTopmost , className , className , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 )
215+ windowHandle , err := createWindowEx (wsExTopmost , className , className , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 )
184216 if err != nil {
185- startResult <- fmt .Errorf ("creating window: %s" , err )
186- return
217+ return syscall .InvalidHandle , nil , fmt .Errorf ("creating window: %s" , err )
218+ }
219+ return windowHandle , className , nil
220+ }
221+
222+ func destroyWindow (windowHandle syscall.Handle , className * byte ) error {
223+ if err := destroyWindowEx (windowHandle ); err != nil {
224+ return fmt .Errorf ("error destroying window: %s" , err )
187225 }
226+ if err := unregisterClass (className ); err != nil {
227+ return fmt .Errorf ("error unregistering window class: %s" , err )
228+ }
229+ return nil
230+ }
188231
232+ func registerNotifications (windowHandle syscall.Handle ) (syscall.Handle , error ) {
189233 notificationFilter := devBroadcastDeviceInterface {
190234 dwDeviceType : dbtDevtypeDeviceInterface ,
191- // TODO: Filter USB events using the correct GUID
235+ classGUID : usbEventGUID ,
192236 }
193237 notificationFilter .dwSize = uint32 (unsafe .Sizeof (notificationFilter ))
194238
195- if _ , err := registerDeviceNotification (
196- hwnd ,
197- & notificationFilter ,
198- deviceNotifyWindowHandle | deviceNotifyAllInterfaceClasses ); err != nil {
199- startResult <- fmt .Errorf ("registering for devices notification: %s" , err )
200- return
239+ var flags uint32 = deviceNotifyWindowHandle | deviceNotifyAllInterfaceClasses
240+ notificationsDevHandle , err := registerDeviceNotification (windowHandle , & notificationFilter , flags )
241+ if err != nil {
242+ return syscall .InvalidHandle , err
201243 }
202244
203- startResult <- nil
245+ return notificationsDevHandle , nil
246+ }
247+
248+ func unregisterNotifications (notificationsDevHandle syscall.Handle ) error {
249+ if err := unregisterDeviceNotification (notificationsDevHandle ); err != nil {
250+ return fmt .Errorf ("error unregistering device notifications: %s" , err )
251+ }
252+ return nil
253+ }
204254
255+ func consumeMessages (ctx context.Context , windowHandle syscall.Handle ) error {
205256 var m msg
206257 for {
207- if res , err := getMessage (& m , hwnd , 0 , 0 ); res == 0 || res == - 1 {
208- if err != nil {
209- // TODO: send err and stop sync mode.
210- // fmt.Println(err)
211- }
212- break
258+ select {
259+ case <- ctx .Done ():
260+ return nil
261+ default :
262+ }
263+ if err := getMessage (& m , windowHandle , 0 , 0 ); err != nil {
264+ return fmt .Errorf ("error consuming messages: %s" , err )
213265 }
214- translateMessage (& m )
215266 dispatchMessage (& m )
216267 }
217268}
0 commit comments