1616import sys
1717import time
1818
19- from .script import CScript
19+ from .script import CScript , CScriptWitness
2020
2121from .serialize import *
2222
@@ -304,11 +304,44 @@ def from_txout(cls, txout):
304304 """Create a fullly mutable copy of an existing TxOut"""
305305 return cls (txout .nValue , txout .scriptPubKey )
306306
307+
308+ class CTxInWitness (ImmutableSerializable ):
309+ """Witness data for a transaction input. """
310+ __slots__ = ['scriptWitness' ]
311+
312+ def __init__ (self , scriptWitness = CScriptWitness ()):
313+ object .__setattr__ (self , 'scriptWitness' , scriptWitness )
314+
315+ @classmethod
316+ def stream_deserialize (cls , f ):
317+ scriptWitness = CScriptWitness .stream_deserialize (f )
318+ return cls (scriptWitness )
319+
320+ def stream_serialize (self , f ):
321+ self .scriptWitness .stream_serialize (f )
322+
323+ def __repr__ (self ):
324+ return "CTxInWitness(%s)" % (repr (self .scriptWitness ))
325+
326+ @classmethod
327+ def from_txinwitness (cls , txinwitness ):
328+ """Create an immutable copy of an existing TxInWitness
329+
330+ If txin is already immutable (txin.__class__ is CTxIn) it is returned
331+ directly.
332+ """
333+ if txinwitness .__class__ is CTxInWitness :
334+ return txinwitness
335+
336+ else :
337+ return cls (txinwitness .scriptWitness )
338+
339+
307340class CTransaction (ImmutableSerializable ):
308341 """A transaction"""
309- __slots__ = ['nVersion' , 'vin' , 'vout' , 'nLockTime' ]
342+ __slots__ = ['nVersion' , 'vin' , 'vout' , 'nLockTime' , 'wit' ]
310343
311- def __init__ (self , vin = (), vout = (), nLockTime = 0 , nVersion = 1 ):
344+ def __init__ (self , vin = (), vout = (), nLockTime = 0 , nVersion = 1 , witness = () ):
312345 """Create a new transaction
313346
314347 vin and vout are iterables of transaction inputs and outputs
@@ -322,26 +355,55 @@ def __init__(self, vin=(), vout=(), nLockTime=0, nVersion=1):
322355 object .__setattr__ (self , 'nVersion' , nVersion )
323356 object .__setattr__ (self , 'vin' , tuple (CTxIn .from_txin (txin ) for txin in vin ))
324357 object .__setattr__ (self , 'vout' , tuple (CTxOut .from_txout (txout ) for txout in vout ))
358+ object .__setattr__ (self , 'wit' ,
359+ tuple (CTxInWitness .from_txinwitness (witness ) for txinwitness in
360+ witness ))
325361
326362 @classmethod
327363 def stream_deserialize (cls , f ):
328364 nVersion = struct .unpack (b"<i" , ser_read (f ,4 ))[0 ]
329- vin = VectorSerializer .stream_deserialize (CTxIn , f )
330- vout = VectorSerializer .stream_deserialize (CTxOut , f )
331- nLockTime = struct .unpack (b"<I" , ser_read (f ,4 ))[0 ]
332- return cls (vin , vout , nLockTime , nVersion )
365+ pos = f .tell ()
366+ markerbyte = struct .unpack (b'B' , ser_read (f , 1 ))[0 ]
367+ if markerbyte == 0 :
368+ flagbyte = struct .unpack (b'B' , ser_read (f , 1 ))[0 ]
369+ if flagbyte != 1 :
370+ raise DeserializationFormatError
371+ vin = VectorSerializer .stream_deserialize (CTxIn , f )
372+ vout = VectorSerializer .stream_deserialize (CTxOut , f )
373+ wit = VectorSerializer .stream_deserialize (CTxInWitness , f )
374+ nLockTime = struct .unpack (b"<I" , ser_read (f ,4 ))[0 ]
375+ return cls (vin , vout , nLockTime , nVersion , wit )
376+ else :
377+ f .seek (pos ) # put marker byte back, since we don't have peek
378+ vin = VectorSerializer .stream_deserialize (CTxIn , f )
379+ vout = VectorSerializer .stream_deserialize (CTxOut , f )
380+ nLockTime = struct .unpack (b"<I" , ser_read (f ,4 ))[0 ]
381+ return cls (vin , vout , nLockTime , nVersion )
382+
333383
334384 def stream_serialize (self , f ):
335- f .write (struct .pack (b"<i" , self .nVersion ))
336- VectorSerializer .stream_serialize (CTxIn , self .vin , f )
337- VectorSerializer .stream_serialize (CTxOut , self .vout , f )
338- f .write (struct .pack (b"<I" , self .nLockTime ))
385+ if self .wit :
386+ if len (self .wit ) != len (self .vin ):
387+ raise SerializationMissingWitnessError
388+ f .write (struct .pack (b"<i" , self .nVersion ))
389+ f .write (b'\x00 ' ) # Marker
390+ f .write (b'\x01 ' ) # Flag
391+ VectorSerializer .stream_serialize (CTxIn , self .vin , f )
392+ VectorSerializer .stream_serialize (CTxOut , self .vout , f )
393+ for w in self .wit : w .stream_serialize (f )
394+ f .write (struct .pack (b"<I" , self .nLockTime ))
395+ else :
396+ f .write (struct .pack (b"<i" , self .nVersion ))
397+ VectorSerializer .stream_serialize (CTxIn , self .vin , f )
398+ VectorSerializer .stream_serialize (CTxOut , self .vout , f )
399+ f .write (struct .pack (b"<I" , self .nLockTime ))
339400
340401 def is_coinbase (self ):
341402 return len (self .vin ) == 1 and self .vin [0 ].prevout .is_null ()
342403
343404 def __repr__ (self ):
344- return "CTransaction(%r, %r, %i, %i)" % (self .vin , self .vout , self .nLockTime , self .nVersion )
405+ return "CTransaction(%r, %r, %i, %i, %r)" % (self .vin , self .vout ,
406+ self .nLockTime , self .nVersion , self .wit )
345407
346408 @classmethod
347409 def from_tx (cls , tx ):
@@ -354,15 +416,30 @@ def from_tx(cls, tx):
354416 return tx
355417
356418 else :
357- return cls (tx .vin , tx .vout , tx .nLockTime , tx .nVersion )
419+ return cls (tx .vin , tx .vout , tx .nLockTime , tx .nVersion , tx .wit )
420+
421+ def GetTxid (self ):
422+ """Get the transaction ID. This differs from the transactions hash as
423+ given by GetHash. GetTxid excludes witness data, while GetHash
424+ includes it. """
425+ if self .wit :
426+ wit = self .wit
427+ self .wit = b''
428+ txid = Hash (self .serialize ())
429+ self .wit = wit
430+ else :
431+ txid = Hash (self .serialize ())
432+ return txid
433+
434+
358435
359436
360437@__make_mutable
361438class CMutableTransaction (CTransaction ):
362439 """A mutable transaction"""
363440 __slots__ = []
364441
365- def __init__ (self , vin = None , vout = None , nLockTime = 0 , nVersion = 1 ):
442+ def __init__ (self , vin = None , vout = None , nLockTime = 0 , nVersion = 1 , witness = CScriptWitness ([]) ):
366443 if not (0 <= nLockTime <= 0xffffffff ):
367444 raise ValueError ('CTransaction: nLockTime must be in range 0x0 to 0xffffffff; got %x' % nLockTime )
368445 self .nLockTime = nLockTime
@@ -375,14 +452,15 @@ def __init__(self, vin=None, vout=None, nLockTime=0, nVersion=1):
375452 vout = []
376453 self .vout = vout
377454 self .nVersion = nVersion
455+ self .wit = witness
378456
379457 @classmethod
380458 def from_tx (cls , tx ):
381459 """Create a fully mutable copy of a pre-existing transaction"""
382460 vin = [CMutableTxIn .from_txin (txin ) for txin in tx .vin ]
383461 vout = [CMutableTxOut .from_txout (txout ) for txout in tx .vout ]
384462
385- return cls (vin , vout , tx .nLockTime , tx .nVersion )
463+ return cls (vin , vout , tx .nLockTime , tx .nVersion , tx . wit )
386464
387465
388466
@@ -478,7 +556,7 @@ def build_merkle_tree_from_txids(txids):
478556 @staticmethod
479557 def build_merkle_tree_from_txs (txs ):
480558 """Build a full merkle tree from transactions"""
481- txids = [tx .GetHash () for tx in txs ]
559+ txids = [tx .GetTxid () for tx in txs ]
482560 return CBlock .build_merkle_tree_from_txids (txids )
483561
484562 def calc_merkle_root (self ):
@@ -730,7 +808,7 @@ def CheckBlock(block, fCheckPoW = True, fCheckMerkleRoot = True, cur_time=None):
730808
731809 CheckTransaction (tx )
732810
733- txid = tx .GetHash ()
811+ txid = tx .GetTxid ()
734812 if txid in unique_txids :
735813 raise CheckBlockError ("CheckBlock() : duplicate transaction" )
736814 unique_txids .add (txid )
0 commit comments