package aghrenameio_test

import (
	"io/fs"
	"os"
	"path/filepath"
	"testing"

	"github.com/AdguardTeam/AdGuardHome/internal/aghrenameio"
	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testPerm is the common permission mode for tests.
const testPerm fs.FileMode = 0o644

// Common file data for tests.
var (
	initialData = []byte("initial data\n")
	newData     = []byte("new data\n")
)

// TestPendingFile checks that data written to a pending file replaces the
// contents of the target file once CloseReplace has been called.
func TestPendingFile(t *testing.T) {
	t.Parallel()

	targetPath := newInitialFile(t)
	f, err := aghrenameio.NewPendingFile(targetPath, testPerm)
	require.NoError(t, err)

	_, err = f.Write(newData)
	require.NoError(t, err)

	err = f.CloseReplace()
	require.NoError(t, err)

	gotData, err := os.ReadFile(targetPath)
	require.NoError(t, err)

	assert.Equal(t, newData, gotData)
}

// newInitialFile is a test helper that returns the path to the file containing
// [initialData].  The file is created with [testPerm] inside a directory
// obtained from t.TempDir.
func newInitialFile(t *testing.T) (targetPath string) {
	t.Helper()

	dir := t.TempDir()
	targetPath = filepath.Join(dir, "target")

	// Use the shared permission constant rather than a duplicated literal so
	// that the initial file and the pending file agree on the mode.
	err := os.WriteFile(targetPath, initialData, testPerm)
	require.NoError(t, err)

	return targetPath
}

// TestWithDeferredCleanup checks the result of WithDeferredCleanup for both a
// nil and a non-nil incoming error.  With a nil error, the target file is
// expected to contain [newData]; with a non-nil error, the same error message
// is expected back and the target file is expected to still contain
// [initialData].
func TestWithDeferredCleanup(t *testing.T) {
	t.Parallel()

	const testError errors.Error = "test error"

	testCases := []struct {
		error      error
		name       string
		wantErrMsg string
		wantData   []byte
	}{{
		name:       "success",
		error:      nil,
		wantErrMsg: "",
		wantData:   newData,
	}, {
		name:       "error",
		error:      testError,
		wantErrMsg: testError.Error(),
		wantData:   initialData,
	}}

	for _, tc := range testCases {
		// Copy the range variable, since the parallel subtest below runs after
		// the loop has advanced; this is required before Go 1.22.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// Each subtest gets its own target file, so parallel runs do not
			// interfere with each other.
			targetPath := newInitialFile(t)
			f, err := aghrenameio.NewPendingFile(targetPath, testPerm)
			require.NoError(t, err)

			_, err = f.Write(newData)
			require.NoError(t, err)

			err = aghrenameio.WithDeferredCleanup(tc.error, f)
			testutil.AssertErrorMsg(t, tc.wantErrMsg, err)

			gotData, err := os.ReadFile(targetPath)
			require.NoError(t, err)

			assert.Equal(t, tc.wantData, gotData)
		})
	}
}