"""Toy character-level RNN trained on a short customer-service dialogue.

Trains a SimpleRNN next-character model on a single hard-coded dialogue
snippet, then greedily generates a continuation for a prompt.
"""
import numpy as np
from keras.layers import Dense, Embedding, SimpleRNN
from keras.models import Sequential

# Training corpus (contains commas and other punctuation).
text = "用户:今天想吃火锅吗? 客服:我们海鲜火锅很受欢迎。用户:但朋友对海鲜过敏,推荐其他吧。客服:好的,我们有菌汤火锅。"

# Character used both to pad short input windows and to stand in for
# out-of-vocabulary characters at generation time; it must therefore be part
# of the vocabulary.
PAD_CHAR = ' '

# Guarantee that basic punctuation and the pad character are in the vocabulary,
# even if the corpus is edited and stops containing them.
base_chars = [',', '。', '?', ':', PAD_CHAR]
chars = sorted(set(text).union(base_chars))
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = dict(enumerate(chars))

# Build (window, next-character) training pairs with a sliding window:
# each input is `max_length` consecutive characters and the target is the
# character that immediately follows them.
max_length = 20
if len(text) <= max_length:
    raise ValueError(
        f"text must be longer than max_length={max_length} to build training pairs"
    )
starts = range(len(text) - max_length)
X = np.array([[char_to_idx[c] for c in text[i:i + max_length]] for i in starts])
y = np.array([char_to_idx[text[i + max_length]] for i in starts])

# Model: embedding -> single recurrent layer -> softmax over the vocabulary.
# `input_length` is intentionally omitted: Keras 3 deprecates it, and the
# model infers the sequence length when it is first called in `fit`.
model = Sequential([
    Embedding(input_dim=len(chars), output_dim=32),
    SimpleRNN(128),
    Dense(len(chars), activation='softmax'),
])
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')

# Training
model.fit(X, y, epochs=50, batch_size=32)


def generate_response(prompt, num_chars=30):
    """Greedily extend *prompt* by *num_chars* characters with the trained model.

    Args:
        prompt: Seed text. Characters outside the training vocabulary are
            replaced by PAD_CHAR before being fed to the model.
        num_chars: Number of characters to generate (default 30).

    Returns:
        The original prompt followed by the generated characters.
    """
    generated = prompt
    for _ in range(num_chars):
        # Last `max_length` characters, with unknown characters mapped to the pad.
        window = [c if c in char_to_idx else PAD_CHAR for c in generated[-max_length:]]
        # Left-pad (not right-pad) so the most recent real character sits in
        # the final time step: SimpleRNN(128) returns only its last output,
        # and training windows likewise end right before the target character.
        window = [PAD_CHAR] * (max_length - len(window)) + window
        seq_indices = np.array([[char_to_idx[c] for c in window]])
        probs = model.predict(seq_indices, verbose=0)[0]
        # Greedy decoding: always take the most probable next character.
        generated += idx_to_char[int(np.argmax(probs))]
    return generated


print(generate_response("用户:朋友海鲜过敏,能不能推荐一些其他的?"))