You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

148 lines
5.1 KiB

  1. // Copyright 2017, The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE.md file.
  4. package cmpopts
  5. import (
  6. "fmt"
  7. "reflect"
  8. "sort"
  9. "github.com/google/go-cmp/cmp"
  10. "github.com/google/go-cmp/cmp/internal/function"
  11. )
  12. // SortSlices returns a Transformer option that sorts all []V.
  13. // The less function must be of the form "func(T, T) bool" which is used to
  14. // sort any slice with element type V that is assignable to T.
  15. //
  16. // The less function must be:
  17. // • Deterministic: less(x, y) == less(x, y)
  18. // • Irreflexive: !less(x, x)
  19. // • Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
  20. //
  21. // The less function does not have to be "total". That is, if !less(x, y) and
  22. // !less(y, x) for two elements x and y, their relative order is maintained.
  23. //
  24. // SortSlices can be used in conjunction with EquateEmpty.
  25. func SortSlices(lessFunc interface{}) cmp.Option {
  26. vf := reflect.ValueOf(lessFunc)
  27. if !function.IsType(vf.Type(), function.Less) || vf.IsNil() {
  28. panic(fmt.Sprintf("invalid less function: %T", lessFunc))
  29. }
  30. ss := sliceSorter{vf.Type().In(0), vf}
  31. return cmp.FilterValues(ss.filter, cmp.Transformer("cmpopts.SortSlices", ss.sort))
  32. }
  // sliceSorter carries what the SortSlices transformer needs: the parameter
// type of the user's less function (used to decide which slices the option
// applies to) and the less function itself.
type sliceSorter struct {
	in  reflect.Type  // T: the parameter type of the less function
	fnc reflect.Value // the less function, with signature func(T, T) bool
}
  37. func (ss sliceSorter) filter(x, y interface{}) bool {
  38. vx, vy := reflect.ValueOf(x), reflect.ValueOf(y)
  39. if !(x != nil && y != nil && vx.Type() == vy.Type()) ||
  40. !(vx.Kind() == reflect.Slice && vx.Type().Elem().AssignableTo(ss.in)) ||
  41. (vx.Len() <= 1 && vy.Len() <= 1) {
  42. return false
  43. }
  44. // Check whether the slices are already sorted to avoid an infinite
  45. // recursion cycle applying the same transform to itself.
  46. ok1 := sort.SliceIsSorted(x, func(i, j int) bool { return ss.less(vx, i, j) })
  47. ok2 := sort.SliceIsSorted(y, func(i, j int) bool { return ss.less(vy, i, j) })
  48. return !ok1 || !ok2
  49. }
  50. func (ss sliceSorter) sort(x interface{}) interface{} {
  51. src := reflect.ValueOf(x)
  52. dst := reflect.MakeSlice(src.Type(), src.Len(), src.Len())
  53. for i := 0; i < src.Len(); i++ {
  54. dst.Index(i).Set(src.Index(i))
  55. }
  56. sort.SliceStable(dst.Interface(), func(i, j int) bool { return ss.less(dst, i, j) })
  57. ss.checkSort(dst)
  58. return dst.Interface()
  59. }
  60. func (ss sliceSorter) checkSort(v reflect.Value) {
  61. start := -1 // Start of a sequence of equal elements.
  62. for i := 1; i < v.Len(); i++ {
  63. if ss.less(v, i-1, i) {
  64. // Check that first and last elements in v[start:i] are equal.
  65. if start >= 0 && (ss.less(v, start, i-1) || ss.less(v, i-1, start)) {
  66. panic(fmt.Sprintf("incomparable values detected: want equal elements: %v", v.Slice(start, i)))
  67. }
  68. start = -1
  69. } else if start == -1 {
  70. start = i
  71. }
  72. }
  73. }
  74. func (ss sliceSorter) less(v reflect.Value, i, j int) bool {
  75. vx, vy := v.Index(i), v.Index(j)
  76. return ss.fnc.Call([]reflect.Value{vx, vy})[0].Bool()
  77. }
  78. // SortMaps returns a Transformer option that flattens map[K]V types to be a
  79. // sorted []struct{K, V}. The less function must be of the form
  80. // "func(T, T) bool" which is used to sort any map with key K that is
  81. // assignable to T.
  82. //
  83. // Flattening the map into a slice has the property that cmp.Equal is able to
  84. // use Comparers on K or the K.Equal method if it exists.
  85. //
  86. // The less function must be:
  87. // • Deterministic: less(x, y) == less(x, y)
  88. // • Irreflexive: !less(x, x)
  89. // • Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
  90. // • Total: if x != y, then either less(x, y) or less(y, x)
  91. //
  92. // SortMaps can be used in conjunction with EquateEmpty.
  93. func SortMaps(lessFunc interface{}) cmp.Option {
  94. vf := reflect.ValueOf(lessFunc)
  95. if !function.IsType(vf.Type(), function.Less) || vf.IsNil() {
  96. panic(fmt.Sprintf("invalid less function: %T", lessFunc))
  97. }
  98. ms := mapSorter{vf.Type().In(0), vf}
  99. return cmp.FilterValues(ms.filter, cmp.Transformer("cmpopts.SortMaps", ms.sort))
  100. }
  // mapSorter carries what the SortMaps transformer needs: the parameter type
// of the user's less function (used to decide which maps the option applies
// to, by key type) and the less function itself.
type mapSorter struct {
	in  reflect.Type  // T: the parameter type of the less function
	fnc reflect.Value // the less function, with signature func(T, T) bool
}
  105. func (ms mapSorter) filter(x, y interface{}) bool {
  106. vx, vy := reflect.ValueOf(x), reflect.ValueOf(y)
  107. return (x != nil && y != nil && vx.Type() == vy.Type()) &&
  108. (vx.Kind() == reflect.Map && vx.Type().Key().AssignableTo(ms.in)) &&
  109. (vx.Len() != 0 || vy.Len() != 0)
  110. }
  111. func (ms mapSorter) sort(x interface{}) interface{} {
  112. src := reflect.ValueOf(x)
  113. outType := reflect.StructOf([]reflect.StructField{
  114. {Name: "K", Type: src.Type().Key()},
  115. {Name: "V", Type: src.Type().Elem()},
  116. })
  117. dst := reflect.MakeSlice(reflect.SliceOf(outType), src.Len(), src.Len())
  118. for i, k := range src.MapKeys() {
  119. v := reflect.New(outType).Elem()
  120. v.Field(0).Set(k)
  121. v.Field(1).Set(src.MapIndex(k))
  122. dst.Index(i).Set(v)
  123. }
  124. sort.Slice(dst.Interface(), func(i, j int) bool { return ms.less(dst, i, j) })
  125. ms.checkSort(dst)
  126. return dst.Interface()
  127. }
  128. func (ms mapSorter) checkSort(v reflect.Value) {
  129. for i := 1; i < v.Len(); i++ {
  130. if !ms.less(v, i-1, i) {
  131. panic(fmt.Sprintf("partial order detected: want %v < %v", v.Index(i-1), v.Index(i)))
  132. }
  133. }
  134. }
  135. func (ms mapSorter) less(v reflect.Value, i, j int) bool {
  136. vx, vy := v.Index(i).Field(0), v.Index(j).Field(0)
  137. return ms.fnc.Call([]reflect.Value{vx, vy})[0].Bool()
  138. }