diff --git a/lib/promscrape/discovery/azure/api.go b/lib/promscrape/discovery/azure/api.go index 6c17a2676e..2ff0841936 100644 --- a/lib/promscrape/discovery/azure/api.go +++ b/lib/promscrape/discovery/azure/api.go @@ -199,12 +199,17 @@ func getRefreshTokenFunc(sdc *SDConfig, ac, proxyAC *promauth.Config, env *cloud q := endpointURL.Query() msiSecret := os.Getenv("MSI_SECRET") + identityHeader := os.Getenv("IDENTITY_HEADER") clientIDParam := "client_id" apiVersion := "2018-02-01" if msiSecret != "" { clientIDParam = "clientid" apiVersion = "2017-09-01" } + if identityHeader != "" { + clientIDParam = "client_id" + apiVersion = "2019-08-01" + } q.Set("api-version", apiVersion) q.Set(clientIDParam, sdc.ClientID) q.Set("resource", env.ResourceManagerEndpoint) @@ -214,6 +219,9 @@ func getRefreshTokenFunc(sdc *SDConfig, ac, proxyAC *promauth.Config, env *cloud modifyRequest = func(request *http.Request) { if msiSecret != "" { request.Header.Set("secret", msiSecret) + if identityHeader != "" { + request.Header.Set("X-IDENTITY-HEADER", msiSecret) + } } else { request.Header.Set("Metadata", "true") } @@ -235,15 +243,38 @@ func getRefreshTokenFunc(sdc *SDConfig, ac, proxyAC *promauth.Config, env *cloud if err := json.Unmarshal(data, &tr); err != nil { return "", 0, fmt.Errorf("cannot parse token auth response %q: %w", data, err) } - expiresInSeconds, err := strconv.ParseInt(tr.ExpiresIn, 10, 64) + + expiresInSeconds, err := parseTokenExpiry(tr) if err != nil { - return "", 0, fmt.Errorf("cannot parse expiresIn param in token auth %q: %w", tr.ExpiresIn, err) + return "", 0, err } return tr.AccessToken, time.Second * time.Duration(expiresInSeconds), nil } return refreshToken, nil } +// parseTokenExpiry returns token expiry in seconds +func parseTokenExpiry(tr tokenResponse) (int64, error) { + var expiresInSeconds int64 + var err error + + if tr.ExpiresIn == "" { + var expiresOnSeconds int64 + expiresOnSeconds, err = strconv.ParseInt(tr.ExpiresOn, 10, 64) + if err != nil { + return 0, fmt.Errorf("cannot parse expiresOn=%q in auth token response: %w", tr.ExpiresOn, err) + } + expiresInSeconds = expiresOnSeconds - time.Now().Unix() + } else { + expiresInSeconds, err = strconv.ParseInt(tr.ExpiresIn, 10, 64) + if err != nil { + return 0, fmt.Errorf("cannot parse expiresIn=%q auth token response: %w", tr.ExpiresIn, err) + } + } + + return expiresInSeconds, nil +} + // mustGetAuthToken returns auth token // in case of error, logs error and return empty token func (ac *apiConfig) mustGetAuthToken() string { @@ -270,4 +301,5 @@ func (ac *apiConfig) mustGetAuthToken() string { type tokenResponse struct { AccessToken string `json:"access_token"` ExpiresIn string `json:"expires_in"` + ExpiresOn string `json:"expires_on"` }