provider.go raw

   1  package processcreds
   2  
   3  import (
   4  	"bytes"
   5  	"context"
   6  	"encoding/json"
   7  	"fmt"
   8  	"io"
   9  	"os"
  10  	"os/exec"
  11  	"runtime"
  12  	"time"
  13  
  14  	"github.com/aws/aws-sdk-go-v2/aws"
  15  	"github.com/aws/aws-sdk-go-v2/internal/sdkio"
  16  )
  17  
  18  const (
  19  	// ProviderName is the name this credentials provider will label any
  20  	// returned credentials Value with.
  21  	ProviderName = `ProcessProvider`
  22  
  23  	// DefaultTimeout default limit on time a process can run.
  24  	DefaultTimeout = time.Duration(1) * time.Minute
  25  )
  26  
  27  // ProviderError is an error indicating failure initializing or executing the
  28  // process credentials provider
  29  type ProviderError struct {
  30  	Err error
  31  }
  32  
  33  // Error returns the error message.
  34  func (e *ProviderError) Error() string {
  35  	return fmt.Sprintf("process provider error: %v", e.Err)
  36  }
  37  
  38  // Unwrap returns the underlying error the provider error wraps.
  39  func (e *ProviderError) Unwrap() error {
  40  	return e.Err
  41  }
  42  
  43  // Provider satisfies the credentials.Provider interface, and is a
  44  // client to retrieve credentials from a process.
  45  type Provider struct {
  46  	// Provides a constructor for exec.Cmd that are invoked by the provider for
  47  	// retrieving credentials. Use this to provide custom creation of exec.Cmd
  48  	// with things like environment variables, or other configuration.
  49  	//
  50  	// The provider defaults to the DefaultNewCommand function.
  51  	commandBuilder NewCommandBuilder
  52  
  53  	options Options
  54  }
  55  
  56  // Options is the configuration options for configuring the Provider.
  57  type Options struct {
  58  	// Timeout limits the time a process can run.
  59  	Timeout time.Duration
  60  	// The chain of providers that was used to create this provider
  61  	// These values are for reporting purposes and are not meant to be set up directly
  62  	CredentialSources []aws.CredentialSource
  63  }
  64  
  65  // NewCommandBuilder provides the interface for specifying how command will be
  66  // created that the Provider will use to retrieve credentials with.
  67  type NewCommandBuilder interface {
  68  	NewCommand(context.Context) (*exec.Cmd, error)
  69  }
  70  
  71  // NewCommandBuilderFunc provides a wrapper type around a function pointer to
  72  // satisfy the NewCommandBuilder interface.
  73  type NewCommandBuilderFunc func(context.Context) (*exec.Cmd, error)
  74  
  75  // NewCommand calls the underlying function pointer the builder was initialized with.
  76  func (fn NewCommandBuilderFunc) NewCommand(ctx context.Context) (*exec.Cmd, error) {
  77  	return fn(ctx)
  78  }
  79  
  80  // DefaultNewCommandBuilder provides the default NewCommandBuilder
  81  // implementation used by the provider. It takes a command and arguments to
  82  // invoke. The command will also be initialized with the current process
  83  // environment variables, stderr, and stdin pipes.
  84  type DefaultNewCommandBuilder struct {
  85  	Args []string
  86  }
  87  
  88  // NewCommand returns an initialized exec.Cmd with the builder's initialized
  89  // Args. The command is also initialized current process environment variables,
  90  // stderr, and stdin pipes.
  91  func (b DefaultNewCommandBuilder) NewCommand(ctx context.Context) (*exec.Cmd, error) {
  92  	var cmdArgs []string
  93  	if runtime.GOOS == "windows" {
  94  		cmdArgs = []string{"cmd.exe", "/C"}
  95  	} else {
  96  		cmdArgs = []string{"sh", "-c"}
  97  	}
  98  
  99  	if len(b.Args) == 0 {
 100  		return nil, &ProviderError{
 101  			Err: fmt.Errorf("failed to prepare command: command must not be empty"),
 102  		}
 103  	}
 104  
 105  	cmdArgs = append(cmdArgs, b.Args...)
 106  	cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
 107  	cmd.Env = os.Environ()
 108  
 109  	cmd.Stderr = os.Stderr // display stderr on console for MFA
 110  	cmd.Stdin = os.Stdin   // enable stdin for MFA
 111  
 112  	return cmd, nil
 113  }
 114  
 115  // NewProvider returns a pointer to a new Credentials object wrapping the
 116  // Provider.
 117  //
 118  // The provider defaults to the DefaultNewCommandBuilder for creating command
 119  // the Provider will use to retrieve credentials with.
 120  func NewProvider(command string, options ...func(*Options)) *Provider {
 121  	var args []string
 122  
 123  	// Ensure that the command arguments are not set if the provided command is
 124  	// empty. This will error out when the command is executed since no
 125  	// arguments are specified.
 126  	if len(command) > 0 {
 127  		args = []string{command}
 128  	}
 129  
 130  	commanBuilder := DefaultNewCommandBuilder{
 131  		Args: args,
 132  	}
 133  	return NewProviderCommand(commanBuilder, options...)
 134  }
 135  
 136  // NewProviderCommand returns a pointer to a new Credentials object with the
 137  // specified command, and default timeout duration. Use this to provide custom
 138  // creation of exec.Cmd for options like environment variables, or other
 139  // configuration.
 140  func NewProviderCommand(builder NewCommandBuilder, options ...func(*Options)) *Provider {
 141  	p := &Provider{
 142  		commandBuilder: builder,
 143  		options: Options{
 144  			Timeout: DefaultTimeout,
 145  		},
 146  	}
 147  
 148  	for _, option := range options {
 149  		option(&p.options)
 150  	}
 151  
 152  	return p
 153  }
 154  
 155  // A CredentialProcessResponse is the AWS credentials format that must be
 156  // returned when executing an external credential_process.
 157  type CredentialProcessResponse struct {
 158  	// As of this writing, the Version key must be set to 1. This might
 159  	// increment over time as the structure evolves.
 160  	Version int
 161  
 162  	// The access key ID that identifies the temporary security credentials.
 163  	AccessKeyID string `json:"AccessKeyId"`
 164  
 165  	// The secret access key that can be used to sign requests.
 166  	SecretAccessKey string
 167  
 168  	// The token that users must pass to the service API to use the temporary credentials.
 169  	SessionToken string
 170  
 171  	// The date on which the current credentials expire.
 172  	Expiration *time.Time
 173  
 174  	// The ID of the account for credentials
 175  	AccountID string `json:"AccountId"`
 176  }
 177  
 178  // Retrieve executes the credential process command and returns the
 179  // credentials, or error if the command fails.
 180  func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
 181  	out, err := p.executeCredentialProcess(ctx)
 182  	if err != nil {
 183  		return aws.Credentials{Source: ProviderName}, err
 184  	}
 185  
 186  	// Serialize and validate response
 187  	resp := &CredentialProcessResponse{}
 188  	if err = json.Unmarshal(out, resp); err != nil {
 189  		return aws.Credentials{Source: ProviderName}, &ProviderError{
 190  			Err: fmt.Errorf("parse failed of process output: %s, error: %w", out, err),
 191  		}
 192  	}
 193  
 194  	if resp.Version != 1 {
 195  		return aws.Credentials{Source: ProviderName}, &ProviderError{
 196  			Err: fmt.Errorf("wrong version in process output (not 1)"),
 197  		}
 198  	}
 199  
 200  	if len(resp.AccessKeyID) == 0 {
 201  		return aws.Credentials{Source: ProviderName}, &ProviderError{
 202  			Err: fmt.Errorf("missing AccessKeyId in process output"),
 203  		}
 204  	}
 205  
 206  	if len(resp.SecretAccessKey) == 0 {
 207  		return aws.Credentials{Source: ProviderName}, &ProviderError{
 208  			Err: fmt.Errorf("missing SecretAccessKey in process output"),
 209  		}
 210  	}
 211  
 212  	creds := aws.Credentials{
 213  		Source:          ProviderName,
 214  		AccessKeyID:     resp.AccessKeyID,
 215  		SecretAccessKey: resp.SecretAccessKey,
 216  		SessionToken:    resp.SessionToken,
 217  		AccountID:       resp.AccountID,
 218  	}
 219  
 220  	// Handle expiration
 221  	if resp.Expiration != nil {
 222  		creds.CanExpire = true
 223  		creds.Expires = *resp.Expiration
 224  	}
 225  
 226  	return creds, nil
 227  }
 228  
 229  // executeCredentialProcess starts the credential process on the OS and
 230  // returns the results or an error.
 231  func (p *Provider) executeCredentialProcess(ctx context.Context) ([]byte, error) {
 232  	if p.options.Timeout >= 0 {
 233  		var cancelFunc func()
 234  		ctx, cancelFunc = context.WithTimeout(ctx, p.options.Timeout)
 235  		defer cancelFunc()
 236  	}
 237  
 238  	cmd, err := p.commandBuilder.NewCommand(ctx)
 239  	if err != nil {
 240  		return nil, err
 241  	}
 242  
 243  	// get creds json on process's stdout
 244  	output := bytes.NewBuffer(make([]byte, 0, int(8*sdkio.KibiByte)))
 245  	if cmd.Stdout != nil {
 246  		cmd.Stdout = io.MultiWriter(cmd.Stdout, output)
 247  	} else {
 248  		cmd.Stdout = output
 249  	}
 250  
 251  	execCh := make(chan error, 1)
 252  	go executeCommand(cmd, execCh)
 253  
 254  	select {
 255  	case execError := <-execCh:
 256  		if execError == nil {
 257  			break
 258  		}
 259  		select {
 260  		case <-ctx.Done():
 261  			return output.Bytes(), &ProviderError{
 262  				Err: fmt.Errorf("credential process timed out: %w", execError),
 263  			}
 264  		default:
 265  			return output.Bytes(), &ProviderError{
 266  				Err: fmt.Errorf("error in credential_process: %w", execError),
 267  			}
 268  		}
 269  	}
 270  
 271  	out := output.Bytes()
 272  	if runtime.GOOS == "windows" {
 273  		// windows adds slashes to quotes
 274  		out = bytes.ReplaceAll(out, []byte(`\"`), []byte(`"`))
 275  	}
 276  
 277  	return out, nil
 278  }
 279  
 280  // ProviderSources returns the credential chain that was used to construct this provider
 281  func (p *Provider) ProviderSources() []aws.CredentialSource {
 282  	if p.options.CredentialSources == nil {
 283  		return []aws.CredentialSource{aws.CredentialSourceProcess}
 284  	}
 285  	return p.options.CredentialSources
 286  }
 287  
 288  func executeCommand(cmd *exec.Cmd, exec chan error) {
 289  	// Start the command
 290  	err := cmd.Start()
 291  	if err == nil {
 292  		err = cmd.Wait()
 293  	}
 294  
 295  	exec <- err
 296  }
 297