Skip to content

Commit 52d622f

Browse files
committed
Update Spago to v1.0.2-0.20231029222829-dea27c85cd66;
Replace `ag.Node` with `mat.Tensor`
1 parent 47c6ce7 commit 52d622f

File tree

104 files changed

+655
-5736
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

104 files changed

+655
-5736
lines changed

.github/workflows/go.yml

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
- uses: actions/checkout@v3
99
- uses: actions/setup-go@v3
1010
with:
11-
go-version: '1.20.3'
11+
go-version: '1.21.3'
1212
- name: Run tests and generate coverage report
1313
run: go test -coverprofile cover.out -covermode atomic ./...
1414
- name: Upload coverage to Codecov
@@ -23,7 +23,7 @@ jobs:
2323
steps:
2424
- uses: actions/setup-go@v3
2525
with:
26-
go-version: '1.20.3'
26+
go-version: '1.21.3'
2727
- uses: actions/checkout@v3
2828
- name: go vet
2929
run: go vet ./...
@@ -34,7 +34,7 @@ jobs:
3434
steps:
3535
- uses: actions/setup-go@v3
3636
with:
37-
go-version: '1.20.3'
37+
go-version: '1.21.3'
3838
- name: Install gocyclo
3939
run: go install github.com/fzipp/gocyclo/cmd/gocyclo@latest
4040
- uses: actions/checkout@v3
@@ -47,7 +47,7 @@ jobs:
4747
steps:
4848
- uses: actions/setup-go@v3
4949
with:
50-
go-version: '1.20.3'
50+
go-version: '1.21.3'
5151
- name: Install staticcheck
5252
run: go install honnef.co/go/tools/cmd/staticcheck@latest
5353
- uses: actions/checkout@v3

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ Usage of server:
6767
-network value
6868
network type for server listening
6969
-task value
70-
type of inference/computation that the model can fulfill ("text2text"|"zero-shot-classification"|"question-answering"|"text-classification"|"token-classification"|"text-encoding")
70+
type of inference/computation that the model can fulfill ("textgeneration"|"zero-shot-classification"|"question-answering"|"text-classification"|"token-classification"|"text-encoding")
7171
-tls value
7272
whether to enable TLS ("true"|"false")
7373
-tls-cert value
@@ -82,7 +82,7 @@ For example, to run Cybertron in server mode for Machine Translation (e.g. `en`
8282
```console
8383
echo "CYBERTRON_MODEL=Helsinki-NLP/opus-mt-en-it" > .env
8484
echo "CYBERTRON_MODELS_DIR=models" >> .env
85-
echo "CYBERTRON_MODEL_TASK=text2text" >> .env
85+
echo "CYBERTRON_MODEL_TASK=text-generation" >> .env
8686
```
8787

8888
and execute the following command:

cmd/server/config.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import (
2020
type TaskType string
2121

2222
const (
23-
Text2TextTask TaskType = "text2text"
23+
TextGenerationTask TaskType = "text-generation"
2424
ZeroShotClassificationTask TaskType = "zero-shot-classification"
2525
QuestionAnsweringTask TaskType = "question-answering"
2626
TextClassificationTask TaskType = "text-classification"
@@ -31,7 +31,7 @@ const (
3131

3232
// TaskTypeValues is the list of supported task types.
3333
var TaskTypeValues = []TaskType{
34-
Text2TextTask,
34+
TextGenerationTask,
3535
ZeroShotClassificationTask,
3636
QuestionAnsweringTask,
3737
TextClassificationTask,
@@ -124,7 +124,7 @@ func (conf *config) bindFlagSet(fs *flag.FlagSet) {
124124
flagParseFunc(tasks.ParseConversionPolicy, &mm.ConversionPolicy))
125125
fs.Func("model-conversion-precision", `floating-point bits of precision to use if the model is converted ("32"|"64")`,
126126
flagParseFunc(tasks.ParseFloatPrecision, &mm.ConversionPrecision))
127-
fs.Func("task", `type of inference/computation that the model can fulfill ("text2text"|"zero-shot-classification"|"question-answering"|"text-classification"|"token-classification"|"text-encoding"|"language-modeling")`,
127+
fs.Func("task", `type of inference/computation that the model can fulfill ("text-generation"|"zero-shot-classification"|"question-answering"|"text-classification"|"token-classification"|"text-encoding"|"language-modeling")`,
128128
flagParseFunc(ParseTaskType, &conf.task))
129129

130130
s := conf.serverConfig

cmd/server/main.go

+35-3
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@ import (
1818
"github.com/nlpodyssey/cybertron/pkg/tasks"
1919
"github.com/nlpodyssey/cybertron/pkg/tasks/languagemodeling"
2020
"github.com/nlpodyssey/cybertron/pkg/tasks/questionanswering"
21-
"github.com/nlpodyssey/cybertron/pkg/tasks/text2text"
2221
"github.com/nlpodyssey/cybertron/pkg/tasks/textclassification"
2322
"github.com/nlpodyssey/cybertron/pkg/tasks/textencoding"
23+
"github.com/nlpodyssey/cybertron/pkg/tasks/textgeneration"
2424
"github.com/nlpodyssey/cybertron/pkg/tasks/tokenclassification"
2525
"github.com/nlpodyssey/cybertron/pkg/tasks/zeroshotclassifier"
2626
"github.com/rs/zerolog"
2727
"github.com/rs/zerolog/log"
28+
"github.com/shirou/gopsutil/v3/cpu"
29+
"github.com/shirou/gopsutil/v3/mem"
30+
"github.com/shirou/gopsutil/v3/process"
2831
)
2932

3033
const defaultModelsDir = "models"
@@ -71,6 +74,8 @@ func run() error {
7174
}
7275
defer tasks.Finalize(m)
7376

77+
logMetrics()
78+
7479
requestHandler, err := server.ResolveRequestHandler(m)
7580
if err != nil {
7681
return err
@@ -84,12 +89,39 @@ func run() error {
8489
return s.Start(ctx)
8590
}
8691

92+
func logMetrics() {
93+
// Set up zerolog to print with human-readable timestamps
94+
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
95+
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
96+
97+
// Get total CPU count
98+
totalCpu, _ := cpu.Counts(false)
99+
// Get process CPU percentage
100+
p, _ := process.NewProcess(int32(os.Getpid()))
101+
percent, _ := p.CPUPercent()
102+
103+
log.Info().
104+
Int("total_cpus", totalCpu).
105+
Float64("cpu_used_by_process_percent", percent).
106+
Msg("CPU Metrics")
107+
108+
// Get total available RAM
109+
vmStat, _ := mem.VirtualMemory()
110+
// Get process RAM usage
111+
memInfo, _ := p.MemoryInfo()
112+
113+
log.Info().
114+
Uint64("total_RAM_available", vmStat.Total).
115+
Uint64("RAM_used_by_process", memInfo.RSS).
116+
Msg("RAM Metrics")
117+
}
118+
87119
func loadModelForTask(conf *config) (m any, err error) {
88120
switch conf.task {
89121
case ZeroShotClassificationTask:
90122
return tasks.Load[zeroshotclassifier.Interface](conf.loaderConfig)
91-
case Text2TextTask:
92-
return tasks.Load[text2text.Interface](conf.loaderConfig)
123+
case TextGenerationTask:
124+
return tasks.Load[textgeneration.Interface](conf.loaderConfig)
93125
case QuestionAnsweringTask:
94126
return tasks.Load[questionanswering.Interface](conf.loaderConfig)
95127
case TextClassificationTask:

examples/abstractivequestionasnwering/main.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ import (
1212
//lint:ignore ST1001 allow dot import just to make the example more readable
1313
. "github.com/nlpodyssey/cybertron/examples"
1414
"github.com/nlpodyssey/cybertron/pkg/tasks"
15-
"github.com/nlpodyssey/cybertron/pkg/tasks/text2text"
16-
"github.com/nlpodyssey/cybertron/pkg/tasks/text2text/bart"
15+
"github.com/nlpodyssey/cybertron/pkg/tasks/textgeneration"
16+
"github.com/nlpodyssey/cybertron/pkg/tasks/textgeneration/bart"
1717
"github.com/rs/zerolog"
1818
"github.com/rs/zerolog/log"
1919
)
@@ -35,19 +35,19 @@ func main() {
3535

3636
modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")
3737

38-
m, err := tasks.Load[*bart.Text2Text](&tasks.Config{
38+
m, err := tasks.Load[*bart.TextGeneration](&tasks.Config{
3939
ModelsDir: modelsDir,
40-
ModelName: text2text.DefaultModelForAbstractiveQuestionAnswering,
40+
ModelName: textgeneration.DefaultModelForAbstractiveQuestionAnswering,
4141
})
4242
if err != nil {
4343
log.Fatal().Err(err).Send()
4444
}
4545
defer tasks.Finalize(m)
4646

47-
opts := text2text.DefaultOptions()
47+
opts := textgeneration.DefaultOptions()
4848

4949
start := time.Now()
50-
result, err := m.Generate(context.Background(), text2text.PrepareInputForAbstractiveQuestionAnswering(query, passages), opts)
50+
result, err := m.Generate(context.Background(), textgeneration.PrepareInputForAbstractiveQuestionAnswering(query, passages), opts)
5151
if err != nil {
5252
panic(err)
5353
}

examples/relationextraction/main.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
//lint:ignore ST1001 allow dot import just to make the example more readable
1414
. "github.com/nlpodyssey/cybertron/examples"
1515
"github.com/nlpodyssey/cybertron/pkg/tasks"
16-
"github.com/nlpodyssey/cybertron/pkg/tasks/text2text"
16+
"github.com/nlpodyssey/cybertron/pkg/tasks/textgeneration"
1717
"github.com/rs/zerolog"
1818
"github.com/rs/zerolog/log"
1919
)
@@ -38,7 +38,7 @@ func main() {
3838
}
3939
defer tasks.Finalize(m)
4040

41-
opts := text2text.DefaultOptions()
41+
opts := textgeneration.DefaultOptions()
4242

4343
fn := func(text string) error {
4444
start := time.Now()

examples/textgeneration/main.go

+64-5
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,39 @@ import (
88
"context"
99
"fmt"
1010
"os"
11+
"runtime"
1112
"time"
1213

1314
//lint:ignore ST1001 allow dot import just to make the example more readable
1415
. "github.com/nlpodyssey/cybertron/examples"
1516
"github.com/nlpodyssey/cybertron/pkg/tasks"
16-
"github.com/nlpodyssey/cybertron/pkg/tasks/text2text"
17+
"github.com/nlpodyssey/cybertron/pkg/tasks/textgeneration"
1718
"github.com/rs/zerolog"
1819
"github.com/rs/zerolog/log"
20+
"github.com/shirou/gopsutil/v3/cpu"
21+
"github.com/shirou/gopsutil/v3/mem"
22+
"github.com/shirou/gopsutil/v3/process"
1923
)
2024

2125
func main() {
2226
zerolog.SetGlobalLevel(zerolog.DebugLevel)
2327
LoadDotenv()
2428

25-
modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")
26-
modelName := HasEnvVar("CYBERTRON_MODEL")
29+
modelsDir := "/Users/mg/Projects/nlpodyssey/cybertron/models" //HasEnvVar("CYBERTRON_MODELS_DIR")
30+
modelName := "Helsinki-NLP/opus-mt-it-en"
2731

28-
m, err := tasks.Load[text2text.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName})
32+
start := time.Now()
33+
m, err := tasks.Load[textgeneration.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName})
2934
if err != nil {
3035
log.Fatal().Err(err).Send()
3136
}
3237
defer tasks.Finalize(m)
3338

34-
opts := text2text.DefaultOptions()
39+
log.Debug().Msgf("Loaded model %q in %v", modelName, time.Since(start))
40+
41+
logMetrics()
42+
43+
opts := textgeneration.DefaultOptions()
3544

3645
fn := func(text string) error {
3746
start := time.Now()
@@ -41,6 +50,7 @@ func main() {
4150
}
4251
fmt.Println(time.Since(start).Seconds())
4352
fmt.Println(result.Texts[0])
53+
runtime.GC()
4454
return nil
4555
}
4656

@@ -49,3 +59,52 @@ func main() {
4959
log.Fatal().Err(err).Send()
5060
}
5161
}
62+
63+
func logMetrics() error {
64+
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
65+
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
66+
67+
// Get total CPU count
68+
totalCpu, err := cpu.Counts(false)
69+
if err != nil {
70+
return err
71+
}
72+
// Get process CPU percentage
73+
p, err := process.NewProcess(int32(os.Getpid()))
74+
if err != nil {
75+
return err
76+
}
77+
percent, err := p.CPUPercent()
78+
if err != nil {
79+
return err
80+
}
81+
82+
// Log CPU Metrics
83+
log.Info().
84+
Int("total_cpu_cores", totalCpu).
85+
Float64("process_cpu_usage_percent", percent).
86+
Msg("CPU Metrics")
87+
88+
// Get total available RAM
89+
vmStat, err := mem.VirtualMemory()
90+
if err != nil {
91+
return err
92+
}
93+
// Get process RAM usage
94+
memInfo, err := p.MemoryInfo()
95+
if err != nil {
96+
return err
97+
}
98+
99+
// Log RAM Metrics
100+
log.Info().
101+
Float64("total_ram_available_mb", byteToMb(vmStat.Total)).
102+
Float64("process_ram_usage_mb", byteToMb(memInfo.RSS)).
103+
Msg("RAM Metrics")
104+
105+
return nil
106+
}
107+
108+
func byteToMb(b uint64) float64 {
109+
return float64(b) / 1024 / 1024
110+
}

0 commit comments

Comments
 (0)