Skip to content
5 changes: 5 additions & 0 deletions cli/azd/cmd/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,10 @@ func authActions(root *actions.ActionDescriptor) *actions.ActionDescriptor {
ActionResolver: newLogoutAction,
})

group.Add("serve", &actions.ActionDescriptorOptions{
Command: newServeCmd("auth"),
ActionResolver: newServeAction,
})

return group
}
174 changes: 174 additions & 0 deletions cli/azd/cmd/auth_serve.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package cmd

import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/azure/azure-dev/cli/azd/cmd/actions"
"github.com/azure/azure-dev/cli/azd/pkg/auth"
"github.com/azure/azure-dev/cli/azd/pkg/input"
"github.com/azure/azure-dev/cli/azd/pkg/output"
"github.com/spf13/cobra"
)

// TokenResponse defines the structure of the token response.
type TokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresOn int64 `json:"expires_on"`
}

// tokenHandler handles token requests.
func (serve *serveAction) tokenHandler(w http.ResponseWriter, r *http.Request) {
clientIP := r.Header.Get("X-Forwarded-For")
if clientIP == "" {
clientIP, _, _ = net.SplitHostPort(r.RemoteAddr)
}

fmt.Printf("Client IP: %s\n", clientIP)

// Only allow requests from 127.0.0.1 or host.docker.internal
allowedIPs := []string{"::1", "127.0.0.1", "host.docker.internal"}

// Check if the request comes from an allowed IP address
ipAllowed := false
for _, allowedIP := range allowedIPs {
if clientIP == allowedIP {
ipAllowed = true
break
}
}

if !ipAllowed {
http.Error(w, "Forbidden: Requests are only allowed from 127.0.0.1 or host.docker.internal", http.StatusForbidden)
return
}

resource := r.URL.Query().Get("resource")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to ensure anything about the verb that was used (I thought that MSI uses POST, not GET, for some reason?)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should ensure that the Metadata header is present, as outlined on the linked page and reject other verbs, then.

if resource == "" {
resource = r.FormValue("resource")
}
if resource == "" {
resource = "https://management.azure.com/"
}

fmt.Printf("Received request for resource: %s from IP: %s\n", resource, clientIP)

ctx := context.Background()
var cred azcore.TokenCredential

cred, err := serve.credentialProvider(ctx, &auth.CredentialForCurrentUserOptions{
NoPrompt: true,
TenantID: "",
})
if err != nil {
fmt.Printf("credentialProvider: %v", err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For these error cases, we likely still want to call something like http.Error instead of just returning here which I think would just close the connection without an HTTP response at all. Using log would also be preferable vs pushing to the console.

In addition - if the user is never able to provide a tenant ID, perhaps we call the credentialProvider when we construct the serve action and save it on the object, so we don't have to do it per request?

http.Error(w, "Failed to get credentials: "+err.Error(), http.StatusInternalServerError)
return
}

token, err := cred.GetToken(ctx, policy.TokenRequestOptions{
Scopes: []string{resource + "/.default"},
})
if err != nil {
fmt.Printf("fetching token: %v", err)
http.Error(w, "Failed to fetch token: "+err.Error(), http.StatusInternalServerError)
return
}

res := TokenResponse{
AccessToken: token.Token,
ExpiresOn: token.ExpiresOn.Unix(),
}

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(res); err != nil {
http.Error(w, "Failed to encode response: "+err.Error(), http.StatusInternalServerError)
}
}

func (serve *serveAction) start(port string) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: It would be better if this was an int not a string perhaps? If we pass in a non int I assume ListenAndServe() is going to return an error.

http.HandleFunc("/MSI/token", serve.tokenHandler)
http.HandleFunc("/metadata/identity/oauth2/token", serve.tokenHandler)

srv := &http.Server{
Addr: ":" + port,
WriteTimeout: 15 * time.Second,
ReadTimeout: 15 * time.Second,
}

go func() {
fmt.Printf("Server started on port %s\n", port)
fmt.Printf("MSI endpoint for local development: http://localhost:%s/MSI/token\n", port)
fmt.Printf("MSI endpoint for Docker: http://host.docker.internal:%s/MSI/token\n", port)
fmt.Println("Set the MSI_ENDPOINT environment variable to the appropriate URL above.")
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
fmt.Printf("Server stopped: %s\n", err)
}
}()

c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
<-c

ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
log.Printf("Server shutdown failed: %s\n", err)
}

log.Println("Shutting down")
os.Exit(0)
}

func newServeCmd(parent string) *cobra.Command {
return &cobra.Command{
Use: "serve",
Short: "Starts a local Managed Identity endpoint for development purposes.",
Annotations: map[string]string{
loginCmdParentAnnotation: parent,
},
}
}

type serveAction struct {
console input.Console
credentialProvider CredentialProviderFn
formatter output.Formatter
writer io.Writer
}

func newServeAction(
console input.Console,
credentialProvider CredentialProviderFn,
formatter output.Formatter,
writer io.Writer) actions.Action {
return &serveAction{
console: console,
credentialProvider: credentialProvider,
formatter: formatter,
writer: writer,
}
}

func (serve *serveAction) Run(ctx context.Context) (*actions.ActionResult, error) {
port := os.Getenv("AZD_AUTH_SERVER_PORT")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not add a --port to the command instead of controlling this via an env-var?

if port == "" {
port = "53028"
}
serve.start(port)
return nil, nil
}
16 changes: 16 additions & 0 deletions cli/azd/cmd/testdata/TestUsage-azd-auth-serve.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

Starts a local Managed Identity endpoint for development purposes.

Usage
azd auth serve [flags]

Global Flags
-C, --cwd string : Sets the current working directory.
--debug : Enables debugging and diagnostics logging.
--docs : Opens the documentation for azd auth serve in your web browser.
-h, --help : Gets help for serve.
--no-prompt : Accepts the default value instead of prompting, or it fails if there is no default.

Find a bug? Want to let us know how we're doing? Fill out this brief survey: https://aka.ms/azure-dev/hats.


1 change: 1 addition & 0 deletions cli/azd/cmd/testdata/TestUsage-azd-auth.snap
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Usage
Available Commands
login : Log in to Azure.
logout : Log out of Azure.
serve : Starts a local Managed Identity endpoint for development purposes.

Global Flags
-C, --cwd string : Sets the current working directory.
Expand Down