You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

322 lines
9.4 KiB

  1. package handlers
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "hash/crc32"
  6. "net/http"
  7. "strings"
  8. "time"
  9. "github.com/PuerkitoBio/ghost"
  10. "github.com/gorilla/securecookie"
  11. "github.com/nu7hatch/gouuid"
  12. )
// defaultCookieName is the session cookie name SessionHandler uses when
// SessionOptions.CookieTemplate.Name is left empty.
const defaultCookieName = "ghost.sid"

var (
	// ErrSessionSecretMissing is the panic value raised by SessionHandler when
	// SessionOptions.Secret is empty (the secret is required to sign cookies).
	ErrSessionSecretMissing = errors.New("session secret is missing")
	// ErrNoSessionID is the panic value raised by newSession when a random
	// session ID cannot be generated.
	ErrNoSessionID = errors.New("session ID could not be generated")
)
// Session holds the data map that persists for the duration of the session.
// The values stored in the map must be marshalable by the target SessionStore's
// persistence format (e.g. JSON, SQL, gob, depending on how the store persists
// the data).
//
// The accessor methods ID, MaxAge and Created shadow the promoted
// internalSession fields of the same names, so those are read-only through
// Session. Data is not shadowed and remains directly accessible as sess.Data.
type Session struct {
	// isNew is true for a session created by the current request. It is kept
	// private and is not part of the JSON encoding (MarshalJSON only encodes
	// internalSession), so it is expected to be false once the session is read
	// back from the store.
	isNew bool
	internalSession
}
// internalSession holds the persisted state of a Session. It is a separate,
// unexported struct so that Session.MarshalJSON/UnmarshalJSON can encode it
// directly. Its fields are exported only so that encoding/json sees them,
// while Session's own private fields stay out of the JSON.
type internalSession struct {
	// Data holds the user data. The key type is string because encoding/json
	// cannot marshal a map[interface{}]interface{}.
	Data map[string]interface{}
	// ID identifies the session; newSession sets it to a random UUID v4 string.
	ID string
	// Created is the creation time; newSession sets it to time.Now().
	Created time.Time
	// MaxAge is the session's maximum age; newSession converts its maxAge
	// argument from seconds.
	MaxAge time.Duration
}
  34. // Create a new Session instance. It panics in the unlikely event that a new random ID cannot be generated.
  35. func newSession(maxAge int) *Session {
  36. uid, err := uuid.NewV4()
  37. if err != nil {
  38. panic(ErrNoSessionID)
  39. }
  40. return &Session{
  41. true, // is new
  42. internalSession{
  43. make(map[string]interface{}),
  44. uid.String(),
  45. time.Now(),
  46. time.Duration(maxAge) * time.Second,
  47. },
  48. }
  49. }
// ID returns the session identifier. This method shadows the embedded
// internalSession.ID field, which is why the field is selected explicitly.
func (ø *Session) ID() string {
	return ø.internalSession.ID
}
// MaxAge returns the session's max age duration. For sessions built by
// newSession, this is the maxAge argument converted from seconds.
func (ø *Session) MaxAge() time.Duration {
	return ø.internalSession.MaxAge
}
// Created returns the creation time of the session (time.Now() when it was
// built by newSession).
func (ø *Session) Created() time.Time {
	return ø.internalSession.Created
}
// IsNew reports whether the session was created by the current request
// rather than loaded from the store. SessionHandler uses it to decide whether
// to send the session cookie and whether the session must be saved.
func (ø *Session) IsNew() bool {
	return ø.isNew
}
// resetMaxAge is a placeholder for sliding expiration. It is meant to reset the
// session's max age to its original value, but it is currently a no-op.
// TODO : implement sliding expiration.
func (ø *Session) resetMaxAge() {
}
  69. // Marshal the session to JSON.
  70. func (ø *Session) MarshalJSON() ([]byte, error) {
  71. return json.Marshal(ø.internalSession)
  72. }
  73. // Unmarshal the JSON into the internal session struct.
  74. func (ø *Session) UnmarshalJSON(b []byte) error {
  75. return json.Unmarshal(b, &ø.internalSession)
  76. }
// SessionOptions configures SessionHandler. It specifies the session store
// used for persistence, the template for the session cookie, whether the
// proxy may be trusted to tell whether the connection is secure, and the
// secret required to sign the session cookie.
type SessionOptions struct {
	// Store persists sessions. SessionHandler loads a session with Store.Get
	// and saves it with Store.Set.
	Store SessionStore
	// CookieTemplate is copied for every session cookie sent.
	//   - An empty Name defaults to "ghost.sid" and an empty Path to "/"
	//     (SessionHandler fills them in).
	//   - MaxAge (seconds) is also used as the MaxAge of new sessions.
	//   - Path restricts which request paths get a session.
	//   - Secure prevents the cookie from being sent on connections not
	//     detected as secure.
	CookieTemplate http.Cookie
	// TrustProxy makes an "https" X-Forwarded-Proto request header count as
	// a secure connection.
	TrustProxy bool
	// Secret is the securecookie hash key used to sign (not encrypt) the
	// session cookie value. It is required: SessionHandler panics with
	// ErrSessionSecretMissing if it is empty.
	Secret string
}
  87. // Create a new SessionOptions struct, using default cookie and proxy values.
  88. func NewSessionOptions(store SessionStore, secret string) *SessionOptions {
  89. return &SessionOptions{
  90. Store: store,
  91. Secret: secret,
  92. }
  93. }
// sessResponseWriter is the augmented ResponseWriter installed by
// SessionHandler. It carries the current Session and SessionStore (exposed by
// GetSession and GetSessionStore). Its Write and WriteHeader methods invoke
// sendCookieFn once, before the response headers are first written.
type sessResponseWriter struct {
	http.ResponseWriter
	// sess is the session of the current request.
	sess *Session
	// sessStore is the store the session is loaded from and saved to.
	sessStore SessionStore
	// sessSent records that sendCookieFn has already been invoked.
	sessSent bool
	// sendCookieFn may add the session cookie to the response headers.
	sendCookieFn func()
}
// WrappedWriter implements the WrapWriter interface by returning the
// ResponseWriter this writer wraps. Presumably GetResponseWriter (used by
// getSessionWriter) relies on it to look through nested wrappers.
func (ø *sessResponseWriter) WrappedWriter() http.ResponseWriter {
	return ø.ResponseWriter
}
  108. // Intercept the Write() method to add the Set-Cookie header before it's too late.
  109. func (ø *sessResponseWriter) Write(data []byte) (int, error) {
  110. if !ø.sessSent {
  111. ø.sendCookieFn()
  112. ø.sessSent = true
  113. }
  114. return ø.ResponseWriter.Write(data)
  115. }
  116. // Intercept the WriteHeader() method to add the Set-Cookie header before it's too late.
  117. func (ø *sessResponseWriter) WriteHeader(code int) {
  118. if !ø.sessSent {
  119. ø.sendCookieFn()
  120. ø.sessSent = true
  121. }
  122. ø.ResponseWriter.WriteHeader(code)
  123. }
  124. // SessionHandlerFunc is the same as SessionHandler, it is just a convenience
  125. // signature that accepts a func(http.ResponseWriter, *http.Request) instead of
  126. // a http.Handler interface. It saves the boilerplate http.HandlerFunc() cast.
  127. func SessionHandlerFunc(h http.HandlerFunc, opts *SessionOptions) http.HandlerFunc {
  128. return SessionHandler(h, opts)
  129. }
  130. // Create a Session handler to offer the Session behaviour to the specified handler.
  131. func SessionHandler(h http.Handler, opts *SessionOptions) http.HandlerFunc {
  132. // Make sure the required cookie fields are set
  133. if opts.CookieTemplate.Name == "" {
  134. opts.CookieTemplate.Name = defaultCookieName
  135. }
  136. if opts.CookieTemplate.Path == "" {
  137. opts.CookieTemplate.Path = "/"
  138. }
  139. // Secret is required
  140. if opts.Secret == "" {
  141. panic(ErrSessionSecretMissing)
  142. }
  143. // Return the actual handler
  144. return func(w http.ResponseWriter, r *http.Request) {
  145. if _, ok := getSessionWriter(w); ok {
  146. // Self-awareness
  147. h.ServeHTTP(w, r)
  148. return
  149. }
  150. if strings.Index(r.URL.Path, opts.CookieTemplate.Path) != 0 {
  151. // Session does not apply to this path
  152. h.ServeHTTP(w, r)
  153. return
  154. }
  155. // Create a new Session or retrieve the existing session based on the
  156. // session cookie received.
  157. var sess *Session
  158. var ckSessId string
  159. exCk, err := r.Cookie(opts.CookieTemplate.Name)
  160. if err != nil {
  161. sess = newSession(opts.CookieTemplate.MaxAge)
  162. ghost.LogFn("ghost.session : error getting session cookie : %s", err)
  163. } else {
  164. ckSessId, err = parseSignedCookie(exCk, opts.Secret)
  165. if err != nil {
  166. sess = newSession(opts.CookieTemplate.MaxAge)
  167. ghost.LogFn("ghost.session : error parsing signed cookie : %s", err)
  168. } else if ckSessId == "" {
  169. sess = newSession(opts.CookieTemplate.MaxAge)
  170. ghost.LogFn("ghost.session : no existing session ID")
  171. } else {
  172. // Get the session
  173. sess, err = opts.Store.Get(ckSessId)
  174. if err != nil {
  175. sess = newSession(opts.CookieTemplate.MaxAge)
  176. ghost.LogFn("ghost.session : error getting session from store : %s", err)
  177. } else if sess == nil {
  178. sess = newSession(opts.CookieTemplate.MaxAge)
  179. ghost.LogFn("ghost.session : nil session")
  180. }
  181. }
  182. }
  183. // Save the original hash of the session, used to compare if the contents
  184. // have changed during the handling of the request, so that it has to be
  185. // saved to the stored.
  186. oriHash := hash(sess)
  187. // Create the augmented ResponseWriter.
  188. srw := &sessResponseWriter{w, sess, opts.Store, false, func() {
  189. // This function is called when the header is about to be written, so that
  190. // the session cookie is correctly set.
  191. // Check if the connection is secure
  192. proto := strings.Trim(strings.ToLower(r.Header.Get("X-Forwarded-Proto")), " ")
  193. tls := r.TLS != nil || (strings.HasPrefix(proto, "https") && opts.TrustProxy)
  194. if opts.CookieTemplate.Secure && !tls {
  195. ghost.LogFn("ghost.session : secure cookie on a non-secure connection, cookie not sent")
  196. return
  197. }
  198. if !sess.IsNew() {
  199. // If this is not a new session, no need to send back the cookie
  200. // TODO : Handle expires?
  201. return
  202. }
  203. // Send the session cookie
  204. ck := opts.CookieTemplate
  205. ck.Value = sess.ID()
  206. err := signCookie(&ck, opts.Secret)
  207. if err != nil {
  208. ghost.LogFn("ghost.session : error signing cookie : %s", err)
  209. return
  210. }
  211. http.SetCookie(w, &ck)
  212. }}
  213. // Call wrapped handler
  214. h.ServeHTTP(srw, r)
  215. // TODO : Expiration management? srw.sess.resetMaxAge()
  216. // Do not save if content is the same, unless session is new (to avoid
  217. // creating a new session and sending a cookie on each successive request).
  218. if newHash := hash(sess); !sess.IsNew() && oriHash == newHash && newHash != 0 {
  219. // No changes to the session, no need to save
  220. ghost.LogFn("ghost.session : no changes to save to store")
  221. return
  222. }
  223. err = opts.Store.Set(sess)
  224. if err != nil {
  225. ghost.LogFn("ghost.session : error saving session to store : %s", err)
  226. }
  227. }
  228. }
  229. // Helper function to retrieve the session for the current request.
  230. func GetSession(w http.ResponseWriter) (*Session, bool) {
  231. ss, ok := getSessionWriter(w)
  232. if ok {
  233. return ss.sess, true
  234. }
  235. return nil, false
  236. }
  237. // Helper function to retrieve the session store
  238. func GetSessionStore(w http.ResponseWriter) (SessionStore, bool) {
  239. ss, ok := getSessionWriter(w)
  240. if ok {
  241. return ss.sessStore, true
  242. }
  243. return nil, false
  244. }
  245. // Internal helper function to retrieve the session writer object.
  246. func getSessionWriter(w http.ResponseWriter) (*sessResponseWriter, bool) {
  247. ss, ok := GetResponseWriter(w, func(tst http.ResponseWriter) bool {
  248. _, ok := tst.(*sessResponseWriter)
  249. return ok
  250. })
  251. if ok {
  252. return ss.(*sessResponseWriter), true
  253. }
  254. return nil, false
  255. }
  256. // Parse a signed cookie and return the cookie value
  257. func parseSignedCookie(ck *http.Cookie, secret string) (string, error) {
  258. var val string
  259. sck := securecookie.New([]byte(secret), nil)
  260. err := sck.Decode(ck.Name, ck.Value, &val)
  261. if err != nil {
  262. return "", err
  263. }
  264. return val, nil
  265. }
  266. // Sign the specified cookie's value
  267. func signCookie(ck *http.Cookie, secret string) error {
  268. sck := securecookie.New([]byte(secret), nil)
  269. enc, err := sck.Encode(ck.Name, ck.Value)
  270. if err != nil {
  271. return err
  272. }
  273. ck.Value = enc
  274. return nil
  275. }
  276. // Compute a CRC32 hash of the session's JSON-encoded contents.
  277. func hash(s *Session) uint32 {
  278. data, err := json.Marshal(s)
  279. if err != nil {
  280. ghost.LogFn("ghost.session : error hash : %s", err)
  281. return 0 // 0 is always treated as "modified" session content
  282. }
  283. return crc32.ChecksumIEEE(data)
  284. }