oasis/fastcgi/client.go

211 lines
5.3 KiB
Go

package fastcgi
import (
"bufio"
"bytes"
"encoding/binary"
"fmt"
"io"
"net/http"
"net/http/httputil"
"net/textproto"
"strconv"
"strings"
"sync"
)
// FCGIClient is a client for one FastCGI connection. Requests are written
// as FastCGI records to rwc, and the responder's output is read back from
// the same connection.
type FCGIClient struct {
// mutex is held by writeRecord while it writes a record to rwc.
mutex sync.Mutex
// rwc is the underlying connection; Close closes it.
rwc io.ReadWriteCloser
// h is a record header. NOTE(review): not referenced in this part of
// the file — confirm where it is used.
h Header
// buf is cleared by writeRecord but not otherwise used in this part of
// the file. NOTE(review): confirm whether other code still writes to it.
buf bytes.Buffer
// keepAlive and reqId are not referenced in this part of the file;
// presumably the FastCGI keep-conn flag and the request ID — TODO confirm.
keepAlive bool
reqId uint16
}
// Close closes the underlying FastCGI connection. Any error reported by the
// connection's own Close is discarded.
func (client *FCGIClient) Close() {
	conn := client.rwc
	_ = conn.Close()
}
// writeRecord encodes one record of type recType carrying content and
// writes it to the connection in a single Write call. The client mutex is
// held for the whole call, so writes made through writeRecord do not
// interleave with one another.
//
// If encoding the record fails, that error is returned and nothing is
// written. Otherwise the error (if any) from the connection write is
// returned.
func (client *FCGIClient) writeRecord(recType FCGIRequestType, content []byte) error {
	client.mutex.Lock()
	defer client.mutex.Unlock()
	// NOTE(review): buf is not used below. It is still cleared here, under
	// the lock, in case other code depends on that — confirm and remove.
	client.buf.Reset()

	rec := NewRecord(recType, content)
	b, err := rec.toBytes()
	if err != nil {
		// Previously this error was overwritten by the Write below, so a
		// failed encoding silently sent nil/partial bytes.
		return err
	}
	_, err = client.rwc.Write(b)
	return err
}
// writeEndRequest sends an FCGI_END_REQUEST record with an 8-byte body:
// appStatus as a big-endian uint32, then protocolStatus, then three zero
// bytes (reserved in the FastCGI spec).
func (client *FCGIClient) writeEndRequest(appStatus int, protocolStatus uint8) error {
	var body [8]byte
	binary.BigEndian.PutUint32(body[:4], uint32(appStatus))
	body[4] = protocolStatus
	// body[5:8] are left as zero.
	return client.writeRecord(FCGI_END_REQUEST, body[:])
}
// NameValuePair is one FastCGI name-value pair, such as
// SCRIPT_PATH = /some/path.
//
// Spec: https://www.mit.edu/~yandros/doc/specs/fcgi-spec.html#S3
// (section 3.4, Name-Value Pairs). On the wire a pair is encoded as:
//
//	name length
//	value length
//	name bytes
//	value bytes
//
// Each length takes one byte when it is below 128, and four bytes
// (big-endian, with the high bit set) otherwise.
type NameValuePair struct {
// The lengths are held as uint32 for convenience. When a pair is encoded,
// the one-byte/four-byte rule above decides how many bytes are written.
NameLength uint32
ValueLength uint32
// Raw name and value data.
NameData string
ValueData string
}
// writePairs sends pairs to the responder as a stream of recType records
// (for example FCGI_PARAMS), using the FastCGI name-value pair encoding:
// name length, value length, name bytes, value bytes. A length below 128
// is written as one byte. Longer lengths are written as four big-endian
// bytes with the high bit set.
//
// All pairs are encoded into one buffer and handed to a stream writer,
// which is then flushed and closed. If a pair's encoded size would exceed
// maxWrite, its value is truncated to fit. If the name alone cannot fit,
// the pair is rejected with an error and nothing is sent. Errors from
// writing, flushing or closing the stream are returned; the first one
// wins.
//
// Pairs are sent in map iteration order, which is unspecified.
func (client *FCGIClient) writePairs(recType FCGIRequestType, pairs map[string]string) (err error) {
	var buf bytes.Buffer
	var lenBytes [4]byte
	// writeLength appends n using the one-byte or four-byte length encoding.
	writeLength := func(n uint32) {
		if n > 127 {
			binary.BigEndian.PutUint32(lenBytes[:], n|1<<31)
			buf.Write(lenBytes[:])
			return
		}
		buf.WriteByte(byte(n))
	}

	for name, value := range pairs {
		nameLen := uint32(len(name))
		valueLen := uint32(len(value))
		// 8 bytes is the largest possible size of the two length fields.
		if 8+nameLen+valueLen > maxWrite {
			if 8+nameLen > maxWrite {
				// Truncating the value cannot help here; maxWrite-8-nameLen
				// would underflow and the slice below would panic.
				return fmt.Errorf("fastcgi: param name of %d bytes does not fit in a %d-byte record", nameLen, maxWrite)
			}
			valueLen = maxWrite - 8 - nameLen
			value = value[:valueLen]
		}
		writeLength(nameLen)
		writeLength(valueLen)
		buf.WriteString(name)
		buf.WriteString(value)
	}

	w := newWriter(client, recType)
	defer func() {
		// Close ends the stream. Report its failure unless an earlier
		// error already explains what went wrong.
		if cerr := w.Close(); err == nil {
			err = cerr
		}
	}()
	if _, err = w.Write(buf.Bytes()); err != nil {
		return err
	}
	return w.Flush()
}
// Do sends req to the FastCGI responder and parses the output it writes
// back as an HTTP response.
//
// The request is sent as a begin-request record followed by the
// FCGI_PARAMS stream built from req.Context. The output is read through a
// streamReader and accepted in either of two shapes:
//
//   - Output that starts with an HTTP status line ("HTTP/1.1 200 OK"). It
//     is parsed the way net/http's ReadResponse parses it, and a malformed
//     status line, status code or protocol version is returned as an error.
//   - CGI-style output (RFC 3875) that starts directly with header lines.
//     The status then comes from the "Status" header. Without one, it is
//     302 Found when a Location header is present and 200 OK otherwise.
//     Proto, ProtoMajor and ProtoMinor are left unset in this case.
//
// ContentLength is -1 (unknown) unless the body is not chunked and a valid
// Content-Length header is present. Body reads the rest of the stream,
// de-chunked when Transfer-Encoding is chunked. Closing Body is a no-op
// and does not close the connection.
func (client *FCGIClient) Do(req *FCGIRequest) (http.Response, error) {
	beginRequestRecord := NewBeginRequestRecord()
	if err := client.writeRecord(beginRequestRecord.Header.Type, beginRequestRecord.Content); err != nil {
		return http.Response{}, err
	}
	if err := client.writePairs(FCGI_PARAMS, req.Context); err != nil {
		return http.Response{}, err
	}
	// NOTE(review): no FCGI_STDIN stream is sent, not even the empty record
	// that marks end of input. The FastCGI spec expects one for Responder
	// requests, and some applications wait for it — TODO confirm against
	// the target responder.
	// body := newWriter(client, FCGI_STDIN)
	// if req != nil {
	// 	io.Copy(body, req)
	// }
	// body.Close()

	// parseStatusCode extracts the three-digit code that starts status
	// (e.g. 404 from "404 Not Found").
	parseStatusCode := func(status string) (int, error) {
		code := status
		if i := strings.IndexByte(status, ' '); i != -1 {
			code = status[:i]
		}
		n, err := strconv.Atoi(code)
		if len(code) != 3 || err != nil || n < 0 {
			return 0, &badStringError{"malformed HTTP status code", code}
		}
		return n, nil
	}

	rb := bufio.NewReader(&streamReader{c: client})
	tp := textproto.NewReader(rb)
	var resp http.Response

	// Peek without consuming, so that CGI-style output (which has no
	// status line) reaches the header parser intact.
	prefix, err := rb.Peek(len("HTTP/"))
	if err != nil && err != io.EOF {
		return http.Response{}, err
	}
	hasStatusLine := string(prefix) == "HTTP/"
	if hasStatusLine {
		var line string
		if line, err = tp.ReadLine(); err != nil {
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			return http.Response{}, err
		}
		i := strings.IndexByte(line, ' ')
		if i == -1 {
			return http.Response{}, &badStringError{"malformed HTTP response", line}
		}
		resp.Proto = line[:i]
		resp.Status = strings.TrimLeft(line[i+1:], " ")
		if resp.StatusCode, err = parseStatusCode(resp.Status); err != nil {
			return http.Response{}, err
		}
		var ok bool
		if resp.ProtoMajor, resp.ProtoMinor, ok = http.ParseHTTPVersion(resp.Proto); !ok {
			return http.Response{}, &badStringError{"malformed HTTP version", resp.Proto}
		}
	}

	// Parse the response headers.
	mimeHeader, err := tp.ReadMIMEHeader()
	if err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return http.Response{}, err
	}
	resp.Header = http.Header(mimeHeader)

	if !hasStatusLine {
		// CGI output carries its status in a "Status" header (RFC 3875 §6.3.3).
		switch status := resp.Header.Get("Status"); {
		case status != "":
			if resp.StatusCode, err = parseStatusCode(status); err != nil {
				return http.Response{}, err
			}
			resp.Status = status
		case resp.Header.Get("Location") != "":
			resp.StatusCode = http.StatusFound
			resp.Status = "302 Found"
		default:
			resp.StatusCode = http.StatusOK
			resp.Status = "200 OK"
		}
	}

	// TODO: fixTransferEncoding ?
	resp.TransferEncoding = resp.Header["Transfer-Encoding"]
	isChunked := chunked(resp.TransferEncoding)
	// -1 means "unknown"; 0 would tell callers there is no body.
	resp.ContentLength = -1
	if cl := strings.TrimSpace(resp.Header.Get("Content-Length")); cl != "" && !isChunked {
		if n, perr := strconv.ParseInt(cl, 10, 64); perr == nil && n >= 0 {
			resp.ContentLength = n
		}
	}
	if isChunked {
		resp.Body = io.NopCloser(httputil.NewChunkedReader(rb))
	} else {
		resp.Body = io.NopCloser(rb)
	}
	return resp, nil
}