route53.go raw

   1  // Package route53 implements a DNS provider for solving the DNS-01 challenge using AWS Route 53 DNS.
   2  package route53
   3  
   4  import (
   5  	"context"
   6  	"errors"
   7  	"fmt"
   8  	"math/rand"
   9  	"strings"
  10  	"time"
  11  
  12  	"github.com/aws/aws-sdk-go-v2/aws"
  13  	"github.com/aws/aws-sdk-go-v2/aws/retry"
  14  	awsconfig "github.com/aws/aws-sdk-go-v2/config"
  15  	"github.com/aws/aws-sdk-go-v2/credentials"
  16  	"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
  17  	"github.com/aws/aws-sdk-go-v2/service/route53"
  18  	awstypes "github.com/aws/aws-sdk-go-v2/service/route53/types"
  19  	"github.com/aws/aws-sdk-go-v2/service/sts"
  20  	"github.com/cenkalti/backoff/v5"
  21  	"github.com/go-acme/lego/v4/challenge"
  22  	"github.com/go-acme/lego/v4/challenge/dns01"
  23  	"github.com/go-acme/lego/v4/platform/config/env"
  24  	"github.com/go-acme/lego/v4/platform/wait"
  25  	"github.com/go-acme/lego/v4/providers/dns/internal/ptr"
  26  )
  27  
// Environment variables names.
const (
	// envNamespace is the prefix shared by every variable name declared in this block.
	envNamespace = "AWS_"

	// Credentials, region, zone selection, retry budget, and role-assumption settings.
	// NOTE(review): ACCESS_KEY_ID, SECRET_ACCESS_KEY, and REGION are not read anywhere in this file;
	// presumably the AWS SDK default configuration chain consumes them — confirm.
	EnvAccessKeyID     = envNamespace + "ACCESS_KEY_ID"
	EnvSecretAccessKey = envNamespace + "SECRET_ACCESS_KEY"
	EnvRegion          = envNamespace + "REGION"
	EnvHostedZoneID    = envNamespace + "HOSTED_ZONE_ID"
	EnvMaxRetries      = envNamespace + "MAX_RETRIES"
	EnvAssumeRoleArn   = envNamespace + "ASSUME_ROLE_ARN"
	EnvExternalID      = envNamespace + "EXTERNAL_ID"
	EnvPrivateZone     = envNamespace + "PRIVATE_ZONE"

	// EnvWaitForRecordSetsChanged feeds Config.WaitForRecordSetsChanged,
	// which controls whether a record set change is polled until it is in sync.
	EnvWaitForRecordSetsChanged = envNamespace + "WAIT_FOR_RECORD_SETS_CHANGED"

	// Record TTL and propagation-check timing.
	EnvTTL                = envNamespace + "TTL"
	EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT"
	EnvPollingInterval    = envNamespace + "POLLING_INTERVAL"
)
  47  
// Compile-time assertion that *DNSProvider satisfies challenge.ProviderTimeout.
var _ challenge.ProviderTimeout = (*DNSProvider)(nil)
  49  
// Config is used to configure the creation of the DNSProvider.
type Config struct {
	// Static credential chain.
	// These are not set via environment for the time being and are only used if they are explicitly provided.
	AccessKeyID     string
	SecretAccessKey string
	SessionToken    string
	Region          string

	// Hosted zone selection, SDK retry budget, and optional role assumption.
	// A non-empty HostedZoneID is used as-is instead of looking the zone up by name.
	// MaxRetries is assigned to the SDK retryer's MaxAttempts.
	// A non-empty AssumeRoleArn replaces the loaded credentials with an STS assume-role provider,
	// which receives ExternalID when that is non-empty.
	// PrivateZone must equal the PrivateZone flag of a hosted zone for the by-name lookup to select it.
	HostedZoneID  string
	MaxRetries    int
	AssumeRoleArn string
	ExternalID    string
	PrivateZone   bool

	// WaitForRecordSetsChanged makes every record set change poll GetChange
	// until the change status is ChangeStatusInsync.
	WaitForRecordSetsChanged bool

	// TTL is the TTL written on the challenge TXT record set.
	// PropagationTimeout and PollingInterval are returned by Timeout and also bound
	// the change-status polling (maximum elapsed time and constant back-off).
	TTL                int
	PropagationTimeout time.Duration
	PollingInterval    time.Duration

	// Client, when non-nil, is used directly by NewDNSProviderConfig and no AWS configuration is loaded.
	Client *route53.Client
}
  73  
  74  // NewDefaultConfig returns a default configuration for the DNSProvider.
  75  func NewDefaultConfig() *Config {
  76  	return &Config{
  77  		HostedZoneID:  env.GetOrFile(EnvHostedZoneID),
  78  		MaxRetries:    env.GetOrDefaultInt(EnvMaxRetries, 5),
  79  		AssumeRoleArn: env.GetOrDefaultString(EnvAssumeRoleArn, ""),
  80  		ExternalID:    env.GetOrDefaultString(EnvExternalID, ""),
  81  		PrivateZone:   env.GetOrDefaultBool(EnvPrivateZone, false),
  82  
  83  		WaitForRecordSetsChanged: env.GetOrDefaultBool(EnvWaitForRecordSetsChanged, true),
  84  
  85  		TTL:                env.GetOrDefaultInt(EnvTTL, 10),
  86  		PropagationTimeout: env.GetOrDefaultSecond(EnvPropagationTimeout, 2*time.Minute),
  87  		PollingInterval:    env.GetOrDefaultSecond(EnvPollingInterval, 4*time.Second),
  88  	}
  89  }
  90  
// DNSProvider implements the challenge.Provider interface.
type DNSProvider struct {
	// client is the Route 53 API client every request of this provider goes through.
	client *route53.Client
	// config holds the settings the provider was created with.
	config *Config
}
  96  
  97  // NewDNSProvider returns a DNSProvider instance configured for the AWS Route 53 service.
  98  //
  99  // AWS Credentials are automatically detected in the following locations and prioritized in the following order:
 100  //  1. Environment variables: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
 101  //     AWS_REGION, [AWS_SESSION_TOKEN]
 102  //  2. Shared credentials file (defaults to ~/.aws/credentials)
 103  //  3. Amazon EC2 IAM role
 104  //
 105  // If AWS_HOSTED_ZONE_ID is not set, Lego tries to determine the correct public hosted zone via the FQDN.
 106  //
 107  // See also: https://github.com/aws/aws-sdk-go/wiki/configuring-sdk
 108  func NewDNSProvider() (*DNSProvider, error) {
 109  	return NewDNSProviderConfig(NewDefaultConfig())
 110  }
 111  
 112  // NewDNSProviderConfig takes a given config and returns a custom configured DNSProvider instance.
 113  func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {
 114  	if config == nil {
 115  		return nil, errors.New("route53: the configuration of the Route53 DNS provider is nil")
 116  	}
 117  
 118  	if config.Client != nil {
 119  		return &DNSProvider{client: config.Client, config: config}, nil
 120  	}
 121  
 122  	ctx := context.Background()
 123  
 124  	cfg, err := createAWSConfig(ctx, config)
 125  	if err != nil {
 126  		return nil, err
 127  	}
 128  
 129  	return &DNSProvider{
 130  		client: route53.NewFromConfig(cfg),
 131  		config: config,
 132  	}, nil
 133  }
 134  
 135  // Timeout returns the timeout and interval to use when checking for DNS propagation.
 136  func (d *DNSProvider) Timeout() (timeout, interval time.Duration) {
 137  	return d.config.PropagationTimeout, d.config.PollingInterval
 138  }
 139  
 140  // Present creates a TXT record using the specified parameters.
 141  func (d *DNSProvider) Present(domain, token, keyAuth string) error {
 142  	ctx := context.Background()
 143  	info := dns01.GetChallengeInfo(domain, keyAuth)
 144  
 145  	hostedZoneID, err := d.getHostedZoneID(ctx, info.EffectiveFQDN)
 146  	if err != nil {
 147  		return fmt.Errorf("route53: failed to determine hosted zone ID: %w", err)
 148  	}
 149  
 150  	records, err := d.getExistingRecordSets(ctx, hostedZoneID, info.EffectiveFQDN)
 151  	if err != nil {
 152  		return fmt.Errorf("route53: %w", err)
 153  	}
 154  
 155  	realValue := `"` + info.Value + `"`
 156  
 157  	var found bool
 158  
 159  	for _, record := range records {
 160  		if ptr.Deref(record.Value) == realValue {
 161  			found = true
 162  		}
 163  	}
 164  
 165  	if !found {
 166  		records = append(records, awstypes.ResourceRecord{Value: aws.String(realValue)})
 167  	}
 168  
 169  	recordSet := &awstypes.ResourceRecordSet{
 170  		Name:            aws.String(info.EffectiveFQDN),
 171  		Type:            "TXT",
 172  		TTL:             aws.Int64(int64(d.config.TTL)),
 173  		ResourceRecords: records,
 174  	}
 175  
 176  	err = d.changeRecord(ctx, awstypes.ChangeActionUpsert, hostedZoneID, recordSet)
 177  	if err != nil {
 178  		return fmt.Errorf("route53: %w", err)
 179  	}
 180  
 181  	return nil
 182  }
 183  
 184  // CleanUp removes the TXT record matching the specified parameters.
 185  func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
 186  	ctx := context.Background()
 187  	info := dns01.GetChallengeInfo(domain, keyAuth)
 188  
 189  	hostedZoneID, err := d.getHostedZoneID(ctx, info.EffectiveFQDN)
 190  	if err != nil {
 191  		return fmt.Errorf("failed to determine Route 53 hosted zone ID: %w", err)
 192  	}
 193  
 194  	existingRecords, err := d.getExistingRecordSets(ctx, hostedZoneID, info.EffectiveFQDN)
 195  	if err != nil {
 196  		return fmt.Errorf("route53: %w", err)
 197  	}
 198  
 199  	if len(existingRecords) == 0 {
 200  		return nil
 201  	}
 202  
 203  	var nonLegoRecords []awstypes.ResourceRecord
 204  
 205  	for _, record := range existingRecords {
 206  		if ptr.Deref(record.Value) != `"`+info.Value+`"` {
 207  			nonLegoRecords = append(nonLegoRecords, record)
 208  		}
 209  	}
 210  
 211  	action := awstypes.ChangeActionUpsert
 212  
 213  	recordSet := &awstypes.ResourceRecordSet{
 214  		Name:            aws.String(info.EffectiveFQDN),
 215  		Type:            "TXT",
 216  		TTL:             aws.Int64(int64(d.config.TTL)),
 217  		ResourceRecords: nonLegoRecords,
 218  	}
 219  
 220  	// If the records are only records created by lego.
 221  	if len(nonLegoRecords) == 0 {
 222  		action = awstypes.ChangeActionDelete
 223  
 224  		recordSet.ResourceRecords = existingRecords
 225  	}
 226  
 227  	err = d.changeRecord(ctx, action, hostedZoneID, recordSet)
 228  	if err != nil {
 229  		return fmt.Errorf("route53: %w", err)
 230  	}
 231  
 232  	return nil
 233  }
 234  
 235  func (d *DNSProvider) changeRecord(ctx context.Context, action awstypes.ChangeAction, hostedZoneID string, recordSet *awstypes.ResourceRecordSet) error {
 236  	recordSetInput := &route53.ChangeResourceRecordSetsInput{
 237  		HostedZoneId: aws.String(hostedZoneID),
 238  		ChangeBatch: &awstypes.ChangeBatch{
 239  			Comment: aws.String("Managed by Lego"),
 240  			Changes: []awstypes.Change{{
 241  				Action:            action,
 242  				ResourceRecordSet: recordSet,
 243  			}},
 244  		},
 245  	}
 246  
 247  	resp, err := d.client.ChangeResourceRecordSets(ctx, recordSetInput)
 248  	if err != nil {
 249  		return fmt.Errorf("failed to change record set: %w", err)
 250  	}
 251  
 252  	changeID := resp.ChangeInfo.Id
 253  
 254  	if d.config.WaitForRecordSetsChanged {
 255  		return wait.Retry(ctx,
 256  			func() error {
 257  				resp, err := d.client.GetChange(ctx, &route53.GetChangeInput{Id: changeID})
 258  				if err != nil {
 259  					return fmt.Errorf("failed to query change status: %w", err)
 260  				}
 261  
 262  				if resp.ChangeInfo.Status != awstypes.ChangeStatusInsync {
 263  					return fmt.Errorf("unable to retrieve change: ID=%s, status=%s", ptr.Deref(changeID), resp.ChangeInfo.Status)
 264  				}
 265  
 266  				return nil
 267  			},
 268  			backoff.WithBackOff(backoff.NewConstantBackOff(d.config.PollingInterval)),
 269  			backoff.WithMaxElapsedTime(d.config.PropagationTimeout),
 270  		)
 271  	}
 272  
 273  	return nil
 274  }
 275  
 276  func (d *DNSProvider) getExistingRecordSets(ctx context.Context, hostedZoneID, fqdn string) ([]awstypes.ResourceRecord, error) {
 277  	listInput := &route53.ListResourceRecordSetsInput{
 278  		HostedZoneId:    aws.String(hostedZoneID),
 279  		StartRecordName: aws.String(fqdn),
 280  		StartRecordType: "TXT",
 281  	}
 282  
 283  	recordSetsOutput, err := d.client.ListResourceRecordSets(ctx, listInput)
 284  	if err != nil {
 285  		return nil, err
 286  	}
 287  
 288  	if recordSetsOutput == nil {
 289  		return nil, nil
 290  	}
 291  
 292  	var records []awstypes.ResourceRecord
 293  
 294  	for _, recordSet := range recordSetsOutput.ResourceRecordSets {
 295  		if ptr.Deref(recordSet.Name) == fqdn {
 296  			records = append(records, recordSet.ResourceRecords...)
 297  		}
 298  	}
 299  
 300  	return records, nil
 301  }
 302  
 303  func (d *DNSProvider) getHostedZoneID(ctx context.Context, fqdn string) (string, error) {
 304  	if d.config.HostedZoneID != "" {
 305  		return d.config.HostedZoneID, nil
 306  	}
 307  
 308  	authZone, err := dns01.FindZoneByFqdn(fqdn)
 309  	if err != nil {
 310  		return "", fmt.Errorf("could not find zone for FQDN %q: %w", fqdn, err)
 311  	}
 312  
 313  	// .DNSName should not have a trailing dot
 314  	reqParams := &route53.ListHostedZonesByNameInput{
 315  		DNSName: aws.String(dns01.UnFqdn(authZone)),
 316  	}
 317  
 318  	resp, err := d.client.ListHostedZonesByName(ctx, reqParams)
 319  	if err != nil {
 320  		return "", err
 321  	}
 322  
 323  	var hostedZoneID string
 324  
 325  	for _, hostedZone := range resp.HostedZones {
 326  		// .Name has a trailing dot
 327  		if ptr.Deref(hostedZone.Name) == authZone && d.config.PrivateZone == hostedZone.Config.PrivateZone {
 328  			hostedZoneID = ptr.Deref(hostedZone.Id)
 329  			break
 330  		}
 331  	}
 332  
 333  	if hostedZoneID == "" {
 334  		return "", fmt.Errorf("zone %s not found for domain %s", authZone, fqdn)
 335  	}
 336  
 337  	hostedZoneID = strings.TrimPrefix(hostedZoneID, "/hostedzone/")
 338  
 339  	return hostedZoneID, nil
 340  }
 341  
 342  func createAWSConfig(ctx context.Context, config *Config) (aws.Config, error) {
 343  	if err := createAWSConfigCheckParams(config); err != nil {
 344  		return aws.Config{}, err
 345  	}
 346  
 347  	optFns := []func(options *awsconfig.LoadOptions) error{
 348  		awsconfig.WithRetryer(func() aws.Retryer {
 349  			return retry.NewStandard(func(options *retry.StandardOptions) {
 350  				options.MaxAttempts = config.MaxRetries
 351  
 352  				// It uses a basic exponential backoff algorithm that returns an initial
 353  				// delay of ~400ms with an upper limit of ~30 seconds which should prevent
 354  				// causing a high number of consecutive throttling errors.
 355  				// For reference: Route 53 enforces an account-wide(!) 5req/s query limit.
 356  				options.Backoff = retry.BackoffDelayerFunc(func(attempt int, err error) (time.Duration, error) {
 357  					retryCount := min(attempt, 7)
 358  
 359  					delay := (1 << uint(retryCount)) * (rand.Intn(50) + 200)
 360  
 361  					return time.Duration(delay) * time.Millisecond, nil
 362  				})
 363  			})
 364  		}),
 365  	}
 366  
 367  	if config.AccessKeyID != "" && config.SecretAccessKey != "" {
 368  		optFns = append(optFns,
 369  			awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(config.AccessKeyID, config.SecretAccessKey, config.SessionToken)),
 370  		)
 371  	}
 372  
 373  	if config.Region != "" {
 374  		optFns = append(optFns, awsconfig.WithRegion(config.Region))
 375  	}
 376  
 377  	cfg, err := awsconfig.LoadDefaultConfig(ctx, optFns...)
 378  	if err != nil {
 379  		return aws.Config{}, err
 380  	}
 381  
 382  	if config.AssumeRoleArn != "" {
 383  		cfg.Credentials = stscreds.NewAssumeRoleProvider(sts.NewFromConfig(cfg), config.AssumeRoleArn, func(options *stscreds.AssumeRoleOptions) {
 384  			if config.ExternalID != "" {
 385  				options.ExternalID = &config.ExternalID
 386  			}
 387  		})
 388  	}
 389  
 390  	return cfg, nil
 391  }
 392  
 393  func createAWSConfigCheckParams(config *Config) error {
 394  	if config == nil {
 395  		return errors.New("config is nil")
 396  	}
 397  
 398  	switch {
 399  	case config.SessionToken != "" && config.AccessKeyID == "" && config.SecretAccessKey == "":
 400  		return errors.New("SessionToken must be supplied with AccessKeyID and SecretAccessKey")
 401  
 402  	case config.AccessKeyID == "" && config.SecretAccessKey != "" || config.AccessKeyID != "" && config.SecretAccessKey == "":
 403  		return errors.New("AccessKeyID and SecretAccessKey must be supplied together")
 404  	}
 405  
 406  	return nil
 407  }
 408