diff --git a/pkg/api/routes.go b/pkg/api/routes.go index 2521b72..7068dc1 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -9,6 +9,7 @@ import ( ) var commonMiddlewares = middleware.Middlewares{ + middleware.XFF, middleware.Cors, middleware.Logger, middleware.Tracing(nextRequestID), diff --git a/pkg/middleware/xff.go b/pkg/middleware/xff.go new file mode 100644 index 0000000..5b9afa9 --- /dev/null +++ b/pkg/middleware/xff.go @@ -0,0 +1,27 @@ +package middleware + +import ( + "net" + "net/http" +) + +const xForwardedFor = "X-Forwarded-For" + +func getIP(req *http.Request) string { + ip, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + return req.RemoteAddr + } + return ip +} + +// XFF is a middleware to identifying the originating IP address using X-Forwarded-For header +func XFF(inner http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + xff := r.Header.Get(xForwardedFor) + if xff == "" { + r.Header.Set(xForwardedFor, getIP(r)) + } + inner.ServeHTTP(w, r) + }) +}