tls_config_provider.go raw

   1  // Copyright (c) 2016, 2018, 2025, Oracle and/or its affiliates.  All rights reserved.
   2  // This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license.
   3  
   4  package common
   5  
   6  import (
   7  	"crypto/tls"
   8  	"crypto/x509"
   9  	"fmt"
  10  	"os"
  11  	"sync"
  12  )
  13  
  14  // GetTLSConfigTemplateForTransport returns the TLSConfigTemplate to used depending on whether any additional
  15  // CA Bundle or client side certs have been configured
  16  func GetTLSConfigTemplateForTransport() TLSConfigProvider {
  17  	certPath := os.Getenv(ociDefaultClientCertsPath)
  18  	keyPath := os.Getenv(ociDefaultClientCertsPrivateKeyPath)
  19  	caBundlePath := os.Getenv(ociDefaultCertsPath)
  20  	if certPath != "" && keyPath != "" {
  21  		return &DefaultMTLSConfigProvider{
  22  			caBundlePath:         caBundlePath,
  23  			clientCertPath:       certPath,
  24  			clientKeyPath:        keyPath,
  25  			watchedFilesStatsMap: make(map[string]os.FileInfo),
  26  		}
  27  	}
  28  	return &DefaultTLSConfigProvider{
  29  		caBundlePath: caBundlePath,
  30  	}
  31  }
  32  
  33  // TLSConfigProvider is an interface the defines a function that creates a new *tls.Config.
  34  type TLSConfigProvider interface {
  35  	NewOrDefault() (*tls.Config, error)
  36  	WatchedFilesModified() bool
  37  }
  38  
  39  // DefaultTLSConfigProvider is a provider that provides a TLS tls.config for the HTTPTransport
  40  type DefaultTLSConfigProvider struct {
  41  	caBundlePath string
  42  	mux          sync.Mutex
  43  	currentStat  os.FileInfo
  44  }
  45  
  46  // NewOrDefault returns a default tls.Config which
  47  // sets its RootCAs to be a *x509.CertPool from caBundlePath.
  48  func (t *DefaultTLSConfigProvider) NewOrDefault() (*tls.Config, error) {
  49  	if t.caBundlePath == "" {
  50  		return &tls.Config{}, nil
  51  	}
  52  
  53  	// Keep the current Stat info from the ca bundle in a map
  54  	Debugf("Getting Initial Stats for file: %s", t.caBundlePath)
  55  	caBundleStat, err := os.Stat(t.caBundlePath)
  56  	if err != nil {
  57  		return nil, err
  58  	}
  59  	t.mux.Lock()
  60  	defer t.mux.Unlock()
  61  	t.currentStat = caBundleStat
  62  
  63  	rootCAs, err := CertPoolFrom(t.caBundlePath)
  64  	if err != nil {
  65  		return nil, err
  66  	}
  67  	return &tls.Config{
  68  		RootCAs: rootCAs,
  69  	}, nil
  70  }
  71  
  72  // WatchedFilesModified returns true if any files in the watchedFilesStatsMap has been modified else returns false
  73  func (t *DefaultTLSConfigProvider) WatchedFilesModified() bool {
  74  	modified := false
  75  	if t.caBundlePath != "" {
  76  		newStat, err := os.Stat(t.caBundlePath)
  77  		if err == nil && (t.currentStat.Size() != newStat.Size() || t.currentStat.ModTime() != newStat.ModTime()) {
  78  			Logf("Modification detected in cert/ca-bundle file: %s", t.caBundlePath)
  79  			modified = true
  80  			t.mux.Lock()
  81  			defer t.mux.Unlock()
  82  			t.currentStat = newStat
  83  		}
  84  	}
  85  	return modified
  86  }
  87  
  88  // DefaultMTLSConfigProvider is a provider that provides a MTLS tls.config for the HTTPTransport
  89  type DefaultMTLSConfigProvider struct {
  90  	caBundlePath         string
  91  	clientCertPath       string
  92  	clientKeyPath        string
  93  	mux                  sync.Mutex
  94  	watchedFilesStatsMap map[string]os.FileInfo
  95  }
  96  
  97  // NewOrDefault returns a default tls.Config which sets its RootCAs
  98  // to be a *x509.CertPool from caBundlePath and calls
  99  // tls.LoadX509KeyPair(clientCertPath, clientKeyPath) to set mtls client certs.
 100  func (t *DefaultMTLSConfigProvider) NewOrDefault() (*tls.Config, error) {
 101  	rootCAs, err := CertPoolFrom(t.caBundlePath)
 102  	if err != nil {
 103  		return nil, err
 104  	}
 105  	cert, err := tls.LoadX509KeyPair(t.clientCertPath, t.clientKeyPath)
 106  	if err != nil {
 107  		return nil, err
 108  	}
 109  
 110  	// Configure the initial certs file stats, error skipped because we error out before this if the files don't exist
 111  	t.mux.Lock()
 112  	defer t.mux.Unlock()
 113  	t.watchedFilesStatsMap[t.caBundlePath], _ = os.Stat(t.caBundlePath)
 114  	t.watchedFilesStatsMap[t.clientCertPath], _ = os.Stat(t.clientCertPath)
 115  	t.watchedFilesStatsMap[t.clientKeyPath], _ = os.Stat(t.clientKeyPath)
 116  
 117  	return &tls.Config{
 118  		RootCAs:      rootCAs,
 119  		Certificates: []tls.Certificate{cert},
 120  	}, nil
 121  }
 122  
 123  // WatchedFilesModified returns true if any files in the watchedFilesStatsMap has been modified else returns false
 124  func (t *DefaultMTLSConfigProvider) WatchedFilesModified() bool {
 125  	modified := false
 126  
 127  	t.mux.Lock()
 128  	defer t.mux.Unlock()
 129  	for k, v := range t.watchedFilesStatsMap {
 130  		if k != "" {
 131  			currentStat, err := os.Stat(k)
 132  			if err == nil && (v.Size() != currentStat.Size() || v.ModTime() != currentStat.ModTime()) {
 133  				modified = true
 134  				Logf("Modification detected in cert/ca-bundle file: %s", k)
 135  				t.watchedFilesStatsMap[k] = currentStat
 136  			}
 137  		}
 138  	}
 139  
 140  	return modified
 141  }
 142  
 143  // CertPoolFrom creates a new x509.CertPool from a given file.
 144  func CertPoolFrom(caBundleFile string) (*x509.CertPool, error) {
 145  	pemCerts, err := os.ReadFile(caBundleFile)
 146  	if err != nil {
 147  		return nil, err
 148  	}
 149  
 150  	trust := x509.NewCertPool()
 151  	if !trust.AppendCertsFromPEM(pemCerts) {
 152  		return nil, fmt.Errorf("creating a new x509.CertPool from %s: no certs added", caBundleFile)
 153  	}
 154  
 155  	return trust, nil
 156  }
 157