resolvers.go raw
1 // Copyright (c) Microsoft Corporation.
2 // Licensed under the MIT license.
3
// TODO(msal): Write some tests. The original code this came from didn't have tests, and I'm
// too tired at this point to write them. Like much of the other *Manager code I found, it
// was broken because it lacked mutex protection.
7
8 package oauth
9
10 import (
11 "context"
12 "errors"
13 "fmt"
14 "strings"
15 "sync"
16
17 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
18 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
19 )
20
21 type cacheEntry struct {
22 Endpoints authority.Endpoints
23 ValidForDomainsInList map[string]bool
24 // Aliases stores host aliases from instance discovery for quick lookup
25 Aliases map[string]bool
26 }
27
28 func createcacheEntry(endpoints authority.Endpoints) cacheEntry {
29 return cacheEntry{endpoints, map[string]bool{}, map[string]bool{}}
30 }
31
32 // AuthorityEndpoint retrieves endpoints from an authority for auth and token acquisition.
33 type authorityEndpoint struct {
34 rest *ops.REST
35
36 mu sync.Mutex
37 cache map[string]cacheEntry
38 }
39
40 // newAuthorityEndpoint is the constructor for AuthorityEndpoint.
41 func newAuthorityEndpoint(rest *ops.REST) *authorityEndpoint {
42 m := &authorityEndpoint{rest: rest, cache: map[string]cacheEntry{}}
43 return m
44 }
45
46 // ResolveEndpoints gets the authorization and token endpoints and creates an AuthorityEndpoints instance
47 func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
48
49 if endpoints, found := m.cachedEndpoints(authorityInfo, userPrincipalName); found {
50 return endpoints, nil
51 }
52
53 endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo)
54 if err != nil {
55 return authority.Endpoints{}, err
56 }
57
58 resp, err := m.rest.Authority().GetTenantDiscoveryResponse(ctx, endpoint)
59 if err != nil {
60 return authority.Endpoints{}, err
61 }
62 if err := resp.Validate(); err != nil {
63 return authority.Endpoints{}, fmt.Errorf("ResolveEndpoints(): %w", err)
64 }
65
66 tenant := authorityInfo.Tenant
67
68 endpoints := authority.NewEndpoints(
69 strings.Replace(resp.AuthorizationEndpoint, "{tenant}", tenant, -1),
70 strings.Replace(resp.TokenEndpoint, "{tenant}", tenant, -1),
71 strings.Replace(resp.Issuer, "{tenant}", tenant, -1),
72 authorityInfo.Host)
73
74 m.addCachedEndpoints(authorityInfo, userPrincipalName, endpoints)
75
76 if err := resp.ValidateIssuerMatchesAuthority(authorityInfo.CanonicalAuthorityURI,
77 m.cache[authorityInfo.CanonicalAuthorityURI].Aliases); err != nil {
78 return authority.Endpoints{}, fmt.Errorf("ResolveEndpoints(): %w", err)
79 }
80
81 return endpoints, nil
82 }
83
84 // cachedEndpoints returns the cached endpoints if they exist. If not, we return false.
85 func (m *authorityEndpoint) cachedEndpoints(authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, bool) {
86 m.mu.Lock()
87 defer m.mu.Unlock()
88
89 if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok {
90 if authorityInfo.AuthorityType == authority.ADFS {
91 domain, err := adfsDomainFromUpn(userPrincipalName)
92 if err == nil {
93 if _, ok := cacheEntry.ValidForDomainsInList[domain]; ok {
94 return cacheEntry.Endpoints, true
95 }
96 }
97 }
98 return cacheEntry.Endpoints, true
99 }
100 return authority.Endpoints{}, false
101 }
102
103 func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, userPrincipalName string, endpoints authority.Endpoints) {
104 m.mu.Lock()
105 defer m.mu.Unlock()
106
107 updatedCacheEntry := createcacheEntry(endpoints)
108
109 if authorityInfo.AuthorityType == authority.ADFS {
110 // Since we're here, we've made a call to the backend. We want to ensure we're caching
111 // the latest values from the server.
112 if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok {
113 for k := range cacheEntry.ValidForDomainsInList {
114 updatedCacheEntry.ValidForDomainsInList[k] = true
115 }
116 }
117 domain, err := adfsDomainFromUpn(userPrincipalName)
118 if err == nil {
119 updatedCacheEntry.ValidForDomainsInList[domain] = true
120 }
121 }
122
123 // Extract aliases from instance discovery metadata and add to cache
124 for _, metadata := range authorityInfo.InstanceDiscoveryMetadata {
125 for _, alias := range metadata.Aliases {
126 updatedCacheEntry.Aliases[alias] = true
127 }
128 }
129
130 m.cache[authorityInfo.CanonicalAuthorityURI] = updatedCacheEntry
131 }
132
133 func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info) (string, error) {
134 if authorityInfo.AuthorityType == authority.ADFS {
135 return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil
136 } else if authorityInfo.AuthorityType == authority.DSTS {
137 return fmt.Sprintf("https://%s/dstsv2/%s/v2.0/.well-known/openid-configuration", authorityInfo.Host, authority.DSTSTenant), nil
138
139 } else if authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host) {
140 resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
141 if err != nil {
142 return "", err
143 }
144 authorityInfo.InstanceDiscoveryMetadata = resp.Metadata
145 return resp.TenantDiscoveryEndpoint, nil
146 } else if authorityInfo.Region != "" {
147 resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
148 if err != nil {
149 return "", err
150 }
151 authorityInfo.InstanceDiscoveryMetadata = resp.Metadata
152 return resp.TenantDiscoveryEndpoint, nil
153 }
154
155 return authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil
156 }
157
// adfsDomainFromUpn extracts the domain from a user principal name.
//
// The domain is the text after the first "@", up to but not including any later "@".
// It returns an error when userPrincipalName contains no "@". An empty domain
// (for example from "user@") is returned without error.
func adfsDomainFromUpn(userPrincipalName string) (string, error) {
	at := strings.IndexByte(userPrincipalName, '@')
	if at < 0 {
		return "", errors.New("no @ present in user principal name")
	}
	domain := userPrincipalName[at+1:]
	if end := strings.IndexByte(domain, '@'); end >= 0 {
		domain = domain[:end]
	}
	return domain, nil
}
164 }
165