...

Source file src/github.com/cybertec-postgresql/pgwatch/v3/internal/sources/yaml_test.go

Documentation: github.com/cybertec-postgresql/pgwatch/v3/internal/sources

     1  package sources_test
     2  
     3  import (
     4  	"os"
     5  	"path/filepath"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  
    10  	"github.com/cybertec-postgresql/pgwatch/v3/internal/sources"
    11  )
    12  
    13  // the number of entries in the sample.sources.yaml file
    14  const sampleEntriesNumber = 4
    15  
    16  const (
    17  	contribDir = "../../contrib/"
    18  	sampleFile = "../../contrib/sample.sources.yaml"
    19  )
    20  
    21  func TestNewYAMLSourcesReaderWriter(t *testing.T) {
    22  	a := assert.New(t)
    23  	yamlrw, err := sources.NewYAMLSourcesReaderWriter(ctx, sampleFile)
    24  	a.NoError(err)
    25  	a.NotNil(t, yamlrw)
    26  }
    27  
    28  func TestYAMLGetMonitoredDatabases(t *testing.T) {
    29  	a := assert.New(t)
    30  
    31  	t.Run("single file", func(*testing.T) {
    32  		yamlrw, err := sources.NewYAMLSourcesReaderWriter(ctx, sampleFile)
    33  		a.NoError(err)
    34  
    35  		dbs, err := yamlrw.GetSources()
    36  		a.NoError(err)
    37  		a.Len(dbs, sampleEntriesNumber)
    38  	})
    39  
    40  	t.Run("folder with yaml files", func(*testing.T) {
    41  		yamlrw, err := sources.NewYAMLSourcesReaderWriter(ctx, contribDir)
    42  		a.NoError(err)
    43  
    44  		dbs, err := yamlrw.GetSources()
    45  		a.NoError(err)
    46  		a.Len(dbs, sampleEntriesNumber)
    47  	})
    48  
    49  	t.Run("nonexistent file", func(*testing.T) {
    50  		yamlrw, err := sources.NewYAMLSourcesReaderWriter(ctx, "nonexistent.yaml")
    51  		a.NoError(err)
    52  		dbs, err := yamlrw.GetSources()
    53  		a.Error(err)
    54  		a.Nil(dbs)
    55  	})
    56  
    57  	t.Run("garbage file", func(*testing.T) {
    58  		yamlrw, err := sources.NewYAMLSourcesReaderWriter(ctx, filepath.Join(contribDir, "yaml.go"))
    59  		a.NoError(err)
    60  		dbs, err := yamlrw.GetSources()
    61  		a.Error(err)
    62  		a.Nil(dbs)
    63  	})
    64  
    65  	t.Run("duplicate in single file", func(t *testing.T) {
    66  		tmpFile := filepath.Join(t.TempDir(), "duplicate.yaml")
    67  		yamlContent := `
    68  - name: test1
    69    conn_str: postgresql://localhost/test1
    70  - name: test2
    71    conn_str: postgresql://localhost/test2
    72  - name: test1
    73    conn_str: postgresql://localhost/test1_duplicate
    74  `
    75  		err := os.WriteFile(tmpFile, []byte(yamlContent), 0644)
    76  		a.NoError(err)
    77  		yamlrw, err := sources.NewYAMLSourcesReaderWriter(ctx, tmpFile)
    78  		a.NoError(err)
    79  
    80  		dbs, err := yamlrw.GetSources()
    81  		a.Error(err)
    82  		a.Nil(dbs)
    83  	})
    84  
    85  	t.Run("duplicates across files", func(t *testing.T) {
    86  		tmpDir := t.TempDir()
    87  		yamlContent1 := `
    88  - name: test1
    89    conn_str: postgresql://localhost/test1
    90  - name: test2
    91    conn_str: postgresql://localhost/test2
    92  `
    93  		err := os.WriteFile(filepath.Join(tmpDir, "sources1.yaml"), []byte(yamlContent1), 0644)
    94  		a.NoError(err)
    95  
    96  		yamlContent2 := `
    97  - name: test1
    98    conn_str: postgresql://localhost/test1_duplicate
    99  `
   100  		err = os.WriteFile(filepath.Join(tmpDir, "sources2.yaml"), []byte(yamlContent2), 0644)
   101  		a.NoError(err)
   102  		yamlrw, err := sources.NewYAMLSourcesReaderWriter(ctx, tmpDir)
   103  		a.NoError(err)
   104  
   105  		dbs, err := yamlrw.GetSources()
   106  		a.Error(err)
   107  		a.Nil(dbs)
   108  	})
   109  }
   110  
   111  func TestYAMLDeleteDatabase(t *testing.T) {
   112  	a := assert.New(t)
   113  
   114  	t.Run("happy path", func(*testing.T) {
   115  		data, err := os.ReadFile(sampleFile)
   116  		a.NoError(err)
   117  		tmpSampleFile := filepath.Join(t.TempDir(), "sample.sources.yaml")
   118  		err = os.WriteFile(tmpSampleFile, data, 0644)
   119  		a.NoError(err)
   120  		defer os.Remove(tmpSampleFile)
   121  
   122  		yamlrw, err := sources.NewYAMLSourcesReaderWriter(ctx, tmpSampleFile)
   123  		a.NoError(err)
   124  
   125  		err = yamlrw.DeleteSource("test1")
   126  		a.NoError(err)
   127  
   128  		dbs, err := yamlrw.GetSources()
   129  		a.NoError(err)
   130  		a.Len(dbs, sampleEntriesNumber-1)
   131  	})
   132  
   133  	t.Run("nonexistent file", func(*testing.T) {
   134  		yamlrw, err := sources.NewYAMLSourcesReaderWriter(ctx, "nonexistent.yaml")
   135  		a.NoError(err)
   136  		err = yamlrw.DeleteSource("test1")
   137  		a.Error(err)
   138  	})
   139  }
   140  
   141  func TestYAMLUpdateDatabase(t *testing.T) {
   142  	a := assert.New(t)
   143  
   144  	t.Run("happy path", func(*testing.T) {
   145  		data, err := os.ReadFile(sampleFile)
   146  		a.NoError(err)
   147  		tmpSampleFile := filepath.Join(t.TempDir(), "sample.sources.yaml")
   148  		err = os.WriteFile(tmpSampleFile, data, 0644)
   149  		a.NoError(err)
   150  		defer os.Remove(tmpSampleFile)
   151  
   152  		yamlrw, err := sources.NewYAMLSourcesReaderWriter(ctx, tmpSampleFile)
   153  		a.NoError(err)
   154  
   155  		// change the connection string of the first database
   156  		md := sources.Source{}
   157  		md.Name = "test1"
   158  		md.ConnStr = "postgresql://localhost/test1"
   159  		err = yamlrw.UpdateSource(md)
   160  		a.NoError(err)
   161  
   162  		// add a new database
   163  		md = sources.Source{}
   164  		md.Name = "test5"
   165  		md.ConnStr = "postgresql://localhost/test5"
   166  		err = yamlrw.UpdateSource(md)
   167  		a.NoError(err)
   168  
   169  		dbs, err := yamlrw.GetSources()
   170  		a.NoError(err)
   171  		a.Len(dbs, sampleEntriesNumber+1)
   172  		dbs[0].ConnStr = "postgresql://localhost/test1"
   173  		dbs[sampleEntriesNumber].ConnStr = "postgresql://localhost/test5"
   174  	})
   175  
   176  	t.Run("nonexistent file", func(*testing.T) {
   177  		yamlrw, err := sources.NewYAMLSourcesReaderWriter(ctx, "")
   178  		a.NoError(err)
   179  		err = yamlrw.UpdateSource(sources.Source{})
   180  		a.Error(err)
   181  	})
   182  }
   183