client.go raw
1 package internal
2
3 import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10 "net/url"
11
12 "github.com/go-acme/lego/v4/challenge/dns01"
13 "github.com/go-acme/lego/v4/providers/dns/internal/errutils"
14 "golang.org/x/net/publicsuffix"
15 )
16
// defaultBaseURL is the root URL of the StackPath DNS API.
// NewClient parses it into the base URL onto which the stack ID and the
// resource path segments are joined for each request.
const defaultBaseURL = "https://gateway.stackpath.com/dns/v1/stacks/"
18
// Client is the API client for Stackpath.
type Client struct {
	// baseURL is the API root; request endpoints are built by joining path
	// segments onto it.
	baseURL *url.URL
	// httpClient performs the HTTP round trips.
	httpClient *http.Client

	// stackID is joined onto baseURL as the first path segment of every endpoint.
	stackID string
}
26
27 // NewClient creates a new Client.
28 func NewClient(stackID string, hc *http.Client) *Client {
29 baseURL, _ := url.Parse(defaultBaseURL)
30
31 return &Client{
32 baseURL: baseURL,
33 stackID: stackID,
34 httpClient: hc,
35 }
36 }
37
38 // GetZones gets all zones.
39 // https://stackpath.dev/reference/getzones
40 func (c *Client) GetZones(ctx context.Context, domain string) (*Zone, error) {
41 endpoint := c.baseURL.JoinPath(c.stackID, "zones")
42
43 req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
44 if err != nil {
45 return nil, err
46 }
47
48 tld, err := publicsuffix.EffectiveTLDPlusOne(dns01.UnFqdn(domain))
49 if err != nil {
50 return nil, err
51 }
52
53 query := req.URL.Query()
54 query.Add("page_request.filter", fmt.Sprintf("domain='%s'", tld))
55 req.URL.RawQuery = query.Encode()
56
57 var zones Zones
58
59 err = c.do(req, &zones)
60 if err != nil {
61 return nil, err
62 }
63
64 if len(zones.Zones) == 0 {
65 return nil, fmt.Errorf("did not find zone with domain %s", domain)
66 }
67
68 return &zones.Zones[0], nil
69 }
70
71 // GetZoneRecords gets all records.
72 // https://stackpath.dev/reference/getzonerecords
73 func (c *Client) GetZoneRecords(ctx context.Context, name string, zone *Zone) ([]Record, error) {
74 endpoint := c.baseURL.JoinPath(c.stackID, "zones", zone.ID, "records")
75
76 req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
77 if err != nil {
78 return nil, err
79 }
80
81 query := req.URL.Query()
82 query.Add("page_request.filter", fmt.Sprintf("name='%s' and type='TXT'", name))
83 req.URL.RawQuery = query.Encode()
84
85 var records Records
86
87 err = c.do(req, &records)
88 if err != nil {
89 return nil, err
90 }
91
92 if len(records.Records) == 0 {
93 return nil, fmt.Errorf("did not find record with name %s", name)
94 }
95
96 return records.Records, nil
97 }
98
99 // CreateZoneRecord creates a record.
100 // https://stackpath.dev/reference/createzonerecord
101 func (c *Client) CreateZoneRecord(ctx context.Context, zone *Zone, record Record) error {
102 endpoint := c.baseURL.JoinPath(c.stackID, "zones", zone.ID, "records")
103
104 req, err := newJSONRequest(ctx, http.MethodPost, endpoint, record)
105 if err != nil {
106 return err
107 }
108
109 return c.do(req, nil)
110 }
111
112 // DeleteZoneRecord deletes a record.
113 // https://stackpath.dev/reference/deletezonerecord
114 func (c *Client) DeleteZoneRecord(ctx context.Context, zone *Zone, record Record) error {
115 endpoint := c.baseURL.JoinPath(c.stackID, "zones", zone.ID, "records", record.ID)
116
117 req, err := newJSONRequest(ctx, http.MethodDelete, endpoint, nil)
118 if err != nil {
119 return err
120 }
121
122 return c.do(req, nil)
123 }
124
// newJSONRequest builds an HTTP request for endpoint that accepts JSON.
// A non-nil payload is JSON-encoded into the request body and the
// Content-Type header is set accordingly; a nil payload yields an empty body.
// It returns an error when the payload cannot be encoded or the request cannot
// be created.
func newJSONRequest(ctx context.Context, method string, endpoint *url.URL, payload any) (*http.Request, error) {
	hasPayload := payload != nil

	var body bytes.Buffer

	if hasPayload {
		if err := json.NewEncoder(&body).Encode(payload); err != nil {
			return nil, fmt.Errorf("failed to create request JSON body: %w", err)
		}
	}

	req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), &body)
	if err != nil {
		return nil, fmt.Errorf("unable to create request: %w", err)
	}

	req.Header.Set("Accept", "application/json")

	// Only advertise a JSON body when one was actually encoded.
	if hasPayload {
		req.Header.Set("Content-Type", "application/json")
	}

	return req, nil
}
148
149 func (c *Client) do(req *http.Request, result any) error {
150 resp, err := c.httpClient.Do(req)
151 if err != nil {
152 return errutils.NewHTTPDoError(req, err)
153 }
154
155 defer func() { _ = resp.Body.Close() }()
156
157 if resp.StatusCode/100 != 2 {
158 return parseError(req, resp)
159 }
160
161 if result == nil {
162 return nil
163 }
164
165 raw, err := io.ReadAll(resp.Body)
166 if err != nil {
167 return errutils.NewReadResponseError(req, resp.StatusCode, err)
168 }
169
170 err = json.Unmarshal(raw, result)
171 if err != nil {
172 return errutils.NewUnmarshalError(req, resp.StatusCode, raw, err)
173 }
174
175 return nil
176 }
177
178 func parseError(req *http.Request, resp *http.Response) error {
179 raw, _ := io.ReadAll(resp.Body)
180
181 errResp := &ErrorResponse{}
182
183 err := json.Unmarshal(raw, errResp)
184 if err != nil {
185 return errutils.NewUnexpectedStatusCodeError(req, resp.StatusCode, raw)
186 }
187
188 return errResp
189 }
190