Skip to content

Commit

Permalink
feat(db,server): add GetCveIDs (#358)
Browse files Browse the repository at this point in the history
* feat(db,server): add GetCveIDs

* chore(db/rdb): remove tautological error check

* chore(db/redis): fix log message
  • Loading branch information
MaineK00n authored Apr 24, 2024
1 parent 0382630 commit 7d9560e
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 4 deletions.
1 change: 1 addition & 0 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type DB interface {

Get(string) (*models.CveDetail, error)
GetMulti([]string) (map[string]models.CveDetail, error)
GetCveIDs() ([]string, error)
GetCveIDsByCpeURI(string) ([]string, []string, []string, error)
GetByCpeURI(string) ([]models.CveDetail, error)
InsertJvn([]string) error
Expand Down
36 changes: 33 additions & 3 deletions db/rdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/knqyf263/go-cpe/common"
"github.com/knqyf263/go-cpe/naming"
"github.com/spf13/viper"
"golang.org/x/exp/maps"
"golang.org/x/xerrors"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
Expand Down Expand Up @@ -188,9 +189,7 @@ func (r *RDBDriver) MigrateDB() error {
}
}
case dialectMysql, dialectPostgreSQL:
if err != nil {
return xerrors.Errorf("Failed to migrate. err: %w", err)
}
return xerrors.Errorf("Failed to migrate. err: %w", err)
default:
return xerrors.Errorf("Not Supported DB dialects. r.name: %s", r.name)
}
Expand Down Expand Up @@ -308,6 +307,37 @@ func (r *RDBDriver) Get(cveID string) (*models.CveDetail, error) {
return &detail, nil
}

// GetCveIDs select all cve ids
func (r *RDBDriver) GetCveIDs() ([]string, error) {
cveIDs := map[string]struct{}{}

var nvds []string
if err := r.conn.Model(&models.Nvd{}).Pluck("cve_id", &nvds).Error; err != nil {
return nil, err
}
for _, cveID := range nvds {
cveIDs[cveID] = struct{}{}
}

var jvns []string
if err := r.conn.Model(&models.Jvn{}).Distinct().Pluck("cve_id", &jvns).Error; err != nil {
return nil, err
}
for _, cveID := range jvns {
cveIDs[cveID] = struct{}{}
}

var fortinets []string
if err := r.conn.Model(&models.Fortinet{}).Distinct().Pluck("cve_id", &fortinets).Error; err != nil {
return nil, err
}
for _, cveID := range fortinets {
cveIDs[cveID] = struct{}{}
}

return maps.Keys(cveIDs), nil
}

func (r *RDBDriver) getCveIDsByPartVendorProduct(uri string) ([]string, error) {
specified, err := naming.UnbindURI(uri)
if err != nil {
Expand Down
31 changes: 30 additions & 1 deletion db/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,35 @@ func (r *RedisDriver) GetMulti(cveIDs []string) (map[string]models.CveDetail, er
return cveDetails, nil
}

// GetCveIDs select all cve ids
func (r *RedisDriver) GetCveIDs() ([]string, error) {
ctx := context.Background()

dbsize, err := r.conn.DBSize(ctx).Result()
if err != nil {
return nil, xerrors.Errorf("Failed to DBSize. err: %w", err)
}

cveIDs := []string{}
var cursor uint64
for {
var keys []string
var err error
keys, cursor, err = r.conn.Scan(ctx, cursor, fmt.Sprintf(cveKeyFormat, "*"), dbsize/5).Result()
if err != nil {
return nil, xerrors.Errorf("Failed to Scan. err: %w", err)
}

cveIDs = append(cveIDs, keys...)

if cursor == 0 {
break
}
}

return cveIDs, nil
}

// GetCveIDsByCpeURI Select Cve Ids by by pseudo-CPE
func (r *RedisDriver) GetCveIDsByCpeURI(uri string) ([]string, []string, []string, error) {
specified, err := naming.UnbindURI(uri)
Expand Down Expand Up @@ -504,7 +533,7 @@ func (r *RedisDriver) InsertNvd(years []string) error {
}
var err error

log.Infof("Fetching CVE information from NVD(recent, modified).")
log.Infof("Fetching CVE information from NVD.")
cacheDir, err := nvd.Fetch()
if err != nil {
return xerrors.Errorf("Failed to Fetch. err: %w", err)
Expand Down
30 changes: 30 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ func Start(logToFile bool, logDir string, driver db.DB) error {
// Routes
e.GET("/health", health())
e.GET("/cves/:id", getCve(driver))
e.POST("/cves", getCveMulti(driver))
e.GET("/cves/ids", getCveIDs(driver))
e.POST("/cpes", getCveByCpeName(driver))
e.POST("/cpes/ids", getCveIDsByCpeName(driver))

Expand Down Expand Up @@ -67,6 +69,23 @@ func getCve(driver db.DB) echo.HandlerFunc {
}
}

func getCveMulti(driver db.DB) echo.HandlerFunc {
return func(c echo.Context) error {
var cveIDs []string
if err := c.Bind(&cveIDs); err != nil {
log.Errorf("%s", err)
return err
}

cveDetails, err := driver.GetMulti(cveIDs)
if err != nil {
log.Errorf("%s", err)
return err
}
return c.JSON(http.StatusOK, &cveDetails)
}
}

type cpeName struct {
Name string `form:"name"`
}
Expand All @@ -92,6 +111,17 @@ func getCveByCpeName(driver db.DB) echo.HandlerFunc {
}
}

func getCveIDs(driver db.DB) echo.HandlerFunc {
return func(c echo.Context) error {
cveIDs, err := driver.GetCveIDs()
if err != nil {
log.Errorf("%s", err)
return err
}
return c.JSON(http.StatusOK, &cveIDs)
}
}

func getCveIDsByCpeName(driver db.DB) echo.HandlerFunc {
return func(c echo.Context) error {
cpe := cpeName{}
Expand Down

0 comments on commit 7d9560e

Please sign in to comment.