Go的基本语法

This commit is contained in:
krust 2025-03-19 15:56:43 +08:00
parent 171c9e9815
commit 626790f753
17 changed files with 897 additions and 3 deletions

5
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,5 @@
{
"files.associations": {
"stdio.h": "c"
}
}

80
Game1.py Normal file
View File

@ -0,0 +1,80 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
# 生成合成数据
t = np.linspace(0, 24*np.pi, 1000)
data = np.sin(t) + 0.5*np.sin(3*t) + 0.05*t # 混合波形+趋势项
# 数据预处理
def create_dataset(data, look_back=30):
X, y = [], []
for i in range(len(data)-look_back):
X.append(data[i:i+look_back])
y.append(data[i+look_back])
return torch.FloatTensor(X).unsqueeze(-1), torch.FloatTensor(y)
X, y = create_dataset(data)
train_size = int(0.8 * len(X))
train_X, test_X = X[:train_size], X[train_size:]
train_y, test_y = y[:train_size], y[train_size:]
# 模型定义
class TimeSeriesModel(nn.Module):
def __init__(self, model_type):
super().__init__()
self.model_type = model_type
if model_type == 'LSTM':
self.rnn = nn.LSTM(1, 64, num_layers=2)
else:
self.rnn = nn.RNN(1, 64)
self.fc = nn.Linear(64, 1)
def forward(self, x):
out, _ = self.rnn(x)
return self.fc(out[-1, :, :])
# 训练函数
# 修改后的训练函数,返回预测结果和测试损失
def train_and_predict(model_type):
model = TimeSeriesModel(model_type)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练循环
for epoch in range(100):
output = model(train_X.transpose(0, 1))
loss = criterion(output.squeeze(), train_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch % 20 == 0:
print(f"{model_type} Epoch {epoch} Loss: {loss.item():.4f}")
# 预测阶段
with torch.no_grad():
test_pred = model(test_X.transpose(0, 1))
test_loss = criterion(test_pred.squeeze(), test_y)
print(f"{model_type} Test MSE: {test_loss.item():.4f}")
return test_pred.squeeze().numpy(), test_loss.item()
# 同时训练两种模型并收集结果
lstm_pred, lstm_loss = train_and_predict('LSTM')
rnn_pred, rnn_loss = train_and_predict('RNN')
# 统一可视化比较
plt.figure(figsize=(12,6))
plt.plot(test_y.numpy(), label='True Values', alpha=0.7)
plt.plot(lstm_pred, label=f'LSTM (MSE: {lstm_loss:.4f})', linestyle='--')
plt.plot(rnn_pred, label=f'RNN (MSE: {rnn_loss:.4f})', linestyle='--')
plt.title('Time Series Prediction Comparison')
plt.xlabel('Time Steps')
plt.ylabel('Value')
plt.legend()
plt.show()

124
Game2.py Normal file
View File

@ -0,0 +1,124 @@
import torch
import torch.nn as nn
import numpy as np
#这个实验的目的是比较RNN和GRU在相同任务上的性能即学习序列中两个随机位置数值的和。
# 数据生成
def generate_add_data(seq_len=30):
data = torch.zeros(seq_len, 2) # (seq_len, 2)
idx1, idx2 = np.random.choice(seq_len, 2, replace=False)
val1, val2 = np.random.rand()*0.5, np.random.rand()*0.5
data[idx1, 0] = val1
data[idx2, 0] = val2
target = torch.tensor([val1 + val2]).view(1,1)
return data.unsqueeze(0), target # (1, seq_len, 2)
# 模型定义
class AdditionRNN(nn.Module):
def __init__(self):
super().__init__()
self.rnn = nn.RNN(2, 16, batch_first=True)
self.fc = nn.Linear(16, 1)
def forward(self, x):
out, _ = self.rnn(x)
return self.fc(out[:, -1, :])
class AdditionGRU(nn.Module):
def __init__(self):
super().__init__()
self.gru = nn.GRU(2, 16, batch_first=True)
self.fc = nn.Linear(16, 1)
def forward(self, x):
out, _ = self.gru(x)
return self.fc(out[:, -1, :])
# 训练对比
import matplotlib.pyplot as plt
# 修改后的训练函数,记录损失变化
def train_addition():
rnn = AdditionRNN()
gru = AdditionGRU()
criterion = nn.MSELoss()
optim_rnn = torch.optim.Adam(rnn.parameters(), lr=0.01)
optim_gru = torch.optim.Adam(gru.parameters(), lr=0.01)
# 记录训练过程
losses = {'RNN': [], 'GRU': []}
for step in range(1000):
inputs, target = generate_add_data(seq_len=30)
# RNN训练
optim_rnn.zero_grad()
rnn_pred = rnn(inputs)
loss_rnn = criterion(rnn_pred, target)
loss_rnn.backward()
optim_rnn.step()
# GRU训练
optim_gru.zero_grad()
gru_pred = gru(inputs)
loss_gru = criterion(gru_pred, target)
loss_gru.backward()
optim_gru.step()
# 记录损失
losses['RNN'].append(loss_rnn.item())
losses['GRU'].append(loss_gru.item())
if step % 200 == 0:
print(f"Step {step:03d} | RNN Loss: {loss_rnn.item():.4f} | GRU Loss: {loss_gru.item():.4f}")
# 绘制损失曲线
plt.figure(figsize=(10, 5))
plt.plot(losses['RNN'], label='RNN', alpha=0.7)
plt.plot(losses['GRU'], label='GRU', alpha=0.7)
plt.xlabel('Training Steps')
plt.ylabel('MSE Loss')
plt.title('Training Comparison: RNN vs GRU')
plt.legend()
plt.grid(True)
plt.show()
return rnn, gru
# 执行训练
rnn_model, gru_model = train_addition()
def show_test_cases(model, model_name, num_cases=5):
print(f"\n{model_name} 测试样例:")
criterion = nn.MSELoss()
total_error = 0
for case_idx in range(num_cases):
# 生成测试数据
inputs, target = generate_add_data()
seq_len = inputs.shape[1]
# 解析输入数据
non_zero_indices = torch.nonzero(inputs[0, :, 0])
pos1, pos2 = non_zero_indices[0].item(), non_zero_indices[1].item()
val1 = inputs[0, pos1, 0].item()
val2 = inputs[0, pos2, 0].item()
# 模型预测
with torch.no_grad():
pred = model(inputs)
loss = criterion(pred, target)
# 格式输出
print(f"案例 {case_idx+1}:")
print(f"输入序列长度: {seq_len}")
print(f"数值位置: [{pos1:2d}]={val1:.4f}, [{pos2:2d}]={val2:.4f}")
print(f"真实值: {target.item():.4f}")
print(f"预测值: {pred.item():.4f}")
print(f"绝对误差: {abs(pred.item()-target.item()):.4f}")
print("-" * 40)
total_error += abs(pred.item()-target.item())
print(f"平均绝对误差: {total_error/num_cases:.4f}\n")
show_test_cases(rnn_model, "RNN")
show_test_cases(gru_model, "GRU")

View File

@ -32,10 +32,9 @@ func main() {
}
```
- 我们写完了程序,如何把这个程序运行起来呢,分为两步:
- 第一步:执行指令`go mod init`,初始化与版本相关联的 Go 包的集合确定了根目录、定义了项目的依赖和版本确保项目可以重建。也叫做Go的模块这一步会在当前路径下创建`go.mod`文件
- 第一步:执行指令`go mod init xxx` ,初始化与版本相关联的 Go 包的集合确定了根目录、定义了项目的依赖和版本确保项目可以重建。也叫做Go的模块这一步会在当前路径下创建`go.mod`文件
- 第二步:执行指令`go mod tidy`拉取我们需要的go的组建又叫做库这个操作可以类比为`pip install -r requirements.txt`所需要的包IDE会自动写入`go.mod`文件中。
- ![image.png](https://krust.top:5244/d/public/KrustBlogPNG/20250319090617163.png?sign=j3448Wo2gLjatAdHKplY3LSbHQv7-H445FLMbOhOFk0=:0)
- 随后通过上面这两步go程序就初始化好了相关的包也下载好了接下来我们就能重新运行了
- 运行项目:`go run` 开始执行整个项目
- 运行单个程序: `go run xxx.go`
- 运行项目: `go run xxx.go`
- 构建go项目将整个项目打包成为可执行文件 `go build`

View File

@ -0,0 +1,265 @@
### **1. 基础结构**
```go
package main // 包声明(必须)
import "fmt" // 导入包
func main() { // 主函数(程序入口)
fmt.Println("Hello, World!")
}
```
---
### **2. 变量与常量**
- **变量声明**
```go
var a int = 10 // 显式类型
var b = 20 // 类型推断
c := 30 // 短声明(函数内使用)
var d, e int = 1, 2 // 多变量声明
```
- **常量**
```go
const Pi = 3.14
const (
A = 1
B = 2
)
```
- 着重强调:**开头是大写的是Public**
---
### **3. 基本数据类型**
- **基本类型**
```go
int, int8, int16, int32, int64
uint, uint8, uint16, uint32, uint64
float32, float64
bool
string
byte (等同于 uint8)
rune (等同于 int32, 表示 Unicode 字符)
```
- **复合类型**
```go
var arr [3]int // 数组(固定长度)
slice := []int{1, 2, 3} // 切片(动态数组)
m := map[string]int{"key": 1} // 映射(字典)
type Person struct { Name string } // 结构体
```
---
### **4. 控制结构**
- **条件语句**
```go
if x > 0 {
// ...
} else if x == 0 {
// ...
} else {
// ...
}
```
- **循环**
```go
for i := 0; i < 10; i++ { ... } // 传统 for 循环
for i < 10 { ... } // 类似 while
for index, value := range slice { ... } // 遍历切片/数组/map
```
- **Switch**
```go
switch x {
case 1:
// ...
case 2, 3: // 多值匹配
// ...
default:
// ...
}
```
- 特点:不需要break
---
### **5. 函数**
- **基本函数**
```go
func add(a int, b int) int {
return a + b
}
```
- **多返回值**
```go
func swap(a, b int) (int, int) {
return b, a
}
```
- **命名返回值**
```go
func split(sum int) (x, y int) {
x = sum * 4 / 9
y = sum - x
return // 隐式返回 x, y
}
```
- **匿名函数与闭包**(即函数内套函数)
```go
func() {
fmt.Println("Anonymous function")
}()
```
---
### **6. 指针与结构体**
- **指针**
```go
var p *int
x := 10
p = &x
*p = 20 // 修改 x 的值
```
- 特点:
- Go 的指针**不支持算术运算**,避免了内存越界和非法访问的风险,同时通过垃圾回收机制自动管理内存,**减少了内存泄漏的可能性**。
- Go 的指针类型严格区分,空指针用 `nil` 表示,解引用空指针会触发 panic不支持**指针算术运算和强制类型转换**。
- **结构体与方法**
```go
type Circle struct {
Radius float64
}
// 方法(值接收者)
func (c Circle) Area() float64 {
return math.Pi * c.Radius * c.Radius
}
// 方法(指针接收者)
func (c *Circle) Scale(factor float64) {
c.Radius *= factor
}
```
---
### **7. 接口与错误处理**
- **接口**
```go
type Shape interface {
Area() float64
}
// 隐式实现接口
func (c Circle) Area() float64 { ... }
```
- **错误处理**
```go
func readFile() ([]byte, error) {
data, err := os.ReadFile("file.txt")
if err != nil {
return nil, err
}
return data, nil
}
// 调用
data, err := readFile()
if err != nil {
log.Fatal(err)
}
```
- **Panic & Recover**
```go
func safeCall() {
defer func() {
if r := recover(); r != nil {
fmt.Println("Recovered:", r)
}
}()
panic("Something went wrong!")
}
```
---
### **8. 并发编程**
- **Goroutine**
```go
go func() {
fmt.Println("Running in goroutine")
}()
```
- **Channel**(数据通道)
```go
ch := make(chan int)
go func() { ch <- 1 }() // 发送数据
value := <-ch // 接收数据
```
- **Select**(主要用于事件驱动)
```go
select {
case msg1 := <-ch1:
fmt.Println(msg1)
case msg2 := <-ch2:
fmt.Println(msg2)
case <-time.After(1 * time.Second):
fmt.Println("Timeout")
}
```
---
### **9. 包与模块**
- **创建模块**
```bash
go mod init xx
```
- **导入包**
```go
import (
"fmt"
"math/rand"
"github.com/user/package"
)
```
---
### **10. 其他特性**
- **Defer**
- `defer` 是 Go 语言中的一个关键字,用于延迟执行一个函数调用。**被 `defer` 修饰的函数调用会推迟到当前函数返回之前执行**,无论当前函数是正常返回还是由于错误(如 `panic`)提前返回。`defer` 的主要用途是确保某些操作(如资源释放、清理工作等)一定会被执行,避免遗漏。
```go
func readFile() {
file, _ := os.Open("file.txt")
defer file.Close() // 函数返回前执行
// ...
}
```
- **JSON 处理**
```go
type User struct {
Name string `json:"name"`
Age int `json:"age"`
}
data, _ := json.Marshal(user) //序列化
```
---
### **常用内置函数**
- `len()`:获取长度
- `cap()`:切片容量
- `make()`:创建切片/map/channel
- `append()`:切片追加元素

Binary file not shown.

View File

@ -0,0 +1 @@
package main

View File

@ -0,0 +1,3 @@
module BasicGrammar
go 1.24.1

View File

@ -0,0 +1,147 @@
package main
import (
"errors"
"fmt"
"sync"
"time"
)
// 常量
const Pi = 3.14159
// 结构体
type Circle struct {
Radius float64
}
// 方法(值接收者)
func (c Circle) Area() float64 {
return Pi * c.Radius * c.Radius
}
// 方法(指针接收者)
func (c *Circle) Scale(factor float64) {
c.Radius *= factor
}
// 接口
type Shape interface {
Area() float64
}
// 函数(多返回值)
func divide(a, b float64) (float64, error) {
if b == 0 {
return 0, errors.New("division by zero")
}
return a / b, nil
}
// 主函数
func main() {
// 变量声明
var a int = 10
b := 20
c, d := 30, 40
fmt.Println("Variables:", a, b, c, d)
// 数组与切片
arr := [3]int{1, 2, 3}
slice := []int{4, 5, 6}
slice = append(slice, 7)
fmt.Println("Array:", arr, "Slice:", slice)
// 映射
m := map[string]int{"one": 1, "two": 2}
fmt.Println("Map:", m)
// 控制结构
if a > 5 {
fmt.Println("a is greater than 5")
} else {
fmt.Println("a is not greater than 5")
}
for i := 0; i < 3; i++ {
fmt.Println("Loop:", i)
}
switch a {
case 10:
fmt.Println("a is 10")
default:
fmt.Println("a is not 10")
}
// 函数调用
result, err := divide(10, 2)
if err != nil {
fmt.Println("Error:", err)
} else {
fmt.Println("Division result:", result)
}
// 指针
x := 10
p := &x
*p = 20
fmt.Println("Pointer:", x)
// 结构体与方法
circle := Circle{Radius: 5}
fmt.Println("Circle area:", circle.Area())
circle.Scale(2)
fmt.Println("Scaled circle area:", circle.Area())
// 接口
var shape Shape = Circle{Radius: 3}
fmt.Println("Shape area:", shape.Area())
// 错误处理
_, err = divide(10, 0)
if err != nil {
fmt.Println("Error:", err)
}
// 并发
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
fmt.Println("Running in goroutine")
}()
wg.Wait()
// Channel
ch := make(chan int)
go func() {
ch <- 42
}()
value := <-ch
fmt.Println("Channel value:", value)
// Select
ch1 := make(chan string)
ch2 := make(chan string)
go func() {
time.Sleep(1 * time.Second)
ch1 <- "from ch1"
}()
go func() {
time.Sleep(2 * time.Second)
ch2 <- "from ch2"
}()
select {
case msg := <-ch1:
fmt.Println("Select:", msg)
case msg := <-ch2:
fmt.Println("Select:", msg)
case <-time.After(3 * time.Second):
fmt.Println("Timeout")
}
// Defer
defer fmt.Println("Defer: This runs last")
fmt.Println("Main function end")
}

View File

@ -0,0 +1,21 @@
//go:build ignore
// +build ignore
#include "stdio.h"
int main() {
int a = 2;
switch (a) {
case 1:
printf("a = 1\n");
break;
case 2:
printf("a = 2\n");
break;
case 3:
printf("a = 3\n");
break;
default:
printf("FUCK"); // default 语句是可选的
}

View File

@ -0,0 +1,31 @@
# Gorm的使用
## Gorm是什么
- GORM 是 Go 语言的 ORM 库,提供模型定义、关联管理、事务支持、查询构建、数据迁移、钩子回调等功能,支持主流数据库(如 MySQL/PostgreSQL/SQLite简化数据库操作。
## 如何使用Gorm
### 导入Gorm库
```go
import (
"gorm.io/driver/mysql"
)
```
### 基本操作
- 1、连接数据库
```go
const (
USERNAME = "root"
PASSWD = "oppofindx2"
DATABASENAME = "Class"
)
dsn := fmt.Sprintf("%s:%s@tcp(127.0.0.1:3306)/%s?charset=utf8mb4&parseTime=True&loc=Local", USERNAME, PASSWD, DATABASENAME)
db := mysql.Open(dsn)
`````
- 2、数据库中表的定义
- 在Gorm中定义一张表使用的是结构体
```go
type User struct {
gorm.Model
Name string
Age int
}
```
- 自动迁移表结构(方便我们修改表,给表添加参数)

View File

@ -0,0 +1,12 @@
module Gorm
go 1.24.1
require gorm.io/driver/mysql v1.5.7
require (
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
gorm.io/gorm v1.25.7 // indirect
)

View File

@ -0,0 +1,10 @@
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=

View File

@ -0,0 +1,26 @@
package main
import (
"fmt"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
const (
USERNAME = "root"
PASSWD = "oppofindx2"
DATABASENAME = "Class"
)
func main() {
dsn := fmt.Sprintf("%s:%s@tcp(127.0.0.1:3306)/%s?charset=utf8mb4&parseTime=True&loc=Local", USERNAME, PASSWD, DATABASENAME)
db := mysql.Open(dsn)
print(db)
}
type User struct {
gorm.Model
Name string
Age int
}

111
NN_normal.py Normal file
View File

@ -0,0 +1,111 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# 1. 数据准备以MNIST手写数字识别为例
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) # 像素值归一化到[-1,1]
])
train_set = datasets.MNIST('data', download=True, train=True, transform=transform)
test_set = datasets.MNIST('data', download=True, train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000)
# 2. 神经网络模型(演示梯度控制技巧)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
# He初始化适配ReLU
nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
nn.init.kaiming_normal_(self.fc2.weight, nonlinearity='relu')
def forward(self, x):
x = x.view(-1, 784) # 展平图像
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x) # 输出层无需激活CrossEntropyLoss内置Softmax
return x
# 3. 训练配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# 梯度裁剪阈值
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
# 4. 训练过程可视化记录
train_losses = []
accuracies = []
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# 记录训练损失
if batch_idx % 100 == 0:
train_losses.append(loss.item())
# 5. 测试函数(含准确率计算)
def test():
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
accuracy = 100. * correct / len(test_loader.dataset)
accuracies.append(accuracy)
return test_loss
# 6. 执行训练3个epoch演示
for epoch in range(1, 4):
train(epoch)
loss = test()
print(f'Epoch {epoch}: Test Loss={loss:.4f}, Accuracy={accuracies[-1]:.2f}%')
# 7. 可视化训练过程
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.plot(train_losses, label='Training Loss')
plt.title("Loss Curve")
plt.subplot(1,2,2)
plt.plot(accuracies, label='Accuracy', color='orange')
plt.title("Accuracy Curve")
plt.show()
# 8. 示例预测展示
sample_data, sample_label = next(iter(test_loader))
with torch.no_grad():
prediction = model(sample_data.to(device)).argmax(dim=1)
# 显示预测结果对比
plt.figure(figsize=(10,6))
for i in range(6):
plt.subplot(2,3,i+1)
plt.imshow(sample_data[i][0], cmap='gray')
plt.title(f"True: {sample_label[i]}\nPred: {prediction[i].item()}")
plt.axis('off')
plt.tight_layout()
plt.show()

59
RNN_good.py Normal file
View File

@ -0,0 +1,59 @@
import numpy as np
from keras.layers import Embedding, SimpleRNN, Dense
from keras.models import Sequential
# 训练数据(包含逗号)
text = "用户:今天想吃火锅吗? 客服:我们海鲜火锅很受欢迎。用户:但朋友对海鲜过敏,推荐其他吧。客服:好的,我们有菌汤火锅。"
base_chars = ['', '', '', ''] # 确保基础标点存在
chars = sorted(list(set(text + ''.join(base_chars))))
char_to_idx = {c:i for i,c in enumerate(chars)}
idx_to_char = {i:c for c,i in char_to_idx.items()}
# 创建训练序列
max_length = 20
X, y = [], []
for i in range(len(text)-max_length):
seq = text[i:i+max_length]
target = text[i+max_length]
X.append([char_to_idx[c] for c in seq])
y.append(char_to_idx[target])
# 模型构建
model = Sequential([
Embedding(input_dim=len(chars), output_dim=32, input_length=max_length),
SimpleRNN(128),
Dense(len(chars), activation='softmax')
])
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
# 训练
X = np.array(X)
y = np.array(y)
model.fit(X, y, epochs=50, batch_size=32)
# 增强后的生成函数
def generate_response(prompt):
generated = prompt
for _ in range(30):
# 过滤并处理未知字符
valid_chars = []
for c in generated[-max_length:]:
if c in char_to_idx:
valid_chars.append(c)
else:
valid_chars.append(' ') # 未知字符替换为空格
# 填充序列
seq = valid_chars[-max_length:]
seq = seq + [' ']*(max_length - len(seq))
# 转换为索引
seq_indices = [char_to_idx[c] for c in seq]
# 生成下一个字符
pred = model.predict(np.array([seq_indices]), verbose=0)
next_char = idx_to_char[np.argmax(pred)]
generated += next_char
return generated
print(generate_response("用户:朋友海鲜过敏,能不能推荐一些其他的?"))