From 3a8b4fab97f9ae021a3f7ad9da2d4f71f8222851 Mon Sep 17 00:00:00 2001
From: Boris Petersen <transacid@gmail.com>
Date: Wed, 18 May 2022 14:58:31 +0200
Subject: [PATCH] Add ability to sign requests for all AWS services (#2604)

This adds the ability to utilize sigv4 signing for all AWS services not
just "aps". When the newly introduced property "service" is not set it
will default to "aps".

Signed-off-by: Boris Petersen <boris.petersen@idealo.de>
---
 app/vmagent/remotewrite/client.go   |  7 +++++--
 lib/awsapi/config.go                | 12 +++++++++---
 lib/promscrape/discovery/ec2/api.go |  2 +-
 3 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/app/vmagent/remotewrite/client.go b/app/vmagent/remotewrite/client.go
index 65bc5d12b..aa5f98c79 100644
--- a/app/vmagent/remotewrite/client.go
+++ b/app/vmagent/remotewrite/client.go
@@ -70,6 +70,8 @@ var (
 		"If multiple args are set, then they are applied independently for the corresponding -remoteWrite.url")
 	awsAccessKey = flagutil.NewArray("remoteWrite.aws.accessKey", "Optional AWS AccessKey to use for -remoteWrite.url if -remoteWrite.aws.useSigv4 is set. "+
 		"If multiple args are set, then they are applied independently for the corresponding -remoteWrite.url")
+	awsService = flagutil.NewArray("remoteWrite.aws.serice", "Optional AWS Service to use for -remoteWrite.url if -remoteWrite.aws.useSigv4 is set. "+
+		"If multiple args are set, then they are applied independently for the corresponding -remoteWrite.url. Defaults to \"aps\".")
 	awsSecretKey = flagutil.NewArray("remoteWrite.aws.secretKey", "Optional AWS SecretKey to use for -remoteWrite.url if -remoteWrite.aws.useSigv4 is set. "+
 		"If multiple args are set, then they are applied independently for the corresponding -remoteWrite.url")
 )
@@ -232,7 +234,8 @@ func getAWSAPIConfig(argIdx int) (*awsapi.Config, error) {
 	roleARN := awsRoleARN.GetOptionalArg(argIdx)
 	accessKey := awsAccessKey.GetOptionalArg(argIdx)
 	secretKey := awsSecretKey.GetOptionalArg(argIdx)
-	cfg, err := awsapi.NewConfig(region, roleARN, accessKey, secretKey)
+	service := awsService.GetOptionalArg(argIdx)
+	cfg, err := awsapi.NewConfig(region, roleARN, accessKey, secretKey, service)
 	if err != nil {
 		return nil, err
 	}
@@ -307,7 +310,7 @@ again:
 		req.Header.Set("Authorization", ah)
 	}
 	if c.awsCfg != nil {
-		if err := c.awsCfg.SignRequest(req, "aps", sigv4Hash); err != nil {
+		if err := c.awsCfg.SignRequest(req, sigv4Hash); err != nil {
 			// there is no need in retry, request will be rejected by client.Do and retried by code below
 			logger.Warnf("cannot sign remoteWrite request with AWS sigv4: %s", err)
 		}
diff --git a/lib/awsapi/config.go b/lib/awsapi/config.go
index eb6dd5f22..f183bbe4d 100644
--- a/lib/awsapi/config.go
+++ b/lib/awsapi/config.go
@@ -24,6 +24,7 @@ type Config struct {
 
 	ec2Endpoint string
 	stsEndpoint string
+	service     string
 
 	// these keys are needed for obtaining creds.
 	defaultAccessKey string
@@ -43,14 +44,19 @@ type credentials struct {
 }
 
 // NewConfig returns new AWS Config.
-func NewConfig(region, roleARN, accessKey, secretKey string) (*Config, error) {
+func NewConfig(region, roleARN, accessKey, secretKey, service string) (*Config, error) {
 	cfg := &Config{
 		client:           http.DefaultClient,
 		region:           region,
 		roleARN:          roleARN,
+		service:          service,
 		defaultAccessKey: os.Getenv("AWS_ACCESS_KEY_ID"),
 		defaultSecretKey: os.Getenv("AWS_SECRET_ACCESS_KEY"),
 	}
+	cfg.service = service
+	if cfg.service == "" {
+		cfg.service = "aps"
+	}
 	cfg.region = region
 	if cfg.region == "" {
 		r, err := getDefaultRegion(cfg.client)
@@ -115,12 +121,12 @@ func (cfg *Config) GetEC2APIResponse(action, filtersQueryString, nextPageToken s
 }
 
 // SignRequest signs request for service access and payloadHash.
-func (cfg *Config) SignRequest(req *http.Request, service string, payloadHash string) error {
+func (cfg *Config) SignRequest(req *http.Request, payloadHash string) error {
 	ac, err := cfg.getFreshAPICredentials()
 	if err != nil {
 		return err
 	}
-	return signRequestWithTime(req, service, cfg.region, payloadHash, ac, time.Now().UTC())
+	return signRequestWithTime(req, cfg.service, cfg.region, payloadHash, ac, time.Now().UTC())
 }
 
 func readResponseBody(resp *http.Response, apiURL string) ([]byte, error) {
diff --git a/lib/promscrape/discovery/ec2/api.go b/lib/promscrape/discovery/ec2/api.go
index dc3640c97..652199fb3 100644
--- a/lib/promscrape/discovery/ec2/api.go
+++ b/lib/promscrape/discovery/ec2/api.go
@@ -33,7 +33,7 @@ func newAPIConfig(sdc *SDConfig) (*apiConfig, error) {
 	if sdc.Port != nil {
 		port = *sdc.Port
 	}
-	awsCfg, err := awsapi.NewConfig(sdc.Region, sdc.RoleARN, sdc.AccessKey, sdc.SecretKey.String())
+	awsCfg, err := awsapi.NewConfig(sdc.Region, sdc.RoleARN, sdc.AccessKey, sdc.SecretKey.String(), "ec2")
 	if err != nil {
 		return nil, err
 	}