client.go raw
1 package internal
2
3 import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "net/http"
11 "net/url"
12 "strconv"
13 "strings"
14 "sync"
15 "time"
16
17 "github.com/go-acme/lego/v4/providers/dns/internal/errutils"
18 "golang.org/x/oauth2"
19 )
20
// Nameservers a domain must be delegated to before its records can be
// managed through the Checkdomain API (checked by CheckNameservers).
const (
	ns1 = "ns.checkdomain.de"
	ns2 = "ns2.checkdomain.de"
)

const (
	// DefaultEndpoint is the default API endpoint.
	DefaultEndpoint = "https://api.checkdomain.de"

	// domainNotFound is the ID returned together with an error when no
	// registered domain matches a name.
	domainNotFound = -1

	// maxLimit is the largest page size the Checkdomain API allows.
	maxLimit = 100

	// maxInt is the largest int value; the listing functions use it as a
	// "page count not yet known" sentinel.
	maxInt = int(^uint(0) >> 1)
)
36
// Client is a client for the Checkdomain API.
type Client struct {
	// BaseURL is the root URL of the API; NewClient sets it to DefaultEndpoint.
	BaseURL    *url.URL
	httpClient *http.Client

	// domainIDMapping caches domain IDs by looked-up name (see GetDomainIDByName).
	// It must only be accessed while holding domainIDMu.
	domainIDMapping map[string]int
	domainIDMu      sync.Mutex
}
45
46 // NewClient creates a new Client.
47 func NewClient(hc *http.Client) *Client {
48 baseURL, _ := url.Parse(DefaultEndpoint)
49
50 if hc == nil {
51 hc = &http.Client{Timeout: 10 * time.Second}
52 }
53
54 return &Client{
55 BaseURL: baseURL,
56 httpClient: hc,
57 domainIDMapping: make(map[string]int),
58 }
59 }
60
61 func (c *Client) GetDomainIDByName(ctx context.Context, name string) (int, error) {
62 // Load from cache if exists
63 c.domainIDMu.Lock()
64 id, ok := c.domainIDMapping[name]
65 c.domainIDMu.Unlock()
66
67 if ok {
68 return id, nil
69 }
70
71 // Find out by querying API
72 domains, err := c.listDomains(ctx)
73 if err != nil {
74 return domainNotFound, err
75 }
76
77 // Linear search over all registered domains
78 for _, domain := range domains {
79 if domain.Name == name || strings.HasSuffix(name, "."+domain.Name) {
80 c.domainIDMu.Lock()
81 c.domainIDMapping[name] = domain.ID
82 c.domainIDMu.Unlock()
83
84 return domain.ID, nil
85 }
86 }
87
88 return domainNotFound, errors.New("domain not found")
89 }
90
91 func (c *Client) listDomains(ctx context.Context) ([]*Domain, error) {
92 endpoint := c.BaseURL.JoinPath("v1", "domains")
93
94 // Checkdomain also provides a query param 'query' which allows filtering domains for a string.
95 // But that functionality is kinda broken,
96 // so we scan through the whole list of registered domains to later find the one that is of interest to us.
97 q := endpoint.Query()
98 q.Set("limit", strconv.Itoa(maxLimit))
99
100 currentPage := 1
101 totalPages := maxInt
102
103 var domainList []*Domain
104
105 for currentPage <= totalPages {
106 q.Set("page", strconv.Itoa(currentPage))
107 endpoint.RawQuery = q.Encode()
108
109 req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
110 if err != nil {
111 return nil, fmt.Errorf("failed to make request: %w", err)
112 }
113
114 var res DomainListingResponse
115 if err := c.do(req, &res); err != nil {
116 return nil, fmt.Errorf("failed to send domain listing request: %w", err)
117 }
118
119 // This is the first response,
120 // so we update totalPages and allocate the slice memory.
121 if totalPages == maxInt {
122 totalPages = res.Pages
123 domainList = make([]*Domain, 0, res.Total)
124 }
125
126 domainList = append(domainList, res.Embedded.Domains...)
127 currentPage++
128 }
129
130 return domainList, nil
131 }
132
133 func (c *Client) getNameserverInfo(ctx context.Context, domainID int) (*NameserverResponse, error) {
134 endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers")
135
136 req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
137 if err != nil {
138 return nil, err
139 }
140
141 res := &NameserverResponse{}
142 if err := c.do(req, res); err != nil {
143 return nil, err
144 }
145
146 return res, nil
147 }
148
149 func (c *Client) CheckNameservers(ctx context.Context, domainID int) error {
150 info, err := c.getNameserverInfo(ctx, domainID)
151 if err != nil {
152 return err
153 }
154
155 var found1, found2 bool
156
157 for _, item := range info.Nameservers {
158 switch item.Name {
159 case ns1:
160 found1 = true
161 case ns2:
162 found2 = true
163 }
164 }
165
166 if !found1 || !found2 {
167 return errors.New("not using checkdomain nameservers, can not update records")
168 }
169
170 return nil
171 }
172
173 func (c *Client) CreateRecord(ctx context.Context, domainID int, record *Record) error {
174 endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
175
176 req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
177 if err != nil {
178 return err
179 }
180
181 return c.do(req, nil)
182 }
183
184 // DeleteTXTRecord Checkdomain doesn't seem provide a way to delete records but one can replace all records at once.
185 // The current solution is to fetch all records and then use that list minus the record deleted as the new record list.
186 // TODO: Simplify this function once Checkdomain do provide the functionality.
187 func (c *Client) DeleteTXTRecord(ctx context.Context, domainID int, recordName, recordValue string) error {
188 domainInfo, err := c.getDomainInfo(ctx, domainID)
189 if err != nil {
190 return err
191 }
192
193 nsInfo, err := c.getNameserverInfo(ctx, domainID)
194 if err != nil {
195 return err
196 }
197
198 allRecords, err := c.listRecords(ctx, domainID, "")
199 if err != nil {
200 return err
201 }
202
203 recordName = strings.TrimSuffix(recordName, "."+domainInfo.Name+".")
204
205 var recordsToKeep []*Record
206
207 // Find and delete matching records
208 for _, record := range allRecords {
209 if skipRecord(recordName, recordValue, record, nsInfo) {
210 continue
211 }
212
213 // Checkdomain API can return records without any TTL set (indicated by the value of 0).
214 // The API Call to replace the records would fail if we wouldn't specify a value.
215 // Thus, we use the default TTL queried beforehand
216 if record.TTL == 0 {
217 record.TTL = nsInfo.SOA.TTL
218 }
219
220 recordsToKeep = append(recordsToKeep, record)
221 }
222
223 return c.replaceRecords(ctx, domainID, recordsToKeep)
224 }
225
226 func (c *Client) getDomainInfo(ctx context.Context, domainID int) (*DomainResponse, error) {
227 endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID))
228
229 req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
230 if err != nil {
231 return nil, err
232 }
233
234 var res DomainResponse
235
236 err = c.do(req, &res)
237 if err != nil {
238 return nil, err
239 }
240
241 return &res, nil
242 }
243
244 func (c *Client) listRecords(ctx context.Context, domainID int, recordType string) ([]*Record, error) {
245 endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
246
247 q := endpoint.Query()
248 q.Set("limit", strconv.Itoa(maxLimit))
249
250 if recordType != "" {
251 q.Set("type", recordType)
252 }
253
254 currentPage := 1
255 totalPages := maxInt
256
257 var recordList []*Record
258
259 for currentPage <= totalPages {
260 q.Set("page", strconv.Itoa(currentPage))
261 endpoint.RawQuery = q.Encode()
262
263 req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
264 if err != nil {
265 return nil, fmt.Errorf("failed to create request: %w", err)
266 }
267
268 var res RecordListingResponse
269 if err := c.do(req, &res); err != nil {
270 return nil, fmt.Errorf("failed to send record listing request: %w", err)
271 }
272
273 // This is the first response, so we update totalPages and allocate the slice memory.
274 if totalPages == maxInt {
275 totalPages = res.Pages
276 recordList = make([]*Record, 0, res.Total)
277 }
278
279 recordList = append(recordList, res.Embedded.Records...)
280 currentPage++
281 }
282
283 return recordList, nil
284 }
285
286 func (c *Client) replaceRecords(ctx context.Context, domainID int, records []*Record) error {
287 endpoint := c.BaseURL.JoinPath("v1", "domains", strconv.Itoa(domainID), "nameservers", "records")
288
289 req, err := newJSONRequest(ctx, http.MethodPut, endpoint, records)
290 if err != nil {
291 return err
292 }
293
294 return c.do(req, nil)
295 }
296
297 func (c *Client) do(req *http.Request, result any) error {
298 resp, err := c.httpClient.Do(req)
299 if err != nil {
300 return errutils.NewHTTPDoError(req, err)
301 }
302
303 defer func() { _ = resp.Body.Close() }()
304
305 if resp.StatusCode/100 != 2 {
306 return errutils.NewUnexpectedResponseStatusCodeError(req, resp)
307 }
308
309 if result == nil {
310 return nil
311 }
312
313 raw, err := io.ReadAll(resp.Body)
314 if err != nil {
315 return errutils.NewReadResponseError(req, resp.StatusCode, err)
316 }
317
318 err = json.Unmarshal(raw, result)
319 if err != nil {
320 return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
321 }
322
323 return nil
324 }
325
326 func (c *Client) CleanCache(fqdn string) {
327 c.domainIDMu.Lock()
328 delete(c.domainIDMapping, fqdn)
329 c.domainIDMu.Unlock()
330 }
331
332 func skipRecord(recordName, recordValue string, record *Record, nsInfo *NameserverResponse) bool {
333 // Skip empty records
334 if record.Value == "" {
335 return true
336 }
337
338 // Skip some special records, otherwise we would get a "Nameserver update failed"
339 if record.Type == "SOA" || record.Type == "NS" || record.Name == "@" || (nsInfo.General.IncludeWWW && record.Name == "www") {
340 return true
341 }
342
343 nameMatch := recordName == "" || record.Name == recordName
344 valueMatch := recordValue == "" || record.Value == recordValue
345
346 // Skip our matching record
347 if record.Type == "TXT" && nameMatch && valueMatch {
348 return true
349 }
350
351 return false
352 }
353
// newJSONRequest builds a request for endpoint that asks for a JSON response
// (Accept: application/json). A non-nil payload is JSON-encoded as the body and
// the request gets a JSON Content-Type; otherwise the body is empty.
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
	var body bytes.Buffer

	hasPayload := payload != nil
	if hasPayload {
		if err := json.NewEncoder(&body).Encode(payload); err != nil {
			return nil, fmt.Errorf("failed to create request JSON body: %w", err)
		}
	}

	req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), &body)
	if err != nil {
		return nil, fmt.Errorf("unable to create request: %w", err)
	}

	req.Header.Set("Accept", "application/json")
	if hasPayload {
		req.Header.Set("Content-Type", "application/json")
	}

	return req, nil
}
377
378 func OAuthStaticAccessToken(client *http.Client, accessToken string) *http.Client {
379 if client == nil {
380 client = &http.Client{Timeout: 5 * time.Second}
381 }
382
383 client.Transport = &oauth2.Transport{
384 Source: oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}),
385 Base: client.Transport,
386 }
387
388 return client
389 }
390