From b801b299f051055e3f66b2c78426a195152086e5 Mon Sep 17 00:00:00 2001
From: faceair <git@faceair.me>
Date: Thu, 27 May 2021 19:52:44 +0800
Subject: [PATCH] lib/promscrape: apply body size & sample limit to stream
 parse (#1331)

* lib/promscrape: apply the -promscrape.maxScrapeSize body size limit to stream parse mode, which previously read the response body without any size check

Signed-off-by: faceair <git@faceair.me>
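
For reference, the check boils down to wrapping the response body in a byte-counting io.Reader. A minimal self-contained sketch of that pattern (limitedReader and the 512-byte limit are illustrative stand-ins, not names from this codebase):

    // A byte-counting reader wrapper in the spirit of streamReader below:
    // it forwards Reads, tracks the running total, and fails once the
    // total crosses the configured limit.
    package main

    import (
        "bytes"
        "fmt"
        "io"
    )

    type limitedReader struct {
        r         io.Reader
        bytesRead int64
        max       int64
    }

    func (lr *limitedReader) Read(p []byte) (int, error) {
        n, err := lr.r.Read(p)
        lr.bytesRead += int64(n)
        if lr.bytesRead > lr.max {
            // Return n so the caller can still process the bytes already
            // read, per the io.Reader contract.
            return n, fmt.Errorf("response body exceeds %d bytes", lr.max)
        }
        return n, err
    }

    func main() {
        lr := &limitedReader{r: bytes.NewReader(make([]byte, 1024)), max: 512}
        _, err := io.Copy(io.Discard, lr)
        fmt.Println(err) // response body exceeds 512 bytes
    }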

* lib/promscrape: apply the per-target sample_limit to stream parse mode, which previously ignored it

Signed-off-by: faceair <git@faceair.me>
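
The stream-parse path only ever holds one batch of rows in memory, so the sample count has to be accumulated and re-checked on every parse callback rather than once per response. A rough self-contained sketch of that accumulate-and-check loop (pushBatches and its parameters are hypothetical, not the scrapeStream API):

    // Feed row batches to push until the running sample total crosses
    // sampleLimit, at which point the scrape is aborted.
    package main

    import "fmt"

    func pushBatches(batches [][]string, sampleLimit int, push func([]string)) error {
        samples := 0
        for _, batch := range batches {
            samples += len(batch)
            // Unlike a whole-response check, the limit must be tested on
            // every batch: in stream parse mode the full response never
            // sits in memory at once.
            if sampleLimit > 0 && samples > sampleLimit {
                return fmt.Errorf("response exceeds sample_limit=%d", sampleLimit)
            }
            push(batch)
        }
        return nil
    }

    func main() {
        batches := [][]string{{"a", "b"}, {"c", "d"}, {"e"}}
        err := pushBatches(batches, 3, func(b []string) { fmt.Println("pushed", b) })
        fmt.Println(err) // response exceeds sample_limit=3
    }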
---
 lib/promscrape/client.go     | 18 +++++++++++++-----
 lib/promscrape/scrapework.go |  8 ++++++++
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/lib/promscrape/client.go b/lib/promscrape/client.go
index 5baa7cbab4..fc21f967f7 100644
--- a/lib/promscrape/client.go
+++ b/lib/promscrape/client.go
@@ -192,8 +192,10 @@ func (c *client) GetStreamReader() (*streamReader, error) {
 	}
 	scrapesOK.Inc()
 	return &streamReader{
-		r:      resp.Body,
-		cancel: cancel,
+		r:           resp.Body,
+		cancel:      cancel,
+		scrapeURL:   c.scrapeURL,
+		maxBodySize: int64(c.hc.MaxResponseBodySize),
 	}, nil
 }
 
@@ -328,14 +330,20 @@ func doRequestWithPossibleRetry(hc *fasthttp.HostClient, req *fasthttp.Request,
 }
 
 type streamReader struct {
-	r         io.ReadCloser
-	cancel    context.CancelFunc
-	bytesRead int64
+	r           io.ReadCloser
+	cancel      context.CancelFunc
+	bytesRead   int64
+	scrapeURL   string
+	maxBodySize int64
 }
 
 func (sr *streamReader) Read(p []byte) (int, error) {
 	n, err := sr.r.Read(p)
 	sr.bytesRead += int64(n)
+	if sr.bytesRead > sr.maxBodySize {
+		return n, fmt.Errorf("the response from %q exceeds -promscrape.maxScrapeSize=%d; "+
+			"either reduce the response size for the target or increase -promscrape.maxScrapeSize", sr.scrapeURL, sr.maxBodySize)
+	}
 	return n, err
 }
 
diff --git a/lib/promscrape/scrapework.go b/lib/promscrape/scrapework.go
index 3a09482988..7922344296 100644
--- a/lib/promscrape/scrapework.go
+++ b/lib/promscrape/scrapework.go
@@ -305,6 +305,8 @@ func (sw *scrapeWork) scrapeInternal(scrapeTimestamp, realTimestamp int64) error
 		wc.resetNoRows()
 		up = 0
 		scrapesSkippedBySampleLimit.Inc()
+		err = fmt.Errorf("the response from %q exceeds sample_limit=%d; "+
+			"either reduce the sample count for the target or increase sample_limit", sw.Config.ScrapeURL, sw.Config.SampleLimit)
 	}
 	sw.updateSeriesAdded(wc)
 	seriesAdded := sw.finalizeSeriesAdded(samplesPostRelabeling)
@@ -348,6 +350,12 @@ func (sw *scrapeWork) scrapeStream(scrapeTimestamp, realTimestamp int64) error {
 			// after returning from the callback - this will result in data race.
 			// See https://github.com/VictoriaMetrics/VictoriaMetrics/issues/825#issuecomment-723198247
 			samplesPostRelabeling += len(wc.writeRequest.Timeseries)
+			if sw.Config.SampleLimit > 0 && samplesPostRelabeling > sw.Config.SampleLimit {
+				wc.resetNoRows()
+				scrapesSkippedBySampleLimit.Inc()
+				return fmt.Errorf("the response from %q exceeds sample_limit=%d; "+
+					"either reduce the sample count for the target or increase sample_limit", sw.Config.ScrapeURL, sw.Config.SampleLimit)
+			}
 			sw.updateSeriesAdded(wc)
 			startTime := time.Now()
 			sw.PushData(&wc.writeRequest)