adguard-exporter/internal/adguard/client.go

279 lines
7.3 KiB
Go

package adguard
import (
"crypto/tls"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"strconv"
"strings"
"time"
"github.com/ebrianne/adguard-exporter/internal/metrics"
"github.com/mitchellh/mapstructure"
)
// port holds the AdGuard API port most recently parsed by NewClient.
// NOTE(review): package-level and overwritten by every NewClient call; the
// per-client value lives in Client.port.
var port uint16

// URL templates for the AdGuard Home control API. They are expanded with
// fmt.Sprintf using (protocol, hostname, port) followed by an optional extra
// argument (query-log limit or rDNS query string).
var (
	statusURLPattern      = "%s://%s:%d/control/status"
	statsURLPattern       = "%s://%s:%d/control/stats"
	logstatsURLPattern    = "%s://%s:%d/control/querylog?limit=%s&response_status=\"all\""
	resolveRDNSURLPattern = "%s://%s:%d/control/clients/find?%s"
)

// m is a package-level map from DNS record type to count.
var m map[string]int
// Client is an AdGuard Home API client: it periodically requests statistics
// from one AdGuard instance and passes them on as Prometheus metrics.
type Client struct {
	// httpClient performs every request against the AdGuard control API.
	httpClient http.Client
	// interval is the delay between two scrapes.
	interval time.Duration
	// logLimit is substituted verbatim into the querylog "limit" parameter.
	logLimit string

	// Location of the AdGuard instance.
	protocol string
	hostname string
	port     uint16

	// Basic-auth credentials; they are attached only when password is non-empty.
	username string
	password string

	// rdnsenabled reports top clients by resolved client name instead of by IP.
	rdnsenabled bool
}
// NewClient initializes a Client for the AdGuard instance reachable at
// protocol://hostname:adport.
//
// adport must be a decimal port number in the range 0-65535; anything else
// terminates the process via log.Fatal. The returned client never follows
// HTTP redirects and uses the TLS settings from GetTlsConfig.
func NewClient(protocol, hostname, username, password, adport string, interval time.Duration, logLimit string, rdnsenabled bool) *Client {
	// ParseUint with bitSize 16 rejects negative and >65535 values instead of
	// letting a uint16 conversion silently wrap them to a different port.
	parsed, err := strconv.ParseUint(adport, 10, 16)
	if err != nil {
		log.Fatal(err)
	}
	clientPort := uint16(parsed)

	// Kept for backward compatibility: the package-level port mirrors the
	// most recently created client.
	port = clientPort

	return &Client{
		protocol: protocol,
		hostname: hostname,
		port:     clientPort,
		username: username,
		password: password,
		interval: interval,
		logLimit: logLimit,
		httpClient: http.Client{
			Transport: &http.Transport{TLSClientConfig: GetTlsConfig()},
			// Hand 3xx responses back to the caller instead of following them.
			CheckRedirect: func(req *http.Request, via []*http.Request) error {
				return http.ErrUseLastResponse
			},
		},
		rdnsenabled: rdnsenabled,
	}
}
// Scrape loops indefinitely. It waits c.interval, retrieves a statistics
// snapshot from the AdGuard JSON API, publishes it as Prometheus metrics and
// logs a summary. The first snapshot is taken only after the first interval
// has elapsed.
//
// If c.interval is not positive, time.Tick returns a nil channel and the loop
// blocks forever without scraping.
func (c *Client) Scrape() {
	ticks := time.Tick(c.interval)
	for {
		<-ticks

		snapshot := c.getStatistics()
		c.setMetrics(snapshot.status, snapshot.stats, snapshot.logStats, snapshot.rdns)
		log.Printf("New tick of statistics: %s", snapshot.stats.ToString())
	}
}
// setMetrics publishes one scrape's worth of AdGuard data as Prometheus gauges.
// Every series is labelled with c.hostname.
//
//   - status:   Running / ProtectionEnabled, exported as 0 or 1.
//   - stats:    aggregate counters plus per-domain TopQueries / TopBlocked
//     and per-source TopClients.
//   - logstats: query-log entries; their DNS answers are counted per record type.
//   - rdns:     source -> client name. It is consulted only when c.rdnsenabled
//     is set and the source parses as an IP.
//
// NOTE(review): query types that disappear from the log keep their last
// gauge value, because nothing here deletes stale series.
func (c *Client) setMetrics(status *Status, stats *Stats, logstats *LogStats, rdns map[string]string) {
	// Status
	running, protected := 0.0, 0.0
	if status.Running {
		running = 1
	}
	if status.ProtectionEnabled {
		protected = 1
	}
	metrics.Running.WithLabelValues(c.hostname).Set(running)
	metrics.ProtectionEnabled.WithLabelValues(c.hostname).Set(protected)

	// Stats
	metrics.AvgProcessingTime.WithLabelValues(c.hostname).Set(float64(stats.AvgProcessingTime))
	metrics.DnsQueries.WithLabelValues(c.hostname).Set(float64(stats.DnsQueries))
	metrics.BlockedFiltering.WithLabelValues(c.hostname).Set(float64(stats.BlockedFiltering))
	metrics.ParentalFiltering.WithLabelValues(c.hostname).Set(float64(stats.ParentalFiltering))
	metrics.SafeBrowsingFiltering.WithLabelValues(c.hostname).Set(float64(stats.SafeBrowsingFiltering))
	metrics.SafeSearchFiltering.WithLabelValues(c.hostname).Set(float64(stats.SafeSearchFiltering))

	for _, entry := range stats.TopQueries {
		for domain, value := range entry {
			metrics.TopQueries.WithLabelValues(c.hostname, domain).Set(float64(value))
		}
	}
	for _, entry := range stats.TopBlocked {
		for domain, value := range entry {
			metrics.TopBlocked.WithLabelValues(c.hostname, domain).Set(float64(value))
		}
	}
	for _, entry := range stats.TopClients {
		for source, value := range entry {
			label := source
			if c.rdnsenabled && isValidIp(source) {
				if name, ok := rdns[source]; ok {
					label = name
				}
			}
			metrics.TopClients.WithLabelValues(c.hostname, label).Set(float64(value))
		}
	}

	// Query log: count answers per DNS record type. A fresh local map per
	// call (instead of package-level state) keeps concurrent scrapes from
	// racing on a shared map and needs no clearing afterwards.
	queryTypes := make(map[string]int)
	for i := range logstats.Data {
		answers := logstats.Data[i].Answer
		for j := range answers {
			switch v := answers[j].Value.(type) {
			case string:
				// String-valued answers carry their record type name in Type.
				queryTypes[answers[j].Type]++
			case map[string]interface{}:
				// Object-valued answers only expose a numeric record type in
				// their header (presumably e.g. TYPE65 records).
				var dns65 Type65
				// Best effort: a decode error is ignored, and Rrtype may then
				// stay 0 and be counted as "TYPE0".
				// NOTE(review): consider skipping such answers once Type65's
				// shape is confirmed.
				_ = mapstructure.Decode(v, &dns65)
				queryTypes["TYPE"+strconv.Itoa(dns65.Hdr.Rrtype)]++
			}
		}
	}
	for dnsType, count := range queryTypes {
		metrics.QueryTypes.WithLabelValues(c.hostname, dnsType).Set(float64(count))
	}
}
// getStatistics collects one snapshot from the AdGuard control API.
//
// It requests the status, stats and query-log endpoints. When c.rdnsenabled
// is set, it also resolves the top clients to names. JSON decoding errors are
// logged and otherwise ignored; MakeRequest exits the process on HTTP
// failures.
func (c *Client) getStatistics() *AllStats {
	var status Status
	statusURL := fmt.Sprintf(statusURLPattern, c.protocol, c.hostname, c.port)
	if err := json.Unmarshal(c.MakeRequest(statusURL), &status); err != nil {
		log.Println("Unable to unmarshal Adguard status to status struct model", err)
	}

	var stats Stats
	statsURL := fmt.Sprintf(statsURLPattern, c.protocol, c.hostname, c.port)
	if err := json.Unmarshal(c.MakeRequest(statsURL), &stats); err != nil {
		log.Println("Unable to unmarshal Adguard statistics to statistics struct model", err)
	}

	var logstats LogStats
	logstatsURL := fmt.Sprintf(logstatsURLPattern, c.protocol, c.hostname, c.port, c.logLimit)
	if err := json.Unmarshal(c.MakeRequest(logstatsURL), &logstats); err != nil {
		log.Println("Unable to unmarshal Adguard log statistics to log statistics struct model", err)
	}

	var allstats AllStats
	allstats.status = &status
	allstats.stats = &stats
	allstats.logStats = &logstats
	if c.rdnsenabled {
		allstats.rdns = c.lookupClientNames(&stats)
	}
	return &allstats
}

// lookupClientNames resolves the IP-address sources in stats.TopClients to
// client names and returns an IP -> name map.
//
// It queries AdGuard's clients/find endpoint. Only non-empty string names are
// kept, so unresolved IPs are absent and callers fall back to the raw
// address. No request is made when there is nothing to resolve.
func (c *Client) lookupClientNames(stats *Stats) map[string]string {
	names := make(map[string]string)

	// Build "ip0=A&ip1=B&..." with one key per source. A running counter
	// (rather than the slice index) keeps the keys unique and the '&'
	// separators correct even when a TopClients entry holds several sources.
	// Only IP literals are sent: setMetrics consults the result only for
	// sources that parse as IPs, and IP literals need no query escaping.
	var query strings.Builder
	count := 0
	for l := range stats.TopClients {
		for source := range stats.TopClients[l] {
			if !isValidIp(source) {
				continue
			}
			if count > 0 {
				query.WriteByte('&')
			}
			fmt.Fprintf(&query, "ip%d=%s", count, source)
			count++
		}
	}
	if count == 0 {
		return names
	}

	rdnsURL := fmt.Sprintf(resolveRDNSURLPattern, c.protocol, c.hostname, c.port, query.String())
	var results []map[string]interface{}
	if err := json.Unmarshal(c.MakeRequest(rdnsURL), &results); err != nil {
		log.Println("Unable to unmarshal Reverse DNS", err)
		return names
	}

	// Each result maps a queried IP to a client object. Entries with an
	// unexpected shape are skipped rather than panicking on a type assertion.
	for _, result := range results {
		for ip, raw := range result {
			client, ok := raw.(map[string]interface{})
			if !ok {
				continue
			}
			if name, ok := client["name"].(string); ok && name != "" {
				names[ip] = name
			}
		}
	}
	return names
}
// MakeRequest performs a GET request against url and returns the raw response
// body. Basic-auth credentials are attached when a password is configured.
//
// Every failure terminates the process via log.Fatal/log.Fatalln. This covers
// building the request, a transport error, a non-200 status and reading the
// body.
func (c *Client) MakeRequest(url string) []byte {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		log.Fatalln("An error has occurred when creating HTTP statistics request", err)
	}
	req.Host = c.hostname
	req.Header.Add("User-Agent", "Mozilla/5.0")
	if c.isUsingPassword() {
		c.authenticateRequest(req)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		log.Fatalln("An error has occurred during login to Adguard", err)
	}
	// Closing the body lets the transport reuse the connection and avoids
	// leaking a connection/file descriptor on every scrape.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Fatal("An error occured in the request, Status Code ", resp.StatusCode)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		log.Fatalln("Unable to read Adguard statistics HTTP response", err)
	}
	return body
}
// isUsingPassword reports whether a non-empty password was configured, which
// is the condition for attaching basic-auth credentials to requests.
func (c *Client) isUsingPassword() bool {
	return c.password != ""
}
// authenticateRequest adds the client's username and password to req as an
// HTTP Basic authentication header.
func (c *Client) authenticateRequest(req *http.Request) {
	user, pass := c.username, c.password
	req.SetBasicAuth(user, pass)
}
// GetTlsConfig returns a new TLS configuration with certificate verification
// disabled. Each call allocates a fresh *tls.Config.
//
// NOTE(review): InsecureSkipVerify presumably exists to accept self-signed
// AdGuard certificates, but it also removes protection against
// man-in-the-middle attacks; consider making it configurable.
func GetTlsConfig() *tls.Config {
	cfg := new(tls.Config)
	cfg.InsecureSkipVerify = true
	return cfg
}
// isValidIp reports whether ip is a textual IPv4 or IPv6 address as accepted
// by net.ParseIP.
func isValidIp(ip string) bool {
	return net.ParseIP(ip) != nil
}