59 lines
1.9 KiB
Python
59 lines
1.9 KiB
Python
|
import numpy as np
|
|||
|
from keras.layers import Embedding, SimpleRNN, Dense
|
|||
|
from keras.models import Sequential
|
|||
|
|
|||
|
# Training data: a short customer-service dialogue (it contains commas and
# other punctuation that the model must learn to emit).
text = "用户:今天想吃火锅吗? 客服:我们海鲜火锅很受欢迎。用户:但朋友对海鲜过敏,推荐其他吧。客服:好的,我们有菌汤火锅。"

# Characters that must always be in the vocabulary, whether or not the
# training text happens to contain them. The space is listed explicitly
# because generate_response substitutes ' ' for unknown characters and uses
# it as padding; without it that lookup would depend on the text containing
# a space by coincidence.
base_chars = [',', '。', '?', ':', ' ']

# Sorted so that the character <-> index mapping is deterministic across runs.
chars = sorted(set(text + ''.join(base_chars)))

# Forward (character -> index) and reverse (index -> character) lookups.
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = dict(enumerate(chars))
|
|||
|
|
|||
|
# Build the training samples with a sliding window over the text: each run of
# `max_length` consecutive characters is one input, and the character that
# immediately follows it is the label. Both are stored as vocabulary indices.
max_length = 20
X = [
    [char_to_idx[ch] for ch in text[start:start + max_length]]
    for start in range(len(text) - max_length)
]
y = [
    char_to_idx[text[start + max_length]]
    for start in range(len(text) - max_length)
]
|
|||
|
|
|||
|
# Model definition: character embedding -> simple recurrent layer -> softmax
# over the vocabulary (one output unit per known character).
model = Sequential()
model.add(
    Embedding(input_dim=len(chars), output_dim=32, input_length=max_length)
)
model.add(SimpleRNN(128))
model.add(Dense(len(chars), activation='softmax'))

# The labels in `y` are integer character indices, hence the sparse variant
# of categorical cross-entropy.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
|
|||
|
|
|||
|
# Training: turn the Python lists into NumPy arrays, then fit the model.
X = np.asarray(X)
y = np.asarray(y)
model.fit(x=X, y=y, batch_size=32, epochs=50)
|
|||
|
|
|||
|
# Enhanced generation function
def generate_response(prompt, num_chars=30):
    """Greedily extend *prompt* one character at a time using the model.

    At every step the last ``max_length`` characters generated so far are
    encoded as vocabulary indices, fed to ``model``, and the most probable
    next character (argmax of the prediction) is appended.

    Args:
        prompt: Seed text. Characters that are not in the vocabulary are
            treated as the fallback character for prediction purposes, but
            are kept unchanged in the returned string.
        num_chars: Number of characters to generate (default 30).

    Returns:
        ``prompt`` followed by ``num_chars`` generated characters.
    """
    # Index used for out-of-vocabulary characters and for padding. Prefer the
    # space character; fall back to index 0 rather than raising KeyError if
    # the vocabulary has no space.
    fallback_idx = char_to_idx.get(' ', 0)

    generated = prompt
    for _ in range(num_chars):
        # Encode the most recent context window, mapping unknown characters
        # to the fallback index.
        window = generated[-max_length:]
        seq_indices = [char_to_idx.get(c, fallback_idx) for c in window]

        # Pad on the LEFT when the context is shorter than the window. The
        # training windows always end with real text right before the target,
        # so the real context must stay adjacent to the prediction position;
        # padding on the right would make the RNN see padding last.
        missing = max_length - len(seq_indices)
        seq_indices = [fallback_idx] * missing + seq_indices

        # Predict the next character and append the most likely one.
        pred = model.predict(np.array([seq_indices]), verbose=0)
        next_char = idx_to_char[int(np.argmax(pred))]
        generated += next_char
    return generated
|
|||
|
|
|||
|
# Demo: generate a continuation for a sample user prompt and print it.
print(generate_response("用户:朋友海鲜过敏,能不能推荐一些其他的?"))
|