You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

627 line
17 KiB

  1. package dynamodb
  2. import (
  3. "encoding/base64"
  4. "encoding/json"
  5. "fmt"
  6. "math"
  7. "reflect"
  8. "sort"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "unicode"
  13. )
  14. func MarshalAttributes(m interface{}) ([]Attribute, error) {
  15. v := reflect.ValueOf(m).Elem()
  16. builder := &attributeBuilder{}
  17. builder.buffer = []Attribute{}
  18. for _, f := range cachedTypeFields(v.Type()) { // loop on each field
  19. fv := fieldByIndex(v, f.index)
  20. if !fv.IsValid() || isEmptyValueToOmit(fv) {
  21. continue
  22. }
  23. err := builder.reflectToDynamoDBAttribute(f.name, fv)
  24. if err != nil {
  25. return builder.buffer, err
  26. }
  27. }
  28. return builder.buffer, nil
  29. }
  30. func UnmarshalAttributes(attributesRef *map[string]*Attribute, m interface{}) error {
  31. rv := reflect.ValueOf(m)
  32. if rv.Kind() != reflect.Ptr || rv.IsNil() {
  33. return fmt.Errorf("InvalidUnmarshalError reflect.ValueOf(v): %#v, m interface{}: %#v", rv, reflect.TypeOf(m))
  34. }
  35. v := reflect.ValueOf(m).Elem()
  36. attributes := *attributesRef
  37. for _, f := range cachedTypeFields(v.Type()) { // loop on each field
  38. fv := fieldByIndex(v, f.index)
  39. correlatedAttribute := attributes[f.name]
  40. if correlatedAttribute == nil {
  41. continue
  42. }
  43. err := unmarshallAttribute(correlatedAttribute, fv)
  44. if err != nil {
  45. return err
  46. }
  47. }
  48. return nil
  49. }
  50. type attributeBuilder struct {
  51. buffer []Attribute
  52. }
  53. func (builder *attributeBuilder) Push(attribute *Attribute) {
  54. builder.buffer = append(builder.buffer, *attribute)
  55. }
  56. func unmarshallAttribute(a *Attribute, v reflect.Value) error {
  57. switch v.Kind() {
  58. case reflect.Bool:
  59. n, err := strconv.ParseInt(a.Value, 10, 64)
  60. if err != nil {
  61. return fmt.Errorf("UnmarshalTypeError (bool) %#v: %#v", a.Value, err)
  62. }
  63. v.SetBool(n != 0)
  64. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  65. n, err := strconv.ParseInt(a.Value, 10, 64)
  66. if err != nil || v.OverflowInt(n) {
  67. return fmt.Errorf("UnmarshalTypeError (number) %#v: %#v", a.Value, err)
  68. }
  69. v.SetInt(n)
  70. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  71. n, err := strconv.ParseUint(a.Value, 10, 64)
  72. if err != nil || v.OverflowUint(n) {
  73. return fmt.Errorf("UnmarshalTypeError (number) %#v: %#v", a.Value, err)
  74. }
  75. v.SetUint(n)
  76. case reflect.Float32, reflect.Float64:
  77. n, err := strconv.ParseFloat(a.Value, v.Type().Bits())
  78. if err != nil || v.OverflowFloat(n) {
  79. return fmt.Errorf("UnmarshalTypeError (number) %#v: %#v", a.Value, err)
  80. }
  81. v.SetFloat(n)
  82. case reflect.String:
  83. v.SetString(a.Value)
  84. case reflect.Slice:
  85. if v.Type().Elem().Kind() == reflect.Uint8 { // byte arrays are a special case
  86. b := make([]byte, base64.StdEncoding.DecodedLen(len(a.Value)))
  87. n, err := base64.StdEncoding.Decode(b, []byte(a.Value))
  88. if err != nil {
  89. return fmt.Errorf("UnmarshalTypeError (byte) %#v: %#v", a.Value, err)
  90. }
  91. v.Set(reflect.ValueOf(b[0:n]))
  92. break
  93. }
  94. if a.SetType() { // Special NS and SS types should be correctly handled
  95. nativeSetCreated := false
  96. switch v.Type().Elem().Kind() {
  97. case reflect.Bool:
  98. nativeSetCreated = true
  99. arry := reflect.MakeSlice(v.Type(), len(a.SetValues), len(a.SetValues))
  100. for i, aval := range a.SetValues {
  101. n, err := strconv.ParseInt(aval, 10, 64)
  102. if err != nil {
  103. return fmt.Errorf("UnmarshalSetTypeError (bool) %#v: %#v", aval, err)
  104. }
  105. arry.Index(i).SetBool(n != 0)
  106. }
  107. v.Set(arry)
  108. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  109. nativeSetCreated = true
  110. arry := reflect.MakeSlice(v.Type(), len(a.SetValues), len(a.SetValues))
  111. for i, aval := range a.SetValues {
  112. n, err := strconv.ParseInt(aval, 10, 64)
  113. if err != nil || arry.Index(i).OverflowInt(n) {
  114. return fmt.Errorf("UnmarshalSetTypeError (number) %#v: %#v", aval, err)
  115. }
  116. arry.Index(i).SetInt(n)
  117. }
  118. v.Set(arry)
  119. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  120. nativeSetCreated = true
  121. arry := reflect.MakeSlice(v.Type(), len(a.SetValues), len(a.SetValues))
  122. for i, aval := range a.SetValues {
  123. n, err := strconv.ParseUint(aval, 10, 64)
  124. if err != nil || arry.Index(i).OverflowUint(n) {
  125. return fmt.Errorf("UnmarshalSetTypeError (number) %#v: %#v", aval, err)
  126. }
  127. arry.Index(i).SetUint(n)
  128. }
  129. v.Set(arry)
  130. case reflect.Float32, reflect.Float64:
  131. nativeSetCreated = true
  132. arry := reflect.MakeSlice(v.Type(), len(a.SetValues), len(a.SetValues))
  133. for i, aval := range a.SetValues {
  134. n, err := strconv.ParseFloat(aval, arry.Index(i).Type().Bits())
  135. if err != nil || arry.Index(i).OverflowFloat(n) {
  136. return fmt.Errorf("UnmarshalSetTypeError (number) %#v: %#v", aval, err)
  137. }
  138. arry.Index(i).SetFloat(n)
  139. }
  140. v.Set(arry)
  141. case reflect.String:
  142. nativeSetCreated = true
  143. arry := reflect.MakeSlice(v.Type(), len(a.SetValues), len(a.SetValues))
  144. for i, aval := range a.SetValues {
  145. arry.Index(i).SetString(aval)
  146. }
  147. v.Set(arry)
  148. }
  149. if nativeSetCreated {
  150. break
  151. }
  152. }
  153. // Slices can be marshalled as nil, but otherwise are handled
  154. // as arrays.
  155. fallthrough
  156. case reflect.Array, reflect.Struct, reflect.Map, reflect.Interface, reflect.Ptr:
  157. unmarshalled := reflect.New(v.Type())
  158. err := json.Unmarshal([]byte(a.Value), unmarshalled.Interface())
  159. if err != nil {
  160. return err
  161. }
  162. v.Set(unmarshalled.Elem())
  163. default:
  164. return fmt.Errorf("UnsupportedTypeError %#v", v.Type())
  165. }
  166. return nil
  167. }
  168. // reflectValueQuoted writes the value in v to the output.
  169. // If quoted is true, the serialization is wrapped in a JSON string.
  170. func (e *attributeBuilder) reflectToDynamoDBAttribute(name string, v reflect.Value) error {
  171. if !v.IsValid() {
  172. return nil
  173. } // don't build
  174. switch v.Kind() {
  175. case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64:
  176. rv, err := numericReflectedValueString(v)
  177. if err != nil {
  178. return err
  179. }
  180. e.Push(NewNumericAttribute(name, rv))
  181. case reflect.String:
  182. e.Push(NewStringAttribute(name, v.String()))
  183. case reflect.Slice:
  184. if v.IsNil() {
  185. break
  186. }
  187. if v.Type().Elem().Kind() == reflect.Uint8 {
  188. // Byte slices are treated as errors
  189. s := v.Bytes()
  190. dst := make([]byte, base64.StdEncoding.EncodedLen(len(s)))
  191. base64.StdEncoding.Encode(dst, s)
  192. e.Push(NewStringAttribute(name, string(dst)))
  193. break
  194. }
  195. // Special NS and SS types should be correctly handled
  196. nativeSetCreated := false
  197. switch v.Type().Elem().Kind() {
  198. case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64:
  199. nativeSetCreated = true
  200. arrystrings := make([]string, v.Len())
  201. for i, _ := range arrystrings {
  202. var err error
  203. arrystrings[i], err = numericReflectedValueString(v.Index(i))
  204. if err != nil {
  205. return err
  206. }
  207. }
  208. e.Push(NewNumericSetAttribute(name, arrystrings))
  209. case reflect.String: // simple copy will suffice
  210. nativeSetCreated = true
  211. arrystrings := make([]string, v.Len())
  212. for i, _ := range arrystrings {
  213. arrystrings[i] = v.Index(i).String()
  214. }
  215. e.Push(NewStringSetAttribute(name, arrystrings))
  216. }
  217. if nativeSetCreated {
  218. break
  219. }
  220. // Slices can be marshalled as nil, but otherwise are handled
  221. // as arrays.
  222. fallthrough
  223. case reflect.Array, reflect.Struct, reflect.Map, reflect.Interface, reflect.Ptr:
  224. jsonVersion, err := json.Marshal(v.Interface())
  225. if err != nil {
  226. return err
  227. }
  228. escapedJson := `"` + string(jsonVersion) + `"` // strconv.Quote not required because the entire string is escaped from json Marshall
  229. e.Push(NewStringAttribute(name, escapedJson[1:len(escapedJson)-1]))
  230. default:
  231. return fmt.Errorf("UnsupportedTypeError %#v", v.Type())
  232. }
  233. return nil
  234. }
  // numericReflectedValueString formats a bool or numeric reflect.Value as the
// decimal text stored in a DynamoDB number attribute:
//   - bools become "1" or "0";
//   - integers are written in base 10;
//   - floats use the shortest 'g' form that identifies the value at its own
//     bit size.
//
// NaN and ±Inf cannot be represented and yield an "UnsupportedValueError".
// Any other kind yields an "UnsupportedNumericValueError".
func numericReflectedValueString(v reflect.Value) (string, error) {
	switch v.Kind() {
	case reflect.Bool:
		if v.Bool() {
			return "1", nil
		}
		return "0", nil
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(v.Int(), 10), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return strconv.FormatUint(v.Uint(), 10), nil
	case reflect.Float32, reflect.Float64:
		f := v.Float()
		formatted := strconv.FormatFloat(f, 'g', -1, v.Type().Bits())
		if math.IsInf(f, 0) || math.IsNaN(f) {
			return "", fmt.Errorf("UnsupportedValueError %#v (formatted float: %s)", v, formatted)
		}
		return formatted, nil
	}
	return "", fmt.Errorf("UnsupportedNumericValueError %#v", v.Type())
}
  // isEmptyValueToOmit reports whether v must be left out of a DynamoDB item
// because it is empty. Only these values count as empty:
//   - arrays, maps, slices and strings of length zero;
//   - nil interfaces and nil pointers.
//
// Zero numbers and false bools are kept.
// See http://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_PutItem.html
func isEmptyValueToOmit(v reflect.Value) bool {
	k := v.Kind()
	if k == reflect.Interface || k == reflect.Ptr {
		return v.IsNil()
	}
	if k == reflect.Array || k == reflect.Map || k == reflect.Slice || k == reflect.String {
		return v.Len() == 0
	}
	return false
}
  // ---- The helpers below are copied from the Go standard library's encoding/json package (encode.go). ----

// isEmptyValue reports whether v is "empty" in the sense of encoding/json's
// omitempty. The empty values are:
//   - false;
//   - numeric zero;
//   - nil interfaces and nil pointers;
//   - arrays, maps, slices and strings of length zero.
//
// Every other kind (structs, channels, funcs, ...) is never empty.
func isEmptyValue(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Bool:
		return !v.Bool()
	case reflect.Interface, reflect.Ptr:
		return v.IsNil()
	case reflect.Float32, reflect.Float64:
		return v.Float() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return v.Uint() == 0
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int() == 0
	case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
		return v.Len() == 0
	default:
		return false
	}
}
  // fieldByIndex follows the index path produced by typeFields, starting at v,
// and returns the field it designates.
//
// Whenever the current value is a pointer (including v itself), it is
// dereferenced before the next field lookup. If such a pointer is nil, the
// field is unreachable and the zero reflect.Value is returned; its IsValid
// method reports false.
func fieldByIndex(v reflect.Value, index []int) reflect.Value {
	for pos := 0; pos < len(index); pos++ {
		if v.Kind() != reflect.Ptr {
			v = v.Field(index[pos])
			continue
		}
		if v.IsNil() {
			return reflect.Value{}
		}
		v = v.Elem().Field(index[pos])
	}
	return v
}
  // field describes one struct field as resolved by typeFields.
// NOTE: typeFields builds values of this type with a positional composite
// literal, so the order of the fields below must not change.
type field struct {
	name  string       // resolved name: the json tag name if valid, else the Go field name
	tag   bool         // whether name came from the json tag
	index []int        // index path from the root struct, as consumed by fieldByIndex
	typ   reflect.Type // field type; an unnamed pointer type is replaced by its element type

	// Whether the json tag carries the "omitempty" / "string" options.
	// NOTE(review): neither is read by the marshal/unmarshal code in this file.
	omitEmpty, quoted bool
}
  307. // byName sorts field by name, breaking ties with depth,
  308. // then breaking ties with "name came from json tag", then
  309. // breaking ties with index sequence.
  310. type byName []field
  311. func (x byName) Len() int { return len(x) }
  312. func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
  313. func (x byName) Less(i, j int) bool {
  314. if x[i].name != x[j].name {
  315. return x[i].name < x[j].name
  316. }
  317. if len(x[i].index) != len(x[j].index) {
  318. return len(x[i].index) < len(x[j].index)
  319. }
  320. if x[i].tag != x[j].tag {
  321. return x[i].tag
  322. }
  323. return byIndex(x).Less(i, j)
  324. }
  325. // byIndex sorts field by index sequence.
  326. type byIndex []field
  327. func (x byIndex) Len() int { return len(x) }
  328. func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
  329. func (x byIndex) Less(i, j int) bool {
  330. for k, xik := range x[i].index {
  331. if k >= len(x[j].index) {
  332. return false
  333. }
  334. if xik != x[j].index[k] {
  335. return xik < x[j].index[k]
  336. }
  337. }
  338. return len(x[i].index) < len(x[j].index)
  339. }
  // isValidTag reports whether s may be used as a field name taken from a json
// struct tag. s must be non-empty and contain only Unicode letters, digits
// and the punctuation characters in allowedPunct. Backslash and quote
// characters are reserved and therefore rejected.
func isValidTag(s string) bool {
	if len(s) == 0 {
		return false
	}
	const allowedPunct = "!#$%&()*+-./:<=>?@[]^_{|}~ "
	for _, r := range s {
		if strings.ContainsRune(allowedPunct, r) || unicode.IsLetter(r) || unicode.IsDigit(r) {
			continue
		}
		return false
	}
	return true
}
  // tagOptions is the part of a struct field's "json" tag that follows the
// first comma (without that comma), or "" when the tag has no comma.
type tagOptions string

// Contains reports whether optionName is exactly one of the comma-separated
// entries of o; for example, "omit" does not match "omitempty".
func (o tagOptions) Contains(optionName string) bool {
	rest := string(o)
	for rest != "" {
		opt := rest
		rest = ""
		if comma := strings.Index(opt, ","); comma >= 0 {
			opt, rest = opt[:comma], opt[comma+1:]
		}
		if opt == optionName {
			return true
		}
	}
	return false
}
  382. // parseTag splits a struct field's json tag into its name and
  383. // comma-separated options.
  384. func parseTag(tag string) (string, tagOptions) {
  385. if idx := strings.Index(tag, ","); idx != -1 {
  386. return tag[:idx], tagOptions(tag[idx+1:])
  387. }
  388. return tag, tagOptions("")
  389. }
  390. // typeFields returns a list of fields that JSON should recognize for the given type.
  391. // The algorithm is breadth-first search over the set of structs to include - the top struct
  392. // and then any reachable anonymous structs.
  //
  // Each field is named by its `json` tag name when that name passes
  // isValidTag; otherwise it is named by its Go field name. Unexported fields
  // and fields tagged exactly `json:"-"` are skipped.
  //
  // An embedded struct (or an unnamed pointer to one) without a tag name is not
  // recorded itself. Instead, its fields are explored at the next depth.
  //
  // When several fields end up with the same name, dominantField keeps at
  // most one of them. The result is sorted by index sequence (byIndex).
  //
  // NOTE(review): t must be a struct type. f.typ.NumField() panics for any
  // other kind, and nothing here checks it; callers must ensure this.
  393. func typeFields(t reflect.Type) []field {
  394. // Anonymous fields to explore at the current level and the next.
  395. current := []field{}
  396. next := []field{{typ: t}}
  397. // Count of queued names for current level and the next.
  398. count := map[reflect.Type]int{}
  399. nextCount := map[reflect.Type]int{}
  400. // Types already visited at an earlier level.
  401. visited := map[reflect.Type]bool{}
  402. // Fields found.
  403. var fields []field
  404. for len(next) > 0 {
  // The new next queue reuses the backing array of the slice just consumed.
  405. current, next = next, current[:0]
  // count now holds, for each type queued for this level, how many times it
  // was encountered at the previous level (see nextCount[ft]++ below).
  406. count, nextCount = nextCount, map[reflect.Type]int{}
  407. for _, f := range current {
  408. if visited[f.typ] {
  409. continue
  410. }
  411. visited[f.typ] = true
  412. // Scan f.typ for fields to include.
  413. for i := 0; i < f.typ.NumField(); i++ {
  414. sf := f.typ.Field(i)
  415. if sf.PkgPath != "" { // unexported
  416. continue
  417. }
  418. tag := sf.Tag.Get("json")
  419. if tag == "-" {
  420. continue
  421. }
  422. name, opts := parseTag(tag)
  423. if !isValidTag(name) {
  424. name = ""
  425. }
  // index is a fresh copy of the parent's path plus i, so recorded fields
  // never share a backing array.
  426. index := make([]int, len(f.index)+1)
  427. copy(index, f.index)
  428. index[len(f.index)] = i
  429. ft := sf.Type
  430. if ft.Name() == "" && ft.Kind() == reflect.Ptr {
  431. // Follow pointer.
  432. ft = ft.Elem()
  433. }
  434. // Record found field and index sequence.
  // Anything that is not an untagged embedded struct is recorded as a field.
  435. if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
  436. tagged := name != ""
  437. if name == "" {
  438. name = sf.Name
  439. }
  // Positional literal: relies on the declaration order of field's members.
  440. fields = append(fields, field{name, tagged, index, ft,
  441. opts.Contains("omitempty"), opts.Contains("string")})
  442. if count[f.typ] > 1 {
  443. // If there were multiple instances, add a second,
  444. // so that the annihilation code will see a duplicate.
  445. // It only cares about the distinction between 1 or 2,
  446. // so don't bother generating any more copies.
  447. fields = append(fields, fields[len(fields)-1])
  448. }
  449. continue
  450. }
  451. // Record new anonymous struct to explore in next round.
  // Each type is queued once per level; repeats are only counted.
  452. nextCount[ft]++
  453. if nextCount[ft] == 1 {
  454. next = append(next, field{name: ft.Name(), index: index, typ: ft})
  455. }
  456. }
  457. }
  458. }
  459. sort.Sort(byName(fields))
  460. // Delete all fields that are hidden by the Go rules for embedded fields,
  461. // except that fields with JSON tags are promoted.
  462. // The fields are sorted in primary order of name, secondary order
  463. // of field index length. Loop over names; for each name, delete
  464. // hidden fields by choosing the one dominant field that survives.
  // Filter in place: each name group appends at most one field to out, so
  // out never overtakes i and no unread entry is overwritten.
  465. out := fields[:0]
  466. for advance, i := 0, 0; i < len(fields); i += advance {
  467. // One iteration per name.
  468. // Find the sequence of fields with the name of this first field.
  469. fi := fields[i]
  470. name := fi.name
  471. for advance = 1; i+advance < len(fields); advance++ {
  472. fj := fields[i+advance]
  473. if fj.name != name {
  474. break
  475. }
  476. }
  477. if advance == 1 { // Only one field with this name
  478. out = append(out, fi)
  479. continue
  480. }
  481. dominant, ok := dominantField(fields[i : i+advance])
  482. if ok {
  483. out = append(out, dominant)
  484. }
  485. }
  486. fields = out
  487. sort.Sort(byIndex(fields))
  488. return fields
  489. }
  490. // dominantField looks through the fields, all of which are known to
  491. // have the same name, to find the single field that dominates the
  492. // others using Go's embedding rules, modified by the presence of
  493. // JSON tags. If there are multiple top-level fields, the boolean
  494. // will be false: This condition is an error in Go and we skip all
  495. // the fields.
  496. func dominantField(fields []field) (field, bool) {
  497. // The fields are sorted in increasing index-length order. The winner
  498. // must therefore be one with the shortest index length. Drop all
  499. // longer entries, which is easy: just truncate the slice.
  500. length := len(fields[0].index)
  501. tagged := -1 // Index of first tagged field.
  502. for i, f := range fields {
  503. if len(f.index) > length {
  504. fields = fields[:i]
  505. break
  506. }
  507. if f.tag {
  508. if tagged >= 0 {
  509. // Multiple tagged fields at the same level: conflict.
  510. // Return no field.
  511. return field{}, false
  512. }
  513. tagged = i
  514. }
  515. }
  516. if tagged >= 0 {
  517. return fields[tagged], true
  518. }
  519. // All remaining fields have the same length. If there's more than one,
  520. // we have a conflict (two fields named "X" at the same level) and we
  521. // return no field.
  522. if len(fields) > 1 {
  523. return field{}, false
  524. }
  525. return fields[0], true
  526. }
  527. var fieldCache struct {
  528. sync.RWMutex
  529. m map[reflect.Type][]field
  530. }
  531. // cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
  532. func cachedTypeFields(t reflect.Type) []field {
  533. fieldCache.RLock()
  534. f := fieldCache.m[t]
  535. fieldCache.RUnlock()
  536. if f != nil {
  537. return f
  538. }
  539. // Compute fields without lock.
  540. // Might duplicate effort but won't hold other computations back.
  541. f = typeFields(t)
  542. if f == nil {
  543. f = []field{}
  544. }
  545. fieldCache.Lock()
  546. if fieldCache.m == nil {
  547. fieldCache.m = map[reflect.Type][]field{}
  548. }
  549. fieldCache.m[t] = f
  550. fieldCache.Unlock()
  551. return f
  552. }