From 4a078200486572b1d9cb7cd71113ff63e79c73e6 Mon Sep 17 00:00:00 2001
From: Aliaksandr Valialkin <valyala@gmail.com>
Date: Wed, 17 Feb 2021 14:59:04 +0200
Subject: [PATCH] lib/storage: make sure that nobody uses partitions when
 closing the table

---
 lib/storage/storage.go      |  2 ++
 lib/storage/storage_test.go |  3 ++-
 lib/storage/table.go        | 32 ++++++++++++++++++--------------
 3 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/lib/storage/storage.go b/lib/storage/storage.go
index 0cdfc22dfa..d1f8dd57c9 100644
--- a/lib/storage/storage.go
+++ b/lib/storage/storage.go
@@ -588,6 +588,8 @@ func (s *Storage) mustRotateIndexDB() {
 }
 
 // MustClose closes the storage.
+//
+// It is expected that the s is no longer used during the close.
 func (s *Storage) MustClose() {
 	close(s.stop)
 
diff --git a/lib/storage/storage_test.go b/lib/storage/storage_test.go
index 9c8bbd60ca..f2891841ee 100644
--- a/lib/storage/storage_test.go
+++ b/lib/storage/storage_test.go
@@ -967,15 +967,16 @@ func testStorageAddRows(s *Storage) error {
 		return fmt.Errorf("error when force merging partitions: %w", err)
 	}
 	ptws := s1.tb.GetPartitions(nil)
-	defer s1.tb.PutPartitions(ptws)
 	for _, ptw := range ptws {
 		pws := ptw.pt.GetParts(nil)
 		numParts := len(pws)
 		ptw.pt.PutParts(pws)
 		if numParts != 1 {
+			s1.tb.PutPartitions(ptws)
 			return fmt.Errorf("unexpected number of parts for partition %q after force merge; got %d; want 1", ptw.pt.name, numParts)
 		}
 	}
+	s1.tb.PutPartitions(ptws)
 
 	s1.MustClose()
 
diff --git a/lib/storage/table.go b/lib/storage/table.go
index faea8c1a08..646d330742 100644
--- a/lib/storage/table.go
+++ b/lib/storage/table.go
@@ -189,6 +189,7 @@ func (tb *table) addPartitionNolock(pt *partition) {
 }
 
 // MustClose closes the table.
+// It is expected that all the pending searches on the table are finished before calling MustClose.
 func (tb *table) MustClose() {
 	close(tb.stop)
 	tb.retentionWatcherWG.Wait()
@@ -198,10 +199,11 @@ func (tb *table) MustClose() {
 	tb.ptws = nil
 	tb.ptwsLock.Unlock()
 
-	// Decrement references to partitions, so they may be eventually closed after
-	// pending searches are done.
 	for _, ptw := range ptws {
-		ptw.decRef()
+		if n := atomic.LoadUint64(&ptw.refCount); n != 1 {
+			logger.Panicf("BUG: unexpected refCount=%d when closing the partition; probably there are pending searches", n)
+		}
+		ptw.pt.MustClose()
 	}
 
 	// Release exclusive lock on the table.
@@ -271,10 +273,10 @@ func (tb *table) AddRows(rows []rawRow) error {
 
 	ptwsX.a = tb.GetPartitions(ptwsX.a[:0])
 	ptws := ptwsX.a
-	for _, ptw := range ptws {
+	for i, ptw := range ptws {
 		singlePt := true
-		for i := range rows {
-			if !ptw.pt.HasTimestamp(rows[i].Timestamp) {
+		for j := range rows {
+			if !ptw.pt.HasTimestamp(rows[j].Timestamp) {
 				singlePt = false
 				break
 			}
@@ -283,16 +285,18 @@ func (tb *table) AddRows(rows []rawRow) error {
 			continue
 		}
 
-		// Move the partition with the matching rows to the front of tb.ptws,
-		// so it will be detected faster next time.
-		tb.ptwsLock.Lock()
-		for i := range tb.ptws {
-			if ptw == tb.ptws[i] {
-				tb.ptws[0], tb.ptws[i] = tb.ptws[i], tb.ptws[0]
-				break
+		if i != 0 {
+			// Move the partition with the matching rows to the front of tb.ptws,
+			// so it will be detected faster next time.
+			tb.ptwsLock.Lock()
+			for j := range tb.ptws {
+				if ptw == tb.ptws[j] {
+					tb.ptws[0], tb.ptws[j] = tb.ptws[j], tb.ptws[0]
+					break
+				}
 			}
+			tb.ptwsLock.Unlock()
 		}
-		tb.ptwsLock.Unlock()
 
 		// Fast path - add all the rows into the ptw.
 		ptw.pt.AddRows(rows)