diff options
| author | nsfisis <nsfisis@gmail.com> | 2023-10-09 08:42:33 +0900 |
|---|---|---|
| committer | nsfisis <nsfisis@gmail.com> | 2023-10-09 08:42:33 +0900 |
| commit | ceb264cb65f4a62531e11b3ce666f931074b778a (patch) | |
| tree | b727df20ca1c6ef35c4dcea2798f29e19a2035c9 /server.go | |
| parent | d137a764d050e3d5296da2830a32f6d83bdb364f (diff) | |
| download | mioproxy-ceb264cb65f4a62531e11b3ce666f931074b778a.tar.gz mioproxy-ceb264cb65f4a62531e11b3ce666f931074b778a.tar.zst mioproxy-ceb264cb65f4a62531e11b3ce666f931074b778a.zip | |
support basic authv0.2.0
Diffstat (limited to 'server.go')
| -rw-r--r-- | server.go | 67 |
1 files changed, 52 insertions, 15 deletions
@@ -8,6 +8,7 @@ import ( "net/http" "net/http/httputil" "net/url" + "os" "strings" ) @@ -19,7 +20,7 @@ type rewriteRule struct { fromHost string fromPath string toUrl *url.URL - proxy *httputil.ReverseProxy + proxy http.Handler } func (r *rewriteRule) matches(host, path string) bool { @@ -33,29 +34,62 @@ func (r *rewriteRule) matches(host, path string) bool { return ret } -func newMultipleReverseProxyServer(ps []ProxyConfig) *multipleReverseProxyServer { +func basicAuthHandler(handler http.Handler, realm, username, passwordHash string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + inputUsername, inputPassword, ok := r.BasicAuth() + if !ok || inputUsername != username || !VerifyPassword(inputPassword, passwordHash) { + w.Header().Set( + "WWW-Authenticate", + fmt.Sprintf("Basic realm=\"%s\"", realm), + ) + http.Error(w, "401 unauthorized", http.StatusUnauthorized) + return + } + handler.ServeHTTP(w, r) + }) +} + +func newMultipleReverseProxyServer(ps []ProxyConfig) (*multipleReverseProxyServer, error) { var rules []rewriteRule for _, p := range ps { targetUrl, err := url.Parse(fmt.Sprintf("http://%s:%d", p.To.Host, p.To.Port)) if err != nil { - // This setting should be validated when loading config. - panic(err) + return nil, err + } + var proxy http.Handler = &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(targetUrl) + r.SetXForwarded() + }, + } + if p.BasicAuth != nil { + credentialFileContent, err := os.ReadFile(p.BasicAuth.CredentialFile) + if err != nil { + return nil, err + } + usernameAndPasswordHash := strings.Split(strings.TrimSuffix(string(credentialFileContent), "\n"), ":") + if len(usernameAndPasswordHash) != 2 { + return nil, fmt.Errorf("invalid credential file format") + } + username := usernameAndPasswordHash[0] + passwordHash := usernameAndPasswordHash[1] + proxy = basicAuthHandler( + proxy, + p.BasicAuth.Realm, + username, + passwordHash, + ) } rules = append(rules, rewriteRule{ fromHost: p.From.Host, fromPath: p.From.Path, toUrl: targetUrl, - proxy: &httputil.ReverseProxy{ - Rewrite: func(r *httputil.ProxyRequest) { - r.SetURL(targetUrl) - r.SetXForwarded() - }, - }, + proxy: proxy, }) } return &multipleReverseProxyServer{ rules: rules, - } + }, nil } func (s *multipleReverseProxyServer) tryServeHTTP( @@ -77,7 +111,7 @@ type Server struct { tlsEnabled bool } -func NewServer(cfg *ServerConfig) *Server { +func NewServer(cfg *ServerConfig) (*Server, error) { h := http.NewServeMux() if cfg.ACMEChallenge != nil { @@ -95,7 +129,10 @@ func NewServer(cfg *ServerConfig) *Server { http.Redirect(w, r, target.String(), http.StatusMovedPermanently) }) } else { - reverseProxyServer := newMultipleReverseProxyServer(cfg.Proxies) + reverseProxyServer, err := newMultipleReverseProxyServer(cfg.Proxies) + if err != nil { + return nil, err + } h.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { // r.Host may have ":port" part. hostWithoutPort, _, err := net.SplitHostPort(r.Host) @@ -114,7 +151,7 @@ func NewServer(cfg *ServerConfig) *Server { if cfg.TLSCertFile != "" && cfg.TLSKeyFile != "" { cert, err := tls.LoadX509KeyPair(cfg.TLSCertFile, cfg.TLSKeyFile) if err != nil { - panic(err) + return nil, err } tlsConfig = &tls.Config{ Certificates: []tls.Certificate{cert}, @@ -128,7 +165,7 @@ func NewServer(cfg *ServerConfig) *Server { Handler: h, TLSConfig: tlsConfig, }, - } + }, nil } func (s *Server) Label() string { |
