...

Source file src/github.com/cybertec-postgresql/pgwatch/v5/internal/sources/yaml.go

Documentation: github.com/cybertec-postgresql/pgwatch/v5/internal/sources

     1  package sources
     2  
     3  // This file contains the implementation of the ReaderWriter interface for the YAML file.
     4  
     5  import (
     6  	"context"
     7  	"io/fs"
     8  	"os"
     9  	"path/filepath"
    10  	"slices"
    11  	"strings"
    12  	"sync"
    13  
    14  	"gopkg.in/yaml.v3"
    15  )
    16  
    17  func NewYAMLSourcesReaderWriter(ctx context.Context, path string) (ReaderWriter, error) {
    18  	return &fileSourcesReaderWriter{
    19  		ctx:  ctx,
    20  		path: path,
    21  	}, nil
    22  }
    23  
    24  type fileSourcesReaderWriter struct {
    25  	ctx  context.Context
    26  	path string
    27  	sync.Mutex
    28  }
    29  
    30  // WriteSources writes sources to file with locking
    31  func (fcr *fileSourcesReaderWriter) WriteSources(mds Sources) error {
    32  	fcr.Lock()
    33  	defer fcr.Unlock()
    34  	return fcr.writeSources(mds)
    35  }
    36  
    37  // writeSources writes sources to file without locking (internal use only)
    38  func (fcr *fileSourcesReaderWriter) writeSources(mds Sources) error {
    39  	yamlData, _ := yaml.Marshal(mds)
    40  	return os.WriteFile(fcr.path, yamlData, 0644)
    41  }
    42  
    43  // UpdateSource updates an existing source or creates it if it doesn't exist, then writes the updated sources back to file
    44  func (fcr *fileSourcesReaderWriter) UpdateSource(md Source) error {
    45  	fcr.Lock()
    46  	defer fcr.Unlock()
    47  	dbs, err := fcr.getSources()
    48  	if err != nil {
    49  		return err
    50  	}
    51  	for i, db := range dbs {
    52  		if db.Name == md.Name {
    53  			dbs[i] = md
    54  			return fcr.writeSources(dbs)
    55  		}
    56  	}
    57  	dbs = append(dbs, md)
    58  	return fcr.writeSources(dbs)
    59  }
    60  
    61  // CreateSource creates a new source if it doesn't already exist, then writes the updated sources back to file
    62  func (fcr *fileSourcesReaderWriter) CreateSource(md Source) error {
    63  	fcr.Lock()
    64  	defer fcr.Unlock()
    65  	dbs, err := fcr.getSources()
    66  	if err != nil {
    67  		return err
    68  	}
    69  	// Check if source already exists
    70  	for _, db := range dbs {
    71  		if db.Name == md.Name {
    72  			return ErrSourceExists
    73  		}
    74  	}
    75  	dbs = append(dbs, md)
    76  	return fcr.writeSources(dbs)
    77  }
    78  
    79  // DeleteSource deletes a source by name and writes the updated sources back to file
    80  func (fcr *fileSourcesReaderWriter) DeleteSource(name string) error {
    81  	fcr.Lock()
    82  	defer fcr.Unlock()
    83  	dbs, err := fcr.getSources()
    84  	if err != nil {
    85  		return err
    86  	}
    87  	dbs = slices.DeleteFunc(dbs, func(md Source) bool { return md.Name == name })
    88  	return fcr.writeSources(dbs)
    89  }
    90  
    91  // GetSources reads sources from file with locking
    92  func (fcr *fileSourcesReaderWriter) GetSources() (dbs Sources, err error) {
    93  	fcr.Lock()
    94  	defer fcr.Unlock()
    95  	return fcr.getSources()
    96  }
    97  
    98  // getSources reads sources from file without locking (internal use only)
    99  func (fcr *fileSourcesReaderWriter) getSources() (dbs Sources, err error) {
   100  	var fi fs.FileInfo
   101  	if fi, err = os.Stat(fcr.path); err != nil {
   102  		return
   103  	}
   104  	switch mode := fi.Mode(); {
   105  	case mode.IsDir():
   106  		err = filepath.WalkDir(fcr.path, func(path string, d fs.DirEntry, err error) error {
   107  			if err != nil {
   108  				return err
   109  			}
   110  			ext := strings.ToLower(filepath.Ext(d.Name()))
   111  			if d.IsDir() || ext != ".yaml" && ext != ".yml" {
   112  				return nil
   113  			}
   114  			var mdbs Sources
   115  			if mdbs, err = fcr.loadSourcesFromFile(path); err == nil {
   116  				dbs = append(dbs, mdbs...)
   117  			}
   118  			return err
   119  		})
   120  	case mode.IsRegular():
   121  		dbs, err = fcr.loadSourcesFromFile(fcr.path)
   122  	}
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  	return dbs.Validate()
   127  }
   128  
   129  // loadSourcesFromFile reads sources from a single YAML file, expands environment variables, and returns them
   130  func (fcr *fileSourcesReaderWriter) loadSourcesFromFile(configFilePath string) (dbs Sources, err error) {
   131  	var yamlFile []byte
   132  	if yamlFile, err = os.ReadFile(configFilePath); err != nil {
   133  		return
   134  	}
   135  	c := make(Sources, 0) // there can be multiple configs in a single file
   136  	if err = yaml.Unmarshal(yamlFile, &c); err != nil {
   137  		return
   138  	}
   139  	for _, v := range c {
   140  		dbs = append(dbs, fcr.expandEnvVars(v))
   141  	}
   142  	return
   143  }
   144  
   145  func (fcr *fileSourcesReaderWriter) expandEnvVars(md Source) Source {
   146  	if strings.HasPrefix(md.ConnStr, "$") {
   147  		md.ConnStr = os.ExpandEnv(md.ConnStr)
   148  	}
   149  	if strings.HasPrefix(md.Group, "$") {
   150  		md.Group = os.ExpandEnv(md.Group)
   151  	}
   152  	if strings.HasPrefix(string(md.Kind), "$") {
   153  		md.Kind = Kind(os.ExpandEnv(string(md.Kind)))
   154  	}
   155  	if strings.HasPrefix(md.Name, "$") {
   156  		md.Name = os.ExpandEnv(md.Name)
   157  	}
   158  	if strings.HasPrefix(md.IncludePattern, "$") {
   159  		md.IncludePattern = os.ExpandEnv(md.IncludePattern)
   160  	}
   161  	if strings.HasPrefix(md.ExcludePattern, "$") {
   162  		md.ExcludePattern = os.ExpandEnv(md.ExcludePattern)
   163  	}
   164  	if strings.HasPrefix(md.PresetMetrics, "$") {
   165  		md.PresetMetrics = os.ExpandEnv(md.PresetMetrics)
   166  	}
   167  	if strings.HasPrefix(md.PresetMetricsStandby, "$") {
   168  		md.PresetMetricsStandby = os.ExpandEnv(md.PresetMetricsStandby)
   169  	}
   170  	return md
   171  }
   172