--- /sys/src/libsec/port/tlshand.c +++ /sys/src/libsec/port/tlshand.c @@ -51,6 +51,12 @@ typedef struct Finished{ int n; } Finished; +typedef struct HandHash{ + MD5state md5; + SHAstate sha1; + SHA2_256state sha2_256; +} HandHash; + typedef struct TlsConnection{ TlsSec *sec; // security management goo int hand, ctl; // record layer file descriptors @@ -78,8 +84,7 @@ typedef struct TlsConnection{ int nsecret; // amount of secret data to init keys // for finished messages - MD5state hsmd5; // handshake hash - SHAstate hssha1; // handshake hash + HandHash hs; // handshake hash Finished finished; } TlsConnection; @@ -128,15 +133,17 @@ typedef struct TlsSec{ int vers; // final version // byte generation and handshake checksum void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int); - void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int); + void (*setFinished)(TlsSec*, HandHash, uchar*, int); int nfin; } TlsSec; enum { - TLSVersion = 0x0301, - SSL3Version = 0x0300, - ProtocolVersion = 0x0301, // maximum version we speak + SSL3Version = 0x0300, + TLS10Version = 0x0301, + TLS11Version = 0x0302, + TLS12Version = 0x0303, + ProtocolVersion = TLS12Version, // maximum version we speak MinProtoVersion = 0x0300, // limits on version we accept MaxProtoVersion = 0x03ff, }; @@ -273,7 +280,7 @@ static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uc static int tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd); static TlsSec* tlsSecInitc(int cvers, uchar *crandom); static int tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd); -static int tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient); +static int tlsSecFinished(TlsSec *sec, HandHash hs, uchar *fin, int nfin, int isclient); static void tlsSecOk(TlsSec *sec); static void tlsSecKill(TlsSec *sec); static void tlsSecClose(TlsSec *sec); @@ -283,8 +290,9 @@ static void setSecrets(TlsSec *sec, uchar *kd, int nkd); static int clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm); static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype); static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm); -static void tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient); -static void sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient); +static void tlsSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient); +static void tls12SetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient); +static void sslSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient); static void sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1); static int setVers(TlsSec *sec, int version); @@ -556,7 +564,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ... msgClear(&m); /* no CertificateVerify; skip to Finished */ - if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){ + if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 1) < 0){ tlsError(c, EInternalError, "can't set finished: %r"); goto Err; } @@ -578,7 +586,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ... goto Err; } - if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){ + if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 0) < 0){ tlsError(c, EInternalError, "can't set finished: %r"); goto Err; } @@ -747,7 +755,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ... // Cipherchange must occur immediately before Finished to avoid // potential hole; see section 4.3 of Wagner Schneier 1996. - if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){ + if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 1) < 0){ tlsError(c, EInternalError, "can't set finished 1: %r"); goto Err; } @@ -761,7 +769,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ... } msgClear(&m); - if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){ + if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 0) < 0){ fprint(2, "tlsClient nepm=%d\n", nepm); tlsError(c, EInternalError, "can't set finished 0: %r"); goto Err; @@ -803,6 +811,17 @@ Err: static uchar sendbuf[9000], *sendp; +static void +msgHash(TlsConnection *c, uchar *p, int n) +{ + md5(p, n, 0, &c->hs.md5); + sha1(p, n, 0, &c->hs.sha1); + if(c->version >= TLS12Version) + sha2_256(p, n, 0, &c->hs.sha2_256); + else + memset(&c->hs.sha2_256, 0, sizeof c->hs.sha2_256); +} + static int msgSend(TlsConnection *c, Msg *m, int act) { @@ -914,8 +933,7 @@ msgSend(TlsConnection *c, Msg *m, int act) // remember hash of Handshake messages if(m->tag != HHelloRequest) { - md5(sendp, n, 0, &c->hsmd5); - sha1(sendp, n, 0, &c->hssha1); + msgHash(c, sendp, n); } sendp = p; @@ -991,8 +1009,7 @@ msgRecv(TlsConnection *c, Msg *m) p = tlsReadN(c, n); if(p == nil) return 0; - md5(p, n, 0, &c->hsmd5); - sha1(p, n, 0, &c->hssha1); + msgHash(c, p, n); m->tag = HClientHello; if(n < 22) goto Short; @@ -1030,15 +1047,13 @@ msgRecv(TlsConnection *c, Msg *m) goto Ok; } - md5(p, 4, 0, &c->hsmd5); - sha1(p, 4, 0, &c->hssha1); + msgHash(c, p, 4); p = tlsReadN(c, n); if(p == nil) return 0; - md5(p, n, 0, &c->hsmd5); - sha1(p, n, 0, &c->hssha1); + msgHash(c, p, n); m->tag = type; @@ -1388,14 +1403,19 @@ setVersion(TlsConnection *c, int version) return -1; if(version > c->version) version = c->version; - if(version == SSL3Version) { - c->version = version; + switch(version) { + case SSL3Version: c->finished.n = SSL3FinishedLen; - }else if(version == TLSVersion){ - c->version = version; + break; + case TLS10Version: + case TLS11Version: + case TLS12Version: c->finished.n = TLSFinishedLen; - }else + break; + default: return -1; + } + c->version = version; c->verset = 1; return fprint(c->ctl, "version 0x%x", version); } @@ -1721,6 +1741,32 @@ tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, u } } +static void +tlsPsha2_256(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed) +{ + uchar ai[SHA2_256dlen], tmp[SHA2_256dlen]; + int n; + SHAstate *s; + + // generate a1 + s = hmac_sha2_256(label, nlabel, key, nkey, nil, nil); + hmac_sha2_256(seed, nseed, key, nkey, ai, s); + + while(nbuf > 0) { + s = hmac_sha2_256(ai, SHA2_256dlen, key, nkey, nil, nil); + s = hmac_sha2_256(label, nlabel, key, nkey, nil, s); + hmac_sha2_256(seed, nseed, key, nkey, tmp, s); + n = SHA2_256dlen; + if(n > nbuf) + n = nbuf; + memmove(buf, tmp, n); + buf += n; + nbuf -= n; + hmac_sha2_256(ai, SHA2_256dlen, key, nkey, tmp, nil); + memmove(ai, tmp, SHA2_256dlen); + } +} + // fill buf with md5(args)^sha1(args) static void tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1) @@ -1735,6 +1781,17 @@ tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, in tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1); } +void +tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1) +{ + uchar seed[2*RandomSize]; + int nlabel = strlen(label); + + memmove(seed, seed0, nseed0); + memmove(seed+nseed0, seed1, nseed1); + tlsPsha2_256(buf, nbuf, key, nkey, (uchar*)label, nlabel, seed, nseed0+nseed1); +} + /* * for setting server session id's */ @@ -1833,16 +1890,16 @@ Err: } static int -tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient) +tlsSecFinished(TlsSec *sec, HandHash hs, uchar *fin, int nfin, int isclient) { if(sec->nfin != nfin){ sec->ok = -1; werrstr("invalid finished exchange"); return -1; } - md5.malloced = 0; - sha1.malloced = 0; - (*sec->setFinished)(sec, md5, sha1, fin, isclient); + hs.md5.malloced = 0; + hs.sha1.malloced = 0; + (*sec->setFinished)(sec, hs, fin, isclient); return 1; } @@ -1875,15 +1932,24 @@ tlsSecClose(TlsSec *sec) static int setVers(TlsSec *sec, int v) { - if(v == SSL3Version){ + switch(v){ + case SSL3Version: sec->setFinished = sslSetFinished; sec->nfin = SSL3FinishedLen; sec->prf = sslPRF; - }else if(v == TLSVersion){ + break; + case TLS10Version: + case TLS11Version: sec->setFinished = tlsSetFinished; sec->nfin = TLSFinishedLen; sec->prf = tlsPRF; - }else{ + break; + case TLS12Version: + sec->setFinished = tls12SetFinished; + sec->nfin = TLSFinishedLen; + sec->prf = tls12PRF; + break; + default: werrstr("invalid version"); return -1; } @@ -1973,7 +2039,7 @@ clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm) } static void -sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient) +sslSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient) { DigestState *s; uchar h0[MD5dlen], h1[SHA1dlen], pad[48]; @@ -1984,21 +2050,21 @@ sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, in else label = "SRVR"; - md5((uchar*)label, 4, nil, &hsmd5); - md5(sec->sec, MasterSecretSize, nil, &hsmd5); + md5((uchar*)label, 4, nil, &hs.md5); + md5(sec->sec, MasterSecretSize, nil, &hs.md5); memset(pad, 0x36, 48); - md5(pad, 48, nil, &hsmd5); - md5(nil, 0, h0, &hsmd5); + md5(pad, 48, nil, &hs.md5); + md5(nil, 0, h0, &hs.md5); memset(pad, 0x5C, 48); s = md5(sec->sec, MasterSecretSize, nil, nil); s = md5(pad, 48, nil, s); md5(h0, MD5dlen, finished, s); - sha1((uchar*)label, 4, nil, &hssha1); - sha1(sec->sec, MasterSecretSize, nil, &hssha1); + sha1((uchar*)label, 4, nil, &hs.sha1); + sha1(sec->sec, MasterSecretSize, nil, &hs.sha1); memset(pad, 0x36, 40); - sha1(pad, 40, nil, &hssha1); - sha1(nil, 0, h1, &hssha1); + sha1(pad, 40, nil, &hs.sha1); + sha1(nil, 0, h1, &hs.sha1); memset(pad, 0x5C, 40); s = sha1(sec->sec, MasterSecretSize, nil, nil); s = sha1(pad, 40, nil, s); @@ -2007,20 +2073,37 @@ sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, in // fill "finished" arg with md5(args)^sha1(args) static void -tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient) +tlsSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient) { uchar h0[MD5dlen], h1[SHA1dlen]; char *label; // get current hash value, but allow further messages to be hashed in - md5(nil, 0, h0, &hsmd5); - sha1(nil, 0, h1, &hssha1); + md5(nil, 0, h0, &hs.md5); + sha1(nil, 0, h1, &hs.sha1); + + if(isClient) + label = "client finished"; + else + label = "server finished"; + (*sec->prf)(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen); +} + +// fill "finished" arg with sha256(args) +static void +tls12SetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient) +{ + uchar h[SHA2_256dlen]; + char *label; + + // get current hash value, but allow further messages to be hashed in + sha2_256(nil, 0, h, &hs.sha2_256); if(isClient) label = "client finished"; else label = "server finished"; - tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen); + tlsPsha2_256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), h, SHA2_256dlen); } static void