import numpy as np
from keras.layers import Embedding, SimpleRNN, Dense
from keras.models import Sequential

# Training data (contains commas): a short Chinese customer-service dialogue about hotpot
text = "用户:今天想吃火锅吗? 客服:我们海鲜火锅很受欢迎。用户:但朋友对海鲜过敏,推荐其他吧。客服:好的,我们有菌汤火锅。"
base_chars = [',', '。', '?', ':', ' ']  # make sure the basic punctuation and the space used for padding are in the vocabulary
chars = sorted(list(set(text + ''.join(base_chars))))
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for c, i in char_to_idx.items()}

# Build training sequences: each window of max_length characters predicts the next character
max_length = 20
X, y = [], []
for i in range(len(text) - max_length):
    seq = text[i:i + max_length]
    target = text[i + max_length]
    X.append([char_to_idx[c] for c in seq])
    y.append(char_to_idx[target])
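# Sanity check (illustrative, not in the original script): there is one training
# example per sliding window, so X and y should have len(text) - max_length entries.
assert len(X) == len(y) == len(text) - max_length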

# Build the model: embedding -> SimpleRNN -> softmax over the vocabulary
model = Sequential([
    # input_length is not required here (recent Keras versions reject the argument)
    Embedding(input_dim=len(chars), output_dim=32),
    SimpleRNN(128),
    Dense(len(chars), activation='softmax')
])
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')

# Train the model
X = np.array(X)
y = np.array(y)
model.fit(X, y, epochs=50, batch_size=32)
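# Optional check (not in the original script): after fit() the model is built, so
# summary() can confirm the layer shapes and parameter counts. Note that with only a
# few dozen training windows the network mostly memorizes the text; this is a toy example.
model.summary()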

# Enhanced generation function (handles characters that are not in the vocabulary)
def generate_response(prompt):
    generated = prompt
    for _ in range(30):
        # Filter the context and substitute unknown characters
        valid_chars = []
        for c in generated[-max_length:]:
            if c in char_to_idx:
                valid_chars.append(c)
            else:
                valid_chars.append(' ')  # replace unknown characters with a space

        # Pad the sequence to max_length
        seq = valid_chars[-max_length:]
        seq = seq + [' '] * (max_length - len(seq))

        # Convert characters to indices
        seq_indices = [char_to_idx[c] for c in seq]

        # Predict and append the next character (greedy argmax decoding)
        pred = model.predict(np.array([seq_indices]), verbose=0)
        next_char = idx_to_char[np.argmax(pred)]
        generated += next_char
    return generated

print(generate_response("用户:朋友海鲜过敏,能不能推荐一些其他的?"))  # "User: my friend is allergic to seafood, can you recommend something else?"
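# A minimal alternative decoder, sketched under the assumption that model, chars,
# char_to_idx, idx_to_char and max_length are the objects defined above. Greedy argmax
# (used in generate_response) always picks the single most likely character and easily
# falls into repetitive loops on such a small corpus; sampling from the softmax output
# with a temperature gives more varied text. The function name and the default
# temperature are illustrative, not part of the original script.
def generate_response_sampled(prompt, temperature=0.8):
    generated = prompt
    for _ in range(30):
        # Same context preparation as generate_response: replace unknown characters, pad with spaces
        context = [c if c in char_to_idx else ' ' for c in generated[-max_length:]]
        context = context + [' '] * (max_length - len(context))
        indices = [char_to_idx[c] for c in context]

        # Re-scale the predicted distribution: lower temperature moves it towards argmax
        probs = model.predict(np.array([indices]), verbose=0)[0]
        logits = np.log(probs + 1e-8) / temperature
        probs = np.exp(logits) / np.sum(np.exp(logits))

        next_idx = np.random.choice(len(chars), p=probs)
        generated += idx_to_char[next_idx]
    return generated

# print(generate_response_sampled("用户:朋友海鲜过敏,能不能推荐一些其他的?"))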