Skip to content

Commit

Permalink
Update examples with default models
Browse files Browse the repository at this point in the history
  • Loading branch information
matteo-grella committed Oct 30, 2023
1 parent 9d9e067 commit e93f0fc
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 24 deletions.
2 changes: 1 addition & 1 deletion examples/abstractivequestionasnwering/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func main() {
zerolog.SetGlobalLevel(zerolog.TraceLevel)
LoadDotenv()

modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")
modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models")

m, err := tasks.Load[*bart.TextGeneration](&tasks.Config{
ModelsDir: modelsDir,
Expand Down
4 changes: 2 additions & 2 deletions examples/languagemodeling/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ func main() {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
LoadDotenv()

modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")
modelName := HasEnvVar("CYBERTRON_MODEL")
modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models")
modelName := HasEnvVarOr("CYBERTRON_MODEL", languagemodeling.DefaultModel)

m, err := tasks.Load[languagemodeling.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName})
if err != nil {
Expand Down
12 changes: 7 additions & 5 deletions examples/questionanswering/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ import (
"github.com/rs/zerolog/log"
)

// Example of content to be used as context for the question answering task.
const content = `Cloud computing is a technology that allows individuals and businesses to access computing resources over the Internet. It enables users to utilize hardware and software that are managed by third parties at remote locations. Services provided by cloud computing include storage solutions, databases, and computing power, which can be used on a pay-per-use basis. This model offers flexibility and scalability, reducing the need for large upfront investments in infrastructure. Major providers of cloud computing services include Amazon Web Services (AWS), Microsoft Azure, and Google Cloud Platform (GCP).`

func main() {
zerolog.SetGlobalLevel(zerolog.TraceLevel)
LoadDotenv()

modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")
modelName := HasEnvVar("CYBERTRON_MODEL")
paragraph := HasEnvVar("CYBERTRON_QA_PARAGRAPH")
modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models")
modelName := HasEnvVarOr("CYBERTRON_MODEL", questionanswering.DefaultEnglishModel)

m, err := tasks.Load[questionanswering.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName})
if err != nil {
Expand All @@ -36,7 +38,7 @@ func main() {

fn := func(text string) error {
start := time.Now()
result, err := m.Answer(context.Background(), text, paragraph, opts)
result, err := m.Answer(context.Background(), text, content, opts)
if err != nil {
return err
}
Expand All @@ -45,7 +47,7 @@ func main() {
return nil
}

fmt.Println(paragraph)
fmt.Println(content)

err = ForEachInput(os.Stdin, fn)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion examples/relationextraction/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func main() {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
LoadDotenv()

modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")
modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models")

m, err := tasks.LoadModelForTextGeneration(&tasks.Config{
ModelsDir: modelsDir,
Expand Down
11 changes: 8 additions & 3 deletions examples/textclassification/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ import (
"github.com/rs/zerolog/log"
)

const limit = 5 // number of labels to show

func main() {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
LoadDotenv()

modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")
modelName := HasEnvVar("CYBERTRON_MODEL")
modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models")
modelName := HasEnvVarOr("CYBERTRON_MODEL", textclassification.DefaultModelForGeographicCategorizationMulti)

m, err := tasks.Load[textclassification.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName})
if err != nil {
Expand All @@ -38,7 +40,10 @@ func main() {
return err
}
fmt.Println(time.Since(start).Seconds())
fmt.Println(result)

for i := range result.Labels[:limit] {
fmt.Printf("%s\t%0.3f\n", result.Labels[i], result.Scores[i])
}
return nil
}

Expand Down
6 changes: 3 additions & 3 deletions examples/textencoding/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ import (
"github.com/rs/zerolog/log"
)

const limit = 10
const limit = 10 // number of dimensions to show

func main() {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
LoadDotenv()

modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")
modelName := HasEnvVar("CYBERTRON_MODEL")
modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models")
modelName := HasEnvVarOr("CYBERTRON_MODEL", textencoding.DefaultModelMulti)

m, err := tasks.Load[textencoding.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName})
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions examples/textgeneration/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ func main() {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
LoadDotenv()

modelsDir := "/Users/mg/Projects/nlpodyssey/cybertron/models" //HasEnvVar("CYBERTRON_MODELS_DIR")
modelName := "Helsinki-NLP/opus-mt-it-en"
modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models")
modelName := HasEnvVarOr("CYBERTRON_MODEL", textgeneration.DefaultModelForMachineTranslation("en", "it"))

start := time.Now()
m, err := tasks.Load[textgeneration.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName})
Expand Down
4 changes: 2 additions & 2 deletions examples/tokenclassification/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func main() {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
LoadDotenv()

modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")
modelName := HasEnvVar("CYBERTRON_MODEL")
modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models")
modelName := HasEnvVarOr("CYBERTRON_MODEL", tokenclassification.DefaultEnglishModel)

m, err := tasks.Load[tokenclassification.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName})
if err != nil {
Expand Down
10 changes: 10 additions & 0 deletions examples/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ func HasEnvVar(key string) string {
return value
}

// HasEnvVarOr returns the value of the environment variable with the given key.
// It returns the alternative value if the environment variable is not set.
func HasEnvVarOr(key string, alt string) string {
value := os.Getenv(key)
if value == "" || len(strings.Trim(value, " ")) == 0 {
return alt
}
return value
}

// MarshalJSON returns the JSON string representation of the input data
func MarshalJSON(data any) string {
m, _ := json.MarshalIndent(data, "", " ")
Expand Down
15 changes: 11 additions & 4 deletions examples/zeroshotclassification/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ func main() {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
LoadDotenv()

modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")
modelName := HasEnvVar("CYBERTRON_MODEL")
possibleClasses := HasEnvVar("CYBERTRON_ZERO_SHOT_POSSIBLE_CLASSES")
modelsDir := HasEnvVarOr("CYBERTRON_MODELS_DIR", "models")
modelName := HasEnvVarOr("CYBERTRON_MODEL", zeroshotclassifier.DefaultModel)

if len(os.Args) < 2 {
log.Fatal().Msg("missing possible classes (comma separated)")
}
possibleClasses := os.Args[1]

m, err := tasks.Load[zeroshotclassifier.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName})
if err != nil {
Expand All @@ -46,7 +50,10 @@ func main() {
return err
}
fmt.Println(time.Since(start).Seconds())
fmt.Println(result)

for i := range result.Labels {
fmt.Printf("%s\t%0.3f\n", result.Labels[i], result.Scores[i])
}
return nil
}

Expand Down
3 changes: 3 additions & 0 deletions pkg/models/bart/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ func ConfigFromFile(file string) (Config, error) {
if config.MaxLength == 0 {
config.MaxLength = config.MaxPositionEmbeddings
}
if config.NumBeams == 0 {
config.NumBeams = 4 // TODO: check if this is the default value?
}
return config, nil
}

Expand Down
7 changes: 6 additions & 1 deletion pkg/tasks/tokenclassification/tokenclassification.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@ import (

const (
// DefaultEnglishModel is a model for Named Entities Recognition for the English language.
// It supports the following entities (CoNLL-2003 NER dataset):
// LOC, MISC, ORG, PER
DefaultEnglishModel = "dbmdz/bert-large-cased-finetuned-conll03-english"

// DefaultEnglishModelOntonotes is a model for Named Entities Recognition for the English language.
// It supports the following entities:
// CARDINAL, DATE, EVENT, FAC, GPE, LANGUAGE, LAW, LOC, MONEY, NORP, ORDINAL, PERCENT, PERSON, PRODUCT, QUANTITY, TIME, WORK_OF_ART
// Model card: https://huggingface.co/djagatiya/ner-bert-base-cased-ontonotesv5-englishv4
DefaultEnglishModel = "djagatiya/ner-bert-base-cased-ontonotesv5-englishv4"
DefaultEnglishModelOntonotes = "djagatiya/ner-bert-base-cased-ontonotesv5-englishv4"

// DefaultModelMulti is a multilingual model for Named Entities Recognition supporting 9 languages:
// de, en, es, fr, it, nl, pl, pt, ru.
Expand Down

0 comments on commit e93f0fc

Please sign in to comment.