From 729bb1a0886f399d59abacd0ac80b77d6847dd84 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Erik=20Hedenstro=CC=88m?= <erik@hedenstroem.com>
Date: Tue, 21 May 2024 00:49:34 +0200
Subject: [PATCH] Added context to test suite

---
 snok_test.go  |  9 +++++----
 test_suite.go | 18 ++++++++++++++++++
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/snok_test.go b/snok_test.go
index ae32f22..c756f96 100644
--- a/snok_test.go
+++ b/snok_test.go
@@ -1,7 +1,6 @@
 package snok
 
 import (
-	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -48,8 +47,8 @@ func TestCmds(t *testing.T) {
 }
 
 func mockContainer(s *TestSuite, t *testing.T, args []string) {
-	if os.Getenv("MOCKSERVER_URL") == "" {
-		ctx := context.Background()
+	if s.GetContextValue("MOCKSERVER_URL") == nil {
+		ctx := s.GetContext()
 		opts := []testcontainers.ContainerCustomizer{
 			testcontainers.WithImage("mockserver/mockserver:5.15.0"),
 		}
@@ -67,6 +66,7 @@ func mockContainer(s *TestSuite, t *testing.T, args []string) {
 			require.NoError(t, mockserverContainer.Terminate(ctx))
 		})
 		s.Setenv("MOCKSERVER_URL", url)
+		s.SetContextValue("MOCKSERVER_URL", url)
 	}
 }
 
@@ -76,7 +76,8 @@ func patchTestInput(s *TestSuite, t *testing.T, args []string) {
 	_, err = s.GetTest("Does not exist")
 	require.Error(t, err)
 	test, err := s.GetTest("Echo Input (http)")
-	dashboardUrl := os.Getenv("MOCKSERVER_URL") + "/mockserver/dashboard"
+	baseUrl := s.GetContextValue("MOCKSERVER_URL")
+	dashboardUrl := fmt.Sprintf("%s/mockserver/dashboard", baseUrl)
 	t.Logf("Modifying %s", test.String())
 	test.Input = &dashboardUrl
 	require.NoError(t, err)
diff --git a/test_suite.go b/test_suite.go
index 51f9437..c41bd9c 100644
--- a/test_suite.go
+++ b/test_suite.go
@@ -2,6 +2,7 @@ package snok
 
 import (
 	"bytes"
+	"context"
 	"flag"
 	"fmt"
 	"io"
@@ -16,8 +17,11 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
+type ctxKey string
+
 type TestSuite struct {
 	t             *testing.T
+	ctx           context.Context
 	RootCmd       *cobra.Command
 	LogLevel      LogLevel
 	Graph         *TestGraph
@@ -26,6 +30,7 @@ type TestSuite struct {
 
 func NewTestSuite(cmd *cobra.Command) *TestSuite {
 	s := &TestSuite{
+		ctx:           context.Background(),
 		RootCmd:       cmd,
 		LogLevel:      InfoLevel,
 		Graph:         NewTestGraph(),
@@ -45,6 +50,18 @@ func NewTestSuite(cmd *cobra.Command) *TestSuite {
 	return s
 }
 
+func (s *TestSuite) GetContext() context.Context {
+	return s.ctx
+}
+
+func (s *TestSuite) SetContextValue(key string, value any) {
+	s.ctx = context.WithValue(s.ctx, ctxKey(key), value)
+}
+
+func (s *TestSuite) GetContextValue(key string) any {
+	return s.ctx.Value(ctxKey(key))
+}
+
 func (s *TestSuite) GetLogLevel() LogLevel {
 	return s.LogLevel
 }
@@ -159,6 +176,7 @@ func (s *TestSuite) executeCmd(args []string, input io.Reader) (string, error) {
 		wg.Done()
 	}()
 
+	s.RootCmd.SetContext(s.ctx)
 	s.RootCmd.SetIn(input)
 	s.RootCmd.SetOutput(w)
 	s.RootCmd.SetArgs(args)
-- 
GitLab