Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions internal/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,23 @@ func New(options *types.Options) (*Runner, error) {
os.Exit(0)
}

tmpDir, err := os.MkdirTemp("", "nuclei-tmp-*")
if err != nil {
return nil, errors.Wrap(err, "could not create temporary directory")
}
runner.tmpDir = tmpDir

// Cleanup tmpDir only if initialization fails
// On successful initialization, Close() method will handle cleanup
cleanupOnError := true
defer func() {
if cleanupOnError && runner.tmpDir != "" {
_ = os.RemoveAll(runner.tmpDir)
}
}()

// create the input provider and load the inputs
inputProvider, err := provider.NewInputProvider(provider.InputOptions{Options: options})
inputProvider, err := provider.NewInputProvider(provider.InputOptions{Options: options, TempDir: runner.tmpDir})
if err != nil {
return nil, errors.Wrap(err, "could not create input provider")
}
Expand Down Expand Up @@ -386,10 +401,8 @@ func New(options *types.Options) (*Runner, error) {
}
runner.rateLimiter = utils.GetRateLimiter(context.Background(), options.RateLimit, options.RateLimitDuration)

if tmpDir, err := os.MkdirTemp("", "nuclei-tmp-*"); err == nil {
runner.tmpDir = tmpDir
}

// Initialization successful, disable cleanup on error
cleanupOnError = false
return runner, nil
}

Expand Down
11 changes: 11 additions & 0 deletions pkg/input/formats/formats.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/projectdiscovery/nuclei/v3/pkg/input/types"
"github.com/projectdiscovery/retryablehttp-go"
fileutil "github.com/projectdiscovery/utils/file"
"gopkg.in/yaml.v3"
)
Expand Down Expand Up @@ -47,6 +48,16 @@ type Format interface {
SetOptions(options InputFormatOptions)
}

// SpecDownloader is an interface for downloading API specifications from URLs
type SpecDownloader interface {
// Download downloads the spec from the given URL and saves it to tmpDir
// Returns the path to the downloaded file
// httpClient is a retryablehttp.Client instance (can be nil for fallback)
Download(url, tmpDir string, httpClient *retryablehttp.Client) (string, error)
// SupportedExtensions returns the list of supported file extensions
SupportedExtensions() []string
}

var (
DefaultVarDumpFileName = "required_openapi_params.yaml"
ErrNoVarsDumpFile = errors.New("no required params file found")
Expand Down
136 changes: 136 additions & 0 deletions pkg/input/formats/openapi/downloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package openapi

import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"

"github.com/pkg/errors"
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats"
"github.com/projectdiscovery/retryablehttp-go"
)

// OpenAPIDownloader implements the SpecDownloader interface for OpenAPI 3.0 specs
type OpenAPIDownloader struct{}

// NewDownloader creates a new OpenAPI downloader
func NewDownloader() formats.SpecDownloader {
return &OpenAPIDownloader{}
}

// This function downloads an OpenAPI 3.0 spec from the given URL and saves it to tmpDir
func (d *OpenAPIDownloader) Download(urlStr, tmpDir string, httpClient *retryablehttp.Client) (string, error) {
// Validate URL format, OpenAPI 3.0 specs are typically JSON
if !strings.HasSuffix(urlStr, ".json") {
return "", fmt.Errorf("URL does not appear to be an OpenAPI JSON spec")
}

const maxSpecSizeBytes = 10 * 1024 * 1024 // 10MB

// Use provided httpClient or create a fallback
var client *http.Client
if httpClient != nil {
client = httpClient.HTTPClient
} else {
// Fallback to simple client if no httpClient provided
client = &http.Client{Timeout: 30 * time.Second}
}

resp, err := client.Get(urlStr)
if err != nil {
return "", errors.Wrap(err, "failed to download OpenAPI spec")
}

defer func() {
_ = resp.Body.Close()
}()

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d when downloading OpenAPI spec", resp.StatusCode)
}

bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxSpecSizeBytes))
if err != nil {
return "", errors.Wrap(err, "failed to read response body")
}

// Validate it's a valid JSON and has OpenAPI structure
var spec map[string]interface{}
if err := json.Unmarshal(bodyBytes, &spec); err != nil {
return "", fmt.Errorf("downloaded content is not valid JSON: %w", err)
}

// Check if it's an OpenAPI 3.0 spec
if openapi, exists := spec["openapi"]; exists {
if openapiStr, ok := openapi.(string); ok && strings.HasPrefix(openapiStr, "3.") {
// Valid OpenAPI 3.0 spec
} else {
return "", fmt.Errorf("not a valid OpenAPI 3.0 spec (found version: %v)", openapi)
}
} else {
return "", fmt.Errorf("not an OpenAPI spec (missing 'openapi' field)")
}

// Extract host from URL for server configuration
parsedURL, err := url.Parse(urlStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse URL")
}
host := parsedURL.Host
scheme := parsedURL.Scheme
if scheme == "" {
scheme = "https"
}

// Add servers section if missing or empty
servers, exists := spec["servers"]
if !exists || servers == nil {
spec["servers"] = []map[string]interface{}{{"url": scheme + "://" + host}}
} else if serverList, ok := servers.([]interface{}); ok && len(serverList) == 0 {
spec["servers"] = []map[string]interface{}{{"url": scheme + "://" + host}}
}

// Marshal back to JSON
modifiedJSON, err := json.Marshal(spec)
if err != nil {
return "", errors.Wrap(err, "failed to marshal modified spec")
}

// Create output directory
openapiDir := filepath.Join(tmpDir, "openapi")
if err := os.MkdirAll(openapiDir, 0755); err != nil {
return "", errors.Wrap(err, "failed to create openapi directory")
}

// Generate filename
filename := fmt.Sprintf("openapi-spec-%d.json", time.Now().Unix())
filePath := filepath.Join(openapiDir, filename)

// Write file
file, err := os.Create(filePath)
if err != nil {
return "", fmt.Errorf("failed to create file: %w", err)
}

defer func() {
_ = file.Close()
}()

if _, writeErr := file.Write(modifiedJSON); writeErr != nil {
_ = os.Remove(filePath)
return "", errors.Wrap(writeErr, "failed to write OpenAPI spec to file")
}

return filePath, nil
}

// SupportedExtensions returns the list of supported file extensions for OpenAPI
func (d *OpenAPIDownloader) SupportedExtensions() []string {
return []string{".json"}
}
Loading
Loading