sm2.go raw
1 /*
2 Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 */
15
16 package sm2
17
18 // reference to ecdsa
19 import (
20 "bytes"
21 "crypto"
22 "crypto/elliptic"
23 "crypto/rand"
24 "encoding/asn1"
25 "encoding/binary"
26 "errors"
27 "io"
28 "math/big"
29
30 "github.com/tjfoc/gmsm/sm3"
31 )
32
var (
	// default_uid is the user identifier used whenever a caller passes an
	// empty uid: the ASCII string "1234567812345678".
	default_uid = []byte("1234567812345678")

	// C1C3C2 selects the ciphertext layout 0x04 || C1 || C3 || C2
	// (point, SM3 hash, masked message).
	C1C3C2 = 0
	// C1C2C3 selects the ciphertext layout 0x04 || C1 || C2 || C3
	// (point, masked message, SM3 hash).
	C1C2C3 = 1
)
38
// PublicKey is an SM2 public key: the point (X, Y) on the embedded curve.
type PublicKey struct {
	// Curve is the curve the point belongs to; its methods are promoted.
	elliptic.Curve
	// X is the x-coordinate of the public point.
	X *big.Int
	// Y is the y-coordinate of the public point.
	Y *big.Int
}
43
44 type PrivateKey struct {
45 PublicKey
46 D *big.Int
47 }
48
// sm2Signature is the ASN.1 form of an SM2 signature:
// SEQUENCE { r INTEGER, s INTEGER }. Field order fixes the encoding.
type sm2Signature struct {
	R *big.Int
	S *big.Int
}
// sm2Cipher is the ASN.1 form of an SM2 ciphertext. The field order is
// significant because it fixes the encoding:
// SEQUENCE { x INTEGER, y INTEGER, hash OCTET STRING, ciphertext OCTET STRING }.
type sm2Cipher struct {
	XCoordinate *big.Int // C1 x-coordinate
	YCoordinate *big.Int // C1 y-coordinate
	HASH        []byte   // C3: SM3 hash
	CipherText  []byte   // C2: masked message
}
58
59 // The SM2's private key contains the public key
60 func (priv *PrivateKey) Public() crypto.PublicKey {
61 return &priv.PublicKey
62 }
63
var (
	// errZeroParam is returned by Sm2Sign when the curve order N is zero.
	errZeroParam = errors.New("zero parameter")

	// one and two are shared big.Int constants; callers must never mutate them.
	one = big.NewInt(1)
	two = big.NewInt(2)
)
67
68 // sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s
69 func (priv *PrivateKey) Sign(random io.Reader, msg []byte, signer crypto.SignerOpts) ([]byte, error) {
70 r, s, err := Sm2Sign(priv, msg, nil, random)
71 if err != nil {
72 return nil, err
73 }
74 return asn1.Marshal(sm2Signature{r, s})
75 }
76
77 func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
78 var sm2Sign sm2Signature
79 _, err := asn1.Unmarshal(sign, &sm2Sign)
80 if err != nil {
81 return false
82 }
83 return Sm2Verify(pub, msg, default_uid, sm2Sign.R, sm2Sign.S)
84 }
85
86 func (pub *PublicKey) Sm3Digest(msg, uid []byte) ([]byte, error) {
87 if len(uid) == 0 {
88 uid = default_uid
89 }
90
91 za, err := ZA(pub, uid)
92 if err != nil {
93 return nil, err
94 }
95
96 e, err := msgHash(za, msg)
97 if err != nil {
98 return nil, err
99 }
100
101 return e.Bytes(), nil
102 }
103
104 //****************************Encryption algorithm****************************//
105 func (pub *PublicKey) EncryptAsn1(data []byte, random io.Reader) ([]byte, error) {
106 return EncryptAsn1(pub, data, random)
107 }
108
109 func (priv *PrivateKey) DecryptAsn1(data []byte) ([]byte, error) {
110 return DecryptAsn1(priv, data)
111 }
112
113 //**************************Key agreement algorithm**************************//
114 // KeyExchangeB 协商第二部,用户B调用, 返回共享密钥k
115 func KeyExchangeB(klen int, ida, idb []byte, priB *PrivateKey, pubA *PublicKey, rpri *PrivateKey, rpubA *PublicKey) (k, s1, s2 []byte, err error) {
116 return keyExchange(klen, ida, idb, priB, pubA, rpri, rpubA, false)
117 }
118
119 // KeyExchangeA 协商第二部,用户A调用,返回共享密钥k
120 func KeyExchangeA(klen int, ida, idb []byte, priA *PrivateKey, pubB *PublicKey, rpri *PrivateKey, rpubB *PublicKey) (k, s1, s2 []byte, err error) {
121 return keyExchange(klen, ida, idb, priA, pubB, rpri, rpubB, true)
122 }
123
124 //****************************************************************************//
125
126 func Sm2Sign(priv *PrivateKey, msg, uid []byte, random io.Reader) (r, s *big.Int, err error) {
127 digest, err := priv.PublicKey.Sm3Digest(msg, uid)
128 if err != nil {
129 return nil, nil, err
130 }
131 e := new(big.Int).SetBytes(digest)
132 c := priv.PublicKey.Curve
133 N := c.Params().N
134 if N.Sign() == 0 {
135 return nil, nil, errZeroParam
136 }
137 var k *big.Int
138 for { // 调整算法细节以实现SM2
139 for {
140 k, err = randFieldElement(c, random)
141 if err != nil {
142 r = nil
143 return
144 }
145 r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
146 r.Add(r, e)
147 r.Mod(r, N)
148 if r.Sign() != 0 {
149 if t := new(big.Int).Add(r, k); t.Cmp(N) != 0 {
150 break
151 }
152 }
153
154 }
155 rD := new(big.Int).Mul(priv.D, r)
156 s = new(big.Int).Sub(k, rD)
157 d1 := new(big.Int).Add(priv.D, one)
158 d1Inv := new(big.Int).ModInverse(d1, N)
159 s.Mul(s, d1Inv)
160 s.Mod(s, N)
161 if s.Sign() != 0 {
162 break
163 }
164 }
165 return
166 }
167 func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
168 c := pub.Curve
169 N := c.Params().N
170 one := new(big.Int).SetInt64(1)
171 if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
172 return false
173 }
174 if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
175 return false
176 }
177 if len(uid) == 0 {
178 uid = default_uid
179 }
180 za, err := ZA(pub, uid)
181 if err != nil {
182 return false
183 }
184 e, err := msgHash(za, msg)
185 if err != nil {
186 return false
187 }
188 t := new(big.Int).Add(r, s)
189 t.Mod(t, N)
190 if t.Sign() == 0 {
191 return false
192 }
193 var x *big.Int
194 x1, y1 := c.ScalarBaseMult(s.Bytes())
195 x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
196 x, _ = c.Add(x1, y1, x2, y2)
197
198 x.Add(x, e)
199 x.Mod(x, N)
200 return x.Cmp(r) == 0
201 }
202
203 /*
204 za, err := ZA(pub, uid)
205 if err != nil {
206 return
207 }
208 e, err := msgHash(za, msg)
209 hash=e.getBytes()
210 */
211 func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
212 c := pub.Curve
213 N := c.Params().N
214
215 if r.Sign() <= 0 || s.Sign() <= 0 {
216 return false
217 }
218 if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
219 return false
220 }
221
222 // 调整算法细节以实现SM2
223 t := new(big.Int).Add(r, s)
224 t.Mod(t, N)
225 if t.Sign() == 0 {
226 return false
227 }
228
229 var x *big.Int
230 x1, y1 := c.ScalarBaseMult(s.Bytes())
231 x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
232 x, _ = c.Add(x1, y1, x2, y2)
233
234 e := new(big.Int).SetBytes(hash)
235 x.Add(x, e)
236 x.Mod(x, N)
237 return x.Cmp(r) == 0
238 }
239
240 /*
241 * sm2密文结构如下:
242 * x
243 * y
244 * hash
245 * CipherText
246 */
247 func Encrypt(pub *PublicKey, data []byte, random io.Reader,mode int) ([]byte, error) {
248 length := len(data)
249 for {
250 c := []byte{}
251 curve := pub.Curve
252 k, err := randFieldElement(curve, random)
253 if err != nil {
254 return nil, err
255 }
256 x1, y1 := curve.ScalarBaseMult(k.Bytes())
257 x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
258 x1Buf := x1.Bytes()
259 y1Buf := y1.Bytes()
260 x2Buf := x2.Bytes()
261 y2Buf := y2.Bytes()
262 if n := len(x1Buf); n < 32 {
263 x1Buf = append(zeroByteSlice()[:32-n], x1Buf...)
264 }
265 if n := len(y1Buf); n < 32 {
266 y1Buf = append(zeroByteSlice()[:32-n], y1Buf...)
267 }
268 if n := len(x2Buf); n < 32 {
269 x2Buf = append(zeroByteSlice()[:32-n], x2Buf...)
270 }
271 if n := len(y2Buf); n < 32 {
272 y2Buf = append(zeroByteSlice()[:32-n], y2Buf...)
273 }
274 c = append(c, x1Buf...) // x分量
275 c = append(c, y1Buf...) // y分量
276 tm := []byte{}
277 tm = append(tm, x2Buf...)
278 tm = append(tm, data...)
279 tm = append(tm, y2Buf...)
280 h := sm3.Sm3Sum(tm)
281 c = append(c, h...)
282 ct, ok := kdf(length, x2Buf, y2Buf) // 密文
283 if !ok {
284 continue
285 }
286 c = append(c, ct...)
287 for i := 0; i < length; i++ {
288 c[96+i] ^= data[i]
289 }
290 switch mode{
291
292 case C1C3C2:
293 return append([]byte{0x04}, c...), nil
294 case C1C2C3:
295 c1 := make([]byte, 64)
296 c2 := make([]byte, len(c) - 96)
297 c3 := make([]byte, 32)
298 copy(c1, c[:64])//x1,y1
299 copy(c3, c[64:96])//hash
300 copy(c2, c[96:])//密文
301 ciphertext := []byte{}
302 ciphertext = append(ciphertext, c1...)
303 ciphertext = append(ciphertext, c2...)
304 ciphertext = append(ciphertext, c3...)
305 return append([]byte{0x04}, ciphertext...), nil
306 default:
307 return append([]byte{0x04}, c...), nil
308 }
309 }
310 }
311
312
313
314 func Decrypt(priv *PrivateKey, data []byte,mode int) ([]byte, error) {
315 switch mode {
316 case C1C3C2:
317 data = data[1:]
318 case C1C2C3:
319 data = data[1:]
320 c1 := make([]byte, 64)
321 c2 := make([]byte, len(data) - 96)
322 c3 := make([]byte, 32)
323 copy(c1, data[:64])//x1,y1
324 copy(c2, data[64:len(data) - 32])//密文
325 copy(c3, data[len(data) - 32:])//hash
326 c := []byte{}
327 c = append(c, c1...)
328 c = append(c, c3...)
329 c = append(c, c2...)
330 data = c
331 default:
332 data = data[1:]
333 }
334 length := len(data) - 96
335 curve := priv.Curve
336 x := new(big.Int).SetBytes(data[:32])
337 y := new(big.Int).SetBytes(data[32:64])
338 x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
339 x2Buf := x2.Bytes()
340 y2Buf := y2.Bytes()
341 if n := len(x2Buf); n < 32 {
342 x2Buf = append(zeroByteSlice()[:32-n], x2Buf...)
343 }
344 if n := len(y2Buf); n < 32 {
345 y2Buf = append(zeroByteSlice()[:32-n], y2Buf...)
346 }
347 c, ok := kdf(length, x2Buf, y2Buf)
348 if !ok {
349 return nil, errors.New("Decrypt: failed to decrypt")
350 }
351 for i := 0; i < length; i++ {
352 c[i] ^= data[i+96]
353 }
354 tm := []byte{}
355 tm = append(tm, x2Buf...)
356 tm = append(tm, c...)
357 tm = append(tm, y2Buf...)
358 h := sm3.Sm3Sum(tm)
359 if bytes.Compare(h, data[64:96]) != 0 {
360 return c, errors.New("Decrypt: failed to decrypt")
361 }
362 return c, nil
363 }
364
365
366
367 // keyExchange 为SM2密钥交换算法的第二部和第三步复用部分,协商的双方均调用此函数计算共同的字节串
368 // klen: 密钥长度
369 // ida, idb: 协商双方的标识,ida为密钥协商算法发起方标识,idb为响应方标识
370 // pri: 函数调用者的密钥
371 // pub: 对方的公钥
372 // rpri: 函数调用者生成的临时SM2密钥
373 // rpub: 对方发来的临时SM2公钥
374 // thisIsA: 如果是A调用,文档中的协商第三步,设置为true,否则设置为false
375 // 返回 k 为klen长度的字节串
376 func keyExchange(klen int, ida, idb []byte, pri *PrivateKey, pub *PublicKey, rpri *PrivateKey, rpub *PublicKey, thisISA bool) (k, s1, s2 []byte, err error) {
377 curve := P256Sm2()
378 N := curve.Params().N
379 x2hat := keXHat(rpri.PublicKey.X)
380 x2rb := new(big.Int).Mul(x2hat, rpri.D)
381 tbt := new(big.Int).Add(pri.D, x2rb)
382 tb := new(big.Int).Mod(tbt, N)
383 if !curve.IsOnCurve(rpub.X, rpub.Y) {
384 err = errors.New("Ra not on curve")
385 return
386 }
387 x1hat := keXHat(rpub.X)
388 ramx1, ramy1 := curve.ScalarMult(rpub.X, rpub.Y, x1hat.Bytes())
389 vxt, vyt := curve.Add(pub.X, pub.Y, ramx1, ramy1)
390
391 vx, vy := curve.ScalarMult(vxt, vyt, tb.Bytes())
392 pza := pub
393 if thisISA {
394 pza = &pri.PublicKey
395 }
396 za, err := ZA(pza, ida)
397 if err != nil {
398 return
399 }
400 zero := new(big.Int)
401 if vx.Cmp(zero) == 0 || vy.Cmp(zero) == 0 {
402 err = errors.New("V is infinite")
403 }
404 pzb := pub
405 if !thisISA {
406 pzb = &pri.PublicKey
407 }
408 zb, err := ZA(pzb, idb)
409 k, ok := kdf(klen, vx.Bytes(), vy.Bytes(), za, zb)
410 if !ok {
411 err = errors.New("kdf: zero key")
412 return
413 }
414 h1 := BytesCombine(vx.Bytes(), za, zb, rpub.X.Bytes(), rpub.Y.Bytes(), rpri.X.Bytes(), rpri.Y.Bytes())
415 if !thisISA {
416 h1 = BytesCombine(vx.Bytes(), za, zb, rpri.X.Bytes(), rpri.Y.Bytes(), rpub.X.Bytes(), rpub.Y.Bytes())
417 }
418 hash := sm3.Sm3Sum(h1)
419 h2 := BytesCombine([]byte{0x02}, vy.Bytes(), hash)
420 S1 := sm3.Sm3Sum(h2)
421 h3 := BytesCombine([]byte{0x03}, vy.Bytes(), hash)
422 S2 := sm3.Sm3Sum(h3)
423 return k, S1, S2, nil
424 }
425
426 func msgHash(za, msg []byte) (*big.Int, error) {
427 e := sm3.New()
428 e.Write(za)
429 e.Write(msg)
430 return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil
431 }
432
433 // ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
434 func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
435 za := sm3.New()
436 uidLen := len(uid)
437 if uidLen >= 8192 {
438 return []byte{}, errors.New("SM2: uid too large")
439 }
440 Entla := uint16(8 * uidLen)
441 za.Write([]byte{byte((Entla >> 8) & 0xFF)})
442 za.Write([]byte{byte(Entla & 0xFF)})
443 if uidLen > 0 {
444 za.Write(uid)
445 }
446 za.Write(sm2P256ToBig(&sm2P256.a).Bytes())
447 za.Write(sm2P256.B.Bytes())
448 za.Write(sm2P256.Gx.Bytes())
449 za.Write(sm2P256.Gy.Bytes())
450
451 xBuf := pub.X.Bytes()
452 yBuf := pub.Y.Bytes()
453 if n := len(xBuf); n < 32 {
454 xBuf = append(zeroByteSlice()[:32-n], xBuf...)
455 }
456 if n := len(yBuf); n < 32 {
457 yBuf = append(zeroByteSlice()[:32-n], yBuf...)
458 }
459 za.Write(xBuf)
460 za.Write(yBuf)
461 return za.Sum(nil)[:32], nil
462 }
463
// zeroByteSlice returns a freshly allocated 32-byte slice of zeros. Callers
// use it to left-pad big-endian integers to a fixed 32-byte width.
func zeroByteSlice() []byte {
	return make([]byte, 32)
}
477
478 /*
479 sm2加密,返回asn.1编码格式的密文内容
480 */
481 func EncryptAsn1(pub *PublicKey, data []byte, rand io.Reader) ([]byte, error) {
482 cipher, err := Encrypt(pub, data, rand,C1C3C2)
483 if err != nil {
484 return nil, err
485 }
486 return CipherMarshal(cipher)
487 }
488
489 /*
490 sm2解密,解析asn.1编码格式的密文内容
491 */
492 func DecryptAsn1(pub *PrivateKey, data []byte) ([]byte, error) {
493 cipher, err := CipherUnmarshal(data)
494 if err != nil {
495 return nil, err
496 }
497 return Decrypt(pub, cipher,C1C3C2)
498 }
499
500 /*
501 *sm2密文转asn.1编码格式
502 *sm2密文结构如下:
503 * x
504 * y
505 * hash
506 * CipherText
507 */
508 func CipherMarshal(data []byte) ([]byte, error) {
509 data = data[1:]
510 x := new(big.Int).SetBytes(data[:32])
511 y := new(big.Int).SetBytes(data[32:64])
512 hash := data[64:96]
513 cipherText := data[96:]
514 return asn1.Marshal(sm2Cipher{x, y, hash, cipherText})
515 }
516
517 /*
518 sm2密文asn.1编码格式转C1|C3|C2拼接格式
519 */
520 func CipherUnmarshal(data []byte) ([]byte, error) {
521 var cipher sm2Cipher
522 _, err := asn1.Unmarshal(data, &cipher)
523 if err != nil {
524 return nil, err
525 }
526 x := cipher.XCoordinate.Bytes()
527 y := cipher.YCoordinate.Bytes()
528 hash := cipher.HASH
529 if err != nil {
530 return nil, err
531 }
532 cipherText := cipher.CipherText
533 if err != nil {
534 return nil, err
535 }
536 if n := len(x); n < 32 {
537 x = append(zeroByteSlice()[:32-n], x...)
538 }
539 if n := len(y); n < 32 {
540 y = append(zeroByteSlice()[:32-n], y...)
541 }
542 c := []byte{}
543 c = append(c, x...) // x分量
544 c = append(c, y...) // y分
545 c = append(c, hash...) // x分量
546 c = append(c, cipherText...) // y分
547 return append([]byte{0x04}, c...), nil
548 }
549
// keXHat computes x̄ = 2^w + (x & (2^w - 1)) with w = 127. It is a helper
// for the key agreement algorithm. Only the magnitude of x is used, and x
// itself is not modified.
func keXHat(x *big.Int) (xul *big.Int) {
	const w = 127
	pow := new(big.Int).Lsh(big.NewInt(1), w)    // 2^w
	mask := new(big.Int).Sub(pow, big.NewInt(1)) // 2^w - 1

	xul = new(big.Int).Abs(x)
	xul.And(xul, mask)
	return xul.Add(xul, pow)
}
568
// BytesCombine returns the concatenation of its arguments in a newly
// allocated slice; the result never aliases any argument. With no arguments
// it returns an empty, non-nil slice (bytes.Join semantics).
func BytesCombine(pBytes ...[]byte) []byte {
	return bytes.Join(pBytes, nil)
}
578
// intToBytes encodes the low 32 bits of x as a 4-byte big-endian slice.
func intToBytes(x int) []byte {
	var out [4]byte
	binary.BigEndian.PutUint32(out[:], uint32(x))
	return out[:]
}
585
586 func kdf(length int, x ...[]byte) ([]byte, bool) {
587 var c []byte
588
589 ct := 1
590 h := sm3.New()
591 for i, j := 0, (length+31)/32; i < j; i++ {
592 h.Reset()
593 for _, xx := range x {
594 h.Write(xx)
595 }
596 h.Write(intToBytes(ct))
597 hash := h.Sum(nil)
598 if i+1 == j && length%32 != 0 {
599 c = append(c, hash[:length%32]...)
600 } else {
601 c = append(c, hash...)
602 }
603 ct++
604 }
605 for i := 0; i < length; i++ {
606 if c[i] != 0 {
607 return c, true
608 }
609 }
610 return c, false
611 }
612
613 func randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) {
614 if random == nil {
615 random = rand.Reader //If there is no external trusted random source,please use rand.Reader to instead of it.
616 }
617 params := c.Params()
618 b := make([]byte, params.BitSize/8+8)
619 _, err = io.ReadFull(random, b)
620 if err != nil {
621 return
622 }
623 k = new(big.Int).SetBytes(b)
624 n := new(big.Int).Sub(params.N, one)
625 k.Mod(k, n)
626 k.Add(k, one)
627 return
628 }
629
630 func GenerateKey(random io.Reader) (*PrivateKey, error) {
631 c := P256Sm2()
632 if random == nil {
633 random = rand.Reader //If there is no external trusted random source,please use rand.Reader to instead of it.
634 }
635 params := c.Params()
636 b := make([]byte, params.BitSize/8+8)
637 _, err := io.ReadFull(random, b)
638 if err != nil {
639 return nil, err
640 }
641
642 k := new(big.Int).SetBytes(b)
643 n := new(big.Int).Sub(params.N, two)
644 k.Mod(k, n)
645 k.Add(k, one)
646 priv := new(PrivateKey)
647 priv.PublicKey.Curve = c
648 priv.D = k
649 priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
650
651 return priv, nil
652 }
653
// zr is a reader that yields an unlimited stream of zero bytes. The embedded
// io.Reader is never consulted because Read is defined on zr itself.
type zr struct {
	io.Reader
}

// Read zero-fills dst and reports that all of it was read; it never fails.
func (z *zr) Read(dst []byte) (n int, err error) {
	for n = range dst {
		dst[n] = 0
	}
	return len(dst), nil
}

// zeroReader is a shared zr instance.
var zeroReader = &zr{}
666
// getLastBit returns bit 0 of a: 1 when a is odd and 0 when it is even.
func getLastBit(a *big.Int) uint {
	const lowest = 0
	return a.Bit(lowest)
}
670
671 // crypto.Decrypter
672 func (priv *PrivateKey) Decrypt(_ io.Reader, msg []byte, _ crypto.DecrypterOpts) (plaintext []byte, err error) {
673 return Decrypt(priv, msg,C1C3C2)
674 }
675