-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.go
156 lines (136 loc) · 3.62 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// Command markovian is a command-line utility for working with Markov chains; run it with the "help" argument for usage.
package main
import (
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"github.com/cassierecher/markovian/impl"
)
// Command-line flags shared by the subcommands. Their usage strings are what
// helpCmd shows (via flag.PrintDefaults), so they double as user documentation.
var (
	inFilePath   = flag.String("in", "", "A file containing a Markov chain to use. Leave empty to start with a new Markov chain.")
	dataFilePath = flag.String("data", "", `A file of data to train the Markov chain on. Alternatively, specify "stdin" to use standard input.`)
	outFilePath  = flag.String("out", "markov.json", "A file to store the Markov chain in. File will be overwritten if it already exists.")
	// Only consulted when a brand-new chain is created; a chain loaded from
	// -in is used as decoded, so say so in the usage text.
	order = flag.Int("order", 2, "The order of the Markov chain. Only used when starting a new chain; ignored when -in is given.")
)
// init fills the package-level flag variables from the process's
// command-line arguments before main runs.
//
// NOTE(review): since Go 1.13 the testing package registers its -test.* flags
// only after package initialization, so parsing here makes `go test` on this
// package reject those flags. Consider moving the parse to the top of main.
func init() {
	// This is exactly what flag.Parse does. CommandLine is built with
	// ExitOnError, so a bad flag terminates the process inside Parse and any
	// error that does come back is always nil.
	_ = flag.CommandLine.Parse(os.Args[1:])
}
// helpCmd implements the "help" command. It writes a short usage summary to
// standard error, followed by every registered flag and its default as
// printed by flag.PrintDefaults.
func helpCmd() {
	// The flag package stops parsing at the first non-flag argument, so flags
	// placed after the command would be counted as extra args by main. The
	// synopsis therefore spells out the required ordering.
	fmt.Fprint(os.Stderr, `Markovian
Synopsis: markovian [flags] ARG
Flags must come before ARG, e.g. markovian -data input.txt train
Args:
- help: Display this message.
- train: Train a Markov chain. Relevant flags: in, data, out, order.
- generate: Not yet implemented; requires a trained Markov chain. Relevant flags: in.
Flags:
`)
	flag.PrintDefaults()
}
// obtainMarkovChain returns the Markov chain the current command should work
// on. If -in is empty, a new chain of order -order is created with impl.New.
// Otherwise the chain is decoded from the JSON file named by -in, and -order
// is not consulted.
//
// On failure it returns a nil chain and an error wrapping the underlying
// cause (creation, file read, or JSON decode).
func obtainMarkovChain() (*impl.MarkovChain, error) {
	if *inFilePath == "" {
		mc, err := impl.New(*order)
		if err != nil {
			return nil, fmt.Errorf("couldn't make new Markov chain: %w", err)
		}
		return mc, nil
	}

	b, err := ioutil.ReadFile(*inFilePath)
	if err != nil {
		// The *os.PathError from ReadFile already names the file.
		return nil, fmt.Errorf("couldn't read input file: %w", err)
	}
	// json.Unmarshal needs a non-nil pointer to decode into, so start from the
	// address of a zero value rather than a nil *MarkovChain.
	mc := &impl.MarkovChain{}
	if err := json.Unmarshal(b, mc); err != nil {
		return nil, fmt.Errorf("couldn't read json from %q: %w", *inFilePath, err)
	}
	return mc, nil
}
// Implements the "train" command.
// Returns errors, if one should occur.
func trainCmd() error {
// Get the input Markov chain.
mc, err := obtainMarkovChain()
if err != nil {
return fmt.Errorf("couldn't obtain Markov chain: %s", err)
}
// Get the data to read.
var r io.Reader
switch *dataFilePath {
case "":
return errors.New("must provide input file path")
case "stdin":
r = os.Stdin
default:
// Input is possibly very large; use a reader instead of ioutil convenience method.
in, err := os.Open(*dataFilePath)
if err != nil {
return fmt.Errorf("couldn't open file: %s", err)
}
defer in.Close()
r = in
}
// Perform training.
mc.Train(r)
// Marshal the resulting Markov chain to JSON in a file.
b, err := json.Marshal(mc)
if err != nil {
return err
}
return ioutil.WriteFile(*outFilePath, b, 0600)
}
// generateCmd implements the "generate" command. For now it only verifies
// that the chain from obtainMarkovChain has at least one lesson, then prints
// "Not yet implemented" to standard output.
//
// It returns an error, wrapping the underlying cause where there is one, if
// the chain cannot be obtained or has no lessons.
func generateCmd() error {
	mc, err := obtainMarkovChain()
	if err != nil {
		return fmt.Errorf("couldn't obtain Markov chain: %w", err)
	}
	// Generation doesn't work with an untrained chain, i.e. one with no
	// lessons.
	if len(mc.Lessons) == 0 {
		return errors.New(`command "generate" requires trained input Markov chain`)
	}
	// NOTE(review): returning nil makes this placeholder exit with status 0;
	// consider returning an error until generation is implemented.
	fmt.Println("Not yet implemented")
	return nil
}
// main dispatches on the single positional argument (the command name) left
// over after flag parsing. Usage problems print a message plus the help text
// and exit with status 1. A failing command prints its error and exits with
// status 1.
func main() {
	positional := flag.Args()

	// usageFail reports a usage problem on stderr, shows the help text, and
	// terminates the process with status 1.
	usageFail := func(msg string) {
		fmt.Fprint(os.Stderr, msg)
		helpCmd()
		os.Exit(1)
	}

	switch n := len(positional); {
	case n < 1:
		usageFail("Not enough args.\n")
	case n > 1:
		usageFail("Too many args.\n")
	}

	// Pick the command to run; "help" and unknown names are handled inline.
	var run func() error
	switch name := positional[0]; name {
	case "help":
		helpCmd()
		return
	case "train":
		run = trainCmd
	case "generate":
		run = generateCmd
	default:
		usageFail(fmt.Sprintf("Unrecognized command %q.\n", name))
		return
	}

	if err := run(); err != nil {
		fmt.Fprintf(os.Stderr, "%s\n", err)
		os.Exit(1)
	}
}