sm3.go raw

   1  /*
   2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
   3  Licensed under the Apache License, Version 2.0 (the "License");
   4  you may not use this file except in compliance with the License.
   5  You may obtain a copy of the License at
   6  
   7                   http://www.apache.org/licenses/LICENSE-2.0
   8  
   9  Unless required by applicable law or agreed to in writing, software
  10  distributed under the License is distributed on an "AS IS" BASIS,
  11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  See the License for the specific language governing permissions and
  13  limitations under the License.
  14  */
  15  
  16  package sm3
  17  
  18  import (
  19  	"encoding/binary"
  20  	"hash"
  21  )
  22  
  23  type SM3 struct {
  24  	digest      [8]uint32 // digest represents the partial evaluation of V
  25  	length      uint64    // length of the message
  26  	unhandleMsg []byte    // uint8  //
  27  }
  28  
  29  func (sm3 *SM3) ff0(x, y, z uint32) uint32 { return x ^ y ^ z }
  30  
  31  func (sm3 *SM3) ff1(x, y, z uint32) uint32 { return (x & y) | (x & z) | (y & z) }
  32  
  33  func (sm3 *SM3) gg0(x, y, z uint32) uint32 { return x ^ y ^ z }
  34  
  35  func (sm3 *SM3) gg1(x, y, z uint32) uint32 { return (x & y) | (^x & z) }
  36  
  37  func (sm3 *SM3) p0(x uint32) uint32 { return x ^ sm3.leftRotate(x, 9) ^ sm3.leftRotate(x, 17) }
  38  
  39  func (sm3 *SM3) p1(x uint32) uint32 { return x ^ sm3.leftRotate(x, 15) ^ sm3.leftRotate(x, 23) }
  40  
  41  func (sm3 *SM3) leftRotate(x uint32, i uint32) uint32 { return x<<(i%32) | x>>(32-i%32) }
  42  
  43  func (sm3 *SM3) pad() []byte {
  44  	msg := sm3.unhandleMsg
  45  	msg = append(msg, 0x80) // Append '1'
  46  	blockSize := 64         // Append until the resulting message length (in bits) is congruent to 448 (mod 512)
  47  	for len(msg)%blockSize != 56 {
  48  		msg = append(msg, 0x00)
  49  	}
  50  	// append message length
  51  	msg = append(msg, uint8(sm3.length>>56&0xff))
  52  	msg = append(msg, uint8(sm3.length>>48&0xff))
  53  	msg = append(msg, uint8(sm3.length>>40&0xff))
  54  	msg = append(msg, uint8(sm3.length>>32&0xff))
  55  	msg = append(msg, uint8(sm3.length>>24&0xff))
  56  	msg = append(msg, uint8(sm3.length>>16&0xff))
  57  	msg = append(msg, uint8(sm3.length>>8&0xff))
  58  	msg = append(msg, uint8(sm3.length>>0&0xff))
  59  
  60  	if len(msg)%64 != 0 {
  61  		panic("------SM3 Pad: error msgLen =")
  62  	}
  63  	return msg
  64  }
  65  
  66  func (sm3 *SM3) update(msg []byte) {
  67  	var w [68]uint32
  68  	var w1 [64]uint32
  69  
  70  	a, b, c, d, e, f, g, h := sm3.digest[0], sm3.digest[1], sm3.digest[2], sm3.digest[3], sm3.digest[4], sm3.digest[5], sm3.digest[6], sm3.digest[7]
  71  	for len(msg) >= 64 {
  72  		for i := 0; i < 16; i++ {
  73  			w[i] = binary.BigEndian.Uint32(msg[4*i : 4*(i+1)])
  74  		}
  75  		for i := 16; i < 68; i++ {
  76  			w[i] = sm3.p1(w[i-16]^w[i-9]^sm3.leftRotate(w[i-3], 15)) ^ sm3.leftRotate(w[i-13], 7) ^ w[i-6]
  77  		}
  78  		for i := 0; i < 64; i++ {
  79  			w1[i] = w[i] ^ w[i+4]
  80  		}
  81  		A, B, C, D, E, F, G, H := a, b, c, d, e, f, g, h
  82  		for i := 0; i < 16; i++ {
  83  			SS1 := sm3.leftRotate(sm3.leftRotate(A, 12)+E+sm3.leftRotate(0x79cc4519, uint32(i)), 7)
  84  			SS2 := SS1 ^ sm3.leftRotate(A, 12)
  85  			TT1 := sm3.ff0(A, B, C) + D + SS2 + w1[i]
  86  			TT2 := sm3.gg0(E, F, G) + H + SS1 + w[i]
  87  			D = C
  88  			C = sm3.leftRotate(B, 9)
  89  			B = A
  90  			A = TT1
  91  			H = G
  92  			G = sm3.leftRotate(F, 19)
  93  			F = E
  94  			E = sm3.p0(TT2)
  95  		}
  96  		for i := 16; i < 64; i++ {
  97  			SS1 := sm3.leftRotate(sm3.leftRotate(A, 12)+E+sm3.leftRotate(0x7a879d8a, uint32(i)), 7)
  98  			SS2 := SS1 ^ sm3.leftRotate(A, 12)
  99  			TT1 := sm3.ff1(A, B, C) + D + SS2 + w1[i]
 100  			TT2 := sm3.gg1(E, F, G) + H + SS1 + w[i]
 101  			D = C
 102  			C = sm3.leftRotate(B, 9)
 103  			B = A
 104  			A = TT1
 105  			H = G
 106  			G = sm3.leftRotate(F, 19)
 107  			F = E
 108  			E = sm3.p0(TT2)
 109  		}
 110  		a ^= A
 111  		b ^= B
 112  		c ^= C
 113  		d ^= D
 114  		e ^= E
 115  		f ^= F
 116  		g ^= G
 117  		h ^= H
 118  		msg = msg[64:]
 119  	}
 120  	sm3.digest[0], sm3.digest[1], sm3.digest[2], sm3.digest[3], sm3.digest[4], sm3.digest[5], sm3.digest[6], sm3.digest[7] = a, b, c, d, e, f, g, h
 121  }
 122  func (sm3 *SM3) update2(msg []byte,) [8]uint32 {
 123  	var w [68]uint32
 124  	var w1 [64]uint32
 125  
 126  	a, b, c, d, e, f, g, h := sm3.digest[0], sm3.digest[1], sm3.digest[2], sm3.digest[3], sm3.digest[4], sm3.digest[5], sm3.digest[6], sm3.digest[7]
 127  	for len(msg) >= 64 {
 128  		for i := 0; i < 16; i++ {
 129  			w[i] = binary.BigEndian.Uint32(msg[4*i : 4*(i+1)])
 130  		}
 131  		for i := 16; i < 68; i++ {
 132  			w[i] = sm3.p1(w[i-16]^w[i-9]^sm3.leftRotate(w[i-3], 15)) ^ sm3.leftRotate(w[i-13], 7) ^ w[i-6]
 133  		}
 134  		for i := 0; i < 64; i++ {
 135  			w1[i] = w[i] ^ w[i+4]
 136  		}
 137  		A, B, C, D, E, F, G, H := a, b, c, d, e, f, g, h
 138  		for i := 0; i < 16; i++ {
 139  			SS1 := sm3.leftRotate(sm3.leftRotate(A, 12)+E+sm3.leftRotate(0x79cc4519, uint32(i)), 7)
 140  			SS2 := SS1 ^ sm3.leftRotate(A, 12)
 141  			TT1 := sm3.ff0(A, B, C) + D + SS2 + w1[i]
 142  			TT2 := sm3.gg0(E, F, G) + H + SS1 + w[i]
 143  			D = C
 144  			C = sm3.leftRotate(B, 9)
 145  			B = A
 146  			A = TT1
 147  			H = G
 148  			G = sm3.leftRotate(F, 19)
 149  			F = E
 150  			E = sm3.p0(TT2)
 151  		}
 152  		for i := 16; i < 64; i++ {
 153  			SS1 := sm3.leftRotate(sm3.leftRotate(A, 12)+E+sm3.leftRotate(0x7a879d8a, uint32(i)), 7)
 154  			SS2 := SS1 ^ sm3.leftRotate(A, 12)
 155  			TT1 := sm3.ff1(A, B, C) + D + SS2 + w1[i]
 156  			TT2 := sm3.gg1(E, F, G) + H + SS1 + w[i]
 157  			D = C
 158  			C = sm3.leftRotate(B, 9)
 159  			B = A
 160  			A = TT1
 161  			H = G
 162  			G = sm3.leftRotate(F, 19)
 163  			F = E
 164  			E = sm3.p0(TT2)
 165  		}
 166  		a ^= A
 167  		b ^= B
 168  		c ^= C
 169  		d ^= D
 170  		e ^= E
 171  		f ^= F
 172  		g ^= G
 173  		h ^= H
 174  		msg = msg[64:]
 175  	}
 176  	var digest [8]uint32
 177  	digest[0], digest[1], digest[2], digest[3], digest[4], digest[5], digest[6], digest[7] = a, b, c, d, e, f, g, h
 178  	return digest
 179  }
 180  
 181  // 创建哈希计算实例
 182  func New() hash.Hash {
 183  	var sm3 SM3
 184  
 185  	sm3.Reset()
 186  	return &sm3
 187  }
 188  
 189  // BlockSize returns the hash's underlying block size.
 190  // The Write method must be able to accept any amount
 191  // of data, but it may operate more efficiently if all writes
 192  // are a multiple of the block size.
 193  func (sm3 *SM3) BlockSize() int { return 64 }
 194  
 195  // Size returns the number of bytes Sum will return.
 196  func (sm3 *SM3) Size() int { return 32 }
 197  
 198  // Reset clears the internal state by zeroing bytes in the state buffer.
 199  // This can be skipped for a newly-created hash state; the default zero-allocated state is correct.
 200  func (sm3 *SM3) Reset() {
 201  	// Reset digest
 202  	sm3.digest[0] = 0x7380166f
 203  	sm3.digest[1] = 0x4914b2b9
 204  	sm3.digest[2] = 0x172442d7
 205  	sm3.digest[3] = 0xda8a0600
 206  	sm3.digest[4] = 0xa96f30bc
 207  	sm3.digest[5] = 0x163138aa
 208  	sm3.digest[6] = 0xe38dee4d
 209  	sm3.digest[7] = 0xb0fb0e4e
 210  
 211  	sm3.length = 0 // Reset numberic states
 212  	sm3.unhandleMsg = []byte{}
 213  }
 214  
 215  // Write (via the embedded io.Writer interface) adds more data to the running hash.
 216  // It never returns an error.
 217  func (sm3 *SM3) Write(p []byte) (int, error) {
 218  	toWrite := len(p)
 219  	sm3.length += uint64(len(p) * 8)
 220  	msg := append(sm3.unhandleMsg, p...)
 221  	nblocks := len(msg) / sm3.BlockSize()
 222  	sm3.update(msg)
 223  	// Update unhandleMsg
 224  	sm3.unhandleMsg = msg[nblocks*sm3.BlockSize():]
 225  
 226  	return toWrite, nil
 227  }
 228  
 229  // 返回SM3哈希算法摘要值
 230  // Sum appends the current hash to b and returns the resulting slice.
 231  // It does not change the underlying hash state.
 232  func (sm3 *SM3) Sum(in []byte) []byte {
 233  	_, _ = sm3.Write(in)
 234  	msg := sm3.pad()
 235  	//Finalize
 236  	digest := sm3.update2(msg)
 237  
 238  	// save hash to in
 239  	needed := sm3.Size()
 240  	if cap(in)-len(in) < needed {
 241  		newIn := make([]byte, len(in), len(in)+needed)
 242  		copy(newIn, in)
 243  		in = newIn
 244  	}
 245  	out := in[len(in) : len(in)+needed]
 246  	for i := 0; i < 8; i++ {
 247  		binary.BigEndian.PutUint32(out[i*4:], digest[i])
 248  	}
 249  	return out
 250  
 251  }
 252  
 253  func Sm3Sum(data []byte) []byte {
 254  	var sm3 SM3
 255  
 256  	sm3.Reset()
 257  	_, _ = sm3.Write(data)
 258  	return sm3.Sum(nil)
 259  }
 260