@@ -828,6 +828,15 @@ def _ensure_bytes(l):
828828 raise ValueError ('expected bytes, found %r' % l )
829829
830830
831+ def _ensure_text (l ):
832+ if isinstance (l , text_type ):
833+ return l
834+ elif isinstance (l , binary_type ):
835+ return text_type (l , 'ascii' )
836+ else :
837+ raise ValueError ('expected text, found %r' % l )
838+
839+
831840class Categorize (Codec ):
832841 """Filter encoding categorical string data as integers.
833842
@@ -862,10 +871,13 @@ class Categorize(Codec):
862871 codec_id = 'categorize'
863872
864873 def __init__ (self , labels , dtype , astype = 'u1' ):
865- self .labels = [_ensure_bytes (l ) for l in labels ]
866874 self .dtype = np .dtype (dtype )
867- if self .dtype .kind != 'S' :
868- raise ValueError ('only string data types are supported' )
875+ if self .dtype .kind == 'S' :
876+ self .labels = [_ensure_bytes (l ) for l in labels ]
877+ elif self .dtype .kind == 'U' :
878+ self .labels = [_ensure_text (l ) for l in labels ]
879+ else :
880+ raise ValueError ('data type not supported' )
869881 self .astype = np .dtype (astype )
870882
871883 def encode (self , buf ):
@@ -909,7 +921,7 @@ def decode(self, buf, out=None):
909921 def get_config (self ):
910922 config = dict ()
911923 config ['id' ] = self .codec_id
912- config ['labels' ] = [text_type ( l , 'ascii' ) for l in self .labels ]
924+ config ['labels' ] = [_ensure_text ( l ) for l in self .labels ]
913925 config ['dtype' ] = encode_dtype (self .dtype )
914926 config ['astype' ] = encode_dtype (self .astype )
915927 return config
@@ -922,8 +934,12 @@ def from_config(cls, config):
922934 return cls (labels = labels , dtype = dtype , astype = astype )
923935
924936 def __repr__ (self ):
925- r = '%s(dtype=%s, astype=%s, labels=%r)' % \
926- (type (self ).__name__ , self .dtype , self .astype , self .labels )
937+ # make sure labels part is not too long
938+ labels = repr (self .labels [:3 ])
939+ if len (self .labels ) > 3 :
940+ labels = labels [:- 1 ] + ', ...]'
941+ r = '%s(dtype=%s, astype=%s, labels=%s)' % \
942+ (type (self ).__name__ , self .dtype , self .astype , labels )
927943 return r
928944
929945
0 commit comments