diff --git a/snok_test.go b/snok_test.go index ae32f22efc18f89ad7ba6405f43d8ab6fd47bd8d..c756f96e83ab169f31c2f39d07ff6c448fb58974 100644 --- a/snok_test.go +++ b/snok_test.go @@ -1,7 +1,6 @@ package snok import ( - "context" "errors" "fmt" "io" @@ -48,8 +47,8 @@ func TestCmds(t *testing.T) { } func mockContainer(s *TestSuite, t *testing.T, args []string) { - if os.Getenv("MOCKSERVER_URL") == "" { - ctx := context.Background() + if s.GetContextValue("MOCKSERVER_URL") == nil { + ctx := s.GetContext() opts := []testcontainers.ContainerCustomizer{ testcontainers.WithImage("mockserver/mockserver:5.15.0"), } @@ -67,6 +66,7 @@ func mockContainer(s *TestSuite, t *testing.T, args []string) { require.NoError(t, mockserverContainer.Terminate(ctx)) }) s.Setenv("MOCKSERVER_URL", url) + s.SetContextValue("MOCKSERVER_URL", url) } } @@ -76,7 +76,8 @@ func patchTestInput(s *TestSuite, t *testing.T, args []string) { _, err = s.GetTest("Does not exist") require.Error(t, err) test, err := s.GetTest("Echo Input (http)") - dashboardUrl := os.Getenv("MOCKSERVER_URL") + "/mockserver/dashboard" + baseUrl := s.GetContextValue("MOCKSERVER_URL") + dashboardUrl := fmt.Sprintf("%s/mockserver/dashboard", baseUrl) t.Logf("Modifying %s", test.String()) test.Input = &dashboardUrl require.NoError(t, err) diff --git a/test_suite.go b/test_suite.go index 51f9437cb226cf52598b95d95c6d6e81b449592c..c41bd9ce4a8c90eb2ed89a9c73ff28dd951fc3d1 100644 --- a/test_suite.go +++ b/test_suite.go @@ -2,6 +2,7 @@ package snok import ( "bytes" + "context" "flag" "fmt" "io" @@ -16,8 +17,11 @@ import ( "github.com/stretchr/testify/require" ) +type ctxKey string + type TestSuite struct { t *testing.T + ctx context.Context RootCmd *cobra.Command LogLevel LogLevel Graph *TestGraph @@ -26,6 +30,7 @@ type TestSuite struct { func NewTestSuite(cmd *cobra.Command) *TestSuite { s := &TestSuite{ + ctx: context.Background(), RootCmd: cmd, LogLevel: InfoLevel, Graph: NewTestGraph(), @@ -45,6 +50,18 @@ func NewTestSuite(cmd *cobra.Command) *TestSuite { return s } +func (s *TestSuite) GetContext() context.Context { + return s.ctx +} + +func (s *TestSuite) SetContextValue(key string, value any) { + s.ctx = context.WithValue(s.ctx, ctxKey(key), value) +} + +func (s *TestSuite) GetContextValue(key string) any { + return s.ctx.Value(ctxKey(key)) +} + func (s *TestSuite) GetLogLevel() LogLevel { return s.LogLevel } @@ -159,6 +176,7 @@ func (s *TestSuite) executeCmd(args []string, input io.Reader) (string, error) { wg.Done() }() + s.RootCmd.SetContext(s.ctx) s.RootCmd.SetIn(input) s.RootCmd.SetOutput(w) s.RootCmd.SetArgs(args)