diff --git a/AGHTechDoc.md b/AGHTechDoc.md index adf483d0..d901a862 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -54,6 +54,7 @@ Contents: * Log-in page * API: Log in * API: Log out + * API: Get current user info ## Relations between subsystems @@ -1207,7 +1208,7 @@ YAML configuration: Session DB file: - session="..." expire=123456 + session="..." user=name expire=123456 ... Session data is SHA(random()+name+password). @@ -1270,3 +1271,20 @@ Response: 302 Found Location: /login.html Set-Cookie: session=...; Expires=Thu, 01 Jan 1970 00:00:00 GMT + + +### API: Get current user info + +Request: + + GET /control/profile + +Response: + + 200 OK + + { + "name":"..." + } + +If no client is configured then authentication is disabled and server sends an empty response. diff --git a/home/auth.go b/home/auth.go index 52b62e70..98f2ccae 100644 --- a/home/auth.go +++ b/home/auth.go @@ -20,10 +20,44 @@ import ( const cookieTTL = 365 * 24 // in hours const expireTime = 30 * 24 // in hours +type session struct { + userName string + expire uint32 // expiration time (in seconds) +} + +/* +expire byte[4] +name_len byte[2] +name byte[] +*/ +func (s *session) serialize() []byte { + var data []byte + data = make([]byte, 4+2+len(s.userName)) + binary.BigEndian.PutUint32(data[0:4], s.expire) + binary.BigEndian.PutUint16(data[4:6], uint16(len(s.userName))) + copy(data[6:], []byte(s.userName)) + return data +} + +func (s *session) deserialize(data []byte) bool { + if len(data) < 4+2 { + return false + } + s.expire = binary.BigEndian.Uint32(data[0:4]) + nameLen := binary.BigEndian.Uint16(data[4:6]) + data = data[6:] + + if len(data) < int(nameLen) { + return false + } + s.userName = string(data) + return true +} + // Auth - global object type Auth struct { db *bbolt.DB - sessions map[string]uint32 // session -> expiration time (in seconds) + sessions map[string]*session // session name -> session data lock sync.Mutex users []User } @@ -37,7 +71,7 @@ type User struct { // InitAuth - create a global object func InitAuth(dbFilename string, users []User) *Auth { a := Auth{} - a.sessions = make(map[string]uint32) + a.sessions = make(map[string]*session) rand.Seed(time.Now().UTC().Unix()) var err error a.db, err = bbolt.Open(dbFilename, 0644, nil) @@ -56,6 +90,10 @@ func (a *Auth) Close() { _ = a.db.Close() } +func bucketName() []byte { + return []byte("sessions-2") +} + // load sessions from file, remove expired sessions func (a *Auth) loadSessions() { tx, err := a.db.Begin(true) @@ -67,16 +105,22 @@ func (a *Auth) loadSessions() { _ = tx.Rollback() }() - bkt := tx.Bucket([]byte("sessions")) + bkt := tx.Bucket(bucketName()) if bkt == nil { return } removed := 0 + + if tx.Bucket([]byte("sessions")) != nil { + _ = tx.DeleteBucket([]byte("sessions")) + removed = 1 + } + now := uint32(time.Now().UTC().Unix()) forEach := func(k, v []byte) error { - i := binary.BigEndian.Uint32(v) - if i <= now { + s := session{} + if !s.deserialize(v) || s.expire <= now { err = bkt.Delete(k) if err != nil { log.Error("Auth: bbolt.Delete: %s", err) @@ -85,7 +129,8 @@ func (a *Auth) loadSessions() { } return nil } - a.sessions[hex.EncodeToString(k)] = i + + a.sessions[hex.EncodeToString(k)] = &s return nil } _ = bkt.ForEach(forEach) @@ -99,11 +144,15 @@ func (a *Auth) loadSessions() { } // store session data in file -func (a *Auth) storeSession(data []byte, expire uint32) { +func (a *Auth) addSession(data []byte, s *session) { a.lock.Lock() - a.sessions[hex.EncodeToString(data)] = expire + a.sessions[hex.EncodeToString(data)] = s a.lock.Unlock() + a.storeSession(data, s) +} +// store session data in file +func (a *Auth) storeSession(data []byte, s *session) { tx, err := a.db.Begin(true) if err != nil { log.Error("Auth: bbolt.Begin: %s", err) @@ -113,15 +162,12 @@ func (a *Auth) storeSession(data []byte, expire uint32) { _ = tx.Rollback() }() - bkt, err := tx.CreateBucketIfNotExists([]byte("sessions")) + bkt, err := tx.CreateBucketIfNotExists(bucketName()) if err != nil { log.Error("Auth: bbolt.CreateBucketIfNotExists: %s", err) return } - var val []byte - val = make([]byte, 4) - binary.BigEndian.PutUint32(val, expire) - err = bkt.Put(data, val) + err = bkt.Put(data, s.serialize()) if err != nil { log.Error("Auth: bbolt.Put: %s", err) return @@ -147,7 +193,7 @@ func (a *Auth) removeSession(sess []byte) { _ = tx.Rollback() }() - bkt := tx.Bucket([]byte("sessions")) + bkt := tx.Bucket(bucketName()) if bkt == nil { log.Error("Auth: bbolt.Bucket") return @@ -174,12 +220,12 @@ func (a *Auth) CheckSession(sess string) int { update := false a.lock.Lock() - expire, ok := a.sessions[sess] + s, ok := a.sessions[sess] if !ok { a.lock.Unlock() return -1 } - if expire <= now { + if s.expire <= now { delete(a.sessions, sess) key, _ := hex.DecodeString(sess) a.removeSession(key) @@ -188,17 +234,17 @@ func (a *Auth) CheckSession(sess string) int { } newExpire := now + expireTime*60*60 - if expire/(24*60*60) != newExpire/(24*60*60) { + if s.expire/(24*60*60) != newExpire/(24*60*60) { // update expiration time once a day update = true - a.sessions[sess] = newExpire + s.expire = newExpire } a.lock.Unlock() if update { key, _ := hex.DecodeString(sess) - a.storeSession(key, expire) + a.storeSession(key, s) } return 0 @@ -238,8 +284,10 @@ func httpCookie(req loginJSON) string { expstr = expstr[:len(expstr)-len("UTC")] // "UTC" -> "GMT" expstr += "GMT" - expireSess := uint32(now.Unix()) + expireTime*60*60 - config.auth.storeSession(sess, expireSess) + s := session{} + s.userName = u.Name + s.expire = uint32(now.Unix()) + expireTime*60*60 + config.auth.addSession(sess, &s) return fmt.Sprintf("session=%s; Path=/; HttpOnly; Expires=%s", hex.EncodeToString(sess), expstr) } @@ -402,6 +450,34 @@ func (a *Auth) UserFind(login string, password string) User { return User{} } +// GetCurrentUser - get the current user +func (a *Auth) GetCurrentUser(r *http.Request) User { + cookie, err := r.Cookie("session") + if err != nil { + // there's no Cookie, check Basic authentication + user, pass, ok := r.BasicAuth() + if ok { + u := config.auth.UserFind(user, pass) + return u + } + } + + a.lock.Lock() + s, ok := a.sessions[cookie.Value] + if !ok { + a.lock.Unlock() + return User{} + } + for _, u := range a.users { + if u.Name == s.userName { + a.lock.Unlock() + return u + } + } + a.lock.Unlock() + return User{} +} + // GetUsers - get users func (a *Auth) GetUsers() []User { a.lock.Lock() diff --git a/home/auth_test.go b/home/auth_test.go index 2ae532fd..ed2c3e6a 100644 --- a/home/auth_test.go +++ b/home/auth_test.go @@ -28,6 +28,7 @@ func TestAuth(t *testing.T) { User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"}, } a := InitAuth(fn, nil) + s := session{} user := User{Name: "name"} a.UserAdd(&user, "password") @@ -38,12 +39,16 @@ func TestAuth(t *testing.T) { sess := getSession(&users[0]) sessStr := hex.EncodeToString(sess) + now := time.Now().UTC().Unix() // check expiration - a.storeSession(sess, uint32(time.Now().UTC().Unix())) + s.expire = uint32(now) + a.addSession(sess, &s) assert.True(t, a.CheckSession(sessStr) == 1) // add session with TTL = 2 sec - a.storeSession(sess, uint32(time.Now().UTC().Unix()+2)) + s = session{} + s.expire = uint32(now + 2) + a.addSession(sess, &s) assert.True(t, a.CheckSession(sessStr) == 0) a.Close() @@ -53,6 +58,9 @@ func TestAuth(t *testing.T) { // the session is still alive assert.True(t, a.CheckSession(sessStr) == 0) + // reset our expiration time because CheckSession() has just updated it + s.expire = uint32(now + 2) + a.storeSession(sess, &s) a.Close() u := a.UserFind("name", "password") diff --git a/home/control.go b/home/control.go index 1f2eb1fa..143f73fc 100644 --- a/home/control.go +++ b/home/control.go @@ -377,6 +377,23 @@ func checkDNS(input string, bootstrap []string) error { return nil } +type profileJSON struct { + Name string `json:"name"` +} + +func handleGetProfile(w http.ResponseWriter, r *http.Request) { + pj := profileJSON{} + u := config.auth.GetCurrentUser(r) + pj.Name = u.Name + + data, err := json.Marshal(pj) + if err != nil { + httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err) + return + } + _, _ = w.Write(data) +} + // -------------- // DNS-over-HTTPS // -------------- @@ -416,6 +433,7 @@ func registerControlHandlers() { httpRegister(http.MethodGet, "/control/access/list", handleAccessList) httpRegister(http.MethodPost, "/control/access/set", handleAccessSet) + httpRegister("GET", "/control/profile", handleGetProfile) RegisterFilteringHandlers() RegisterTLSHandlers() diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index 3f1474cb..eca972b3 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -970,6 +970,18 @@ paths: 302: description: OK + /profile: + get: + tags: + - global + operationId: getProfile + summary: "" + responses: + 200: + description: OK + schema: + $ref: "#/definitions/ProfileInfo" + definitions: ServerStatus: type: "object" @@ -1559,6 +1571,14 @@ definitions: description: "Network interfaces dictionary (key is the interface name)" additionalProperties: $ref: "#/definitions/NetInterface" + + ProfileInfo: + type: "object" + description: "Information about the current user" + properties: + name: + type: "string" + Client: type: "object" description: "Client information"