cpanel.go raw

   1  // Package cpanel implements a DNS provider for solving the DNS-01 challenge using CPanel.
   2  package cpanel
   3  
   4  import (
   5  	"context"
   6  	"encoding/base64"
   7  	"errors"
   8  	"fmt"
   9  	"net/http"
  10  	"slices"
  11  	"strings"
  12  	"time"
  13  
  14  	"github.com/go-acme/lego/v4/challenge"
  15  	"github.com/go-acme/lego/v4/challenge/dns01"
  16  	"github.com/go-acme/lego/v4/platform/config/env"
  17  	"github.com/go-acme/lego/v4/providers/dns/cpanel/internal/cpanel"
  18  	"github.com/go-acme/lego/v4/providers/dns/cpanel/internal/shared"
  19  	"github.com/go-acme/lego/v4/providers/dns/cpanel/internal/whm"
  20  	"github.com/go-acme/lego/v4/providers/dns/internal/clientdebug"
  21  )
  22  
  23  // Environment variables names.
  24  const (
  25  	envNamespace = "CPANEL_"
  26  
  27  	EnvMode     = envNamespace + "MODE"
  28  	EnvUsername = envNamespace + "USERNAME"
  29  	EnvToken    = envNamespace + "TOKEN"
  30  	EnvBaseURL  = envNamespace + "BASE_URL"
  31  
  32  	EnvTTL                = envNamespace + "TTL"
  33  	EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT"
  34  	EnvPollingInterval    = envNamespace + "POLLING_INTERVAL"
  35  	EnvHTTPTimeout        = envNamespace + "HTTP_TIMEOUT"
  36  )
  37  
  38  var _ challenge.ProviderTimeout = (*DNSProvider)(nil)
  39  
  40  type apiClient interface {
  41  	FetchZoneInformation(ctx context.Context, domain string) ([]shared.ZoneRecord, error)
  42  	AddRecord(ctx context.Context, serial uint32, domain string, record shared.Record) (*shared.ZoneSerial, error)
  43  	EditRecord(ctx context.Context, serial uint32, domain string, record shared.Record) (*shared.ZoneSerial, error)
  44  	DeleteRecord(ctx context.Context, serial uint32, domain string, lineIndex int) (*shared.ZoneSerial, error)
  45  }
  46  
// Config is used to configure the creation of the DNSProvider.
type Config struct {
	// Mode selects the API client: "cpanel" or "whm" (matched case-insensitively).
	Mode               string
	// Username is the account name handed to the API client.
	Username           string
	// Token is the API token handed to the API client.
	Token              string
	// BaseURL is the address of the cPanel/WHM server handed to the API client.
	BaseURL            string
	// TTL is applied to the TXT records the provider adds or edits.
	TTL                int
	// PropagationTimeout is returned by DNSProvider.Timeout as the propagation timeout.
	PropagationTimeout time.Duration
	// PollingInterval is returned by DNSProvider.Timeout as the polling interval.
	PollingInterval    time.Duration
	// HTTPClient, when non-nil, replaces the API client's own HTTP client.
	HTTPClient         *http.Client
}
  58  
  59  // NewDefaultConfig returns a default configuration for the DNSProvider.
  60  func NewDefaultConfig() *Config {
  61  	return &Config{
  62  		Mode:               env.GetOrDefaultString(EnvMode, "cpanel"),
  63  		TTL:                env.GetOrDefaultInt(EnvTTL, 300),
  64  		PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 2*time.Minute),
  65  		PollingInterval:    env.GetOrDefaultSecond(EnvPollingInterval, dns01.DefaultPollingInterval),
  66  		HTTPClient: &http.Client{
  67  			Timeout: env.GetOrDefaultSecond(EnvHTTPTimeout, 30*time.Second),
  68  		},
  69  	}
  70  }
  71  
// DNSProvider implements the challenge.Provider interface.
// It publishes and removes dns-01 TXT records through a cPanel or WHM API client.
type DNSProvider struct {
	// config holds the settings the provider was created with.
	config *Config
	// client is the cPanel or WHM API client selected from config.Mode.
	client apiClient
}
  77  
  78  // NewDNSProvider returns a DNSProvider instance configured for CPanel.
  79  // Credentials must be passed in the environment variables:
  80  // CPANEL_USERNAME, CPANEL_TOKEN, CPANEL_BASE_URL, CPANEL_NAMESERVER.
  81  func NewDNSProvider() (*DNSProvider, error) {
  82  	values, err := env.Get(EnvUsername, EnvToken, EnvBaseURL)
  83  	if err != nil {
  84  		return nil, fmt.Errorf("cpanel: %w", err)
  85  	}
  86  
  87  	config := NewDefaultConfig()
  88  	config.Username = values[EnvUsername]
  89  	config.Token = values[EnvToken]
  90  	config.BaseURL = values[EnvBaseURL]
  91  
  92  	return NewDNSProviderConfig(config)
  93  }
  94  
  95  // NewDNSProviderConfig return a DNSProvider instance configured for CPanel.
  96  func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
  97  	if config == nil {
  98  		return nil, errors.New("cpanel: the configuration of the DNS provider is nil")
  99  	}
 100  
 101  	if config.Username == "" || config.Token == "" {
 102  		return nil, errors.New("cpanel: some credentials information are missing")
 103  	}
 104  
 105  	if config.BaseURL == "" {
 106  		return nil, errors.New("cpanel: server information are missing")
 107  	}
 108  
 109  	client, err := createClient(config)
 110  	if err != nil {
 111  		return nil, fmt.Errorf("cpanel: create client error: %w", err)
 112  	}
 113  
 114  	return &DNSProvider{
 115  		config: config,
 116  		client: client,
 117  	}, nil
 118  }
 119  
 120  // Timeout returns the timeout and interval to use when checking for DNS propagation.
 121  // Adjusting here to cope with spikes in propagation times.
 122  func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
 123  	return d.config.PropagationTimeout, d.config.PollingInterval
 124  }
 125  
 126  // Present creates a TXT record to fulfill the dns-01 challenge.
 127  func (d *DNSProvider) Present(domain, _, keyAuth string) error {
 128  	ctx := context.Background()
 129  	info := dns01.GetChallengeInfo(domain, keyAuth)
 130  
 131  	authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
 132  	if err != nil {
 133  		return fmt.Errorf("arvancloud: could not find zone for domain %q: %w", domain, err)
 134  	}
 135  
 136  	zone := dns01.UnFqdn(authZone)
 137  
 138  	zoneInfo, err := d.client.FetchZoneInformation(ctx, zone)
 139  	if err != nil {
 140  		return fmt.Errorf("cpanel[mode=%s]: fetch zone information: %w", d.config.Mode, err)
 141  	}
 142  
 143  	serial, err := getZoneSerial(authZone, zoneInfo)
 144  	if err != nil {
 145  		return fmt.Errorf("cpanel[mode=%s]: get zone serial: %w", d.config.Mode, err)
 146  	}
 147  
 148  	valueB64 := base64.StdEncoding.EncodeToString([]byte(info.Value))
 149  
 150  	var (
 151  		found          bool
 152  		existingRecord shared.ZoneRecord
 153  	)
 154  
 155  	for _, record := range zoneInfo {
 156  		if slices.Contains(record.DataB64, valueB64) {
 157  			existingRecord = record
 158  			found = true
 159  
 160  			break
 161  		}
 162  	}
 163  
 164  	record := shared.Record{
 165  		DName:      info.EffectiveFQDN,
 166  		TTL:        d.config.TTL,
 167  		RecordType: "TXT",
 168  	}
 169  
 170  	// New record.
 171  	if !found {
 172  		record.Data = []string{info.Value}
 173  
 174  		_, err = d.client.AddRecord(ctx, serial, zone, record)
 175  		if err != nil {
 176  			return fmt.Errorf("cpanel[mode=%s]: add record: %w", d.config.Mode, err)
 177  		}
 178  
 179  		return nil
 180  	}
 181  
 182  	// Update existing record.
 183  	record.LineIndex = existingRecord.LineIndex
 184  
 185  	for _, dataB64 := range existingRecord.DataB64 {
 186  		data, errD := base64.StdEncoding.DecodeString(dataB64)
 187  		if errD != nil {
 188  			return fmt.Errorf("cpanel[mode=%s]: decode base64 record value: %w", d.config.Mode, errD)
 189  		}
 190  
 191  		record.Data = append(record.Data, string(data))
 192  	}
 193  
 194  	record.Data = append(record.Data, info.Value)
 195  
 196  	_, err = d.client.EditRecord(ctx, serial, zone, record)
 197  	if err != nil {
 198  		return fmt.Errorf("cpanel[mode=%s]: edit record: %w", d.config.Mode, err)
 199  	}
 200  
 201  	return nil
 202  }
 203  
 204  // CleanUp removes the TXT record matching the specified parameters.
 205  func (d *DNSProvider) CleanUp(domain, _, keyAuth string) error {
 206  	ctx := context.Background()
 207  	info := dns01.GetChallengeInfo(domain, keyAuth)
 208  
 209  	authZone, err := dns01.FindZoneByFqdn(info.EffectiveFQDN)
 210  	if err != nil {
 211  		return fmt.Errorf("arvancloud: could not find zone for domain %q: %w", domain, err)
 212  	}
 213  
 214  	zone := dns01.UnFqdn(authZone)
 215  
 216  	zoneInfo, err := d.client.FetchZoneInformation(ctx, zone)
 217  	if err != nil {
 218  		return fmt.Errorf("cpanel[mode=%s]: fetch zone information: %w", d.config.Mode, err)
 219  	}
 220  
 221  	serial, err := getZoneSerial(authZone, zoneInfo)
 222  	if err != nil {
 223  		return fmt.Errorf("cpanel[mode=%s]: get zone serial: %w", d.config.Mode, err)
 224  	}
 225  
 226  	valueB64 := base64.StdEncoding.EncodeToString([]byte(info.Value))
 227  
 228  	var (
 229  		found          bool
 230  		existingRecord shared.ZoneRecord
 231  	)
 232  
 233  	for _, record := range zoneInfo {
 234  		if slices.Contains(record.DataB64, valueB64) {
 235  			existingRecord = record
 236  			found = true
 237  
 238  			break
 239  		}
 240  	}
 241  
 242  	if !found {
 243  		return nil
 244  	}
 245  
 246  	var newData []string
 247  
 248  	for _, dataB64 := range existingRecord.DataB64 {
 249  		if dataB64 == valueB64 {
 250  			continue
 251  		}
 252  
 253  		data, errD := base64.StdEncoding.DecodeString(dataB64)
 254  		if errD != nil {
 255  			return fmt.Errorf("cpanel[mode=%s]: decode base64 record value: %w", d.config.Mode, errD)
 256  		}
 257  
 258  		newData = append(newData, string(data))
 259  	}
 260  
 261  	// Delete record.
 262  	if len(newData) == 0 {
 263  		_, err = d.client.DeleteRecord(ctx, serial, zone, existingRecord.LineIndex)
 264  		if err != nil {
 265  			return fmt.Errorf("cpanel[mode=%s]: delete record: %w", d.config.Mode, err)
 266  		}
 267  
 268  		return nil
 269  	}
 270  
 271  	// Remove one value.
 272  	record := shared.Record{
 273  		DName:      info.EffectiveFQDN,
 274  		TTL:        d.config.TTL,
 275  		RecordType: "TXT",
 276  		Data:       newData,
 277  		LineIndex:  existingRecord.LineIndex,
 278  	}
 279  
 280  	_, err = d.client.EditRecord(ctx, serial, zone, record)
 281  	if err != nil {
 282  		return fmt.Errorf("cpanel[mode=%s]: edit record: %w", d.config.Mode, err)
 283  	}
 284  
 285  	return nil
 286  }
 287  
 288  func getZoneSerial(zoneFqdn string, zoneInfo []shared.ZoneRecord) (uint32, error) {
 289  	nameB64 := base64.StdEncoding.EncodeToString([]byte(zoneFqdn))
 290  
 291  	for _, record := range zoneInfo {
 292  		if record.Type != "record" || record.RecordType != "SOA" || record.DNameB64 != nameB64 {
 293  			continue
 294  		}
 295  
 296  		// https://github.com/go-acme/lego/issues/1060#issuecomment-1925572386
 297  		// https://github.com/go-acme/lego/issues/1060#issuecomment-1925581832
 298  		data, err := base64.StdEncoding.DecodeString(record.DataB64[2])
 299  		if err != nil {
 300  			return 0, fmt.Errorf("decode serial DNameB64: %w", err)
 301  		}
 302  
 303  		var newSerial uint32
 304  
 305  		_, err = fmt.Sscan(string(data), &newSerial)
 306  		if err != nil {
 307  			return 0, fmt.Errorf("decode serial DNameB64, invalid serial value %q: %w", string(data), err)
 308  		}
 309  
 310  		return newSerial, nil
 311  	}
 312  
 313  	return 0, errors.New("zone serial not found")
 314  }
 315  
 316  func createClient(config *Config) (apiClient, error) {
 317  	switch strings.ToLower(config.Mode) {
 318  	case "cpanel":
 319  		client, err := cpanel.NewClient(config.BaseURL, config.Username, config.Token)
 320  		if err != nil {
 321  			return nil, fmt.Errorf("failed to create cPanel API client: %w", err)
 322  		}
 323  
 324  		if config.HTTPClient != nil {
 325  			client.HTTPClient = config.HTTPClient
 326  		}
 327  
 328  		client.HTTPClient = clientdebug.Wrap(client.HTTPClient)
 329  
 330  		return client, nil
 331  
 332  	case "whm":
 333  		client, err := whm.NewClient(config.BaseURL, config.Username, config.Token)
 334  		if err != nil {
 335  			return nil, fmt.Errorf("failed to create WHM API client: %w", err)
 336  		}
 337  
 338  		if config.HTTPClient != nil {
 339  			client.HTTPClient = config.HTTPClient
 340  		}
 341  
 342  		client.HTTPClient = clientdebug.Wrap(client.HTTPClient)
 343  
 344  		return client, nil
 345  
 346  	default:
 347  		return nil, fmt.Errorf("unsupported mode: %q", config.Mode)
 348  	}
 349  }
 350