diff options
| -rw-r--r-- | config.go | 24 | ||||
| -rw-r--r-- | example.conf.hcl | 4 | ||||
| -rw-r--r-- | main.go | 12 | ||||
| -rw-r--r-- | server.go | 37 |
4 files changed, 53 insertions, 24 deletions
@@ -2,8 +2,10 @@ package main import ( "fmt" + "net" "net/netip" "net/url" + "strconv" "strings" "github.com/hashicorp/hcl/v2/hclsimple" @@ -16,7 +18,7 @@ type Config struct { type ServerConfig struct { Protocol string - Host string + Hosts []string Port int RedirectToHTTPS bool ACMEChallenge *ACMEChallengeConfig @@ -58,7 +60,7 @@ type InternalHCLConfig struct { type InternalHCLServerConfig struct { Protocol string `hcl:"protocol,label"` - Host string `hcl:"host"` + Hosts []string `hcl:"hosts"` Port int `hcl:"port"` RedirectToHTTPS bool `hcl:"redirect_to_https,optional"` ACMEChallenge []InternalHCLACMEChallengeConfig `hcl:"acme_challenge,block"` @@ -128,7 +130,7 @@ func fromHCLConfigToConfig(hclConfig *InternalHCLConfig) *Config { } servers[i] = ServerConfig{ Protocol: s.Protocol, - Host: s.Host, + Hosts: s.Hosts, Port: s.Port, RedirectToHTTPS: s.RedirectToHTTPS, ACMEChallenge: acmeChallenge, @@ -167,9 +169,14 @@ func LoadConfig(fileName string) (*Config, error) { return nil, fmt.Errorf("Invalid protocol %s", server.Protocol) } - _, err = netip.ParseAddr(server.Host) - if err != nil { - return nil, fmt.Errorf("Invalid host %s", server.Host) + if len(server.Hosts) == 0 { + return nil, fmt.Errorf("At least one host address is required") + } + for _, host := range server.Hosts { + _, err = netip.ParseAddr(host) + if err != nil { + return nil, fmt.Errorf("Invalid host %s", host) + } } if len(server.ACMEChallenge) != 0 && len(server.ACMEChallenge) != 1 { @@ -217,9 +224,10 @@ func LoadConfig(fileName string) (*Config, error) { if p.From.Host == "" && p.From.Path == "" { return nil, fmt.Errorf("Either host or path must be specified") } - _, err := url.Parse(fmt.Sprintf("http://%s:%d", p.To.Host, p.To.Port)) + hostAndPort := net.JoinHostPort(p.To.Host, strconv.Itoa(p.To.Port)) + _, err := url.Parse("http://" + hostAndPort) if err != nil { - return nil, fmt.Errorf("Invalid host or port: %s:%d", p.To.Host, p.To.Port) + return nil, fmt.Errorf("Invalid host or port: %s", hostAndPort) } if 2 <= len(p.Auths) { return nil, fmt.Errorf("Too many auth blocks found") diff --git a/example.conf.hcl b/example.conf.hcl index 9274fa4..88f65a4 100644 --- a/example.conf.hcl +++ b/example.conf.hcl @@ -1,5 +1,7 @@ server http { - host = "127.0.0.1" + hosts = ["127.0.0.1"] + # hosts = ["::1"] # Listen on localhost (IPv6) + # hosts = ["0.0.0.0", "::"] # Listen on all interfaces (IPv4 + IPv6) port = 8000 proxy a { @@ -18,14 +18,14 @@ import ( func startServer( s *Server, - listener net.Listener, + listeners []net.Listener, wg *sync.WaitGroup, sigCtx context.Context, ) { defer wg.Done() go func() { - err := s.Serve(listener) + err := s.Serve(listeners) if err != nil && err != http.ErrServerClosed { log.Fatalf("Failed to start server (%s): %s", s.Label(), err) } @@ -113,13 +113,13 @@ func main() { configFileDir := filepath.Dir(configFileName) // Set up listeners. - var listeners []net.Listener + var listeners [][]net.Listener for _, s := range config.Servers { - l, err := NewListener(&s) + ls, err := NewListeners(&s) if err != nil { - log.Fatalf("Failed to create listener (%s:%d): %s", s.Host, s.Port, err) + log.Fatalf("Failed to create listeners (%v:%d): %s", s.Hosts, s.Port, err) } - listeners = append(listeners, l) + listeners = append(listeners, ls) } // Set up servers. @@ -9,6 +9,7 @@ import ( "net/http/httputil" "net/url" "os" + "strconv" "strings" ) @@ -55,7 +56,7 @@ func basicAuthHandler(handler http.Handler, realm, username, passwordHash string func newMultipleReverseProxyServer(cfg *ServerConfig) (*multipleReverseProxyServer, error) { var rules []rewriteRule for _, p := range cfg.Proxies { - targetUrl, err := url.Parse(fmt.Sprintf("http://%s:%d", p.To.Host, p.To.Port)) + targetUrl, err := url.Parse("http://" + net.JoinHostPort(p.To.Host, strconv.Itoa(p.To.Port))) if err != nil { return nil, err } @@ -184,7 +185,7 @@ func NewServer(cfg *ServerConfig) (*Server, error) { return &Server{ tlsEnabled: cfg.Protocol == "https", s: http.Server{ - Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + Addr: fmt.Sprintf("%s:%d", strings.Join(cfg.Hosts, ","), cfg.Port), Handler: h, TLSConfig: tlsConfig, }, @@ -195,18 +196,36 @@ func (s *Server) Label() string { return s.s.Addr } -func (s *Server) Serve(listener net.Listener) error { - if s.tlsEnabled { - return s.s.ServeTLS(listener, "", "") - } else { - return s.s.Serve(listener) +func (s *Server) Serve(listeners []net.Listener) error { + errC := make(chan error, len(listeners)) + for _, l := range listeners { + go func() { + if s.tlsEnabled { + errC <- s.s.ServeTLS(l, "", "") + } else { + errC <- s.s.Serve(l) + } + }() } + return <-errC } func (s *Server) Shutdown(ctx context.Context) { s.s.Shutdown(ctx) } -func NewListener(cfg *ServerConfig) (net.Listener, error) { - return net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)) +func NewListeners(cfg *ServerConfig) ([]net.Listener, error) { + port := strconv.Itoa(cfg.Port) + listeners := make([]net.Listener, 0, len(cfg.Hosts)) + for _, host := range cfg.Hosts { + l, err := net.Listen("tcp", net.JoinHostPort(host, port)) + if err != nil { + for _, prev := range listeners { + prev.Close() + } + return nil, err + } + listeners = append(listeners, l) + } + return listeners, nil } |
