aboutsummaryrefslogtreecommitdiff
path: root/rediscache/main.go
blob: 04ca622a8aaac7639fc11950b41f69f409587048 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package rediscache

import (
	"bytes"
	"log"
	"net/http"
	"os"

	"github.com/gomodule/redigo/redis"
	"github.com/vmihailenco/msgpack"
)

var (
	// Logger is where the caching handlers report Redis and msgpack
	// (un)marshaling errors. By default it writes timestamped lines with a
	// "REDIS: " prefix to standard error. Callers may replace it; do so before
	// the handlers start serving, because it is read without synchronization.
	Logger = log.New(os.Stderr, "REDIS: ", log.LstdFlags)
)

// redisHTTPResponseWriter is an in-memory http.ResponseWriter. A handler
// writes its response into it; the exported fields (Headers, StatusCode,
// Data) are then serialized and stored in Redis so the response can be
// replayed later without running the handler again.
//
// buf is scratch space used while the handler runs. It is unexported, so the
// msgpack encoder does not store it; WriteData copies its contents into Data.
type redisHTTPResponseWriter struct {
	Headers    http.Header
	StatusCode int
	Data       []byte

	buf *bytes.Buffer
}

// Compile-time check that the capture writer satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*redisHTTPResponseWriter)(nil)

// Header returns the captured header map, allocating it on first use so a
// handler can set headers before writing anything.
func (rw *redisHTTPResponseWriter) Header() http.Header {
	if rw.Headers == nil {
		rw.Headers = http.Header{}
	}

	return rw.Headers
}

// Write appends msg to the captured body, allocating the buffer on first use.
func (rw *redisHTTPResponseWriter) Write(msg []byte) (int, error) {
	if rw.buf == nil {
		rw.buf = &bytes.Buffer{}
	}

	return rw.buf.Write(msg)
}

// WriteHeader records the status code. Nothing is sent anywhere; each call
// overwrites the previously recorded code.
func (rw *redisHTTPResponseWriter) WriteHeader(code int) {
	rw.StatusCode = code
}

// WriteData copies the buffered body into Data so it is included when the
// writer is serialized for storage in Redis.
func (rw *redisHTTPResponseWriter) WriteData() {
	// A handler that only sets a status (e.g. 204/304) never calls Write, so
	// buf is still nil; calling Bytes on it would dereference a nil pointer.
	if rw.buf == nil {
		rw.Data = nil

		return
	}

	rw.Data = rw.buf.Bytes()
}

// HandleWithParams is like Handle, except that the request's raw query string
// is part of the cache field: URLs that differ only in their query parameters
// get separate cache entries.
func HandleWithParams(pool *redis.Pool, key string, next http.Handler) http.Handler {
	const includeQuery = true

	return handle(pool, key, includeQuery, next)
}

// Handle wraps next with a Redis-backed response cache. Responses are kept in
// the Redis hash named key, one field per request path. The query string is
// ignored, so /a?x=1 and /a?x=2 share an entry. When an entry for the path
// exists, its stored headers, status and body are replayed instead of calling
// next.
func Handle(pool *redis.Pool, key string, next http.Handler) http.Handler {
	const includeQuery = false

	return handle(pool, key, includeQuery, next)
}

// handle builds the caching middleware shared by Handle and HandleWithParams.
//
// Each response is stored as one field of the Redis hash named key. The field
// name is the request path, followed by "?" and the raw query string when
// params is true.
//   - Hit: the stored headers, status and body are replayed to w; next is not
//     called.
//   - Miss, or a stored entry that cannot be decoded: next runs against an
//     in-memory writer. The captured response is stored (best-effort) and then
//     written to w.
//   - Redis error: the request is passed straight through to next, uncached.
//
// NOTE: every response is cached, whatever its status code.
func handle(pool *redis.Pool, key string, params bool, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		client := pool.Get()
		defer client.Close()

		subkey := r.URL.Path
		if params {
			subkey += "?" + r.URL.RawQuery
		}

		// redis.Bytes maps a nil reply (missing field) to redis.ErrNil and
		// rejects non-bulk replies, instead of an unchecked type assertion.
		cached, err := redis.Bytes(client.Do("HGET", key, subkey))

		switch {
		case err == redis.ErrNil:
			// Not cached yet; generated and stored below.
		case err != nil:
			// Assume something bad has happened with redis, we're
			// just going to log this and then pass through the
			// request as normal.
			Logger.Println("ERROR: ", err)
			next.ServeHTTP(w, r)

			return
		default:
			rw := &redisHTTPResponseWriter{}
			if err := msgpack.Unmarshal(cached, rw); err == nil {
				writeCachedResponse(w, rw)

				return
			} else {
				// Rather than answering with an empty body (and doing so for
				// as long as the bad entry lives), regenerate and overwrite it.
				Logger.Println("ERROR: unmarshaling: ", err)
			}
		}

		rw := &redisHTTPResponseWriter{}
		next.ServeHTTP(rw, r)
		rw.WriteData()

		// Storing is best-effort. On failure the freshly generated response is
		// still sent to the client instead of being dropped. It is written
		// directly, not re-read from Redis: that saves a round trip and avoids
		// looping if the field vanishes between HSET and HGET.
		if b, err := msgpack.Marshal(rw); err != nil {
			Logger.Println("ERROR: marshaling: ", err)
		} else if _, err := client.Do("HSET", key, subkey, b); err != nil {
			Logger.Println("ERROR: during set: ", err)
		}

		writeCachedResponse(w, rw)
	})
}

// writeCachedResponse copies a captured response (headers, status, body) onto w.
func writeCachedResponse(w http.ResponseWriter, rw *redisHTTPResponseWriter) {
	// Ranging over a nil map is a no-op, so no nil check is needed.
	for k, v := range rw.Headers {
		w.Header()[k] = v
	}

	// A zero status means the handler never called WriteHeader; let w fall
	// back to its default.
	if rw.StatusCode != 0 {
		w.WriteHeader(rw.StatusCode)
	}

	// The error is deliberately ignored: the headers are already committed,
	// and there is no caller to report a failed write to.
	_, _ = w.Write(rw.Data)
}