resolvers.go raw

   1  // Copyright (c) Microsoft Corporation.
   2  // Licensed under the MIT license.
   3  
   4  // TODO(msal): Write some tests. The original code this came from had no tests, and I was
   5  // too tired at this point to add them. Like much of the other *Manager code I found, it was
   6  // broken because it lacked mutex protection.
   7  
   8  package oauth
   9  
  10  import (
  11  	"context"
  12  	"errors"
  13  	"fmt"
  14  	"strings"
  15  	"sync"
  16  
  17  	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
  18  	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
  19  )
  20  
  21  type cacheEntry struct {
  22  	Endpoints             authority.Endpoints
  23  	ValidForDomainsInList map[string]bool
  24  	// Aliases stores host aliases from instance discovery for quick lookup
  25  	Aliases map[string]bool
  26  }
  27  
  28  func createcacheEntry(endpoints authority.Endpoints) cacheEntry {
  29  	return cacheEntry{endpoints, map[string]bool{}, map[string]bool{}}
  30  }
  31  
  32  // AuthorityEndpoint retrieves endpoints from an authority for auth and token acquisition.
  33  type authorityEndpoint struct {
  34  	rest *ops.REST
  35  
  36  	mu    sync.Mutex
  37  	cache map[string]cacheEntry
  38  }
  39  
  40  // newAuthorityEndpoint is the constructor for AuthorityEndpoint.
  41  func newAuthorityEndpoint(rest *ops.REST) *authorityEndpoint {
  42  	m := &authorityEndpoint{rest: rest, cache: map[string]cacheEntry{}}
  43  	return m
  44  }
  45  
  46  // ResolveEndpoints gets the authorization and token endpoints and creates an AuthorityEndpoints instance
  47  func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
  48  
  49  	if endpoints, found := m.cachedEndpoints(authorityInfo, userPrincipalName); found {
  50  		return endpoints, nil
  51  	}
  52  
  53  	endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo)
  54  	if err != nil {
  55  		return authority.Endpoints{}, err
  56  	}
  57  
  58  	resp, err := m.rest.Authority().GetTenantDiscoveryResponse(ctx, endpoint)
  59  	if err != nil {
  60  		return authority.Endpoints{}, err
  61  	}
  62  	if err := resp.Validate(); err != nil {
  63  		return authority.Endpoints{}, fmt.Errorf("ResolveEndpoints(): %w", err)
  64  	}
  65  
  66  	tenant := authorityInfo.Tenant
  67  
  68  	endpoints := authority.NewEndpoints(
  69  		strings.Replace(resp.AuthorizationEndpoint, "{tenant}", tenant, -1),
  70  		strings.Replace(resp.TokenEndpoint, "{tenant}", tenant, -1),
  71  		strings.Replace(resp.Issuer, "{tenant}", tenant, -1),
  72  		authorityInfo.Host)
  73  
  74  	m.addCachedEndpoints(authorityInfo, userPrincipalName, endpoints)
  75  
  76  	if err := resp.ValidateIssuerMatchesAuthority(authorityInfo.CanonicalAuthorityURI,
  77  		m.cache[authorityInfo.CanonicalAuthorityURI].Aliases); err != nil {
  78  		return authority.Endpoints{}, fmt.Errorf("ResolveEndpoints(): %w", err)
  79  	}
  80  
  81  	return endpoints, nil
  82  }
  83  
  84  // cachedEndpoints returns the cached endpoints if they exist. If not, we return false.
  85  func (m *authorityEndpoint) cachedEndpoints(authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, bool) {
  86  	m.mu.Lock()
  87  	defer m.mu.Unlock()
  88  
  89  	if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok {
  90  		if authorityInfo.AuthorityType == authority.ADFS {
  91  			domain, err := adfsDomainFromUpn(userPrincipalName)
  92  			if err == nil {
  93  				if _, ok := cacheEntry.ValidForDomainsInList[domain]; ok {
  94  					return cacheEntry.Endpoints, true
  95  				}
  96  			}
  97  		}
  98  		return cacheEntry.Endpoints, true
  99  	}
 100  	return authority.Endpoints{}, false
 101  }
 102  
 103  func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, userPrincipalName string, endpoints authority.Endpoints) {
 104  	m.mu.Lock()
 105  	defer m.mu.Unlock()
 106  
 107  	updatedCacheEntry := createcacheEntry(endpoints)
 108  
 109  	if authorityInfo.AuthorityType == authority.ADFS {
 110  		// Since we're here, we've made a call to the backend.  We want to ensure we're caching
 111  		// the latest values from the server.
 112  		if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok {
 113  			for k := range cacheEntry.ValidForDomainsInList {
 114  				updatedCacheEntry.ValidForDomainsInList[k] = true
 115  			}
 116  		}
 117  		domain, err := adfsDomainFromUpn(userPrincipalName)
 118  		if err == nil {
 119  			updatedCacheEntry.ValidForDomainsInList[domain] = true
 120  		}
 121  	}
 122  
 123  	// Extract aliases from instance discovery metadata and add to cache
 124  	for _, metadata := range authorityInfo.InstanceDiscoveryMetadata {
 125  		for _, alias := range metadata.Aliases {
 126  			updatedCacheEntry.Aliases[alias] = true
 127  		}
 128  	}
 129  
 130  	m.cache[authorityInfo.CanonicalAuthorityURI] = updatedCacheEntry
 131  }
 132  
 133  func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info) (string, error) {
 134  	if authorityInfo.AuthorityType == authority.ADFS {
 135  		return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil
 136  	} else if authorityInfo.AuthorityType == authority.DSTS {
 137  		return fmt.Sprintf("https://%s/dstsv2/%s/v2.0/.well-known/openid-configuration", authorityInfo.Host, authority.DSTSTenant), nil
 138  
 139  	} else if authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host) {
 140  		resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
 141  		if err != nil {
 142  			return "", err
 143  		}
 144  		authorityInfo.InstanceDiscoveryMetadata = resp.Metadata
 145  		return resp.TenantDiscoveryEndpoint, nil
 146  	} else if authorityInfo.Region != "" {
 147  		resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
 148  		if err != nil {
 149  			return "", err
 150  		}
 151  		authorityInfo.InstanceDiscoveryMetadata = resp.Metadata
 152  		return resp.TenantDiscoveryEndpoint, nil
 153  	}
 154  
 155  	return authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil
 156  }
 157  
 158  func adfsDomainFromUpn(userPrincipalName string) (string, error) {
 159  	parts := strings.Split(userPrincipalName, "@")
 160  	if len(parts) < 2 {
 161  		return "", errors.New("no @ present in user principal name")
 162  	}
 163  	return parts[1], nil
 164  }
 165