Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement dataframe.Pivot (#92) #202

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,18 @@ groups := df.GroupBy("key1", "key2") // Group by column "key1", and column "key2
aggre := groups.Aggregation([]AggregationType{Aggregation_MAX, Aggregation_MIN}, []string{"values", "values2"}) // Maximum value in column "values", Minimum value in column "values2"
```

#### Pivot

```go
pivot := df.Pivot(
[]string{"A", "B"}, // rows
[]string{"C", "D"}, // columns
[]PivotValue{ // values
{Colname: "E", AggregationType: Aggregation_SUM},
{Colname: "F", AggregationType: Aggregation_COUNT},
})
```

#### Arrange

With Arrange a DataFrame can be sorted by the given column names:
Expand Down
271 changes: 270 additions & 1 deletion dataframe/dataframe.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ func (gps Groups) Aggregation(typs []AggregationType, colnames []string) DataFra
return DataFrame{Err: fmt.Errorf("Aggregation: this method %s not found", typs[i])}

}
curMap[fmt.Sprintf("%s_%s", c, typs[i])] = value
curMap[buildAggregatedColname(c, typs[i])] = value
}
dfMaps = append(dfMaps, curMap)

Expand Down Expand Up @@ -2293,6 +2293,10 @@ func inIntSlice(i int, is []int) bool {
return false
}

func buildAggregatedColname(c string, typ AggregationType) string {
return fmt.Sprintf("%s_%s", c, typ)
}

// Matrix is an interface which is compatible with gonum's mat.Matrix interface
type Matrix interface {
Dims() (r, c int)
Expand Down Expand Up @@ -2357,3 +2361,268 @@ func (df DataFrame) Describe() DataFrame {
ddf := New(ss...)
return ddf
}

type PivotValue struct {
Colname string
AggregationType AggregationType
}

// Pivot Create a dataframe like spreadsheet-style pivot table
func (df DataFrame) Pivot(rows []string, columns []string, values []PivotValue) DataFrame {
err := df.checkPivotParams(rows, columns, values)
if err != nil {
return DataFrame{Err: err}
}

aggregatedDF := df.aggregateByRowsAndColumns(rows, columns, values)
if aggregatedDF.Err != nil {
return aggregatedDF
}

generatedColnames, generatedColtyps := df.buildGeneratedCols(aggregatedDF, columns, values)

var rowGroups map[string]DataFrame
if len(rows) == 0 {
rowGroups = map[string]DataFrame{"": aggregatedDF}
} else {
rowGroups = aggregatedDF.GroupBy(rows...).groups
}
rowGroupsKeys := make([]string, 0, len(rowGroups))
for key := range rowGroups {
rowGroupsKeys = append(rowGroupsKeys, key)
}
sort.Strings(rowGroupsKeys)
newColnames, newColElements := df.buildNewCols(rows, generatedColnames, len(rowGroupsKeys))

rowIdx := 0
for key, rowGroupDF := range rowGroups {
rowIdx = strIndexInStrSlice(rowGroupsKeys, key)

// fill row
for colIdx, colname := range rows {
newColElements[colIdx][rowIdx] = rowGroupDF.Col(colname).Elem(0)
}
// set default value for columns
for colIdx := range generatedColnames {
newColElements[colIdx+len(rows)][rowIdx] = getDefaultElem(generatedColtyps[colIdx])
}

// update value of columns
for i := 0; i < rowGroupDF.Nrow(); i++ {
colNames := make([]string, 0, len(columns))
for _, col := range columns {
colNames = append(colNames, rowGroupDF.Col(col).Elem(i).String())
}

for _, valueColumn := range values {
aggregatedColname := buildAggregatedColname(valueColumn.Colname, valueColumn.AggregationType)
newColNames := append(colNames, aggregatedColname)
newColname := strings.Join(newColNames, "_")
colIdx := strIndexInStrSlice(generatedColnames, newColname)
newColElements[len(rows)+colIdx][rowIdx] = rowGroupDF.Col(aggregatedColname).Elem(i)
}
}
rowIdx++
}

newColumnSlice := make([]series.Series, 0, len(newColnames))
for i, colname := range newColnames {
var typ series.Type
if i < len(rows) {
typ = df.Col(colname).Type()
} else {
typ = generatedColtyps[i-len(rows)]
}
newColumnSlice = append(newColumnSlice, series.New(newColElements[i], typ, colname))
}

return New(newColumnSlice...)
}

func (df *DataFrame) checkPivotParams(rows []string, columns []string, values []PivotValue) error {
if len(values) == 0 {
return fmt.Errorf("values cannot be empty")
}

usedColumnNames := make(map[string]bool, len(rows)+len(columns)+len(values))
dfNames := df.Names()
for _, colName := range rows {
err := df.isValidColumnParam(usedColumnNames, colName, dfNames)
if err != nil {
return err
}
}
for _, colName := range columns {
err := df.isValidColumnParam(usedColumnNames, colName, dfNames)
if err != nil {
return err
}
}
for _, col := range values {
err := df.isValidColumnParam(usedColumnNames, col.Colname, dfNames)
if err != nil {
return err
}
}

for _, value := range values {
switch df.Col(value.Colname).Type() {
case series.Int, series.Float:
// only support numbers
continue
default:
return fmt.Errorf("series cannot aggregate")
}
}
return nil
}

func (df *DataFrame) isValidColumnParam(usedColumnNames map[string]bool, colName string, dfNames []string) error {
if _, ok := usedColumnNames[colName]; ok {
return fmt.Errorf("column %s cannot be used more than once", colName)
}
usedColumnNames[colName] = true
if !isStrInStrSlice(dfNames, colName) {
return fmt.Errorf("column %s not exist", colName)
}
return nil
}

func (df DataFrame) aggregateByRowsAndColumns(rows []string, columns []string, values []PivotValue) DataFrame {
valueColnames := make([]string, 0, len(values))
aggregationTypes := make([]AggregationType, 0, len(values))
for _, value := range values {
valueColnames = append(valueColnames, value.Colname)
if value.AggregationType == 0 {
// default AggregationType is Aggregation_SUM
aggregationTypes = append(aggregationTypes, Aggregation_SUM)
} else {
aggregationTypes = append(aggregationTypes, value.AggregationType)
}
}

var selectedColnames []string
if len(rows) > 0 {
selectedColnames = append(selectedColnames, rows...)
}
if len(columns) > 0 {
selectedColnames = append(selectedColnames, columns...)
}
if len(selectedColnames) == 0 {
t := Groups{groups: map[string]DataFrame{"": df}, colnames: valueColnames}
return t.Aggregation(aggregationTypes, valueColnames)
}

groups := df.GroupBy(selectedColnames...)
if groups.Err != nil {
return DataFrame{Err: groups.Err}
}
return groups.Aggregation(aggregationTypes, valueColnames)
}

func (df DataFrame) buildGeneratedCols(aggregatedDF DataFrame, columns []string, values []PivotValue) ([]string, []series.Type) {
if len(columns) == 0 {
generatedColnames := make([]string, 0, len(values))
generatedColtyps := make([]series.Type, 0, len(values))
for _, value := range values {
aggregatedValueColname := buildAggregatedColname(value.Colname, value.AggregationType)
generatedColnames = append(generatedColnames, aggregatedValueColname)
generatedColtyps = append(generatedColtyps, df.Col(value.Colname).Type())
}
return generatedColnames, generatedColtyps
}

columnGroups := aggregatedDF.GroupBy(columns...).groups
generatedColElemsList := make([][]series.Element, 0, len(columnGroups))
for _, columnGroupDf := range columnGroups {
columnStrValues := make([]string, 0, len(columns))
columnElems := make([]series.Element, 0, len(columns))
for _, column := range columns {
columnStrValues = append(columnStrValues, columnGroupDf.Col(column).Elem(0).String())
columnElems = append(columnElems, columnGroupDf.Col(column).Elem(0))
}
generatedColElemsList = append(generatedColElemsList, columnElems)
}

// sort generatedColElemsList by elements
sort.Slice(generatedColElemsList, func(i, j int) bool {
generatedColElemsI := generatedColElemsList[i]
generatedColElemsJ := generatedColElemsList[j]

for idx := range generatedColElemsI {
if generatedColElemsI[idx].Less(generatedColElemsJ[idx]) {
return true
} else if generatedColElemsI[idx].Greater(generatedColElemsJ[idx]) {
return false
} else {
continue
}
}
// all elements are equal
return false
})

generatedColnames := make([]string, 0, len(generatedColElemsList)*len(values))
generatedColtyps := make([]series.Type, 0, len(generatedColElemsList)*len(values))
for _, generatedColElems := range generatedColElemsList {
tmpColnames := make([]string, 0, len(generatedColElems))
for _, elem := range generatedColElems {
tmpColnames = append(tmpColnames, elem.String())
}
for _, value := range values {
aggregatedValueColname := buildAggregatedColname(value.Colname, value.AggregationType)
tmpColName := strings.Join(append(tmpColnames, aggregatedValueColname), "_")
generatedColnames = append(generatedColnames, tmpColName)
generatedColtyps = append(generatedColtyps, df.Col(value.Colname).Type())
}
}
return generatedColnames, generatedColtyps
}

func (df DataFrame) buildNewCols(rows []string, generatedColnames []string, rowCnt int) ([]string, [][]series.Element) {
newColnames := make([]string, 0, len(rows)+len(generatedColnames))
if len(rows) > 0 {
newColnames = append(newColnames, rows...)
}
if len(generatedColnames) > 0 {
newColnames = append(newColnames, generatedColnames...)
}

newColElements := make([][]series.Element, len(newColnames))
for i := range newColElements {
newColElements[i] = make([]series.Element, rowCnt)
}
return newColnames, newColElements
}

var defaultIntElem = series.New([]int{0}, series.Int, "").Elem(0)
var defaultStringElem = series.New([]string{""}, series.String, "").Elem(0)
var defaultFloatElem = series.New([]float64{0}, series.Float, "").Elem(0)
var defaultBoolElem = series.New([]bool{false}, series.Bool, "").Elem(0)

func getDefaultElem(tpe series.Type) series.Element {
switch tpe {
case series.String:
return defaultStringElem
case series.Int:
return defaultIntElem
case series.Float:
return defaultFloatElem
case series.Bool:
return defaultBoolElem
}
return nil
}

func strIndexInStrSlice(strSlice []string, str string) int {
for i, s := range strSlice {
if s == str {
return i
}
}
return -1
}

func isStrInStrSlice(strSlice []string, str string) bool {
return strIndexInStrSlice(strSlice, str) != -1
}
79 changes: 79 additions & 0 deletions dataframe/dataframe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3000,3 +3000,82 @@ func TestGroups_GetGroups(t *testing.T) {
t.Fatalf("Expected to get 3 groups, got %d", len(groupNames))
}
}

func TestDataFrame_Pivot(t *testing.T) {
// case 1, set rows and values
df := New(
series.New([]string{"a", "b", "b"}, series.String, "A"),
series.New([]int{1, 2, 3}, series.Int, "B"),
)
newDF := df.Pivot([]string{"A"}, nil, []PivotValue{{Colname: "B", AggregationType: Aggregation_SUM}})
expectedRecords := [][]string{
{"A", "B_SUM"},
{"a", "1"},
{"b", "5"},
}
if !reflect.DeepEqual(newDF.Records(), expectedRecords) {
t.Fatalf("unexpced result, result=%v expected=%v", newDF.Records(), expectedRecords)
}

// case 2, set columns and values
df = New(
series.New([]string{"a", "b", "b"}, series.String, "A"),
series.New([]int{1, 2, 3}, series.Int, "B"),
)
newDF = df.Pivot(nil, []string{"A"}, []PivotValue{{Colname: "B", AggregationType: Aggregation_SUM}})
expectedRecords = [][]string{
{"a_B_SUM", "b_B_SUM"},
{"1", "5"},
}
if !reflect.DeepEqual(newDF.Records(), expectedRecords) {
t.Fatalf("unexpced result, result=%v expected=%v", newDF.Records(), expectedRecords)
}

// case 3, only set values
df = New(
series.New([]string{"a", "b"}, series.String, "A"),
series.New([]int{1, 2}, series.Int, "B"),
)
newDF = df.Pivot(nil, nil, []PivotValue{{Colname: "B", AggregationType: Aggregation_SUM}})
expectedRecords = [][]string{
{"B_SUM"},
{"3"},
}
if !reflect.DeepEqual(newDF.Records(), expectedRecords) {
t.Fatalf("unexpced result, result=%v expected=%v", newDF.Records(), expectedRecords)
}

// case4, set all parameters
newDF = New(
series.New([]string{"A1", "A1", "A1", "A1", "A1", "A1", "A1"}, series.String, "A"),
series.New([]string{"B1", "B1", "B1", "B2", "B2", "B2", "B2"}, series.String, "B"),
series.New([]string{"C1", "C1", "C2", "C1", "C1", "C2", "C2"}, series.String, "C"),
series.New([]string{"D1", "D2", "D1", "D1", "D2", "D1", "D1"}, series.String, "D"),
series.New([]int{1, 2, 3, 4, 5, 6, 7}, series.Int, "E"),
series.New([]int{8, 9, 10, 11, 12, 13, 14}, series.Int, "F"),
).Pivot(
[]string{"A", "B"},
[]string{"C", "D"},
[]PivotValue{
{Colname: "E", AggregationType: Aggregation_SUM},
{Colname: "F", AggregationType: Aggregation_COUNT},
})
expectedRecords = [][]string{
{"A", "B", "C1_D1_E_SUM", "C1_D1_F_COUNT", "C1_D2_E_SUM", "C1_D2_F_COUNT", "C2_D1_E_SUM", "C2_D1_F_COUNT"},
{"A1", "B1", "1", "1", "2", "1", "3", "1"},
{"A1", "B2", "4", "1", "5", "1", "13", "2"},
}
if !reflect.DeepEqual(newDF.Records(), expectedRecords) {
t.Fatalf("unexpced result, result=%v expected=%v", newDF.Records(), expectedRecords)
}

// case 5, invalid params
df = New(
series.New([]string{"a", "b"}, series.String, "A"),
series.New([]int{1, 2}, series.Int, "B"),
)
newDF = df.Pivot(nil, nil, nil)
if newDF.Err == nil {
t.Fatalf("expect param error")
}
}