aboutsummaryrefslogtreecommitdiff
path: root/rediscache/main.go
blob: e7564c82d7f0d1d94ab91b9046b59af4503b67fb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package rediscache

import (
	"bytes"
	"log"
	"net/http"
	"os"

	"github.com/gomodule/redigo/redis"
	"github.com/vmihailenco/msgpack"
)

// Logger is the package-wide logger. It writes to standard error, prefixes
// every line with "REDIS: " and stamps it with the date and time. It is
// exported so callers can replace it with their own *log.Logger.
var Logger = log.New(os.Stderr, "REDIS: ", log.Ldate|log.Ltime)

// redisHTTPResponseWriter is an in-memory stand-in for http.ResponseWriter.
// Rather than sending anything to a client, it records the headers, status
// code and body that a handler produces so the response can be serialised
// and stored in Redis.
//
// The zero value is ready to use: the header map and the body buffer are
// allocated lazily on first access.
type redisHTTPResponseWriter struct {
	// Headers holds the response headers set by the handler.
	Headers http.Header
	// StatusCode is the last code passed to WriteHeader; it stays 0 if the
	// handler never called WriteHeader.
	StatusCode int
	// Data is the complete response body. It is only populated once
	// WriteData has been called.
	Data []byte

	// buf accumulates the body across Write calls until WriteData copies
	// the slice header into Data.
	buf *bytes.Buffer
}

// Header returns the recorded header map, creating it on first use so that
// handlers can set headers on a zero-value writer.
func (rw *redisHTTPResponseWriter) Header() http.Header {
	if rw.Headers == nil {
		rw.Headers = http.Header{}
	}

	return rw.Headers
}

// Write appends msg to the internal body buffer, creating the buffer on
// first use. It returns the number of bytes buffered and any error reported
// by the buffer.
func (rw *redisHTTPResponseWriter) Write(msg []byte) (int, error) {
	if rw.buf == nil {
		rw.buf = &bytes.Buffer{}
	}

	return rw.buf.Write(msg)
}

// WriteHeader records the status code chosen by the handler.
func (rw *redisHTTPResponseWriter) WriteHeader(code int) {
	rw.StatusCode = code
}

// WriteData publishes everything written so far into the Data field, ready
// for storage in Redis.
//
// If Write was never called (for example a redirect or another header-only
// response) there is no buffer; Data is then left untouched instead of
// dereferencing the nil buffer, which would panic.
func (rw *redisHTTPResponseWriter) WriteData() {
	if rw.buf == nil {
		return
	}

	rw.Data = rw.buf.Bytes()
}

// Handle wraps next in a Redis-backed response cache.
//
// The cache key is the request's URL path only (r.URL.Path); the method and
// query string are not part of the key. For each request:
//
//   - If Redis returns a stored entry that decodes cleanly, the stored
//     headers, status code and body are replayed to the client and next is
//     not invoked.
//   - If Redis returns an error, the error is logged and the request is
//     passed straight through to next without caching.
//   - Otherwise (no entry, or an entry that cannot be decoded) next is run
//     against an in-memory recorder, the recorded response is stored in
//     Redis on a best-effort basis, and the recorded response is sent to the
//     client.
//
// Failures to encode or store the response are logged but never prevent the
// client from receiving the response that next produced.
func Handle(pool *redis.Pool, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		client := pool.Get()
		defer client.Close()

		key := r.URL.Path

		reply, err := client.Do("GET", key)
		if err != nil {
			// Assume something bad has happened with redis; log it and
			// serve the request as if there were no cache at all.
			Logger.Println("ERROR: ", err)
			next.ServeHTTP(w, r)

			return
		}

		switch raw := reply.(type) {
		case nil:
			// Cache miss: fall through and generate the response.
		case []byte:
			cached := &redisHTTPResponseWriter{}
			if err := msgpack.Unmarshal(raw, cached); err == nil {
				replayResponse(w, cached)

				return
			} else {
				// A corrupt entry must not leave the client with an empty
				// response; regenerate it (and overwrite the entry) below.
				Logger.Println("ERROR: unmarshaling: ", err)
			}
		default:
			// Checked instead of a bare type assertion so an unexpected
			// reply cannot panic the handler.
			Logger.Printf("ERROR: unexpected reply type %T", raw)
		}

		captured := &redisHTTPResponseWriter{}
		next.ServeHTTP(captured, r)
		captured.WriteData()

		// Caching is best-effort: whatever happens here, the captured
		// response is still delivered to the client below.
		encoded, err := msgpack.Marshal(captured)
		if err != nil {
			Logger.Println("ERROR: marshaling: ", err)
		} else if _, err = client.Do("SET", key, encoded); err != nil {
			Logger.Println("ERROR: during set: ", err)
		}

		// Send the captured response directly rather than reading it back
		// from redis: this saves a round trip and cannot loop if the key is
		// not readable after the SET.
		replayResponse(w, captured)
	})
}

// replayResponse copies a recorded response (headers, status code, body)
// onto the real http.ResponseWriter.
func replayResponse(w http.ResponseWriter, rw *redisHTTPResponseWriter) {
	dst := w.Header()
	for k, v := range rw.Headers {
		dst[k] = v
	}

	// A zero status means the handler never called WriteHeader; leave it to
	// net/http to apply its default on the first Write.
	if rw.StatusCode != 0 {
		w.WriteHeader(rw.StatusCode)
	}

	// The write error is deliberately ignored: at this point the only
	// failure mode is the client having gone away.
	_, _ = w.Write(rw.Data)
}