validation.go raw

   1  package find
   2  
   3  import (
   4  	"errors"
   5  	"fmt"
   6  	"regexp"
   7  	"strings"
   8  )
   9  
  10  var (
  11  	ErrInvalidName         = errors.New("invalid name format")
  12  	ErrNameTooLong         = errors.New("name exceeds 253 characters")
  13  	ErrLabelTooLong        = errors.New("label exceeds 63 characters")
  14  	ErrLabelEmpty          = errors.New("label is empty")
  15  	ErrInvalidCharacter    = errors.New("invalid character in name")
  16  	ErrInvalidHyphen       = errors.New("label cannot start or end with hyphen")
  17  	ErrAllNumericLabel     = errors.New("label cannot be all numeric")
  18  	ErrInvalidRecordValue  = errors.New("invalid record value")
  19  	ErrRecordLimitExceeded = errors.New("record limit exceeded")
  20  	ErrNotOwner            = errors.New("not the name owner")
  21  	ErrNameExpired         = errors.New("name registration expired")
  22  	ErrInRenewalWindow     = errors.New("name is in renewal window")
  23  	ErrNotRenewalWindow    = errors.New("not in renewal window")
  24  )
  25  
  26  // Name format validation regex
  27  var (
  28  	labelRegex = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?$`)
  29  	allNumeric = regexp.MustCompile(`^[0-9]+$`)
  30  )
  31  
  32  // NormalizeName converts a name to lowercase
  33  func NormalizeName(name string) string {
  34  	return strings.ToLower(name)
  35  }
  36  
  37  // ValidateName validates a name according to DNS naming rules
  38  func ValidateName(name string) error {
  39  	// Normalize to lowercase
  40  	name = NormalizeName(name)
  41  
  42  	// Check total length
  43  	if len(name) > 253 {
  44  		return fmt.Errorf("%w: %d > 253", ErrNameTooLong, len(name))
  45  	}
  46  
  47  	if len(name) == 0 {
  48  		return fmt.Errorf("%w: name is empty", ErrInvalidName)
  49  	}
  50  
  51  	// Split into labels
  52  	labels := strings.Split(name, ".")
  53  
  54  	for i, label := range labels {
  55  		if err := validateLabel(label); err != nil {
  56  			return fmt.Errorf("invalid label %d (%s): %w", i, label, err)
  57  		}
  58  	}
  59  
  60  	return nil
  61  }
  62  
  63  // validateLabel validates a single label according to DNS rules
  64  func validateLabel(label string) error {
  65  	// Check length
  66  	if len(label) == 0 {
  67  		return ErrLabelEmpty
  68  	}
  69  	if len(label) > 63 {
  70  		return fmt.Errorf("%w: %d > 63", ErrLabelTooLong, len(label))
  71  	}
  72  
  73  	// Check character set and hyphen placement
  74  	if !labelRegex.MatchString(label) {
  75  		if strings.HasPrefix(label, "-") || strings.HasSuffix(label, "-") {
  76  			return ErrInvalidHyphen
  77  		}
  78  		return ErrInvalidCharacter
  79  	}
  80  
  81  	// Check not all numeric
  82  	if allNumeric.MatchString(label) {
  83  		return ErrAllNumericLabel
  84  	}
  85  
  86  	return nil
  87  }
  88  
  89  // GetParentDomain returns the parent domain of a name
  90  // e.g., "www.example.com" -> "example.com", "example.com" -> "com", "com" -> ""
  91  func GetParentDomain(name string) string {
  92  	name = NormalizeName(name)
  93  	parts := strings.Split(name, ".")
  94  	if len(parts) <= 1 {
  95  		return "" // TLD has no parent
  96  	}
  97  	return strings.Join(parts[1:], ".")
  98  }
  99  
 100  // IsTLD returns true if the name is a top-level domain (single label)
 101  func IsTLD(name string) bool {
 102  	name = NormalizeName(name)
 103  	return !strings.Contains(name, ".")
 104  }
 105  
 106  // ValidateIPv4 validates an IPv4 address format
 107  func ValidateIPv4(ip string) error {
 108  	parts := strings.Split(ip, ".")
 109  	if len(parts) != 4 {
 110  		return fmt.Errorf("%w: invalid IPv4 format", ErrInvalidRecordValue)
 111  	}
 112  
 113  	for _, part := range parts {
 114  		var octet int
 115  		if _, err := fmt.Sscanf(part, "%d", &octet); err != nil {
 116  			return fmt.Errorf("%w: invalid IPv4 octet: %v", ErrInvalidRecordValue, err)
 117  		}
 118  		if octet < 0 || octet > 255 {
 119  			return fmt.Errorf("%w: IPv4 octet out of range: %d", ErrInvalidRecordValue, octet)
 120  		}
 121  	}
 122  
 123  	return nil
 124  }
 125  
 126  // ValidateIPv6 validates an IPv6 address format (simplified check)
 127  func ValidateIPv6(ip string) error {
 128  	// Basic validation - contains colons and valid hex characters
 129  	if !strings.Contains(ip, ":") {
 130  		return fmt.Errorf("%w: invalid IPv6 format", ErrInvalidRecordValue)
 131  	}
 132  
 133  	// Split by colons
 134  	parts := strings.Split(ip, ":")
 135  	if len(parts) < 3 || len(parts) > 8 {
 136  		return fmt.Errorf("%w: invalid IPv6 segment count", ErrInvalidRecordValue)
 137  	}
 138  
 139  	// Check for valid hex characters
 140  	validHex := regexp.MustCompile(`^[0-9a-fA-F]*$`)
 141  	for _, part := range parts {
 142  		if part == "" {
 143  			continue // Allow :: notation
 144  		}
 145  		if len(part) > 4 {
 146  			return fmt.Errorf("%w: IPv6 segment too long", ErrInvalidRecordValue)
 147  		}
 148  		if !validHex.MatchString(part) {
 149  			return fmt.Errorf("%w: invalid IPv6 hex", ErrInvalidRecordValue)
 150  		}
 151  	}
 152  
 153  	return nil
 154  }
 155  
 156  // ValidateRecordValue validates a record value based on its type
 157  func ValidateRecordValue(recordType, value string) error {
 158  	switch recordType {
 159  	case RecordTypeA:
 160  		return ValidateIPv4(value)
 161  	case RecordTypeAAAA:
 162  		return ValidateIPv6(value)
 163  	case RecordTypeCNAME, RecordTypeMX, RecordTypeNS:
 164  		return ValidateName(value)
 165  	case RecordTypeTXT:
 166  		if len(value) > 1024 {
 167  			return fmt.Errorf("%w: TXT record exceeds 1024 characters", ErrInvalidRecordValue)
 168  		}
 169  		return nil
 170  	case RecordTypeSRV:
 171  		return ValidateName(value) // Hostname for SRV
 172  	default:
 173  		return fmt.Errorf("%w: unknown record type: %s", ErrInvalidRecordValue, recordType)
 174  	}
 175  }
 176  
 177  // ValidateRecordLimit checks if adding a record would exceed type limits
 178  func ValidateRecordLimit(recordType string, currentCount int) error {
 179  	limit, ok := RecordLimits[recordType]
 180  	if !ok {
 181  		return fmt.Errorf("%w: unknown record type: %s", ErrInvalidRecordValue, recordType)
 182  	}
 183  
 184  	if currentCount >= limit {
 185  		return fmt.Errorf("%w: %s records limited to %d", ErrRecordLimitExceeded, recordType, limit)
 186  	}
 187  
 188  	return nil
 189  }
 190  
 191  // ValidatePriority validates priority value (0-65535)
 192  func ValidatePriority(priority int) error {
 193  	if priority < 0 || priority > 65535 {
 194  		return fmt.Errorf("%w: priority must be 0-65535", ErrInvalidRecordValue)
 195  	}
 196  	return nil
 197  }
 198  
 199  // ValidateWeight validates weight value (0-65535)
 200  func ValidateWeight(weight int) error {
 201  	if weight < 0 || weight > 65535 {
 202  		return fmt.Errorf("%w: weight must be 0-65535", ErrInvalidRecordValue)
 203  	}
 204  	return nil
 205  }
 206  
 207  // ValidatePort validates port value (0-65535)
 208  func ValidatePort(port int) error {
 209  	if port < 0 || port > 65535 {
 210  		return fmt.Errorf("%w: port must be 0-65535", ErrInvalidRecordValue)
 211  	}
 212  	return nil
 213  }
 214  
 215  // ValidateTrustScore validates trust score (0.0-1.0)
 216  func ValidateTrustScore(score float64) error {
 217  	if score < 0.0 || score > 1.0 {
 218  		return fmt.Errorf("trust score must be between 0.0 and 1.0, got %f", score)
 219  	}
 220  	return nil
 221  }
 222