diff --git a/README.md b/README.md
index c653ce8..09cacf3 100644
--- a/README.md
+++ b/README.md
@@ -259,6 +259,18 @@ groups := df.GroupBy("key1", "key2") // Group by column "key1", and column "key2"
 aggre := groups.Aggregation([]AggregationType{Aggregation_MAX, Aggregation_MIN}, []string{"values", "values2"}) // Maximum value in column "values", Minimum value in column "values2"
 ```
 
+#### Pivot
+
+With Pivot a DataFrame can be summarized as a spreadsheet-style pivot table, grouping by the given row and column names and aggregating the given value columns:
+
+```go
+pivot := df.Pivot(
+    []string{"A", "B"}, // rows
+    []string{"C", "D"}, // columns
+    []PivotValue{ // values
+        {Colname: "E", AggregationType: Aggregation_SUM},
+        {Colname: "F", AggregationType: Aggregation_COUNT},
+    })
+```
+
 #### Arrange
 
 With Arrange a DataFrame can be sorted by the given column names:
diff --git a/dataframe/dataframe.go b/dataframe/dataframe.go
index 51d38f6..ba73e6b 100644
--- a/dataframe/dataframe.go
+++ b/dataframe/dataframe.go
@@ -499,7 +499,7 @@ func (gps Groups) Aggregation(typs []AggregationType, colnames []string) DataFrame {
 				return DataFrame{Err: fmt.Errorf("Aggregation: this method %s not found", typs[i])}
 			}
-			curMap[fmt.Sprintf("%s_%s", c, typs[i])] = value
+			curMap[buildAggregatedColname(c, typs[i])] = value
 		}
 		dfMaps = append(dfMaps, curMap)
@@ -2293,6 +2293,10 @@ func inIntSlice(i int, is []int) bool {
 	return false
 }
 
+func buildAggregatedColname(c string, typ AggregationType) string {
+    return fmt.Sprintf("%s_%s", c, typ)
+}
+
 // Matrix is an interface which is compatible with gonum's mat.Matrix interface
 type Matrix interface {
 	Dims() (r, c int)
@@ -2357,3 +2361,268 @@ func (df DataFrame) Describe() DataFrame {
 	ddf := New(ss...)
 	return ddf
 }
+
+type PivotValue struct {
+    Colname         string
+    AggregationType AggregationType
+}
+
+// Pivot creates a spreadsheet-style pivot table: rows and columns define the
+// grouping keys and values define the aggregated cells. An unset
+// AggregationType defaults to Aggregation_SUM.
+func (df DataFrame) Pivot(rows []string, columns []string, values []PivotValue) DataFrame {
+    err := df.checkPivotParams(rows, columns, values)
+    if err != nil {
+        return DataFrame{Err: err}
+    }
+
+    // Normalize the values so that an unset AggregationType defaults to
+    // Aggregation_SUM; the generated column names must match the column names
+    // produced by the aggregation.
+    normalizedValues := make([]PivotValue, len(values))
+    copy(normalizedValues, values)
+    for i := range normalizedValues {
+        if normalizedValues[i].AggregationType == 0 {
+            normalizedValues[i].AggregationType = Aggregation_SUM
+        }
+    }
+    values = normalizedValues
+
+    aggregatedDF := df.aggregateByRowsAndColumns(rows, columns, values)
+    if aggregatedDF.Err != nil {
+        return aggregatedDF
+    }
+
+    generatedColnames, generatedColtyps := df.buildGeneratedCols(aggregatedDF, columns, values)
+
+    var rowGroups map[string]DataFrame
+    if len(rows) == 0 {
+        rowGroups = map[string]DataFrame{"": aggregatedDF}
+    } else {
+        rowGroups = aggregatedDF.GroupBy(rows...).groups
+    }
+    rowGroupsKeys := make([]string, 0, len(rowGroups))
+    for key := range rowGroups {
+        rowGroupsKeys = append(rowGroupsKeys, key)
+    }
+    sort.Strings(rowGroupsKeys)
+    newColnames, newColElements := df.buildNewCols(rows, generatedColnames, len(rowGroupsKeys))
+
+    for key, rowGroupDF := range rowGroups {
+        rowIdx := strIndexInStrSlice(rowGroupsKeys, key)
+
+        // fill the row key columns
+        for colIdx, colname := range rows {
+            newColElements[colIdx][rowIdx] = rowGroupDF.Col(colname).Elem(0)
+        }
+        // set default values for the generated columns
+        for colIdx := range generatedColnames {
+            newColElements[colIdx+len(rows)][rowIdx] = getDefaultElem(generatedColtyps[colIdx])
+        }
+
+        // update the values of the generated columns
+        for i := 0; i < rowGroupDF.Nrow(); i++ {
+            colNames := make([]string, 0, len(columns))
+            for _, col := range columns {
+                colNames = append(colNames, rowGroupDF.Col(col).Elem(i).String())
+            }
+
+            for _, valueColumn := range values {
+                aggregatedColname := buildAggregatedColname(valueColumn.Colname, valueColumn.AggregationType)
+                newColNames := append(colNames, aggregatedColname)
+                newColname := strings.Join(newColNames, "_")
+                colIdx := strIndexInStrSlice(generatedColnames, newColname)
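+                // copy this row's aggregated value into the matching generated column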
+                newColElements[len(rows)+colIdx][rowIdx] = rowGroupDF.Col(aggregatedColname).Elem(i)
+            }
+        }
+    }
+
+    newColumnSlice := make([]series.Series, 0, len(newColnames))
+    for i, colname := range newColnames {
+        var typ series.Type
+        if i < len(rows) {
+            typ = df.Col(colname).Type()
+        } else {
+            typ = generatedColtyps[i-len(rows)]
+        }
+        newColumnSlice = append(newColumnSlice, series.New(newColElements[i], typ, colname))
+    }
+
+    return New(newColumnSlice...)
+}
+
+// checkPivotParams validates that every requested column exists, is used only
+// once, and that every value column has a numeric type.
+func (df *DataFrame) checkPivotParams(rows []string, columns []string, values []PivotValue) error {
+    if len(values) == 0 {
+        return fmt.Errorf("values cannot be empty")
+    }
+
+    usedColumnNames := make(map[string]bool, len(rows)+len(columns)+len(values))
+    dfNames := df.Names()
+    for _, colName := range rows {
+        err := df.isValidColumnParam(usedColumnNames, colName, dfNames)
+        if err != nil {
+            return err
+        }
+    }
+    for _, colName := range columns {
+        err := df.isValidColumnParam(usedColumnNames, colName, dfNames)
+        if err != nil {
+            return err
+        }
+    }
+    for _, col := range values {
+        err := df.isValidColumnParam(usedColumnNames, col.Colname, dfNames)
+        if err != nil {
+            return err
+        }
+    }
+
+    for _, value := range values {
+        switch df.Col(value.Colname).Type() {
+        case series.Int, series.Float:
+            // only numeric columns can be aggregated
+            continue
+        default:
+            return fmt.Errorf("column %s cannot be aggregated: only Int and Float series are supported", value.Colname)
+        }
+    }
+    return nil
+}
+
+func (df *DataFrame) isValidColumnParam(usedColumnNames map[string]bool, colName string, dfNames []string) error {
+    if _, ok := usedColumnNames[colName]; ok {
+        return fmt.Errorf("column %s cannot be used more than once", colName)
+    }
+    usedColumnNames[colName] = true
+    if !isStrInStrSlice(dfNames, colName) {
+        return fmt.Errorf("column %s does not exist", colName)
+    }
+    return nil
+}
+
+// aggregateByRowsAndColumns groups the DataFrame by all row and column keys and
+// aggregates every value column within each group.
+func (df DataFrame) aggregateByRowsAndColumns(rows []string, columns []string, values []PivotValue) DataFrame {
+    valueColnames := make([]string, 0, len(values))
+    aggregationTypes := make([]AggregationType, 0, len(values))
+    for _, value := range values {
+        valueColnames = append(valueColnames, value.Colname)
+        if value.AggregationType == 0 {
+            // default AggregationType is Aggregation_SUM
+            aggregationTypes = append(aggregationTypes, Aggregation_SUM)
+        } else {
+            aggregationTypes = append(aggregationTypes, value.AggregationType)
+        }
+    }
+
+    var selectedColnames []string
+    if len(rows) > 0 {
+        selectedColnames = append(selectedColnames, rows...)
+    }
+    if len(columns) > 0 {
+        selectedColnames = append(selectedColnames, columns...)
+    }
+    if len(selectedColnames) == 0 {
+        t := Groups{groups: map[string]DataFrame{"": df}, colnames: valueColnames}
+        return t.Aggregation(aggregationTypes, valueColnames)
+    }
+
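+    // group on the combined row and column keys and propagate any grouping error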
+    groups := df.GroupBy(selectedColnames...)
+    if groups.Err != nil {
+        return DataFrame{Err: groups.Err}
+    }
+    return groups.Aggregation(aggregationTypes, valueColnames)
+}
+
+// buildGeneratedCols derives the names and types of the output columns generated
+// from the distinct column-key combinations and the value columns.
+func (df DataFrame) buildGeneratedCols(aggregatedDF DataFrame, columns []string, values []PivotValue) ([]string, []series.Type) {
+    if len(columns) == 0 {
+        generatedColnames := make([]string, 0, len(values))
+        generatedColtyps := make([]series.Type, 0, len(values))
+        for _, value := range values {
+            aggregatedValueColname := buildAggregatedColname(value.Colname, value.AggregationType)
+            generatedColnames = append(generatedColnames, aggregatedValueColname)
+            generatedColtyps = append(generatedColtyps, df.Col(value.Colname).Type())
+        }
+        return generatedColnames, generatedColtyps
+    }
+
+    columnGroups := aggregatedDF.GroupBy(columns...).groups
+    generatedColElemsList := make([][]series.Element, 0, len(columnGroups))
+    for _, columnGroupDf := range columnGroups {
+        columnElems := make([]series.Element, 0, len(columns))
+        for _, column := range columns {
+            columnElems = append(columnElems, columnGroupDf.Col(column).Elem(0))
+        }
+        generatedColElemsList = append(generatedColElemsList, columnElems)
+    }
+
+    // sort generatedColElemsList by its elements
+    sort.Slice(generatedColElemsList, func(i, j int) bool {
+        generatedColElemsI := generatedColElemsList[i]
+        generatedColElemsJ := generatedColElemsList[j]
+
+        for idx := range generatedColElemsI {
+            if generatedColElemsI[idx].Less(generatedColElemsJ[idx]) {
+                return true
+            }
+            if generatedColElemsI[idx].Greater(generatedColElemsJ[idx]) {
+                return false
+            }
+        }
+        // all elements are equal
+        return false
+    })
+
+    generatedColnames := make([]string, 0, len(generatedColElemsList)*len(values))
+    generatedColtyps := make([]series.Type, 0, len(generatedColElemsList)*len(values))
+    for _, generatedColElems := range generatedColElemsList {
+        tmpColnames := make([]string, 0, len(generatedColElems))
+        for _, elem := range generatedColElems {
+            tmpColnames = append(tmpColnames, elem.String())
+        }
+        for _, value := range values {
+            aggregatedValueColname := buildAggregatedColname(value.Colname, value.AggregationType)
+            tmpColName := strings.Join(append(tmpColnames, aggregatedValueColname), "_")
+            generatedColnames = append(generatedColnames, tmpColName)
+            generatedColtyps = append(generatedColtyps, df.Col(value.Colname).Type())
+        }
+    }
+    return generatedColnames, generatedColtyps
+}
+
+// buildNewCols allocates the column names and element buffers of the pivoted DataFrame.
+func (df DataFrame) buildNewCols(rows []string, generatedColnames []string, rowCnt int) ([]string, [][]series.Element) {
+    newColnames := make([]string, 0, len(rows)+len(generatedColnames))
+    if len(rows) > 0 {
+        newColnames = append(newColnames, rows...)
+    }
+    if len(generatedColnames) > 0 {
+        newColnames = append(newColnames, generatedColnames...)
+    }
+
+    newColElements := make([][]series.Element, len(newColnames))
+    for i := range newColElements {
+        newColElements[i] = make([]series.Element, rowCnt)
+    }
+    return newColnames, newColElements
+}
+
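+// Zero-valued elements used to pre-fill pivot cells that have no matching source rows.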
+var defaultIntElem = series.New([]int{0}, series.Int, "").Elem(0)
+var defaultStringElem = series.New([]string{""}, series.String, "").Elem(0)
+var defaultFloatElem = series.New([]float64{0}, series.Float, "").Elem(0)
+var defaultBoolElem = series.New([]bool{false}, series.Bool, "").Elem(0)
+
+func getDefaultElem(tpe series.Type) series.Element {
+    switch tpe {
+    case series.String:
+        return defaultStringElem
+    case series.Int:
+        return defaultIntElem
+    case series.Float:
+        return defaultFloatElem
+    case series.Bool:
+        return defaultBoolElem
+    }
+    return nil
+}
+
+func strIndexInStrSlice(strSlice []string, str string) int {
+    for i, s := range strSlice {
+        if s == str {
+            return i
+        }
+    }
+    return -1
+}
+
+func isStrInStrSlice(strSlice []string, str string) bool {
+    return strIndexInStrSlice(strSlice, str) != -1
+}
diff --git a/dataframe/dataframe_test.go b/dataframe/dataframe_test.go
index 6cb0c2b..8c626a7 100644
--- a/dataframe/dataframe_test.go
+++ b/dataframe/dataframe_test.go
@@ -3000,3 +3000,82 @@ func TestGroups_GetGroups(t *testing.T) {
         t.Fatalf("Expected to get 3 groups, got %d", len(groupNames))
     }
 }
+
+func TestDataFrame_Pivot(t *testing.T) {
+    // case 1, set rows and values
+    df := New(
+        series.New([]string{"a", "b", "b"}, series.String, "A"),
+        series.New([]int{1, 2, 3}, series.Int, "B"),
+    )
+    newDF := df.Pivot([]string{"A"}, nil, []PivotValue{{Colname: "B", AggregationType: Aggregation_SUM}})
+    expectedRecords := [][]string{
+        {"A", "B_SUM"},
+        {"a", "1"},
+        {"b", "5"},
+    }
+    if !reflect.DeepEqual(newDF.Records(), expectedRecords) {
+        t.Fatalf("unexpected result, result=%v expected=%v", newDF.Records(), expectedRecords)
+    }
+
+    // case 2, set columns and values
+    df = New(
+        series.New([]string{"a", "b", "b"}, series.String, "A"),
+        series.New([]int{1, 2, 3}, series.Int, "B"),
+    )
+    newDF = df.Pivot(nil, []string{"A"}, []PivotValue{{Colname: "B", AggregationType: Aggregation_SUM}})
+    expectedRecords = [][]string{
+        {"a_B_SUM", "b_B_SUM"},
+        {"1", "5"},
+    }
+    if !reflect.DeepEqual(newDF.Records(), expectedRecords) {
+        t.Fatalf("unexpected result, result=%v expected=%v", newDF.Records(), expectedRecords)
+    }
+
+    // case 3, only set values
+    df = New(
+        series.New([]string{"a", "b"}, series.String, "A"),
+        series.New([]int{1, 2}, series.Int, "B"),
+    )
+    newDF = df.Pivot(nil, nil, []PivotValue{{Colname: "B", AggregationType: Aggregation_SUM}})
+    expectedRecords = [][]string{
+        {"B_SUM"},
+        {"3"},
+    }
+    if !reflect.DeepEqual(newDF.Records(), expectedRecords) {
+        t.Fatalf("unexpected result, result=%v expected=%v", newDF.Records(), expectedRecords)
+    }
+
+    // case 4, set all parameters
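+    // generated column names follow the pattern "<C value>_<D value>_<value column>_<aggregation type>"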
"C1_D1_E_SUM", "C1_D1_F_COUNT", "C1_D2_E_SUM", "C1_D2_F_COUNT", "C2_D1_E_SUM", "C2_D1_F_COUNT"}, + {"A1", "B1", "1", "1", "2", "1", "3", "1"}, + {"A1", "B2", "4", "1", "5", "1", "13", "2"}, + } + if !reflect.DeepEqual(newDF.Records(), expectedRecords) { + t.Fatalf("unexpced result, result=%v expected=%v", newDF.Records(), expectedRecords) + } + + // case 5, invalid params + df = New( + series.New([]string{"a", "b"}, series.String, "A"), + series.New([]int{1, 2}, series.Int, "B"), + ) + newDF = df.Pivot(nil, nil, nil) + if newDF.Err == nil { + t.Fatalf("expect param error") + } +}