From b743b26dded574043d560627c636b2206db14b26 Mon Sep 17 00:00:00 2001
From: evilsocket <evilsocket@gmail.com>
Date: Wed, 3 Apr 2019 09:39:28 +0200
Subject: [PATCH] misc: decoupled session record reader from the modules

---
 modules/api_rest/api_rest_record.go |  2 +-
 modules/api_rest/api_rest_replay.go |  5 +++-
 modules/api_rest/record.go          | 46 ++++++++++++++---------------
 3 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/modules/api_rest/api_rest_record.go b/modules/api_rest/api_rest_record.go
index fd5d7c09..1f9e0413 100644
--- a/modules/api_rest/api_rest_record.go
+++ b/modules/api_rest/api_rest_record.go
@@ -45,7 +45,7 @@ func (mod *RestAPI) recorder() {
 	mod.recTime = 0
 	mod.recording = true
 	mod.replaying = false
-	mod.record = NewRecord(mod.recordFileName, &mod.SessionModule)
+	mod.record = NewRecord(mod.recordFileName, nil)
 
 	mod.Info("started recording to %s (clock %s) ...", mod.recordFileName, clock)
 
diff --git a/modules/api_rest/api_rest_replay.go b/modules/api_rest/api_rest_replay.go
index b04c9ef4..85c73eee 100644
--- a/modules/api_rest/api_rest_replay.go
+++ b/modules/api_rest/api_rest_replay.go
@@ -38,7 +38,10 @@ func (mod *RestAPI) startReplay(filename string) (err error) {
 	mod.Info("loading %s ...", mod.recordFileName)
 
 	start := time.Now()
-	if mod.record, err = LoadRecord(mod.recordFileName, &mod.SessionModule); err != nil {
+	mod.record, err = LoadRecord(mod.recordFileName, func(progress float64) {
+		mod.State.Store("load_progress", progress)
+	})
+	if err != nil {
 		return err
 	}
 	loadedIn := time.Since(start)
diff --git a/modules/api_rest/record.go b/modules/api_rest/record.go
index cd52765f..911740ac 100644
--- a/modules/api_rest/record.go
+++ b/modules/api_rest/record.go
@@ -10,8 +10,6 @@ import (
 	"sync"
 	"time"
 
-	"github.com/bettercap/bettercap/session"
-
 	"github.com/evilsocket/islazy/fs"
 	"github.com/kr/binarydist"
 )
@@ -177,38 +175,40 @@ func (e *RecordEntry) Duration() time.Duration {
 	return e.StoppedAt().Sub(e.StartedAt())
 }
 
+type RecordLoadProgress func(p float64)
+
 // the Record object represents a recorded session
 type Record struct {
 	sync.Mutex
 
-	mod      *session.SessionModule `json:"-"`
-	fileName string                 `json:"-"`
-	done     int                    `json:"-"`
-	total    int                    `json:"-"`
-	progress float64                `json:"-"`
-	Session  *RecordEntry           `json:"session"`
-	Events   *RecordEntry           `json:"events"`
+	fileName   string             `json:"-"`
+	done       int                `json:"-"`
+	total      int                `json:"-"`
+	progress   float64            `json:"-"`
+	onProgress RecordLoadProgress `json:"-"`
+	Session    *RecordEntry       `json:"session"`
+	Events     *RecordEntry       `json:"events"`
 }
 
-func NewRecord(fileName string, mod *session.SessionModule) *Record {
+func NewRecord(fileName string, cb RecordLoadProgress) *Record {
 	r := &Record{
-		fileName: fileName,
-		mod:      mod,
+		fileName:   fileName,
+		onProgress: cb,
 	}
 
-	r.Session = NewRecordEntry(r.onProgress)
-	r.Events = NewRecordEntry(r.onProgress)
+	r.Session = NewRecordEntry(r.onPartialProgress)
+	r.Events = NewRecordEntry(r.onPartialProgress)
 
 	return r
 }
 
-func (r *Record) onProgress(done int) {
+func (r *Record) onPartialProgress(done int) {
 	r.done += done
 	r.progress = float64(r.done) / float64(r.total) * 100.0
-	r.mod.State.Store("load_progress", r.progress)
+	r.onProgress(r.progress)
 }
 
-func LoadRecord(fileName string, mod *session.SessionModule) (*Record, error) {
+func LoadRecord(fileName string, cb RecordLoadProgress) (*Record, error) {
 	if !fs.Exists(fileName) {
 		return nil, fmt.Errorf("%s does not exist", fileName)
 	}
@@ -236,17 +236,15 @@ func LoadRecord(fileName string, mod *session.SessionModule) (*Record, error) {
 		return nil, fmt.Errorf("error while parsing %s: %s", fileName, err)
 	}
 
-	rec.fileName = fileName
-	rec.mod = mod
-
 	rec.Session.NumStates = len(rec.Session.States)
-	rec.Session.progress = rec.onProgress
+	rec.Session.progress = rec.onPartialProgress
 	rec.Events.NumStates = len(rec.Events.States)
-	rec.Events.progress = rec.onProgress
-
-	rec.done = 0
+	rec.Events.progress = rec.onPartialProgress
+	rec.fileName = fileName
 	rec.total = rec.Session.NumStates + rec.Events.NumStates + 2
 	rec.progress = 0.0
+	rec.done = 0
+	rec.onProgress = cb
 
 	// reset state and precompute frames
 	if err = rec.Session.Compile(); err != nil {