# jieba/__init__.py

from __future__ import with_statement
__version__ = '0.31'
__license__ = 'MIT'

import re
import os
import sys
import finalseg
import time
import tempfile
import marshal
from math import log
import random
import threading
from functools import wraps

DICTIONARY = "dict.txt"
DICT_LOCK = threading.RLock()
trie = None  # to be initialized
FREQ = {}
min_freq = 0.0
total = 0.0
user_word_tag_tab = {}
initialized = False

def gen_trie(f_name):
    # Build the prefix trie and the raw word-frequency table from the dictionary file.
    lfreq = {}
    trie = {}
    ltotal = 0.0
    with open(f_name, 'rb') as f:
        lineno = 0
        for line in f.read().rstrip().decode('utf-8').split('\n'):
            lineno += 1
            try:
                word, freq, _ = line.split(' ')
                freq = float(freq)
                lfreq[word] = freq
                ltotal += freq
                p = trie
                for c in word:
                    if c not in p:
                        p[c] = {}
                    p = p[c]
                p[''] = ''  # ending flag
            except ValueError, e:
                print >> sys.stderr, f_name, ' at line', lineno, line
                raise e
    return trie, lfreq, ltotal
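
# Illustrative sketch of the dictionary format gen_trie expects (the entry and
# count below are made-up examples, not taken from dict.txt): each line holds
# "word frequency tag" separated by single spaces, e.g.
#     北京大学 2053 nt
# which adds 2053.0 to lfreq[u'北京大学'] and ltotal, and nests the word one
# character at a time, so trie[u'北'][u'京'][u'大'][u'学'][''] == '' marks the
# end of a complete word.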

def initialize(*args):
    global trie, FREQ, total, min_freq, initialized
    if len(args) == 0:
        dictionary = DICTIONARY
    else:
        dictionary = args[0]
    with DICT_LOCK:
        if initialized:
            return
        if trie:
            del trie
            trie = None
        _curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
        abs_path = os.path.join(_curpath, dictionary)
        print >> sys.stderr, "Building Trie..., from " + abs_path
        t1 = time.time()
        if abs_path == os.path.join(_curpath, "dict.txt"):  # default dictionary
            cache_file = os.path.join(tempfile.gettempdir(), "jieba.cache")
        else:  # custom dictionary
            cache_file = os.path.join(tempfile.gettempdir(), "jieba.user." + str(hash(abs_path)) + ".cache")
        load_from_cache_fail = True
        if os.path.exists(cache_file) and os.path.getmtime(cache_file) > os.path.getmtime(abs_path):
            print >> sys.stderr, "loading model from cache " + cache_file
            try:
                trie, FREQ, total, min_freq = marshal.load(open(cache_file, 'rb'))
                load_from_cache_fail = False
            except:
                load_from_cache_fail = True
        if load_from_cache_fail:
            trie, FREQ, total = gen_trie(abs_path)
            FREQ = dict([(k, log(float(v) / total)) for k, v in FREQ.iteritems()])  # normalize to log probabilities
            min_freq = min(FREQ.itervalues())
            print >> sys.stderr, "dumping model to file cache " + cache_file
            try:
                tmp_suffix = "." + str(random.random())
                with open(cache_file + tmp_suffix, 'wb') as temp_cache_file:
                    marshal.dump((trie, FREQ, total, min_freq), temp_cache_file)
                if os.name == 'nt':
                    import shutil
                    replace_file = shutil.move
                else:
                    replace_file = os.rename
                replace_file(cache_file + tmp_suffix, cache_file)
            except:
                print >> sys.stderr, "dump cache file failed."
                import traceback
                print >> sys.stderr, traceback.format_exc()
        initialized = True
        print >> sys.stderr, "loading model cost ", time.time() - t1, "seconds."
        print >> sys.stderr, "Trie has been built successfully."

def require_initialized(fn):
    global initialized, DICTIONARY

    @wraps(fn)
    def wrapped(*args, **kwargs):
        if initialized:
            return fn(*args, **kwargs)
        else:
            initialize(DICTIONARY)
            return fn(*args, **kwargs)
    return wrapped

def __cut_all(sentence):
    dag = get_DAG(sentence)
    old_j = -1
    for k, L in dag.iteritems():
        if len(L) == 1 and k > old_j:
            yield sentence[k:L[0] + 1]
            old_j = L[0]
        else:
            for j in L:
                if j > k:
                    yield sentence[k:j + 1]
                    old_j = j

def calc(sentence, DAG, idx, route):
    # Dynamic programming over the DAG: walk from the end of the sentence back
    # to the start, keeping for each position the (log probability, end index)
    # of the best segmentation of the remaining text.
    N = len(sentence)
    route[N] = (0.0, '')
    for idx in xrange(N - 1, -1, -1):
        candidates = [(FREQ.get(sentence[idx:x + 1], min_freq) + route[x + 1][0], x) for x in DAG[idx]]
        route[idx] = max(candidates)
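
# Sketch of how calc fills `route` (values are purely illustrative): for a
# 2-character sentence, route[2] = (0.0, ''), then route[1] is the best
# (log probability, end index) pair reachable from position 1, and route[0]
# compares the single-character path FREQ.get(sentence[0:1], min_freq) + route[1][0]
# against the whole-word path FREQ.get(sentence[0:2], min_freq) + route[2][0],
# keeping whichever scores higher.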

@require_initialized
def get_DAG(sentence):
    N = len(sentence)
    i, j = 0, 0
    p = trie
    DAG = {}
    while i < N:
        c = sentence[j]
        if c in p:
            p = p[c]
            if '' in p:
                if i not in DAG:
                    DAG[i] = []
                DAG[i].append(j)
            j += 1
            if j >= N:
                i += 1
                j = i
                p = trie
        else:
            p = trie
            i += 1
            j = i
    for i in xrange(len(sentence)):
        if i not in DAG:
            DAG[i] = [i]
    return DAG
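
# Hypothetical example of the returned DAG: if the dictionary contains the
# two-character word covering positions 0-1 as well as each single character,
# get_DAG would return something like {0: [0, 1], 1: [1]}; in general DAG[k]
# lists every index j for which sentence[k:j+1] is a dictionary word, falling
# back to [k] when nothing matches.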

def __cut_DAG(sentence):
    DAG = get_DAG(sentence)
    route = {}
    calc(sentence, DAG, 0, route=route)
    x = 0
    buf = u''
    N = len(sentence)
    while x < N:
        y = route[x][1] + 1
        l_word = sentence[x:y]
        if y - x == 1:
            buf += l_word
        else:
            if len(buf) > 0:
                if len(buf) == 1:
                    yield buf
                    buf = u''
                else:
                    if buf not in FREQ:
                        recognized = finalseg.cut(buf)
                        for t in recognized:
                            yield t
                    else:
                        for elem in buf:
                            yield elem
                    buf = u''
            yield l_word
        x = y
    if len(buf) > 0:
        if len(buf) == 1:
            yield buf
        else:
            if buf not in FREQ:
                recognized = finalseg.cut(buf)
                for t in recognized:
                    yield t
            else:
                for elem in buf:
                    yield elem

def cut(sentence, cut_all=False):
    if not isinstance(sentence, unicode):
        try:
            sentence = sentence.decode('utf-8')
        except UnicodeDecodeError:
            sentence = sentence.decode('gbk', 'ignore')
    re_han, re_skip = re.compile(ur"([\u4E00-\u9FA5a-zA-Z0-9+#&\._]+)", re.U), re.compile(ur"(\r\n|\s)", re.U)
    if cut_all:
        re_han, re_skip = re.compile(ur"([\u4E00-\u9FA5]+)", re.U), re.compile(ur"[^a-zA-Z0-9+#\n]", re.U)
    blocks = re_han.split(sentence)
    cut_block = __cut_DAG
    if cut_all:
        cut_block = __cut_all
    for blk in blocks:
        if re_han.match(blk):
            for word in cut_block(blk):
                yield word
        else:
            tmp = re_skip.split(blk)
            for x in tmp:
                if re_skip.match(x):
                    yield x
                elif not cut_all:
                    for xx in x:
                        yield xx
                else:
                    yield x
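
# Usage sketch (assumes the module is importable as `jieba`):
#     import jieba
#     seg = jieba.cut(u"我来到北京清华大学")                    # accurate mode, lazy generator
#     seg_all = jieba.cut(u"我来到北京清华大学", cut_all=True)  # full mode
#     print "/ ".join(seg)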

def cut_for_search(sentence):
    words = cut(sentence)
    for w in words:
        if len(w) > 2:
            for i in xrange(len(w) - 1):
                gram2 = w[i:i + 2]
                if gram2 in FREQ:
                    yield gram2
        if len(w) > 3:
            for i in xrange(len(w) - 2):
                gram3 = w[i:i + 3]
                if gram3 in FREQ:
                    yield gram3
        yield w
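
# Usage sketch (assumes `import jieba`): cut_for_search runs the accurate cut
# and additionally re-emits any in-dictionary 2- and 3-grams of longer words,
# which is intended for building search-engine indexes:
#     for w in jieba.cut_for_search(u"小明硕士毕业于中国科学院计算所"):
#         print w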

@require_initialized
def load_userdict(f):
    global trie, total, FREQ
    if isinstance(f, (str, unicode)):
        f = open(f, 'rb')
    content = f.read().decode('utf-8')
    line_no = 0
    for line in content.split("\n"):
        line_no += 1
        if line.rstrip() == '':
            continue
        tup = line.split(" ")
        word, freq = tup[0], tup[1]
        if line_no == 1:
            word = word.replace(u'\ufeff', u"")  # remove BOM flag if it exists
        if len(tup) == 3:
            add_word(word, freq, tup[2])
        else:
            add_word(word, freq)
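
# Sketch of the user dictionary format load_userdict expects (the file name is
# hypothetical): a UTF-8 text file with one "word freq [tag]" entry per line,
# where the part-of-speech tag is optional, e.g.
#     云计算 5
#     创新办 3 i
# loaded with jieba.load_userdict("userdict.txt") or with an already-open
# file-like object.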

def add_word(word, freq, tag=None):
    global FREQ, trie, total, user_word_tag_tab
    freq = float(freq)
    FREQ[word] = log(freq / total)
    if tag is not None:
        user_word_tag_tab[word] = tag.strip()
    p = trie
    for c in word:
        if c not in p:
            p[c] = {}
        p = p[c]
    p[''] = ''  # ending flag
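
# Usage sketch: add_word can also be called directly once the model has been
# initialized (for example after a first cut() or load_userdict() call, since
# it divides by the global `total`); the frequency is interpreted on the same
# raw count scale as dict.txt:
#     jieba.add_word(u"自定义词", 10)             # word plus raw frequency
#     jieba.add_word(u"自定义词", 10, tag="nz")   # optionally with a POS tag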

__ref_cut = cut
__ref_cut_for_search = cut_for_search

def __lcut(sentence):
    return list(__ref_cut(sentence, False))

def __lcut_all(sentence):
    return list(__ref_cut(sentence, True))

def __lcut_for_search(sentence):
    return list(__ref_cut_for_search(sentence))

@require_initialized
def enable_parallel(processnum=None):
    global pool, cut, cut_for_search
    if os.name == 'nt':
        raise Exception("jieba: parallel mode only supports POSIX systems")
    if sys.version_info[0] == 2 and sys.version_info[1] < 6:
        raise Exception("jieba: the parallel feature needs Python version > 2.5")
    from multiprocessing import Pool, cpu_count
    if processnum is None:
        processnum = cpu_count()
    pool = Pool(processnum)

    def pcut(sentence, cut_all=False):
        parts = re.compile('([\r\n]+)').split(sentence)
        if cut_all:
            result = pool.map(__lcut_all, parts)
        else:
            result = pool.map(__lcut, parts)
        for r in result:
            for w in r:
                yield w

    def pcut_for_search(sentence):
        parts = re.compile('([\r\n]+)').split(sentence)
        result = pool.map(__lcut_for_search, parts)
        for r in result:
            for w in r:
                yield w

    cut = pcut
    cut_for_search = pcut_for_search

def disable_parallel():
    global pool, cut, cut_for_search
    if 'pool' in globals():
        pool.close()
        pool = None
    cut = __ref_cut
    cut_for_search = __ref_cut_for_search
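
# Usage sketch (POSIX only, per the check in enable_parallel; `text` is a
# placeholder unicode string): parallel mode splits the input on newlines and
# maps the pieces over a multiprocessing.Pool, rebinding the module-level cut
# functions:
#     jieba.enable_parallel(4)        # 4 worker processes (defaults to cpu_count())
#     words = list(jieba.cut(text))
#     jieba.disable_parallel()        # restore the single-process implementations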

def set_dictionary(dictionary_path):
    global initialized, DICTIONARY
    with DICT_LOCK:
        abs_path = os.path.normpath(os.path.join(os.getcwd(), dictionary_path))
        if not os.path.exists(abs_path):
            raise Exception("jieba: path does not exist: " + abs_path)
        DICTIONARY = abs_path
        initialized = False
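
# Usage sketch (the path is a hypothetical example): switch to a different main
# dictionary before segmenting; the trie is rebuilt lazily on the next call
# that passes through require_initialized:
#     jieba.set_dictionary("extra_dict/dict.txt.big")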

def get_abs_path_dict():
    _curpath = os.path.normpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
    abs_path = os.path.join(_curpath, DICTIONARY)
    return abs_path

def tokenize(unicode_sentence, mode="default"):
    # mode: "default" or "search"
    if not isinstance(unicode_sentence, unicode):
        raise Exception("jieba: the input parameter should be unicode.")
    start = 0
    if mode == 'default':
        for w in cut(unicode_sentence):
            width = len(w)
            yield (w, start, start + width)
            start += width
    else:
        for w in cut(unicode_sentence):
            width = len(w)
            if len(w) > 2:
                for i in xrange(len(w) - 1):
                    gram2 = w[i:i + 2]
                    if gram2 in FREQ:
                        yield (gram2, start + i, start + i + 2)
            if len(w) > 3:
                for i in xrange(len(w) - 2):
                    gram3 = w[i:i + 3]
                    if gram3 in FREQ:
                        yield (gram3, start + i, start + i + 3)
            yield (w, start, start + width)
            start += width
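
# Usage sketch: tokenize yields (word, start, end) tuples over a unicode input;
# in "search" mode it additionally yields in-dictionary 2- and 3-grams of longer
# words with their own offsets:
#     for tk in jieba.tokenize(u"永和服装饰品有限公司"):
#         print tk[0], tk[1], tk[2]
#     for tk in jieba.tokenize(u"永和服装饰品有限公司", mode="search"):
#         print tk[0], tk[1], tk[2]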