From 925c5df801dc940239930034cdee3b683304b860 Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Tue, 22 Dec 2020 21:05:12 +0300 Subject: [PATCH] home: improve checkSession --- internal/home/auth.go | 41 +++++++++++++++++++++++++------------- internal/home/auth_test.go | 14 ++++++------- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/internal/home/auth.go b/internal/home/auth.go index 941c97ab..01f89a26 100644 --- a/internal/home/auth.go +++ b/internal/home/auth.go @@ -59,10 +59,10 @@ func (s *session) deserialize(data []byte) bool { // Auth - global object type Auth struct { db *bbolt.DB - sessions map[string]*session // session name -> session data - lock sync.Mutex + sessions map[string]*session users []User - sessionTTL uint32 // in seconds + lock sync.Mutex + sessionTTL uint32 } // User object @@ -223,23 +223,35 @@ func (a *Auth) removeSession(sess []byte) { log.Debug("Auth: removed session from DB") } -// CheckSession - check if session is valid -// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired -func (a *Auth) CheckSession(sess string) int { +// checkSessionResult is the result of checking a session. +type checkSessionResult int + +// checkSessionResult constants. +const ( + checkSessionOK checkSessionResult = 0 + checkSessionNotFound checkSessionResult = -1 + checkSessionExpired checkSessionResult = 1 +) + +// checkSession checks if the session is valid. +func (a *Auth) checkSession(sess string) (res checkSessionResult) { now := uint32(time.Now().UTC().Unix()) update := false a.lock.Lock() defer a.lock.Unlock() + s, ok := a.sessions[sess] if !ok { - return -1 + return checkSessionNotFound } + if s.expire <= now { delete(a.sessions, sess) key, _ := hex.DecodeString(sess) a.removeSession(key) - return 1 + + return checkSessionExpired } newExpire := now + a.sessionTTL @@ -256,7 +268,7 @@ func (a *Auth) CheckSession(sess string) int { } } - return 0 + return checkSessionOK } // RemoveSession - remove session @@ -389,8 +401,8 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool) ok = true } else if err == nil { - r := Context.auth.CheckSession(cookie.Value) - if r == 0 { + r := Context.auth.checkSession(cookie.Value) + if r == checkSessionOK { ok = true } else if r < 0 { log.Debug("Auth: invalid cookie value: %s", cookie) @@ -431,12 +443,13 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re authRequired := Context.auth != nil && Context.auth.AuthRequired() cookie, err := r.Cookie(sessionCookieName) if authRequired && err == nil { - r := Context.auth.CheckSession(cookie.Value) - if r == 0 { + r := Context.auth.checkSession(cookie.Value) + if r == checkSessionOK { w.Header().Set("Location", "/") w.WriteHeader(http.StatusFound) + return - } else if r < 0 { + } else if r == checkSessionNotFound { log.Debug("Auth: invalid cookie value: %s", cookie) } } diff --git a/internal/home/auth_test.go b/internal/home/auth_test.go index 25db2dd6..0998a2a6 100644 --- a/internal/home/auth_test.go +++ b/internal/home/auth_test.go @@ -38,7 +38,7 @@ func TestAuth(t *testing.T) { user := User{Name: "name"} a.UserAdd(&user, "password") - assert.True(t, a.CheckSession("notfound") == -1) + assert.Equal(t, checkSessionNotFound, a.checkSession("notfound")) a.RemoveSession("notfound") sess, err := getSession(&users[0]) @@ -49,13 +49,13 @@ func TestAuth(t *testing.T) { // check expiration s.expire = uint32(now) a.addSession(sess, &s) - assert.True(t, a.CheckSession(sessStr) == 1) + assert.Equal(t, checkSessionExpired, a.checkSession(sessStr)) // add session with TTL = 2 sec s = session{} s.expire = uint32(time.Now().UTC().Unix() + 2) a.addSession(sess, &s) - assert.True(t, a.CheckSession(sessStr) == 0) + assert.Equal(t, checkSessionOK, a.checkSession(sessStr)) a.Close() @@ -63,8 +63,8 @@ func TestAuth(t *testing.T) { a = InitAuth(fn, users, 60) // the session is still alive - assert.True(t, a.CheckSession(sessStr) == 0) - // reset our expiration time because CheckSession() has just updated it + assert.Equal(t, checkSessionOK, a.checkSession(sessStr)) + // reset our expiration time because checkSession() has just updated it s.expire = uint32(time.Now().UTC().Unix() + 2) a.storeSession(sess, &s) a.Close() @@ -76,7 +76,7 @@ func TestAuth(t *testing.T) { // load and remove expired sessions a = InitAuth(fn, users, 60) - assert.True(t, a.CheckSession(sessStr) == -1) + assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr)) a.Close() os.Remove(fn) @@ -111,7 +111,7 @@ func TestAuthHTTP(t *testing.T) { Context.auth = InitAuth(fn, users, 60) handlerCalled := false - handler := func(w http.ResponseWriter, r *http.Request) { + handler := func(_ http.ResponseWriter, _ *http.Request) { handlerCalled = true } handler2 := optionalAuth(handler)