diff --git a/AGHTechDoc.md b/AGHTechDoc.md index adf483d0..d901a862 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -54,6 +54,7 @@ Contents: * Log-in page * API: Log in * API: Log out + * API: Get current user info ## Relations between subsystems @@ -1207,7 +1208,7 @@ YAML configuration: Session DB file: - session="..." expire=123456 + session="..." user=name expire=123456 ... Session data is SHA(random()+name+password). @@ -1270,3 +1271,20 @@ Response: 302 Found Location: /login.html Set-Cookie: session=...; Expires=Thu, 01 Jan 1970 00:00:00 GMT + + +### API: Get current user info + +Request: + + GET /control/profile + +Response: + + 200 OK + + { + "name":"..." + } + +If no client is configured then authentication is disabled and server sends an empty response. diff --git a/client/src/actions/index.js b/client/src/actions/index.js index 3087c47d..28c2a713 100644 --- a/client/src/actions/index.js +++ b/client/src/actions/index.js @@ -213,6 +213,21 @@ export const getClients = () => async (dispatch) => { } }; +export const getProfileRequest = createAction('GET_PROFILE_REQUEST'); +export const getProfileFailure = createAction('GET_PROFILE_FAILURE'); +export const getProfileSuccess = createAction('GET_PROFILE_SUCCESS'); + +export const getProfile = () => async (dispatch) => { + dispatch(getProfileRequest()); + try { + const profile = await apiClient.getProfile(); + dispatch(getProfileSuccess(profile)); + } catch (error) { + dispatch(addErrorToast({ error })); + dispatch(getProfileFailure()); + } +}; + export const dnsStatusRequest = createAction('DNS_STATUS_REQUEST'); export const dnsStatusFailure = createAction('DNS_STATUS_FAILURE'); export const dnsStatusSuccess = createAction('DNS_STATUS_SUCCESS'); @@ -224,6 +239,7 @@ export const getDnsStatus = () => async (dispatch) => { dispatch(dnsStatusSuccess(dnsStatus)); dispatch(getVersion()); dispatch(getTlsStatus()); + dispatch(getProfile()); } catch (error) { dispatch(addErrorToast({ error })); dispatch(dnsStatusFailure()); diff --git a/client/src/api/Api.js b/client/src/api/Api.js index c5ced2b8..470577a8 100644 --- a/client/src/api/Api.js +++ b/client/src/api/Api.js @@ -525,6 +525,14 @@ class Api { }; return this.makeRequest(path, method, config); } + + // Profile + GET_PROFILE = { path: 'profile', method: 'GET' }; + + getProfile() { + const { path, method } = this.GET_PROFILE; + return this.makeRequest(path, method); + } } const apiClient = new Api(); diff --git a/client/src/components/Header/index.js b/client/src/components/Header/index.js index 28fa0767..8d16e614 100644 --- a/client/src/components/Header/index.js +++ b/client/src/components/Header/index.js @@ -60,9 +60,11 @@ class Header extends Component { />
- - sign_out - + {!dashboard.processingProfile && dashboard.name && + + sign_out + + }
diff --git a/client/src/reducers/index.js b/client/src/reducers/index.js index 589da42e..0e8ff407 100644 --- a/client/src/reducers/index.js +++ b/client/src/reducers/index.js @@ -189,6 +189,14 @@ const dashboard = handleActions( processingDnsSettings: false, }; }, + + [actions.getProfileRequest]: state => ({ ...state, processingProfile: true }), + [actions.getProfileFailure]: state => ({ ...state, processingProfile: false }), + [actions.getProfileSuccess]: (state, { payload }) => ({ + ...state, + name: payload.name, + processingProfile: false, + }), }, { processing: true, @@ -198,6 +206,7 @@ const dashboard = handleActions( processingClients: true, processingUpdate: false, processingDnsSettings: true, + processingProfile: true, upstreamDns: '', bootstrapDns: '', allServers: false, @@ -209,6 +218,7 @@ const dashboard = handleActions( dnsVersion: '', clients: [], autoClients: [], + name: '', }, ); diff --git a/home/auth.go b/home/auth.go index 52b62e70..98f2ccae 100644 --- a/home/auth.go +++ b/home/auth.go @@ -20,10 +20,44 @@ import ( const cookieTTL = 365 * 24 // in hours const expireTime = 30 * 24 // in hours +type session struct { + userName string + expire uint32 // expiration time (in seconds) +} + +/* +expire byte[4] +name_len byte[2] +name byte[] +*/ +func (s *session) serialize() []byte { + var data []byte + data = make([]byte, 4+2+len(s.userName)) + binary.BigEndian.PutUint32(data[0:4], s.expire) + binary.BigEndian.PutUint16(data[4:6], uint16(len(s.userName))) + copy(data[6:], []byte(s.userName)) + return data +} + +func (s *session) deserialize(data []byte) bool { + if len(data) < 4+2 { + return false + } + s.expire = binary.BigEndian.Uint32(data[0:4]) + nameLen := binary.BigEndian.Uint16(data[4:6]) + data = data[6:] + + if len(data) < int(nameLen) { + return false + } + s.userName = string(data) + return true +} + // Auth - global object type Auth struct { db *bbolt.DB - sessions map[string]uint32 // session -> expiration time (in seconds) + sessions map[string]*session // session name -> session data lock sync.Mutex users []User } @@ -37,7 +71,7 @@ type User struct { // InitAuth - create a global object func InitAuth(dbFilename string, users []User) *Auth { a := Auth{} - a.sessions = make(map[string]uint32) + a.sessions = make(map[string]*session) rand.Seed(time.Now().UTC().Unix()) var err error a.db, err = bbolt.Open(dbFilename, 0644, nil) @@ -56,6 +90,10 @@ func (a *Auth) Close() { _ = a.db.Close() } +func bucketName() []byte { + return []byte("sessions-2") +} + // load sessions from file, remove expired sessions func (a *Auth) loadSessions() { tx, err := a.db.Begin(true) @@ -67,16 +105,22 @@ func (a *Auth) loadSessions() { _ = tx.Rollback() }() - bkt := tx.Bucket([]byte("sessions")) + bkt := tx.Bucket(bucketName()) if bkt == nil { return } removed := 0 + + if tx.Bucket([]byte("sessions")) != nil { + _ = tx.DeleteBucket([]byte("sessions")) + removed = 1 + } + now := uint32(time.Now().UTC().Unix()) forEach := func(k, v []byte) error { - i := binary.BigEndian.Uint32(v) - if i <= now { + s := session{} + if !s.deserialize(v) || s.expire <= now { err = bkt.Delete(k) if err != nil { log.Error("Auth: bbolt.Delete: %s", err) @@ -85,7 +129,8 @@ func (a *Auth) loadSessions() { } return nil } - a.sessions[hex.EncodeToString(k)] = i + + a.sessions[hex.EncodeToString(k)] = &s return nil } _ = bkt.ForEach(forEach) @@ -99,11 +144,15 @@ func (a *Auth) loadSessions() { } // store session data in file -func (a *Auth) storeSession(data []byte, expire uint32) { +func (a *Auth) addSession(data []byte, s *session) { a.lock.Lock() - a.sessions[hex.EncodeToString(data)] = expire + a.sessions[hex.EncodeToString(data)] = s a.lock.Unlock() + a.storeSession(data, s) +} +// store session data in file +func (a *Auth) storeSession(data []byte, s *session) { tx, err := a.db.Begin(true) if err != nil { log.Error("Auth: bbolt.Begin: %s", err) @@ -113,15 +162,12 @@ func (a *Auth) storeSession(data []byte, expire uint32) { _ = tx.Rollback() }() - bkt, err := tx.CreateBucketIfNotExists([]byte("sessions")) + bkt, err := tx.CreateBucketIfNotExists(bucketName()) if err != nil { log.Error("Auth: bbolt.CreateBucketIfNotExists: %s", err) return } - var val []byte - val = make([]byte, 4) - binary.BigEndian.PutUint32(val, expire) - err = bkt.Put(data, val) + err = bkt.Put(data, s.serialize()) if err != nil { log.Error("Auth: bbolt.Put: %s", err) return @@ -147,7 +193,7 @@ func (a *Auth) removeSession(sess []byte) { _ = tx.Rollback() }() - bkt := tx.Bucket([]byte("sessions")) + bkt := tx.Bucket(bucketName()) if bkt == nil { log.Error("Auth: bbolt.Bucket") return @@ -174,12 +220,12 @@ func (a *Auth) CheckSession(sess string) int { update := false a.lock.Lock() - expire, ok := a.sessions[sess] + s, ok := a.sessions[sess] if !ok { a.lock.Unlock() return -1 } - if expire <= now { + if s.expire <= now { delete(a.sessions, sess) key, _ := hex.DecodeString(sess) a.removeSession(key) @@ -188,17 +234,17 @@ func (a *Auth) CheckSession(sess string) int { } newExpire := now + expireTime*60*60 - if expire/(24*60*60) != newExpire/(24*60*60) { + if s.expire/(24*60*60) != newExpire/(24*60*60) { // update expiration time once a day update = true - a.sessions[sess] = newExpire + s.expire = newExpire } a.lock.Unlock() if update { key, _ := hex.DecodeString(sess) - a.storeSession(key, expire) + a.storeSession(key, s) } return 0 @@ -238,8 +284,10 @@ func httpCookie(req loginJSON) string { expstr = expstr[:len(expstr)-len("UTC")] // "UTC" -> "GMT" expstr += "GMT" - expireSess := uint32(now.Unix()) + expireTime*60*60 - config.auth.storeSession(sess, expireSess) + s := session{} + s.userName = u.Name + s.expire = uint32(now.Unix()) + expireTime*60*60 + config.auth.addSession(sess, &s) return fmt.Sprintf("session=%s; Path=/; HttpOnly; Expires=%s", hex.EncodeToString(sess), expstr) } @@ -402,6 +450,34 @@ func (a *Auth) UserFind(login string, password string) User { return User{} } +// GetCurrentUser - get the current user +func (a *Auth) GetCurrentUser(r *http.Request) User { + cookie, err := r.Cookie("session") + if err != nil { + // there's no Cookie, check Basic authentication + user, pass, ok := r.BasicAuth() + if ok { + u := config.auth.UserFind(user, pass) + return u + } + } + + a.lock.Lock() + s, ok := a.sessions[cookie.Value] + if !ok { + a.lock.Unlock() + return User{} + } + for _, u := range a.users { + if u.Name == s.userName { + a.lock.Unlock() + return u + } + } + a.lock.Unlock() + return User{} +} + // GetUsers - get users func (a *Auth) GetUsers() []User { a.lock.Lock() diff --git a/home/auth_test.go b/home/auth_test.go index 2ae532fd..ed2c3e6a 100644 --- a/home/auth_test.go +++ b/home/auth_test.go @@ -28,6 +28,7 @@ func TestAuth(t *testing.T) { User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"}, } a := InitAuth(fn, nil) + s := session{} user := User{Name: "name"} a.UserAdd(&user, "password") @@ -38,12 +39,16 @@ func TestAuth(t *testing.T) { sess := getSession(&users[0]) sessStr := hex.EncodeToString(sess) + now := time.Now().UTC().Unix() // check expiration - a.storeSession(sess, uint32(time.Now().UTC().Unix())) + s.expire = uint32(now) + a.addSession(sess, &s) assert.True(t, a.CheckSession(sessStr) == 1) // add session with TTL = 2 sec - a.storeSession(sess, uint32(time.Now().UTC().Unix()+2)) + s = session{} + s.expire = uint32(now + 2) + a.addSession(sess, &s) assert.True(t, a.CheckSession(sessStr) == 0) a.Close() @@ -53,6 +58,9 @@ func TestAuth(t *testing.T) { // the session is still alive assert.True(t, a.CheckSession(sessStr) == 0) + // reset our expiration time because CheckSession() has just updated it + s.expire = uint32(now + 2) + a.storeSession(sess, &s) a.Close() u := a.UserFind("name", "password") diff --git a/home/control.go b/home/control.go index 1f2eb1fa..143f73fc 100644 --- a/home/control.go +++ b/home/control.go @@ -377,6 +377,23 @@ func checkDNS(input string, bootstrap []string) error { return nil } +type profileJSON struct { + Name string `json:"name"` +} + +func handleGetProfile(w http.ResponseWriter, r *http.Request) { + pj := profileJSON{} + u := config.auth.GetCurrentUser(r) + pj.Name = u.Name + + data, err := json.Marshal(pj) + if err != nil { + httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err) + return + } + _, _ = w.Write(data) +} + // -------------- // DNS-over-HTTPS // -------------- @@ -416,6 +433,7 @@ func registerControlHandlers() { httpRegister(http.MethodGet, "/control/access/list", handleAccessList) httpRegister(http.MethodPost, "/control/access/set", handleAccessSet) + httpRegister("GET", "/control/profile", handleGetProfile) RegisterFilteringHandlers() RegisterTLSHandlers() diff --git a/openapi/CHANGELOG.md b/openapi/CHANGELOG.md index 7dc74883..281d58e5 100644 --- a/openapi/CHANGELOG.md +++ b/openapi/CHANGELOG.md @@ -1,6 +1,23 @@ # AdGuard Home API Change Log +## v0.99.1: API changes + +### API: Get current user info: GET /control/profile + +Request: + + GET /control/profile + +Response: + + 200 OK + + { + "name":"..." + } + + ## v0.99: incompatible API changes * A note about web user authentication diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index 3f1474cb..eca972b3 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -970,6 +970,18 @@ paths: 302: description: OK + /profile: + get: + tags: + - global + operationId: getProfile + summary: "" + responses: + 200: + description: OK + schema: + $ref: "#/definitions/ProfileInfo" + definitions: ServerStatus: type: "object" @@ -1559,6 +1571,14 @@ definitions: description: "Network interfaces dictionary (key is the interface name)" additionalProperties: $ref: "#/definitions/NetInterface" + + ProfileInfo: + type: "object" + description: "Information about the current user" + properties: + name: + type: "string" + Client: type: "object" description: "Client information" diff --git a/stats/stats_unit.go b/stats/stats_unit.go index 1d524bcc..3db14d5b 100644 --- a/stats/stats_unit.go +++ b/stats/stats_unit.go @@ -346,7 +346,7 @@ func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB { return nil } - log.Tracef("Loading unit %d", id) + // log.Tracef("Loading unit %d", id) var buf bytes.Buffer buf.Write(bkt.Get([]byte{0}))