Skip to content

Commit

Permalink
Implement dataframe.Pivot (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
toaco committed Dec 1, 2022
1 parent f705409 commit 9720435
Show file tree
Hide file tree
Showing 3 changed files with 361 additions and 1 deletion.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,18 @@ groups := df.GroupBy("key1", "key2") // Group by column "key1", and column "key2
aggre := groups.Aggregation([]AggregationType{Aggregation_MAX, Aggregation_MIN}, []string{"values", "values2"}) // Maximum value in column "values", Minimum value in column "values2"
```

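#### Pivot

Pivot builds a spreadsheet-style pivot table: the first argument lists the columns whose values identify each output row, the second lists the columns whose values are spread into generated column names (for example `C1_D1_E_SUM`), and the third lists the value columns together with the aggregation applied to each:
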
```go
pivot := df.Pivot(
[]string{"A", "B"}, // rows
[]string{"C", "D"}, // columns
[]PivotValue{ // values
{Colname: "E", AggregationType: Aggregation_SUM},
{Colname: "F", AggregationType: Aggregation_COUNT},
})
```

#### Arrange

With Arrange a DataFrame can be sorted by the given column names:
Expand Down
271 changes: 270 additions & 1 deletion dataframe/dataframe.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ func (gps Groups) Aggregation(typs []AggregationType, colnames []string) DataFra
return DataFrame{Err: fmt.Errorf("Aggregation: this method %s not found", typs[i])}

}
curMap[fmt.Sprintf("%s_%s", c, typs[i])] = value
curMap[buildAggregatedColname(c, typs[i])] = value
}
dfMaps = append(dfMaps, curMap)

Expand Down Expand Up @@ -2293,6 +2293,10 @@ func inIntSlice(i int, is []int) bool {
return false
}

// buildAggregatedColname returns the name of the column that holds the result
// of applying aggregation typ to column c: the column name and the formatted
// aggregation type joined by an underscore.
func buildAggregatedColname(c string, typ AggregationType) string {
	const layout = "%s_%s"
	return fmt.Sprintf(layout, c, typ)
}

// Matrix is an interface which is compatible with gonum's mat.Matrix interface
type Matrix interface {
Dims() (r, c int)
Expand Down Expand Up @@ -2357,3 +2361,268 @@ func (df DataFrame) Describe() DataFrame {
ddf := New(ss...)
return ddf
}

// PivotValue describes one value column of a pivot table: which column is
// aggregated and which aggregation is applied to it.
type PivotValue struct {
// Colname is the name of the column to aggregate.
Colname string
// AggregationType selects the aggregation applied to Colname.
// aggregateByRowsAndColumns substitutes Aggregation_SUM when this field is
// left at its zero value.
AggregationType AggregationType
}

// Pivot creates a DataFrame shaped like a spreadsheet-style pivot table.
//
// rows names the columns whose distinct value combinations become the output
// rows, columns names the columns whose distinct value combinations are spread
// into generated output columns, and values lists the columns to aggregate
// together with the aggregation applied to each. A generated column is named
// by joining the column-key values and the aggregated column name with "_".
// Output rows are ordered by their sorted group key, and a cell for which the
// input holds no data is set to the default element of the column's type.
//
// A PivotValue whose AggregationType is left at its zero value is aggregated
// with Aggregation_SUM. When the parameters are invalid, or grouping or
// aggregation fails, the returned DataFrame carries the error in Err.
func (df DataFrame) Pivot(rows []string, columns []string, values []PivotValue) DataFrame {
	if err := df.checkPivotParams(rows, columns, values); err != nil {
		return DataFrame{Err: err}
	}

	// Resolve the default aggregation once, on a private copy, so that the
	// column names built below use the same aggregation type that is actually
	// applied. Without this, a zero AggregationType is aggregated as SUM but
	// looked up under a name derived from the zero value.
	resolved := make([]PivotValue, len(values))
	copy(resolved, values)
	for i := range resolved {
		if resolved[i].AggregationType == 0 {
			resolved[i].AggregationType = Aggregation_SUM
		}
	}
	values = resolved

	aggregatedDF := df.aggregateByRowsAndColumns(rows, columns, values)
	if aggregatedDF.Err != nil {
		return aggregatedDF
	}

	generatedColnames, generatedColtyps := df.buildGeneratedCols(aggregatedDF, columns, values)

	var rowGroups map[string]DataFrame
	if len(rows) == 0 {
		// Without row keys the whole aggregated frame forms a single output row.
		rowGroups = map[string]DataFrame{"": aggregatedDF}
	} else {
		grouped := aggregatedDF.GroupBy(rows...)
		if grouped.Err != nil {
			return DataFrame{Err: grouped.Err}
		}
		rowGroups = grouped.groups
	}

	// Map iteration order is random; sorting the keys fixes the row order.
	rowGroupsKeys := make([]string, 0, len(rowGroups))
	for key := range rowGroups {
		rowGroupsKeys = append(rowGroupsKeys, key)
	}
	sort.Strings(rowGroupsKeys)
	newColnames, newColElements := df.buildNewCols(rows, generatedColnames, len(rowGroupsKeys))

	// Position of every generated column, keeping the first occurrence of a
	// name so lookups behave like a front-to-back search.
	generatedColIdx := make(map[string]int, len(generatedColnames))
	for i, name := range generatedColnames {
		if _, dup := generatedColIdx[name]; !dup {
			generatedColIdx[name] = i
		}
	}

	// The aggregated column names do not depend on the row being processed.
	aggregatedColnames := make([]string, len(values))
	for i, value := range values {
		aggregatedColnames[i] = buildAggregatedColname(value.Colname, value.AggregationType)
	}

	for rowIdx, key := range rowGroupsKeys {
		rowGroupDF := rowGroups[key]

		// Row-key cells: every record of the group shares the same key values.
		for colIdx, colname := range rows {
			newColElements[colIdx][rowIdx] = rowGroupDF.Col(colname).Elem(0)
		}
		// Start every generated cell at its default so combinations that are
		// absent from this group still hold a value.
		for colIdx := range generatedColnames {
			newColElements[len(rows)+colIdx][rowIdx] = getDefaultElem(generatedColtyps[colIdx])
		}

		// Overwrite the cells for the column-key combinations that do occur.
		for i := 0; i < rowGroupDF.Nrow(); i++ {
			keyValues := make([]string, 0, len(columns))
			for _, col := range columns {
				keyValues = append(keyValues, rowGroupDF.Col(col).Elem(i).String())
			}
			prefix := strings.Join(keyValues, "_")

			for _, aggregatedColname := range aggregatedColnames {
				newColname := aggregatedColname
				if len(columns) > 0 {
					newColname = prefix + "_" + aggregatedColname
				}
				colIdx, ok := generatedColIdx[newColname]
				if !ok {
					// Indexing with a missing position would corrupt a
					// row-key column or panic; report it instead.
					return DataFrame{Err: fmt.Errorf("pivot: generated column %s not found", newColname)}
				}
				newColElements[len(rows)+colIdx][rowIdx] = rowGroupDF.Col(aggregatedColname).Elem(i)
			}
		}
	}

	newColumnSlice := make([]series.Series, 0, len(newColnames))
	for i, colname := range newColnames {
		var typ series.Type
		if i < len(rows) {
			// Row-key columns keep the type they have in the source frame.
			typ = df.Col(colname).Type()
		} else {
			typ = generatedColtyps[i-len(rows)]
		}
		newColumnSlice = append(newColumnSlice, series.New(newColElements[i], typ, colname))
	}

	return New(newColumnSlice...)
}

// checkPivotParams validates the arguments of Pivot. It rejects an empty
// values list, any column name that is used more than once across rows,
// columns and values, any column name that the DataFrame does not have, and
// any value column whose type is neither series.Int nor series.Float.
// Names are checked in the order rows, columns, values; the first problem
// found is returned.
func (df *DataFrame) checkPivotParams(rows []string, columns []string, values []PivotValue) error {
	if len(values) == 0 {
		return fmt.Errorf("values cannot be empty")
	}

	total := len(rows) + len(columns) + len(values)
	seen := make(map[string]bool, total)
	available := df.Names()

	// Flatten every requested name into one ordered list so a single loop
	// performs the duplicate and existence checks.
	requested := make([]string, 0, total)
	requested = append(requested, rows...)
	requested = append(requested, columns...)
	for i := range values {
		requested = append(requested, values[i].Colname)
	}
	for _, name := range requested {
		if err := df.isValidColumnParam(seen, name, available); err != nil {
			return err
		}
	}

	// Only numeric columns can be aggregated.
	for i := range values {
		typ := df.Col(values[i].Colname).Type()
		if typ != series.Int && typ != series.Float {
			return fmt.Errorf("series cannot aggregate")
		}
	}
	return nil
}

// isValidColumnParam checks a single column name requested for a pivot.
// It fails when colName is already present in usedColumnNames or is missing
// from dfNames. The name is recorded in usedColumnNames before the existence
// check, so it is marked as used even when it turns out not to exist.
func (df *DataFrame) isValidColumnParam(usedColumnNames map[string]bool, colName string, dfNames []string) error {
	_, alreadyUsed := usedColumnNames[colName]
	if alreadyUsed {
		return fmt.Errorf("column %s cannot be used more than once", colName)
	}
	usedColumnNames[colName] = true

	if isStrInStrSlice(dfNames, colName) {
		return nil
	}
	return fmt.Errorf("column %s not exist", colName)
}

// aggregateByRowsAndColumns groups the DataFrame by the row and column keys
// (rows first, then columns) and aggregates every value column with its
// requested aggregation. A zero AggregationType is replaced by
// Aggregation_SUM. When no keys are given, the whole DataFrame is aggregated
// as a single group. A grouping error is returned in the result's Err.
func (df DataFrame) aggregateByRowsAndColumns(rows []string, columns []string, values []PivotValue) DataFrame {
	valueColnames := make([]string, len(values))
	aggregationTypes := make([]AggregationType, len(values))
	for i, v := range values {
		valueColnames[i] = v.Colname
		aggregationTypes[i] = v.AggregationType
		if v.AggregationType == 0 {
			// default AggregationType is Aggregation_SUM
			aggregationTypes[i] = Aggregation_SUM
		}
	}

	groupKeys := make([]string, 0, len(rows)+len(columns))
	groupKeys = append(groupKeys, rows...)
	groupKeys = append(groupKeys, columns...)

	if len(groupKeys) > 0 {
		grouped := df.GroupBy(groupKeys...)
		if grouped.Err != nil {
			return DataFrame{Err: grouped.Err}
		}
		return grouped.Aggregation(aggregationTypes, valueColnames)
	}

	// No grouping keys: aggregate the entire frame as one group.
	whole := Groups{groups: map[string]DataFrame{"": df}, colnames: valueColnames}
	return whole.Aggregation(aggregationTypes, valueColnames)
}

// buildGeneratedCols returns the names and types of the value columns of the
// pivot table, in output order.
//
// With no column keys there is one generated column per entry of values,
// named after the aggregated column. Otherwise there is one generated column
// per distinct column-key combination found in aggregatedDF and per entry of
// values; combinations are ordered by comparing their elements key by key,
// and each name is the key values followed by the aggregated column name,
// joined with "_". The type of a generated column is the type the value
// column has in df.
func (df DataFrame) buildGeneratedCols(aggregatedDF DataFrame, columns []string, values []PivotValue) ([]string, []series.Type) {
	// The aggregated name and the type of each value column do not depend on
	// the column-key combination, so compute them once.
	aggregatedColnames := make([]string, len(values))
	valueColtyps := make([]series.Type, len(values))
	for i, value := range values {
		aggregatedColnames[i] = buildAggregatedColname(value.Colname, value.AggregationType)
		valueColtyps[i] = df.Col(value.Colname).Type()
	}
	if len(columns) == 0 {
		return aggregatedColnames, valueColtyps
	}

	// Collect the key elements of every distinct column-key combination; all
	// records of a group share them, so the first record is representative.
	columnGroups := aggregatedDF.GroupBy(columns...).groups
	generatedColElemsList := make([][]series.Element, 0, len(columnGroups))
	for _, columnGroupDf := range columnGroups {
		columnElems := make([]series.Element, 0, len(columns))
		for _, column := range columns {
			columnElems = append(columnElems, columnGroupDf.Col(column).Elem(0))
		}
		generatedColElemsList = append(generatedColElemsList, columnElems)
	}

	// Map iteration order is random; order the combinations key by key so the
	// generated columns come out in a stable order.
	sort.Slice(generatedColElemsList, func(i, j int) bool {
		left, right := generatedColElemsList[i], generatedColElemsList[j]
		for idx := range left {
			switch {
			case left[idx].Less(right[idx]):
				return true
			case left[idx].Greater(right[idx]):
				return false
			}
		}
		// all elements are equal
		return false
	})

	generatedColnames := make([]string, 0, len(generatedColElemsList)*len(values))
	generatedColtyps := make([]series.Type, 0, len(generatedColElemsList)*len(values))
	for _, generatedColElems := range generatedColElemsList {
		keyValues := make([]string, len(generatedColElems))
		for i, elem := range generatedColElems {
			keyValues[i] = elem.String()
		}
		prefix := strings.Join(keyValues, "_")
		for i, aggregatedColname := range aggregatedColnames {
			generatedColnames = append(generatedColnames, prefix+"_"+aggregatedColname)
			generatedColtyps = append(generatedColtyps, valueColtyps[i])
		}
	}
	return generatedColnames, generatedColtyps
}

// buildNewCols prepares the skeleton of the pivot result: the full list of
// output column names (row keys first, then the generated columns) and, for
// every output column, a slice of rowCnt element slots still to be filled.
func (df DataFrame) buildNewCols(rows []string, generatedColnames []string, rowCnt int) ([]string, [][]series.Element) {
	total := len(rows) + len(generatedColnames)

	newColnames := make([]string, total)
	copied := copy(newColnames, rows)
	copy(newColnames[copied:], generatedColnames)

	newColElements := make([][]series.Element, 0, total)
	for i := 0; i < total; i++ {
		newColElements = append(newColElements, make([]series.Element, rowCnt))
	}
	return newColnames, newColElements
}

// Default elements, one per series type, each taken from a one-element series
// holding that type's zero value (0, "", 0.0, false). getDefaultElem hands
// them out.
var (
	defaultIntElem    = series.New([]int{0}, series.Int, "").Elem(0)
	defaultStringElem = series.New([]string{""}, series.String, "").Elem(0)
	defaultFloatElem  = series.New([]float64{0}, series.Float, "").Elem(0)
	defaultBoolElem   = series.New([]bool{false}, series.Bool, "").Elem(0)
)

// getDefaultElem returns the package-level default element for the given
// series type, or nil when the type is not one of String, Int, Float or Bool.
func getDefaultElem(tpe series.Type) series.Element {
	var elem series.Element // stays nil for an unrecognised type
	switch tpe {
	case series.Int:
		elem = defaultIntElem
	case series.Float:
		elem = defaultFloatElem
	case series.String:
		elem = defaultStringElem
	case series.Bool:
		elem = defaultBoolElem
	}
	return elem
}

// strIndexInStrSlice returns the index of the first element of strSlice equal
// to str, or -1 when there is none.
func strIndexInStrSlice(strSlice []string, str string) int {
	found := -1
	for i := 0; i < len(strSlice); i++ {
		if strSlice[i] == str {
			found = i
			break
		}
	}
	return found
}

// isStrInStrSlice reports whether str occurs in strSlice.
func isStrInStrSlice(strSlice []string, str string) bool {
	idx := strIndexInStrSlice(strSlice, str)
	return idx >= 0
}
79 changes: 79 additions & 0 deletions dataframe/dataframe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3000,3 +3000,82 @@ func TestGroups_GetGroups(t *testing.T) {
t.Fatalf("Expected to get 3 groups, got %d", len(groupNames))
}
}

func TestDataFrame_Pivot(t *testing.T) {
// case 1, set rows and values
df := New(
series.New([]string{"a", "b", "b"}, series.String, "A"),
series.New([]int{1, 2, 3}, series.Int, "B"),
)
newDF := df.Pivot([]string{"A"}, nil, []PivotValue{{Colname: "B", AggregationType: Aggregation_SUM}})
expectedRecords := [][]string{
{"A", "B_SUM"},
{"a", "1"},
{"b", "5"},
}
if !reflect.DeepEqual(newDF.Records(), expectedRecords) {
t.Fatalf("unexpced result, result=%v expected=%v", newDF.Records(), expectedRecords)
}

// case 2, set columns and values
df = New(
series.New([]string{"a", "b", "b"}, series.String, "A"),
series.New([]int{1, 2, 3}, series.Int, "B"),
)
newDF = df.Pivot(nil, []string{"A"}, []PivotValue{{Colname: "B", AggregationType: Aggregation_SUM}})
expectedRecords = [][]string{
{"a_B_SUM", "b_B_SUM"},
{"1", "5"},
}
if !reflect.DeepEqual(newDF.Records(), expectedRecords) {
t.Fatalf("unexpced result, result=%v expected=%v", newDF.Records(), expectedRecords)
}

// case 3, only set values
df = New(
series.New([]string{"a", "b"}, series.String, "A"),
series.New([]int{1, 2}, series.Int, "B"),
)
newDF = df.Pivot(nil, nil, []PivotValue{{Colname: "B", AggregationType: Aggregation_SUM}})
expectedRecords = [][]string{
{"B_SUM"},
{"3"},
}
if !reflect.DeepEqual(newDF.Records(), expectedRecords) {
t.Fatalf("unexpced result, result=%v expected=%v", newDF.Records(), expectedRecords)
}

// case4, set all parameters
newDF = New(
series.New([]string{"A1", "A1", "A1", "A1", "A1", "A1", "A1"}, series.String, "A"),
series.New([]string{"B1", "B1", "B1", "B2", "B2", "B2", "B2"}, series.String, "B"),
series.New([]string{"C1", "C1", "C2", "C1", "C1", "C2", "C2"}, series.String, "C"),
series.New([]string{"D1", "D2", "D1", "D1", "D2", "D1", "D1"}, series.String, "D"),
series.New([]int{1, 2, 3, 4, 5, 6, 7}, series.Int, "E"),
series.New([]int{8, 9, 10, 11, 12, 13, 14}, series.Int, "F"),
).Pivot(
[]string{"A", "B"},
[]string{"C", "D"},
[]PivotValue{
{Colname: "E", AggregationType: Aggregation_SUM},
{Colname: "F", AggregationType: Aggregation_COUNT},
})
expectedRecords = [][]string{
{"A", "B", "C1_D1_E_SUM", "C1_D1_F_COUNT", "C1_D2_E_SUM", "C1_D2_F_COUNT", "C2_D1_E_SUM", "C2_D1_F_COUNT"},
{"A1", "B1", "1", "1", "2", "1", "3", "1"},
{"A1", "B2", "4", "1", "5", "1", "13", "2"},
}
if !reflect.DeepEqual(newDF.Records(), expectedRecords) {
t.Fatalf("unexpced result, result=%v expected=%v", newDF.Records(), expectedRecords)
}

// case 5, invalid params
df = New(
series.New([]string{"a", "b"}, series.String, "A"),
series.New([]int{1, 2}, series.Int, "B"),
)
newDF = df.Pivot(nil, nil, nil)
if newDF.Err == nil {
t.Fatalf("expect param error")
}
}

0 comments on commit 9720435

Please sign in to comment.