# Scraped raw-file header: 59 lines, 1.9 KiB, Python; last modified 2025-03-16 02:02:53 +08:00.

import numpy as np
from keras.layers import Embedding, SimpleRNN, Dense
from keras.models import Sequential
# Training data (the dialogue includes commas and other punctuation).
text = "用户:今天想吃火锅吗? 客服:我们海鲜火锅很受欢迎。用户:但朋友对海鲜过敏,推荐其他吧。客服:好的,我们有菌汤火锅。"
# Characters that must always be in the vocabulary even if `text` changes:
# the punctuation the dialogue uses, plus the space that generate_response
# substitutes for unknown characters and uses as padding.
base_chars = [' ', ',', '。', '?', ':']
# Sorted, de-duplicated character vocabulary and its two lookup tables.
chars = sorted(set(text).union(base_chars))
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for c, i in char_to_idx.items()}
# Build sliding-window training pairs: each input is `max_length` consecutive
# character indices taken from `text`; its label is the index of the character
# that immediately follows that window.
max_length = 20
window_starts = range(len(text) - max_length)
X = [[char_to_idx[ch] for ch in text[start:start + max_length]] for start in window_starts]
y = [char_to_idx[text[start + max_length]] for start in window_starts]
# Model: character embedding -> single SimpleRNN layer -> softmax over the vocabulary.
from keras.layers import Input  # only needed to declare the input window shape

model = Sequential([
    # Declare the fixed window length with an explicit integer Input layer;
    # Keras 3 deprecates Embedding's `input_length` argument (it only warns).
    Input(shape=(max_length,), dtype="int32"),
    Embedding(input_dim=len(chars), output_dim=32),
    SimpleRNN(128),
    Dense(len(chars), activation='softmax'),
])
# Labels are integer character indices (not one-hot), hence the sparse loss.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')

# Training
X = np.array(X)
y = np.array(y)
model.fit(X, y, epochs=50, batch_size=32)
# Greedy character-by-character text generation with the trained model.
def generate_response(prompt, steps=30):
    """Extend *prompt* one character at a time using the trained model.

    Each step encodes the last ``max_length`` characters generated so far,
    asks ``model`` for a next-character distribution and appends the single
    most probable character (argmax, so the output is deterministic).

    Args:
        prompt: Seed text. Characters outside the training vocabulary are
            encoded as a space before being fed to the model.
        steps: Number of characters to append (defaults to 30).

    Returns:
        ``prompt`` followed by ``steps`` generated characters.
    """
    # The space doubles as padding and unknown-character token; fall back to
    # index 0 instead of raising KeyError if the vocabulary lacks a space.
    pad_idx = char_to_idx.get(' ', 0)
    generated = prompt
    for _ in range(steps):
        window = [char_to_idx.get(c, pad_idx) for c in generated[-max_length:]]
        # Left-pad short windows so the most recent character occupies the
        # final timestep, like training windows that end right before the
        # target. (Right-padding fed trailing spaces as the latest input.)
        window = [pad_idx] * (max_length - len(window)) + window
        pred = model.predict(np.array([window]), verbose=0)
        generated += idx_to_char[int(np.argmax(pred[0]))]
    return generated
# Demo: ask for alternatives after mentioning a friend's seafood allergy.
demo_prompt = "用户:朋友海鲜过敏,能不能推荐一些其他的?"
print(generate_response(demo_prompt))