--- /sys/src/libsec/port/tlshand.c +++ /sys/src/libsec/port/tlshand.c @@ -97,6 +97,7 @@ typedef struct Msg{ Bytes* sid; Ints* ciphers; Bytes* compressors; + Ints* sigAlgs; } clientHello; struct { int version; @@ -245,6 +246,19 @@ enum { CompressionMax }; +// extensions +enum { + ExtSigalgs = 0xd, +}; + +// signature algorithms +enum { + RSA_PKCS1_SHA1 = 0x0201, + RSA_PKCS1_SHA256 = 0x0401, + RSA_PKCS1_SHA384 = 0x0501, + RSA_PKCS1_SHA512 = 0x0601 +}; + static Algs cipherAlgs[] = { {"rc4_128", "md5", 2*(16+MD5dlen), TLS_RSA_WITH_RC4_128_MD5}, {"rc4_128", "sha1", 2*(16+SHA1dlen), TLS_RSA_WITH_RC4_128_SHA}, @@ -257,6 +271,11 @@ static uchar compressors[] = { CompressionNull, }; +static int sigAlgs[] = { + RSA_PKCS1_SHA256, + RSA_PKCS1_SHA1, +}; + static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain); static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...)); @@ -644,6 +663,8 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ... m.u.clientHello.sid = makebytes(csid, ncsid); m.u.clientHello.ciphers = makeciphers(); m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors)); + if(c->clientVersion >= TLS12Version) + m.u.clientHello.sigAlgs = makeints(sigAlgs, nelem(sigAlgs)); if(!msgSend(c, &m, AFlush)) goto Err; msgClear(&m); @@ -874,6 +895,19 @@ msgSend(TlsConnection *c, Msg *m, int act) p[0] = n; memmove(p+1, m->u.clientHello.compressors->data, n); p += n+1; + + if(m->u.clientHello.sigAlgs != nil) { + n = m->u.clientHello.sigAlgs->len; + put16(p, 6 + 2*n); /* length of extensions */ + put16(p+2, ExtSigalgs); + put16(p+4, 2 + 2*n); /* length of extension content */ + put16(p+6, 2*n); /* length of algorithm list */ + p += 8; + for(i = 0; i < n; i++) { + put16(p, m->u.clientHello.sigAlgs->data[i]); + p += 2; + } + } break; case HServerHello: put16(p, m->u.serverHello.version); @@ -983,7 +1017,7 @@ static int msgRecv(TlsConnection *c, Msg *m) { uchar *p; - int type, n, nn, i, nsid, nrandom, nciph; + int type, n, nn, nx, i, nsid, nrandom, nciph; for(;;) { p = tlsReadN(c, 4); @@ -1101,7 +1135,40 @@ msgRecv(TlsConnection *c, Msg *m) nn = p[0]; m->u.clientHello.compressors = newbytes(nn); memmove(m->u.clientHello.compressors->data, p+1, nn); + p += nn + 1; n -= nn + 1; + + /* extensions */ + if(n == 0) + break; + if(n < 2) + goto Short; + nx = get16(p); + p += 2; + n -= 2; + while(nx > 0){ + if(n < nx || nx < 4) + goto Short; + i = get16(p); + nn = get16(p+2); + if(nx < nn+4) + goto Short; + nx -= nn+4; + p += 4; + n -= 4; + if(i == ExtSigalgs){ + if(get16(p) != nn-2) + goto Short; + p += 2; + n -= 2; + nn -= 2; + m->u.clientHello.sigAlgs = newints(nn/2); + for(i = 0; i < nn; i += 2) + m->u.clientHello.sigAlgs->data[i >> 1] = get16(&p[i]); + } + p += nn; + n -= nn; + } break; case HServerHello: if(n < 2) @@ -1254,6 +1321,7 @@ msgClear(Msg *m) freeints(m->u.clientHello.ciphers); m->u.clientHello.ciphers = nil; freebytes(m->u.clientHello.compressors); + freeints(m->u.clientHello.sigAlgs); break; case HServerHello: freebytes(m->u.clientHello.sid); @@ -1338,6 +1406,8 @@ msgPrint(char *buf, int n, Msg *m) bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n"); bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n"); bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n"); + if(m->u.clientHello.sigAlgs != nil) + bs = intsPrint(bs, be, "\tsigAlgs: ", m->u.clientHello.sigAlgs, "\n"); break; case HServerHello: bs = seprint(bs, be, "ServerHello\n");