client.go raw

   1  package internal
   2  
   3  import (
   4  	"bytes"
   5  	"context"
   6  	"encoding/json"
   7  	"errors"
   8  	"fmt"
   9  	"io"
  10  	"net/http"
  11  	"net/url"
  12  	"strconv"
  13  	"strings"
  14  	"sync"
  15  	"time"
  16  
  17  	"github.com/go-acme/lego/v4/providers/dns/internal/errutils"
  18  	"golang.org/x/oauth2"
  19  )
  20  
  21  const (
  22  	ns1 = "ns.checkdomain.de"
  23  	ns2 = "ns2.checkdomain.de"
  24  )
  25  
  26  // DefaultEndpoint the default API endpoint.
  27  const DefaultEndpoint = "https://api.checkdomain.de"
  28  
  29  const domainNotFound = -1
  30  
  31  // max page limit that the checkdomain api allows.
  32  const maxLimit = 100
  33  
  34  // max integer value.
  35  const maxInt = int((^uint(0)) >> 1)
  36  
  37  // Client the Autodns API client.
  38  type Client struct {
  39  	BaseURL    *url.URL
  40  	httpClient *http.Client
  41  
  42  	domainIDMapping map[string]int
  43  	domainIDMu      sync.Mutex
  44  }
  45  
  46  // NewClient creates a new Client.
  47  func NewClient(hc *http.Client) *Client {
  48  	baseURL, _ := url.Parse(DefaultEndpoint)
  49  
  50  	if hc == nil {
  51  		hc = &http.Client{Timeout: 10 * time.Second}
  52  	}
  53  
  54  	return &Client{
  55  		BaseURL:         baseURL,
  56  		httpClient:      hc,
  57  		domainIDMapping: make(map[string]int),
  58  	}
  59  }
  60  
  61  func (c *Client) GetDomainIDByName(ctx context.Context, name string) (int, error) {
  62  	// Load from cache if exists
  63  	c.domainIDMu.Lock()
  64  	id, ok := c.domainIDMapping[name]
  65  	c.domainIDMu.Unlock()
  66  
  67  	if ok {
  68  		return id, nil
  69  	}
  70  
  71  	// Find out by querying API
  72  	domains, err := c.listDomains(ctx)
  73  	if err != nil {
  74  		return domainNotFound, err
  75  	}
  76  
  77  	// Linear search over all registered domains
  78  	for _, domain := range domains {
  79  		if domain.Name == name || strings.HasSuffix(name, "."+domain.Name) {
  80  			c.domainIDMu.Lock()
  81  			c.domainIDMapping[name] = domain.ID
  82  			c.domainIDMu.Unlock()
  83  
  84  			return domain.ID, nil
  85  		}
  86  	}
  87  
  88  	return domainNotFound, errors.New("domain not found")
  89  }
  90  
  91  func (c *Client) listDomains(ctx context.Context) ([]*Domain, error) {
  92  	endpoint := c.BaseURL.JoinPath("v1", "domains")
  93  
  94  	// Checkdomain also provides a query param 'query' which allows filtering domains for a string.
  95  	// But that functionality is kinda broken,
  96  	// so we scan through the whole list of registered domains to later find the one that is of interest to us.
  97  	q := endpoint.Query()
  98  	q.Set("limit", strconv.Itoa(maxLimit))
  99  
 100  	currentPage := 1
 101  	totalPages := maxInt
 102  
 103  	var domainList []*Domain
 104  
 105  	for currentPage <= totalPages {
 106  		q.Set("page", strconv.Itoa(currentPage))
 107  		endpoint.RawQuery = q.Encode()
 108  
 109  		req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
 110  		if err != nil {
 111  			return nil, fmt.Errorf("failed to make request: %w", err)
 112  		}
 113  
 114  		var res DomainListingResponse
 115  		if err := c.do(req, &res); err != nil {
 116  			return nil, fmt.Errorf("failed to send domain listing request: %w", err)
 117  		}
 118  
 119  		// This is the first response,
 120  		// so we update totalPages and allocate the slice memory.
 121  		if totalPages == maxInt {
 122  			totalPages = res.Pages
 123  			domainList = make([]*Domain, 0, res.Total)
 124  		}
 125  
 126  		domainList = append(domainList, res.Embedded.Domains...)
 127  		currentPage++
 128  	}
 129  
 130  	return domainList, nil
 131  }
 132  
 133  func (c *Client) getNameserverInfo(ctx context.Context, domainID int) (*NameserverResponse, error) {
 134  	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers")
 135  
 136  	req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
 137  	if err != nil {
 138  		return nil, err
 139  	}
 140  
 141  	res := &NameserverResponse{}
 142  	if err := c.do(req, res); err != nil {
 143  		return nil, err
 144  	}
 145  
 146  	return res, nil
 147  }
 148  
 149  func (c *Client) CheckNameservers(ctx context.Context, domainID int) error {
 150  	info, err := c.getNameserverInfo(ctx, domainID)
 151  	if err != nil {
 152  		return err
 153  	}
 154  
 155  	var found1, found2 bool
 156  
 157  	for _, item := range info.Nameservers {
 158  		switch item.Name {
 159  		case ns1:
 160  			found1 = true
 161  		case ns2:
 162  			found2 = true
 163  		}
 164  	}
 165  
 166  	if !found1 || !found2 {
 167  		return errors.New("not using checkdomain nameservers, can not update records")
 168  	}
 169  
 170  	return nil
 171  }
 172  
 173  func (c *Client) CreateRecord(ctx context.Context, domainID int, record *Record) error {
 174  	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
 175  
 176  	req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
 177  	if err != nil {
 178  		return err
 179  	}
 180  
 181  	return c.do(req, nil)
 182  }
 183  
 184  // DeleteTXTRecord Checkdomain doesn't seem provide a way to delete records but one can replace all records at once.
 185  // The current solution is to fetch all records and then use that list minus the record deleted as the new record list.
 186  // TODO: Simplify this function once Checkdomain do provide the functionality.
 187  func (c *Client) DeleteTXTRecord(ctx context.Context, domainID int, recordName, recordValue string) error {
 188  	domainInfo, err := c.getDomainInfo(ctx, domainID)
 189  	if err != nil {
 190  		return err
 191  	}
 192  
 193  	nsInfo, err := c.getNameserverInfo(ctx, domainID)
 194  	if err != nil {
 195  		return err
 196  	}
 197  
 198  	allRecords, err := c.listRecords(ctx, domainID, "")
 199  	if err != nil {
 200  		return err
 201  	}
 202  
 203  	recordName = strings.TrimSuffix(recordName, "."+domainInfo.Name+".")
 204  
 205  	var recordsToKeep []*Record
 206  
 207  	// Find and delete matching records
 208  	for _, record := range allRecords {
 209  		if skipRecord(recordName, recordValue, record, nsInfo) {
 210  			continue
 211  		}
 212  
 213  		// Checkdomain API can return records without any TTL set (indicated by the value of 0).
 214  		// The API Call to replace the records would fail if we wouldn't specify a value.
 215  		// Thus, we use the default TTL queried beforehand
 216  		if record.TTL == 0 {
 217  			record.TTL = nsInfo.SOA.TTL
 218  		}
 219  
 220  		recordsToKeep = append(recordsToKeep, record)
 221  	}
 222  
 223  	return c.replaceRecords(ctx, domainID, recordsToKeep)
 224  }
 225  
 226  func (c *Client) getDomainInfo(ctx context.Context, domainID int) (*DomainResponse, error) {
 227  	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID))
 228  
 229  	req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
 230  	if err != nil {
 231  		return nil, err
 232  	}
 233  
 234  	var res DomainResponse
 235  
 236  	err = c.do(req, &res)
 237  	if err != nil {
 238  		return nil, err
 239  	}
 240  
 241  	return &res, nil
 242  }
 243  
 244  func (c *Client) listRecords(ctx context.Context, domainID int, recordType string) ([]*Record, error) {
 245  	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
 246  
 247  	q := endpoint.Query()
 248  	q.Set("limit", strconv.Itoa(maxLimit))
 249  
 250  	if recordType != "" {
 251  		q.Set("type", recordType)
 252  	}
 253  
 254  	currentPage := 1
 255  	totalPages := maxInt
 256  
 257  	var recordList []*Record
 258  
 259  	for currentPage <= totalPages {
 260  		q.Set("page", strconv.Itoa(currentPage))
 261  		endpoint.RawQuery = q.Encode()
 262  
 263  		req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
 264  		if err != nil {
 265  			return nil, fmt.Errorf("failed to create request: %w", err)
 266  		}
 267  
 268  		var res RecordListingResponse
 269  		if err := c.do(req, &res); err != nil {
 270  			return nil, fmt.Errorf("failed to send record listing request: %w", err)
 271  		}
 272  
 273  		// This is the first response, so we update totalPages and allocate the slice memory.
 274  		if totalPages == maxInt {
 275  			totalPages = res.Pages
 276  			recordList = make([]*Record, 0, res.Total)
 277  		}
 278  
 279  		recordList = append(recordList, res.Embedded.Records...)
 280  		currentPage++
 281  	}
 282  
 283  	return recordList, nil
 284  }
 285  
 286  func (c *Client) replaceRecords(ctx context.Context, domainID int, records []*Record) error {
 287  	endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
 288  
 289  	req, err := newJSONRequest(ctx, http.MethodPut, endpoint, records)
 290  	if err != nil {
 291  		return err
 292  	}
 293  
 294  	return c.do(req, nil)
 295  }
 296  
 297  func (c *Client) do(req *http.Request, result any) error {
 298  	resp, err := c.httpClient.Do(req)
 299  	if err != nil {
 300  		return errutils.NewHTTPDoError(req, err)
 301  	}
 302  
 303  	defer func() { _ = resp.Body.Close() }()
 304  
 305  	if resp.StatusCode/100 != 2 {
 306  		return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
 307  	}
 308  
 309  	if result == nil {
 310  		return nil
 311  	}
 312  
 313  	raw, err := io.ReadAll(resp.Body)
 314  	if err != nil {
 315  		return errutils.NewReadResponseError(req, resp.StatusCode, err)
 316  	}
 317  
 318  	err = json.Unmarshal(raw, result)
 319  	if err != nil {
 320  		return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
 321  	}
 322  
 323  	return nil
 324  }
 325  
 326  func (c *Client) CleanCache(fqdn string) {
 327  	c.domainIDMu.Lock()
 328  	delete(c.domainIDMapping, fqdn)
 329  	c.domainIDMu.Unlock()
 330  }
 331  
 332  func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool {
 333  	// Skip empty records
 334  	if record.Value == "" {
 335  		return true
 336  	}
 337  
 338  	// Skip some special records, otherwise we would get a "Nameserver update failed"
 339  	if record.Type == "SOA" || record.Type == "NS" || record.Name == "@" || (nsInfo.General.IncludeWWW && record.Name == "www") {
 340  		return true
 341  	}
 342  
 343  	nameMatch := recordName == "" || record.Name == recordName
 344  	valueMatch := recordValue == "" || record.Value == recordValue
 345  
 346  	// Skip our matching record
 347  	if record.Type == "TXT" && nameMatch && valueMatch {
 348  		return true
 349  	}
 350  
 351  	return false
 352  }
 353  
 354  func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
 355  	buf := new(bytes.Buffer)
 356  
 357  	if payload != nil {
 358  		err := json.NewEncoder(buf).Encode(payload)
 359  		if err != nil {
 360  			return nil, fmt.Errorf("failed to create request JSON body: %w", err)
 361  		}
 362  	}
 363  
 364  	req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), buf)
 365  	if err != nil {
 366  		return nil, fmt.Errorf("unable to create request: %w", err)
 367  	}
 368  
 369  	req.Header.Set("Accept", "application/json")
 370  
 371  	if payload != nil {
 372  		req.Header.Set("Content-Type", "application/json")
 373  	}
 374  
 375  	return req, nil
 376  }
 377  
 378  func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client {
 379  	if client == nil {
 380  		client = &http.Client{Timeout: 5 * time.Second}
 381  	}
 382  
 383  	client.Transport = &oauth2.Transport{
 384  		Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}),
 385  		Base:   client.Transport,
 386  	}
 387  
 388  	return client
 389  }
 390