diff options
| author | nsfisis <nsfisis@gmail.com> | 2023-10-08 00:19:24 +0900 |
|---|---|---|
| committer | nsfisis <nsfisis@gmail.com> | 2023-10-08 00:19:24 +0900 |
| commit | 25bf19b1cf99cd3ebe01c5fb96a6b6b13f528f8e (patch) | |
| tree | b2da56baea86cf4026dba2957f8d21e7417bfff3 | |
| parent | bb8632a8440839de3125989ab3ed8f66d029e95c (diff) | |
| download | mioproxy-25bf19b1cf99cd3ebe01c5fb96a6b6b13f528f8e.tar.gz mioproxy-25bf19b1cf99cd3ebe01c5fb96a6b6b13f528f8e.tar.zst mioproxy-25bf19b1cf99cd3ebe01c5fb96a6b6b13f528f8e.zip | |
add filesv0.1.0
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | Makefile | 14 | ||||
| -rw-r--r-- | config.go | 194 | ||||
| -rw-r--r-- | go.mod | 15 | ||||
| -rw-r--r-- | go.sum | 22 | ||||
| -rw-r--r-- | main.go | 146 | ||||
| -rw-r--r-- | server.go | 128 |
7 files changed, 520 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7b2ce56 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/mioproxy diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..e3f4a1e --- /dev/null +++ b/Makefile @@ -0,0 +1,14 @@ +.PHONY: all +all: build + +.PHONY: build +build: + go build . + +.PHONY: fmt +fmt: + go fmt . + +.PHONY: clean +clean: + rm -f ./mioproxy diff --git a/config.go b/config.go new file mode 100644 index 0000000..ed4b3e0 --- /dev/null +++ b/config.go @@ -0,0 +1,194 @@ +package main + +import ( + "fmt" + "net/netip" + "net/url" + + "github.com/hashicorp/hcl/v2/hclsimple" +) + +type Config struct { + User string + Servers []ServerConfig +} + +type ServerConfig struct { + Protocol string + Host string + Port int + RedirectToHTTPS bool + ACMEChallenge *ACMEChallengeConfig + TLSCertFile string + TLSKeyFile string + Proxies []ProxyConfig +} + +type ACMEChallengeConfig struct { + Root string +} + +type ProxyConfig struct { + Name string + From ProxyFromConfig + To ProxyToConfig +} + +type ProxyFromConfig struct { + Host string +} + +type ProxyToConfig struct { + Host string + Port int +} + +type InternalHCLConfig struct { + User string `hcl:"user,optional"` + Servers []InternalHCLServerConfig `hcl:"server,block"` +} + +type InternalHCLServerConfig struct { + Protocol string `hcl:"protocol,label"` + Host string `hcl:"host"` + Port int `hcl:"port"` + RedirectToHTTPS bool `hcl:"redirect_to_https,optional"` + ACMEChallenge []InternalHCLACMEChallengeConfig `hcl:"acme_challenge,block"` + TLSCertFile string `hcl:"tls_cert_file,optional"` + TLSKeyFile string `hcl:"tls_key_file,optional"` + Proxies []InternalHCLProxyConfig `hcl:"proxy,block"` +} + +type InternalHCLACMEChallengeConfig struct { + Root string `hcl:"root"` +} + +type InternalHCLProxyConfig struct { + Name string `hcl:"name,label"` + From InternalHCLProxyFromConfig `hcl:"from,block"` + To InternalHCLProxyToConfig `hcl:"to,block"` +} + +type InternalHCLProxyFromConfig struct { + Host string `hcl:"host"` +} + +type InternalHCLProxyToConfig struct { + Host string `hcl:"host"` + Port int `hcl:"port"` +} + +func fromHCLConfigToConfig(hclConfig *InternalHCLConfig) *Config { + servers := make([]ServerConfig, len(hclConfig.Servers)) + for i, s := range hclConfig.Servers { + var acmeChallenge *ACMEChallengeConfig + if len(s.ACMEChallenge) != 0 { + acmeChallenge = &ACMEChallengeConfig{ + Root: s.ACMEChallenge[0].Root, + } + } + proxies := make([]ProxyConfig, len(s.Proxies)) + for j, p := range s.Proxies { + proxies[j] = ProxyConfig{ + Name: p.Name, + From: ProxyFromConfig{ + Host: p.From.Host, + }, + To: ProxyToConfig{ + Host: p.To.Host, + Port: p.To.Port, + }, + } + } + servers[i] = ServerConfig{ + Protocol: s.Protocol, + Host: s.Host, + Port: s.Port, + RedirectToHTTPS: s.RedirectToHTTPS, + ACMEChallenge: acmeChallenge, + TLSCertFile: s.TLSCertFile, + TLSKeyFile: s.TLSKeyFile, + Proxies: proxies, + } + } + + return &Config{ + User: hclConfig.User, + Servers: servers, + } +} + +func LoadConfig(fileName string) (*Config, error) { + var hclConfig InternalHCLConfig + err := hclsimple.DecodeFile(fileName, nil, &hclConfig) + if err != nil { + return nil, err + } + + if len(hclConfig.Servers) == 0 { + return nil, fmt.Errorf("No server blocks found") + } + if 2 < len(hclConfig.Servers) { + return nil, fmt.Errorf("Too many server blocks found") + } + + var listenHTTPS = false + var redirectToHTTPS = false + for _, server := range hclConfig.Servers { + if server.Protocol == "https" { + listenHTTPS = true + } else if server.Protocol != "http" { + return nil, fmt.Errorf("Invalid protocol %s", server.Protocol) + } + + _, err = netip.ParseAddr(server.Host) + if err != nil { + return nil, fmt.Errorf("Invalid host %s", server.Host) + } + + if len(server.ACMEChallenge) != 0 && len(server.ACMEChallenge) != 1 { + return nil, fmt.Errorf("Only one acme_challenge block is allowed") + } + if len(server.ACMEChallenge) != 0 && server.Protocol != "http" { + return nil, fmt.Errorf("accept_acme_challenge must be on http listener") + } + + if server.RedirectToHTTPS { + redirectToHTTPS = true + if server.Protocol != "http" { + return nil, fmt.Errorf("redirect_to_https must be on http listener") + } + if len(server.Proxies) != 0 { + return nil, fmt.Errorf("redirect_to_https cannot be used with proxy") + } + } + + if server.Protocol == "https" { + if server.TLSCertFile == "" { + return nil, fmt.Errorf("tls_cert_file is required for https listener") + } + if server.TLSKeyFile == "" { + return nil, fmt.Errorf("tls_key_file is required for https listener") + } + } else { + if server.TLSCertFile != "" { + return nil, fmt.Errorf("tls_cert_file is only allowed for https listener") + } + if server.TLSKeyFile != "" { + return nil, fmt.Errorf("tls_key_file is only allowed for https listener") + } + } + + for _, p := range server.Proxies { + _, err := url.Parse(fmt.Sprintf("http://%s:%d", p.To.Host, p.To.Port)) + if err != nil { + return nil, fmt.Errorf("Invalid host or port: %s:%d", p.To.Host, p.To.Port) + } + } + } + if redirectToHTTPS && !listenHTTPS { + return nil, fmt.Errorf("redirect_to_https requires https listener") + } + + return fromHCLConfigToConfig(&hclConfig), nil +} @@ -0,0 +1,15 @@ +module github.com/nsfisis/mioproxy + +go 1.20 + +require github.com/hashicorp/hcl/v2 v2.18.0 + +require ( + github.com/agext/levenshtein v1.2.1 // indirect + github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect + github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect + github.com/google/go-cmp v0.3.1 // indirect + github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect + github.com/zclconf/go-cty v1.13.0 // indirect + golang.org/x/text v0.11.0 // indirect +) @@ -0,0 +1,22 @@ +github.com/agext/levenshtein v1.2.1 h1:QmvMAjj2aEICytGiWzmxoE0x2KZvE0fvmqMOfy2tjT8= +github.com/agext/levenshtein v1.2.1/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= +github.com/apparentlymart/go-textseg/v13 v13.0.0 h1:Y+KvPE1NYz0xl601PVImeQfFyEy6iT90AvPUL1NNfNw= +github.com/apparentlymart/go-textseg/v13 v13.0.0/go.mod h1:ZK2fH7c4NqDTLtiYLvIkEghdlcqw7yxLeM89kiTRPUo= +github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= +github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68= +github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/hashicorp/hcl/v2 v2.18.0 h1:wYnG7Lt31t2zYkcquwgKo6MWXzRUDIeIVU5naZwHLl8= +github.com/hashicorp/hcl/v2 v2.18.0/go.mod h1:ThLC89FV4p9MPW804KVbe/cEXoQ8NZEh+JtMeeGErHE= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 h1:MtvEpTB6LX3vkb4ax0b5D2DHbNAUsen0Gx5wZoq3lV4= +github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 h1:DpOJ2HYzCv8LZP15IdmG+YdwD2luVPHITV96TkirNBM= +github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= +github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= +github.com/zclconf/go-cty v1.13.0 h1:It5dfKTTZHe9aeppbNOda3mN7Ag7sg6QkBNm6TkyFa0= +github.com/zclconf/go-cty v1.13.0/go.mod h1:YKQzy/7pZ7iq2jNFzy5go57xdxdWoLLpaEp4u238AE0= +golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= @@ -0,0 +1,146 @@ +package main + +import ( + "context" + "fmt" + "log" + "net" + "net/http" + "os" + "os/signal" + "os/user" + "path/filepath" + "strconv" + "sync" + "syscall" + "time" +) + +func startServer( + s *Server, + listener net.Listener, + wg *sync.WaitGroup, + sigCtx context.Context, +) { + defer wg.Done() + + go func() { + err := s.Serve(listener) + if err != nil && err != http.ErrServerClosed { + log.Fatalf("Failed to start server (%s): %s", s.Label(), err) + } + }() + + fmt.Printf("Started server (%s)\n", s.Label()) + + // Wait until we receive a signal to stop the server. + <-sigCtx.Done() + + // Gracefully shutdown the server with timeout. + timeoutCtx, cancelTimeoutCtx := context.WithTimeout(context.Background(), time.Second*10) + defer cancelTimeoutCtx() + s.Shutdown(timeoutCtx) + + fmt.Printf("Shutdown server (%s)\n", s.Label()) +} + +func downgradeToUser(uname string) error { + // Get gid and uid. + u, err := user.Lookup(uname) + if err != nil { + return err + } + // On POSIX system, gid and uid are integers. + gid, err := strconv.Atoi(u.Gid) + if err != nil { + return err + } + uid, err := strconv.Atoi(u.Uid) + if err != nil { + return err + } + + // Set gid and uid. + err = syscall.Setgid(gid) + if err != nil { + return err + } + err = syscall.Setuid(uid) + if err != nil { + return err + } + + return nil +} + +func main() { + // Check mode + if len(os.Args) == 3 && os.Args[1] == "-check" { + configFileName := os.Args[2] + _, err := LoadConfig(configFileName) + if err != nil { + log.Fatalf("%s", err) + } + return + } + + // Load config. + if len(os.Args) < 2 { + log.Fatalf("Usage: %s <config file>", os.Args[0]) + } + configFileName := os.Args[1] + config, err := LoadConfig(configFileName) + if err != nil { + log.Fatalf("Failed to load configuration: %s", err) + } + configFileDir := filepath.Dir(configFileName) + + // Set up listeners. + var listeners []net.Listener + for _, s := range config.Servers { + l, err := NewListener(&s) + if err != nil { + log.Fatalf("Failed to create listener (%s:%d): %s", s.Host, s.Port, err) + } + listeners = append(listeners, l) + } + + // Set up servers. + var servers []*Server + for _, s := range config.Servers { + // Convert relative paths to absolute paths, based on config file location. + if s.ACMEChallenge != nil { + s.ACMEChallenge.Root = filepath.Join(configFileDir, s.ACMEChallenge.Root) + } + if s.TLSCertFile != "" { + s.TLSCertFile = filepath.Join(configFileDir, s.TLSCertFile) + } + if s.TLSKeyFile != "" { + s.TLSKeyFile = filepath.Join(configFileDir, s.TLSKeyFile) + } + servers = append(servers, NewServer(&s)) + } + + // Downgrade to non-root user. + if config.User != "" { + err := downgradeToUser(config.User) + if err != nil { + log.Fatalf("Failed to downgrade to user %s: %s", config.User, err) + } + } + + // Catch signals to stop servers. + sigCtx, cancelSigCtx := signal.NotifyContext( + context.Background(), + syscall.SIGTERM, os.Interrupt, os.Kill, + ) + defer cancelSigCtx() + + // Start servers. + var wg sync.WaitGroup + for i, s := range servers { + wg.Add(1) + go startServer(s, listeners[i], &wg, sigCtx) + } + wg.Wait() +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..ad00d26 --- /dev/null +++ b/server.go @@ -0,0 +1,128 @@ +package main + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" +) + +type multipleReverseProxyServer struct { + rules []rewriteRule +} + +type rewriteRule struct { + fromHost string + toUrl *url.URL + proxy *httputil.ReverseProxy +} + +func newMultipleReverseProxyServer(ps []ProxyConfig) *multipleReverseProxyServer { + var rules []rewriteRule + for _, p := range ps { + targetUrl, err := url.Parse(fmt.Sprintf("http://%s:%d", p.To.Host, p.To.Port)) + if err != nil { + // This setting should be validated when loading config. + panic(err) + } + rules = append(rules, rewriteRule{ + fromHost: p.From.Host, + toUrl: targetUrl, + proxy: &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(targetUrl) + r.SetXForwarded() + }, + }, + }) + } + return &multipleReverseProxyServer{ + rules: rules, + } +} + +func (s *multipleReverseProxyServer) tryServeHTTP(w http.ResponseWriter, r *http.Request) bool { + for _, rule := range s.rules { + if r.Host == rule.fromHost { + rule.proxy.ServeHTTP(w, r) + return true + } + } + return false +} + +type Server struct { + s http.Server + tlsEnabled bool +} + +func NewServer(cfg *ServerConfig) *Server { + h := http.NewServeMux() + + if cfg.ACMEChallenge != nil { + h.Handle( + "/.well-known/acme-challenge/", + http.FileServer(http.Dir(cfg.ACMEChallenge.Root)), + ) + } + + if cfg.RedirectToHTTPS { + h.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + target := r.URL + target.Scheme = "https" + target.Host = r.Host + http.Redirect(w, r, target.String(), http.StatusMovedPermanently) + }) + } else { + reverseProxyServer := newMultipleReverseProxyServer(cfg.Proxies) + h.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + found := reverseProxyServer.tryServeHTTP(w, r) + if !found { + http.NotFound(w, r) + } + }) + } + + var tlsConfig *tls.Config + if cfg.TLSCertFile != "" && cfg.TLSKeyFile != "" { + cert, err := tls.LoadX509KeyPair(cfg.TLSCertFile, cfg.TLSKeyFile) + if err != nil { + panic(err) + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + } + + return &Server{ + tlsEnabled: cfg.Protocol == "https", + s: http.Server{ + Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + Handler: h, + TLSConfig: tlsConfig, + }, + } +} + +func (s *Server) Label() string { + return s.s.Addr +} + +func (s *Server) Serve(listener net.Listener) error { + if s.tlsEnabled { + return s.s.ServeTLS(listener, "", "") + } else { + return s.s.Serve(listener) + } +} + +func (s *Server) Shutdown(ctx context.Context) { + s.s.Shutdown(ctx) +} + +func NewListener(cfg *ServerConfig) (net.Listener, error) { + return net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)) +} |
