1 package adal
2 3 // Copyright 2017 Microsoft Corporation
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 17 import (
18 "crypto/rsa"
19 "crypto/x509"
20 "encoding/json"
21 "errors"
22 "fmt"
23 "io/ioutil"
24 "os"
25 "path/filepath"
26 27 "golang.org/x/crypto/pkcs12"
28 )
var (
	// ErrMissingPrivateKey is the sentinel returned by DecodePfxCertificateData when the
	// decoded PFX data contains no private key block.
	ErrMissingPrivateKey = errors.New("adal: private key missing")

	// ErrMissingCertificate is the sentinel returned by DecodePfxCertificateData when the
	// decoded PFX data contains no certificate whose public key matches the private key.
	ErrMissingCertificate = errors.New("adal: certificate missing")
)
37 38 // LoadToken restores a Token object from a file located at 'path'.
39 func LoadToken(path string) (*Token, error) {
40 file, err := os.Open(path)
41 if err != nil {
42 return nil, fmt.Errorf("failed to open file (%s) while loading token: %v", path, err)
43 }
44 defer file.Close()
45 46 var token Token
47 48 dec := json.NewDecoder(file)
49 if err = dec.Decode(&token); err != nil {
50 return nil, fmt.Errorf("failed to decode contents of file (%s) into Token representation: %v", path, err)
51 }
52 return &token, nil
53 }
54 55 // SaveToken persists an oauth token at the given location on disk.
56 // It moves the new file into place so it can safely be used to replace an existing file
57 // that maybe accessed by multiple processes.
58 func SaveToken(path string, mode os.FileMode, token Token) error {
59 dir := filepath.Dir(path)
60 err := os.MkdirAll(dir, os.ModePerm)
61 if err != nil {
62 return fmt.Errorf("failed to create directory (%s) to store token in: %v", dir, err)
63 }
64 65 newFile, err := ioutil.TempFile(dir, "token")
66 if err != nil {
67 return fmt.Errorf("failed to create the temp file to write the token: %v", err)
68 }
69 tempPath := newFile.Name()
70 71 if err := json.NewEncoder(newFile).Encode(token); err != nil {
72 return fmt.Errorf("failed to encode token to file (%s) while saving token: %v", tempPath, err)
73 }
74 if err := newFile.Close(); err != nil {
75 return fmt.Errorf("failed to close temp file %s: %v", tempPath, err)
76 }
77 78 // Atomic replace to avoid multi-writer file corruptions
79 if err := os.Rename(tempPath, path); err != nil {
80 return fmt.Errorf("failed to move temporary token to desired output location. src=%s dst=%s: %v", tempPath, path, err)
81 }
82 if err := os.Chmod(path, mode); err != nil {
83 return fmt.Errorf("failed to chmod the token file %s: %v", path, err)
84 }
85 return nil
86 }
87 88 // DecodePfxCertificateData extracts the x509 certificate and RSA private key from the provided PFX data.
89 // The PFX data must contain a private key along with a certificate whose public key matches that of the
90 // private key or an error is returned.
91 // If the private key is not password protected pass the empty string for password.
92 func DecodePfxCertificateData(pfxData []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
93 blocks, err := pkcs12.ToPEM(pfxData, password)
94 if err != nil {
95 return nil, nil, err
96 }
97 // first extract the private key
98 var priv *rsa.PrivateKey
99 for _, block := range blocks {
100 if block.Type == "PRIVATE KEY" {
101 priv, err = x509.ParsePKCS1PrivateKey(block.Bytes)
102 if err != nil {
103 return nil, nil, err
104 }
105 break
106 }
107 }
108 if priv == nil {
109 return nil, nil, ErrMissingPrivateKey
110 }
111 // now find the certificate with the matching public key of our private key
112 var cert *x509.Certificate
113 for _, block := range blocks {
114 if block.Type == "CERTIFICATE" {
115 pcert, err := x509.ParseCertificate(block.Bytes)
116 if err != nil {
117 return nil, nil, err
118 }
119 certKey, ok := pcert.PublicKey.(*rsa.PublicKey)
120 if !ok {
121 // keep looking
122 continue
123 }
124 if priv.E == certKey.E && priv.N.Cmp(certKey.N) == 0 {
125 // found a match
126 cert = pcert
127 break
128 }
129 }
130 }
131 if cert == nil {
132 return nil, nil, ErrMissingCertificate
133 }
134 return cert, priv, nil
135 }
136