TeamClass_MD/RNN_good.py
2025-03-19 15:56:43 +08:00

59 lines
1.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from keras.layers import Embedding, SimpleRNN, Dense
from keras.models import Sequential
# Training data: a short customer-service dialogue (it includes commas and
# other punctuation as well as one plain space).
text = "用户:今天想吃火锅吗? 客服:我们海鲜火锅很受欢迎。用户:但朋友对海鲜过敏,推荐其他吧。客服:好的,我们有菌汤火锅。"
# Characters that must always be part of the vocabulary.  generate_response()
# pads short inputs with a space and maps unknown characters to a space, so
# the space is listed explicitly instead of relying on the corpus happening
# to contain one.  Every entry already occurs in `text`, so listing them does
# not change the vocabulary size or the index assignment.
base_chars = [' ', ',', '。', '?']
# Sorted so that the character <-> index mapping is deterministic across runs.
chars = sorted(set(text + ''.join(base_chars)))
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for c, i in char_to_idx.items()}
# Build the training sequences: every window of `max_length` consecutive
# characters is one input, and the character right after it is the target.
max_length = 20
X, y = [], []
for i in range(len(text) - max_length):
    seq = text[i:i + max_length]
    target = text[i + max_length]
    X.append([char_to_idx[c] for c in seq])
    y.append(char_to_idx[target])
# Model definition: character embedding -> single SimpleRNN layer -> softmax
# over the vocabulary (one output unit per known character).
model = Sequential([
    # `input_length` is intentionally not passed: the argument is deprecated
    # in Keras 3 (it only triggers a warning there) and the sequence length
    # is taken from the data the model is first called on.
    Embedding(input_dim=len(chars), output_dim=32),
    SimpleRNN(128),
    Dense(len(chars), activation='softmax'),
])
# Targets are integer class indices (not one-hot vectors), hence the sparse
# variant of categorical cross-entropy.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
# Training: turn the Python lists into NumPy arrays, then fit the model for
# 50 epochs using mini-batches of 32 samples.
X, y = np.array(X), np.array(y)
model.fit(
    X,
    y,
    batch_size=32,
    epochs=50,
)
# Text generation with handling for unknown characters and short prompts
def generate_response(prompt, num_chars=30):
    """Greedily extend ``prompt`` one character at a time with the trained model.

    At every step the last ``max_length`` characters are fed to ``model`` and
    the character with the highest predicted score is appended.

    Args:
        prompt: Seed text.  Characters that are not in ``char_to_idx`` are
            treated as a space when building the model input; the returned
            string still starts with the prompt exactly as given.
        num_chars: Number of characters to generate (default 30).

    Returns:
        ``prompt`` followed by ``num_chars`` generated characters.
    """
    # Index used for unknown characters and for padding.  Falls back to 0 if
    # the vocabulary does not contain a space, rather than raising KeyError.
    fill_idx = char_to_idx.get(' ', 0)

    # Model input: indices of the last `max_length` prompt characters.
    window = [char_to_idx.get(c, fill_idx) for c in prompt[-max_length:]]
    # Pad on the LEFT so the real context sits directly before the prediction
    # step, as in the training windows where the target immediately follows
    # the last input character.
    window = [fill_idx] * (max_length - len(window)) + window

    produced = []
    for _ in range(num_chars):
        pred = model.predict(np.array([window]), verbose=0)
        # pred has one row per input sequence; only one sequence was passed.
        next_idx = int(np.argmax(pred[0]))
        produced.append(idx_to_char[next_idx])
        # Slide the window: drop the oldest index, append the new one.
        window = window[1:] + [next_idx]
    return prompt + ''.join(produced)
# Demo: generate a continuation for a sample user message and print it.
demo_prompt = "用户:朋友海鲜过敏,能不能推荐一些其他的?"
print(generate_response(demo_prompt))