tools.py

# coding:utf-8
import multiprocessing

import jieba
import jieba.posseg as psg
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer, OneHotEncoder

from util.htmlutil.htmltag import Clean

# Segment with one jieba worker per CPU core (jieba's parallel mode is not supported on Windows).
jieba.enable_parallel(multiprocessing.cpu_count())


def chinese2vectors(chinese: list, remove_word: list, stop_words: list) -> list:
    """
    Convert Chinese texts to space-separated token strings (batch version).
    :param chinese: raw texts
    :param remove_word: POS tags to drop, e.g. x, n, eng
    :param stop_words: stop words to drop
    :return: one space-joined token string per input text
    """
    if not remove_word:
        remove_word = ["x"]
    if not stop_words:
        stop_words = []
    space_words = []
    for row in chinese:
        # keep a token only if neither its POS tag nor the token itself is filtered out
        cut_ret = [word for word, x in psg.lcut(Clean(row)) if x not in remove_word and word not in stop_words]
        space_words.append(" ".join(cut_ret))
    return space_words
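
# Usage sketch for chinese2vectors (hypothetical input; exact tokens depend on
# jieba's dictionary, and Clean() is assumed to be this project's HTML stripper):
#   docs = chinese2vectors(["<p>今天天气很好</p>"], ["x"], ["的"])
#   # -> ["今天 天气 很 好"]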


def chinese2vector(chinese: str, remove_word: list, stopwords: list) -> str:
    """
    Convert a single Chinese text to a space-separated token string.
    :param chinese: raw text
    :param remove_word: POS tags to drop, e.g. x, n, eng
    :param stopwords: stop words to drop
    :return: space-joined token string
    """
    if not stopwords:
        stopwords = []
    if not remove_word:
        remove_word = ["x"]
    cut_ret = [word for word, x in psg.lcut(Clean(chinese)) if x not in remove_word and word not in stopwords]
    return " ".join(cut_ret)
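
# Usage sketch for chinese2vector, the single-text variant (hypothetical input;
# segmentation boundaries depend on jieba's dictionary):
#   doc = chinese2vector("机器学习很有趣", ["x", "eng"], [])
#   # -> roughly "机器 学习 很 有趣"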


def tfidf(analyzer, space_words) -> tuple:
    """
    TF-IDF encoding.
    :param analyzer: sklearn analyzer, e.g. "word" or "char"
    :param space_words: space-joined token strings
    :return: (fitted vectorizer, TF-IDF sparse matrix)
    """
    tfidf_vec = TfidfVectorizer(analyzer=analyzer)
    tfidf_ret = tfidf_vec.fit_transform(space_words)
    return tfidf_vec, tfidf_ret
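
# Usage sketch for tfidf on pre-segmented, space-joined texts (made-up corpus).
# Note that TfidfVectorizer's default token_pattern drops single-character tokens:
#   vec, matrix = tfidf("word", ["今天 天气 很 好", "机器 学习 很 有趣"])
#   # matrix is a scipy sparse matrix of shape (2, n_features)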


def one2hot(space_words) -> tuple:
    """
    One-hot encoding.
    :param space_words: space-joined token strings
    :return: (fitted encoder, one-hot sparse matrix)
    """
    oht = OneHotEncoder()
    # OneHotEncoder expects a 2-D array, so wrap each document in its own row;
    # each distinct document string becomes one category/column
    oht_ret = oht.fit_transform([[doc] for doc in space_words])
    return oht, oht_ret
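
# Usage sketch for one2hot. OneHotEncoder treats each whole document as one
# categorical value, so identical strings would share a column (made-up corpus):
#   enc, matrix = one2hot(["今天 天气 很 好", "机器 学习 很 有趣"])
#   # matrix.toarray() -> [[1., 0.], [0., 1.]]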


def combine_row(target_one: list, target_two: list) -> list:
    """
    Merge two 2-D lists row by row.
    :param target_one: list of rows, extended in place
    :param target_two: list of rows appended to the corresponding rows of target_one
    :return: target_one with each row extended by the matching row of target_two
    """
    if len(target_one) != len(target_two):
        raise ValueError("the two lists have different lengths")
    for ind, row in enumerate(target_two):
        target_one[ind] += row
    return target_one
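
# Usage sketch for combine_row (hypothetical rows; note target_one is mutated):
#   combine_row([[1, 2], [3]], [[4], [5, 6]])
#   # -> [[1, 2, 4], [3, 5, 6]]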


def label2encode(labels: list) -> tuple:
    """
    Encode multi-label targets: LabelEncoder maps label names to integers,
    then MultiLabelBinarizer turns each sample's label set into a 0/1 row.
    :param labels: list of label lists, one per sample
    :return: (fitted LabelEncoder, binarized label matrix)
    """
    le = LabelEncoder()
    train_labels = []
    for row in labels:
        train_labels += row
    le.fit(train_labels)  # fit only; the flattened transform result is not needed
    le_ret = [le.transform(row) for row in labels]
    le_ret = MultiLabelBinarizer().fit_transform(le_ret)
    return le, le_ret
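
# Usage sketch for label2encode with two multi-label samples (made-up labels).
# LabelEncoder sorts classes, so the columns below are (news, sports, tech):
#   le, binarized = label2encode([["sports", "news"], ["tech"]])
#   # binarized -> [[1, 1, 0], [0, 0, 1]]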


def encode2label(le, predict_results, target_label: list) -> list:
    """
    Turn binarized prediction rows back into label strings.
    :param le: the fitted LabelEncoder from label2encode
    :param predict_results: 0/1 prediction rows, one per sample
    :param target_label: only labels in this list are kept
    :return: one comma-joined label string per sample ("" when nothing is predicted)
    """
    detail_labels = []
    for row in predict_results:
        label_str = ""
        if row.sum() > 0:
            # indices of the positive columns map back to LabelEncoder classes
            indexes = [i for i, x in enumerate(row) if x > 0]
            label_str = ",".join(label for label in le.inverse_transform(indexes) if label in target_label)
        detail_labels.append(label_str)
    return detail_labels
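
# Usage sketch for encode2label, continuing the label2encode example above
# (the prediction rows are made up; numpy arrays stand in for model output):
#   import numpy as np
#   preds = np.array([[1, 0, 1], [0, 0, 0]])
#   encode2label(le, preds, ["news", "tech"])
#   # -> ["news,tech", ""]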