You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

138 lines
4.1 KiB

  1. // Copyright 2014 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package main
  5. import (
  6. "fmt"
  7. "log"
  8. "net/http"
  9. "os"
  10. "path/filepath"
  11. "strings"
  12. "google.golang.org/api/googleapi"
  13. prediction "google.golang.org/api/prediction/v1.6"
  14. )
  15. func init() {
  16. scopes := []string{
  17. prediction.DevstorageFullControlScope,
  18. prediction.DevstorageReadOnlyScope,
  19. prediction.DevstorageReadWriteScope,
  20. prediction.PredictionScope,
  21. }
  22. registerDemo("prediction", strings.Join(scopes, " "), predictionMain)
  23. }
  24. type predictionType struct {
  25. api *prediction.Service
  26. projectNumber string
  27. bucketName string
  28. trainingFileName string
  29. modelName string
  30. }
  31. // This example demonstrates calling the Prediction API.
  32. // Training data is uploaded to a pre-created Google Cloud Storage Bucket and
  33. // then the Prediction API is called to train a model based on that data.
  34. // After a few minutes, the model should be completely trained and ready
  35. // for prediction. At that point, text is sent to the model and the Prediction
  36. // API attempts to classify the data, and the results are printed out.
  37. //
  38. // To get started, follow the instructions found in the "Hello Prediction!"
  39. // Getting Started Guide located here:
  40. // https://developers.google.com/prediction/docs/hello_world
  41. //
  42. // Example usage:
  43. // go-api-demo -clientid="my-clientid" -secret="my-secret" prediction
  44. // my-project-number my-bucket-name my-training-filename my-model-name
  45. //
  46. // Example output:
  47. // Predict result: language=Spanish
  48. // English Score: 0.000000
  49. // French Score: 0.000000
  50. // Spanish Score: 1.000000
  51. // analyze: output feature text=&{157 English}
  52. // analyze: output feature text=&{149 French}
  53. // analyze: output feature text=&{100 Spanish}
  54. // feature text count=406
  55. func predictionMain(client *http.Client, argv []string) {
  56. if len(argv) != 4 {
  57. fmt.Fprintln(os.Stderr,
  58. "Usage: prediction project_number bucket training_data model_name")
  59. return
  60. }
  61. api, err := prediction.New(client)
  62. if err != nil {
  63. log.Fatalf("unable to create prediction API client: %v", err)
  64. }
  65. t := &predictionType{
  66. api: api,
  67. projectNumber: argv[0],
  68. bucketName: argv[1],
  69. trainingFileName: argv[2],
  70. modelName: argv[3],
  71. }
  72. t.trainModel()
  73. t.predictModel()
  74. }
  75. func (t *predictionType) trainModel() {
  76. // First, check to see if our trained model already exists.
  77. res, err := t.api.Trainedmodels.Get(t.projectNumber, t.modelName).Do()
  78. if err != nil {
  79. if ae, ok := err.(*googleapi.Error); ok && ae.Code != http.StatusNotFound {
  80. log.Fatalf("error getting trained model: %v", err)
  81. }
  82. log.Printf("Training model not found, creating new model.")
  83. res, err = t.api.Trainedmodels.Insert(t.projectNumber, &prediction.Insert{
  84. Id: t.modelName,
  85. StorageDataLocation: filepath.Join(t.bucketName, t.trainingFileName),
  86. }).Do()
  87. if err != nil {
  88. log.Fatalf("unable to create trained model: %v", err)
  89. }
  90. }
  91. if res.TrainingStatus != "DONE" {
  92. // Wait for the trained model to finish training.
  93. fmt.Printf("Training model. Please wait and re-run program after a few minutes.")
  94. os.Exit(0)
  95. }
  96. }
  97. func (t *predictionType) predictModel() {
  98. // Model has now been trained. Predict with it.
  99. input := &prediction.Input{
  100. Input: &prediction.InputInput{
  101. CsvInstance: []interface{}{
  102. "Hola, con quien hablo",
  103. },
  104. },
  105. }
  106. res, err := t.api.Trainedmodels.Predict(t.projectNumber, t.modelName, input).Do()
  107. if err != nil {
  108. log.Fatalf("unable to get trained prediction: %v", err)
  109. }
  110. fmt.Printf("Predict result: language=%v\n", res.OutputLabel)
  111. for _, m := range res.OutputMulti {
  112. fmt.Printf("%v Score: %v\n", m.Label, m.Score)
  113. }
  114. // Now analyze the model.
  115. an, err := t.api.Trainedmodels.Analyze(t.projectNumber, t.modelName).Do()
  116. if err != nil {
  117. log.Fatalf("unable to analyze trained model: %v", err)
  118. }
  119. for _, f := range an.DataDescription.OutputFeature.Text {
  120. fmt.Printf("analyze: output feature text=%v\n", f)
  121. }
  122. for _, f := range an.DataDescription.Features {
  123. fmt.Printf("feature text count=%v\n", f.Text.Count)
  124. }
  125. }