go 反向代理

  • Post author:
  • Post category:其他


对 Go 语言的反向代理还不是很熟悉，先把相关代码记录下来。

proxy.go

package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/http/httputil"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)

// Transport is a per-client-tag wrapper around *http.Transport; the embedded
// transport performs the actual round trips.
type Transport struct {
	*http.Transport
	// lastAccess is never read or written in this file.
	// NOTE(review): looks intended for idle-transport eviction — confirm or remove.
	lastAccess atomic.Value
}

// HTTPProxyRoundTripper is an http.RoundTripper that keeps one Transport
// (and therefore one connection pool) per value of the "client_tag" request
// header.
type HTTPProxyRoundTripper struct {
	// transports maps a client tag to its dedicated Transport. Guarded by rwmutex.
	transports map[string]*Transport

	rwmutex sync.RWMutex
}

// newHTTPProxyRoundTripper returns a round tripper with an empty transport table.
func newHTTPProxyRoundTripper() *HTTPProxyRoundTripper {
	return &HTTPProxyRoundTripper{
		transports: make(map[string]*Transport),
	}
}

// RoundTrip forwards req through the Transport that belongs to the request's
// "client_tag" header, creating that Transport on first use.
//
// The "client_tag" header is stripped before forwarding. In line with the
// http.RoundTripper contract, req itself is left unmodified (a clone is
// forwarded), and req.Body is closed when RoundTrip rejects the request.
// It returns an error if the request carries no client tag.
func (rt *HTTPProxyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	clientTag := req.Header.Get("client_tag")
	if clientTag == "" {
		if req.Body != nil {
			req.Body.Close()
		}
		return nil, errors.New("no client tag src")
	}

	// Clone deep-copies the header map, so deleting the tag does not touch
	// the caller's request.
	outreq := req.Clone(req.Context())
	outreq.Header.Del("client_tag")

	return rt.transportFor(clientTag).RoundTrip(outreq)
}

// transportFor returns the cached Transport for clientTag, creating and
// storing a new one if none exists yet.
func (rt *HTTPProxyRoundTripper) transportFor(clientTag string) *Transport {
	// Fast path: most requests hit an existing entry, so a read lock suffices.
	rt.rwmutex.RLock()
	transport, ok := rt.transports[clientTag]
	rt.rwmutex.RUnlock()
	if ok {
		return transport
	}

	rt.rwmutex.Lock()
	defer rt.rwmutex.Unlock()
	// Re-check: another goroutine may have created it between the two locks.
	if existing, found := rt.transports[clientTag]; found {
		return existing
	}
	transport = &Transport{Transport: rt.createNewTransport()}
	rt.transports[clientTag] = transport
	return transport
}

// createNewTransport builds an http.Transport whose dialer ignores the
// address derived from the request URL. Instead it tries, in order, the
// comma-separated backend addresses stored in the request context under the
// "backaddrs" key and returns the first connection that succeeds.
//
// NOTE(review): http.Transport pools idle connections by the request URL's
// scheme/host, not by the backend actually dialed, so requests with the same
// client tag and Host but different backaddrs may reuse a connection to a
// different backend.
func (rt *HTTPProxyRoundTripper) createNewTransport() *http.Transport {
	netDialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}

	// The third argument (the address derived from the request URL) is
	// deliberately unused; backends come from the context instead.
	dialer := func(ctx context.Context, network string, _ string) (net.Conn, error) {
		// NOTE(review): a plain string context key is flagged by staticcheck
		// (SA1029); it must match the key HTTPProxy.ServeHTTP stores, so change
		// both together.
		s, _ := ctx.Value("backaddrs").(string)
		addrs := parseBackendAddrs(s)
		if len(addrs) == 0 {
			fmt.Println("no backend address find")
			return nil, errors.New("no backend address found")
		}

		var lastErr error
		for _, addr := range addrs {
			conn, err := netDialer.DialContext(ctx, network, addr)
			if err == nil {
				fmt.Println("connect to backend server success, addr is ", addr)
				return conn, nil
			}
			lastErr = err
			// Stop trying further backends once the request is cancelled or
			// its deadline has passed.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return nil, ctxErr
			}
		}

		fmt.Printf("connect to backend server failed, address: %v\n", s)
		return nil, fmt.Errorf("connect to backend server failed. addrs:%v: %w", s, lastErr)
	}

	return &http.Transport{
		Proxy:        nil,
		DialContext:  dialer,
		MaxIdleConns: 100,
	}
}

// parseBackendAddrs splits a comma-separated address list, trimming
// surrounding whitespace and dropping empty entries.
func parseBackendAddrs(s string) []string {
	parts := strings.Split(s, ",")
	addrs := make([]string, 0, len(parts))
	for _, p := range parts {
		if p = strings.TrimSpace(p); p != "" {
			addrs = append(addrs, p)
		}
	}
	return addrs
}

// HTTPProxy is an http.Handler (see its ServeHTTP method) that forwards each
// request through the wrapped httputil.ReverseProxy.
type HTTPProxy struct {
	// rp performs the actual forwarding.
	rp *httputil.ReverseProxy
}

// NewHTTPProxy builds an HTTPProxy whose reverse proxy sends requests through
// a per-client-tag round tripper.
func NewHTTPProxy() *HTTPProxy {
	rp := &httputil.ReverseProxy{
		// ServeHTTP already sets the outgoing URL's scheme and host, so the
		// director has nothing left to rewrite.
		Director:  func(*http.Request) {},
		Transport: newHTTPProxyRoundTripper(),
	}
	return &HTTPProxy{rp: rp}
}

// ServeHTTP proxies r to a backend chosen from its "backaddrs" header.
//
// The header value is passed to the transport's dialer through the request
// context. The value is presumably a comma-separated address list; verify
// this against the dialer. The header itself is not forwarded to the
// backend.
//
// NOTE(review): the backend list is fully client-controlled, so any client
// can make this proxy open connections to arbitrary addresses. Restrict or
// validate it before exposing the proxy to untrusted clients.
func (proxy *HTTPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	serverAddrs := r.Header.Get("backaddrs")

	fmt.Println("url: ", r.URL)
	ctx := context.WithValue(r.Context(), "backaddrs", serverAddrs)
	// Clone deep-copies URL and Header. WithContext is only a shallow copy,
	// so editing its URL would mutate the server-owned request.
	outreq := r.Clone(ctx)
	// Routing metadata for this proxy only; keep it away from the backend.
	outreq.Header.Del("backaddrs")
	// Server-side requests usually carry only a path; make the URL absolute
	// so the reverse proxy's transport can send it.
	outreq.URL.Scheme = "http"
	outreq.URL.Host = r.Host

	fmt.Println("after url: ", outreq.URL)

	proxy.rp.ServeHTTP(w, outreq)
}

// main serves the reverse proxy on :8000 and reports why the server stopped.
func main() {
	httpProxyHandler := NewHTTPProxy()
	server := &http.Server{
		Addr:              ":8000",
		Handler:           httpProxyHandler,
		ReadHeaderTimeout: 10 * time.Second, // bound slow header senders (slowloris)
	}

	// ListenAndServe always returns a non-nil error, for example when the
	// port is already in use; surface it instead of exiting silently.
	if err := server.ListenAndServe(); err != nil {
		fmt.Println("proxy server stopped, error is ", err)
	}
}

test_proxy.go

package main

import (
	"fmt"
	"io/ioutil"
	"net/http"
)

// main sends one GET through the proxy at 127.0.0.1:8000. The request is
// tagged with a client tag and routed to the backend at 127.0.0.1:8001, and
// the response body is printed.
func main() {
	client := &http.Client{}

	req, err := http.NewRequest("GET", "http://127.0.0.1:8000/hello", nil)
	if err != nil {
		fmt.Println("new request failed, error is ", err)
		return
	}

	req.Header.Set("client_tag", "127.0.0.1")
	req.Header.Set("backaddrs", "127.0.0.1:8001")

	resp, err := client.Do(req)
	// resp is nil when Do fails, so check err before touching resp.Body.
	if err != nil {
		fmt.Println("do request failed, error is ", err)
		return
	}
	defer resp.Body.Close()

	// Proxy failures surface as non-2xx statuses (often with an empty body).
	if resp.StatusCode != http.StatusOK {
		fmt.Println("unexpected status: ", resp.Status)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		fmt.Println("read failed", err)
		return
	}

	fmt.Println(string(body))
}



版权声明:本文为qq_35728402原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。