diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..0f06797 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "files.associations": { + "stdio.h": "c" + } +} \ No newline at end of file diff --git a/Game1.py b/Game1.py new file mode 100644 index 0000000..2e025f8 --- /dev/null +++ b/Game1.py @@ -0,0 +1,80 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt +import torch.nn as nn + +import torch.optim as optim + +# 生成合成数据 +t = np.linspace(0, 24*np.pi, 1000) +data = np.sin(t) + 0.5*np.sin(3*t) + 0.05*t # 混合波形+趋势项 + +# 数据预处理 +def create_dataset(data, look_back=30): + X, y = [], [] + for i in range(len(data)-look_back): + X.append(data[i:i+look_back]) + y.append(data[i+look_back]) + return torch.FloatTensor(X).unsqueeze(-1), torch.FloatTensor(y) + +X, y = create_dataset(data) +train_size = int(0.8 * len(X)) +train_X, test_X = X[:train_size], X[train_size:] +train_y, test_y = y[:train_size], y[train_size:] + +# 模型定义 +class TimeSeriesModel(nn.Module): + def __init__(self, model_type): + super().__init__() + self.model_type = model_type + if model_type == 'LSTM': + self.rnn = nn.LSTM(1, 64, num_layers=2) + else: + self.rnn = nn.RNN(1, 64) + self.fc = nn.Linear(64, 1) + + def forward(self, x): + out, _ = self.rnn(x) + return self.fc(out[-1, :, :]) + +# 训练函数 +# 修改后的训练函数,返回预测结果和测试损失 +def train_and_predict(model_type): + model = TimeSeriesModel(model_type) + criterion = nn.MSELoss() + optimizer = optim.Adam(model.parameters(), lr=0.001) + + # 训练循环 + for epoch in range(100): + output = model(train_X.transpose(0, 1)) + loss = criterion(output.squeeze(), train_y) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if epoch % 20 == 0: + print(f"{model_type} Epoch {epoch} Loss: {loss.item():.4f}") + + # 预测阶段 + with torch.no_grad(): + test_pred = model(test_X.transpose(0, 1)) + test_loss = criterion(test_pred.squeeze(), test_y) + print(f"{model_type} Test MSE: {test_loss.item():.4f}") + + return test_pred.squeeze().numpy(), test_loss.item() + +# 同时训练两种模型并收集结果 +lstm_pred, lstm_loss = train_and_predict('LSTM') +rnn_pred, rnn_loss = train_and_predict('RNN') + +# 统一可视化比较 +plt.figure(figsize=(12,6)) +plt.plot(test_y.numpy(), label='True Values', alpha=0.7) +plt.plot(lstm_pred, label=f'LSTM (MSE: {lstm_loss:.4f})', linestyle='--') +plt.plot(rnn_pred, label=f'RNN (MSE: {rnn_loss:.4f})', linestyle='--') +plt.title('Time Series Prediction Comparison') +plt.xlabel('Time Steps') +plt.ylabel('Value') +plt.legend() +plt.show() \ No newline at end of file diff --git a/Game2.py b/Game2.py new file mode 100644 index 0000000..7dd67c2 --- /dev/null +++ b/Game2.py @@ -0,0 +1,124 @@ +import torch +import torch.nn as nn +import numpy as np + +#这个实验的目的是比较RNN和GRU在相同任务上的性能,即学习序列中两个随机位置数值的和。 +# 数据生成 +def generate_add_data(seq_len=30): + data = torch.zeros(seq_len, 2) # (seq_len, 2) + idx1, idx2 = np.random.choice(seq_len, 2, replace=False) + val1, val2 = np.random.rand()*0.5, np.random.rand()*0.5 + data[idx1, 0] = val1 + data[idx2, 0] = val2 + target = torch.tensor([val1 + val2]).view(1,1) + return data.unsqueeze(0), target # (1, seq_len, 2) + +# 模型定义 +class AdditionRNN(nn.Module): + def __init__(self): + super().__init__() + self.rnn = nn.RNN(2, 16, batch_first=True) + self.fc = nn.Linear(16, 1) + + def forward(self, x): + out, _ = self.rnn(x) + return self.fc(out[:, -1, :]) + +class AdditionGRU(nn.Module): + def __init__(self): + super().__init__() + self.gru = nn.GRU(2, 16, batch_first=True) + self.fc = nn.Linear(16, 1) + + def forward(self, x): + out, _ = self.gru(x) + return self.fc(out[:, -1, :]) + +# 训练对比 +import matplotlib.pyplot as plt + +# 修改后的训练函数,记录损失变化 +def train_addition(): + rnn = AdditionRNN() + gru = AdditionGRU() + criterion = nn.MSELoss() + optim_rnn = torch.optim.Adam(rnn.parameters(), lr=0.01) + optim_gru = torch.optim.Adam(gru.parameters(), lr=0.01) + + # 记录训练过程 + losses = {'RNN': [], 'GRU': []} + + for step in range(1000): + inputs, target = generate_add_data(seq_len=30) + + # RNN训练 + optim_rnn.zero_grad() + rnn_pred = rnn(inputs) + loss_rnn = criterion(rnn_pred, target) + loss_rnn.backward() + optim_rnn.step() + + # GRU训练 + optim_gru.zero_grad() + gru_pred = gru(inputs) + loss_gru = criterion(gru_pred, target) + loss_gru.backward() + optim_gru.step() + + # 记录损失 + losses['RNN'].append(loss_rnn.item()) + losses['GRU'].append(loss_gru.item()) + + if step % 200 == 0: + print(f"Step {step:03d} | RNN Loss: {loss_rnn.item():.4f} | GRU Loss: {loss_gru.item():.4f}") + + # 绘制损失曲线 + plt.figure(figsize=(10, 5)) + plt.plot(losses['RNN'], label='RNN', alpha=0.7) + plt.plot(losses['GRU'], label='GRU', alpha=0.7) + plt.xlabel('Training Steps') + plt.ylabel('MSE Loss') + plt.title('Training Comparison: RNN vs GRU') + plt.legend() + plt.grid(True) + plt.show() + + return rnn, gru + +# 执行训练 +rnn_model, gru_model = train_addition() +def show_test_cases(model, model_name, num_cases=5): + print(f"\n{model_name} 测试样例:") + criterion = nn.MSELoss() + total_error = 0 + + for case_idx in range(num_cases): + # 生成测试数据 + inputs, target = generate_add_data() + seq_len = inputs.shape[1] + + # 解析输入数据 + non_zero_indices = torch.nonzero(inputs[0, :, 0]) + pos1, pos2 = non_zero_indices[0].item(), non_zero_indices[1].item() + val1 = inputs[0, pos1, 0].item() + val2 = inputs[0, pos2, 0].item() + + # 模型预测 + with torch.no_grad(): + pred = model(inputs) + loss = criterion(pred, target) + + # 格式输出 + print(f"案例 {case_idx+1}:") + print(f"输入序列长度: {seq_len}") + print(f"数值位置: [{pos1:2d}]={val1:.4f}, [{pos2:2d}]={val2:.4f}") + print(f"真实值: {target.item():.4f}") + print(f"预测值: {pred.item():.4f}") + print(f"绝对误差: {abs(pred.item()-target.item()):.4f}") + print("-" * 40) + total_error += abs(pred.item()-target.item()) + + print(f"平均绝对误差: {total_error/num_cases:.4f}\n") + +show_test_cases(rnn_model, "RNN") +show_test_cases(gru_model, "GRU") diff --git a/Go入门速成/Day1/Class1 安装、初始化Go.md b/Go入门速成/Day1/Class1 安装、初始化Go.md index fe9f0d8..b4ed52e 100644 --- a/Go入门速成/Day1/Class1 安装、初始化Go.md +++ b/Go入门速成/Day1/Class1 安装、初始化Go.md @@ -32,10 +32,9 @@ func main() { } ``` - 我们写完了程序,如何把这个程序运行起来呢,分为两步: - - 第一步:执行指令`go mod init`,初始化与版本相关联的 Go 包的集合,确定了根目录、定义了项目的依赖和版本,确保项目可以重建。(也叫做Go的模块)!这一步,会在当前路径下创建`go.mod`文件 + - 第一步:执行指令`go mod init xxx` ,初始化与版本相关联的 Go 包的集合,确定了根目录、定义了项目的依赖和版本,确保项目可以重建。(也叫做Go的模块)!这一步,会在当前路径下创建`go.mod`文件 - 第二步:执行指令`go mod tidy`,拉取我们需要的go的组建(又叫做库),这个操作可以类比为`pip install -r requirements.txt`,所需要的包IDE会自动写入`go.mod`文件中。 - ![image.png](https://krust.top:5244/d/public/KrustBlogPNG/20250319090617163.png?sign=j3448Wo2gLjatAdHKplY3LSbHQv7-H445FLMbOhOFk0=:0) - 随后,通过上面这两步,go程序就初始化好了,相关的包也下载好了,接下来,我们就能重新运行了 - - 运行项目:`go run` 开始执行整个项目 - - 运行单个程序: `go run xxx.go` + - 运行项目: `go run xxx.go` - 构建go项目(将整个项目打包成为可执行文件) `go build` \ No newline at end of file diff --git a/Go入门速成/Day1/Class2 Go的基本语法.md b/Go入门速成/Day1/Class2 Go的基本语法.md index e69de29..f030795 100644 --- a/Go入门速成/Day1/Class2 Go的基本语法.md +++ b/Go入门速成/Day1/Class2 Go的基本语法.md @@ -0,0 +1,265 @@ +### **1. 基础结构** +```go +package main // 包声明(必须) +import "fmt" // 导入包 + +func main() { // 主函数(程序入口) + fmt.Println("Hello, World!") +} +``` + +--- + +### **2. 变量与常量** +- **变量声明** + ```go + var a int = 10 // 显式类型 + var b = 20 // 类型推断 + c := 30 // 短声明(函数内使用) + var d, e int = 1, 2 // 多变量声明 + ``` + +- **常量** + ```go + const Pi = 3.14 + const ( + A = 1 + B = 2 + ) + ``` +- 着重强调:**开头是大写的是Public!** +--- + +### **3. 基本数据类型** +- **基本类型** + ```go + int, int8, int16, int32, int64 + uint, uint8, uint16, uint32, uint64 + float32, float64 + bool + string + byte (等同于 uint8) + rune (等同于 int32, 表示 Unicode 字符) + ``` + +- **复合类型** + ```go + var arr [3]int // 数组(固定长度) + slice := []int{1, 2, 3} // 切片(动态数组) + m := map[string]int{"key": 1} // 映射(字典) + type Person struct { Name string } // 结构体 + ``` + +--- + +### **4. 控制结构** +- **条件语句** + ```go + if x > 0 { + // ... + } else if x == 0 { + // ... + } else { + // ... + } + ``` + +- **循环** + ```go + for i := 0; i < 10; i++ { ... } // 传统 for 循环 + for i < 10 { ... } // 类似 while + for index, value := range slice { ... } // 遍历切片/数组/map + ``` + +- **Switch** + ```go + switch x { + case 1: + // ... + case 2, 3: // 多值匹配 + // ... + default: + // ... + } + ``` +- 特点:不需要break! +--- + +### **5. 函数** +- **基本函数** + ```go + func add(a int, b int) int { + return a + b + } + ``` + +- **多返回值** + ```go + func swap(a, b int) (int, int) { + return b, a + } + ``` + +- **命名返回值** + ```go + func split(sum int) (x, y int) { + x = sum * 4 / 9 + y = sum - x + return // 隐式返回 x, y + } + ``` + +- **匿名函数与闭包**(即函数内套函数) + ```go + func() { + fmt.Println("Anonymous function") + }() + ``` + +--- + +### **6. 指针与结构体** +- **指针** + ```go + var p *int + x := 10 + p = &x + *p = 20 // 修改 x 的值 + ``` +- 特点: + - Go 的指针**不支持算术运算**,避免了内存越界和非法访问的风险,同时通过垃圾回收机制自动管理内存,**减少了内存泄漏的可能性**。 + - Go 的指针类型严格区分,空指针用 `nil` 表示,解引用空指针会触发 panic,不支持**指针算术运算和强制类型转换**。 +- **结构体与方法** + ```go + type Circle struct { + Radius float64 + } + + // 方法(值接收者) + func (c Circle) Area() float64 { + return math.Pi * c.Radius * c.Radius + } + + // 方法(指针接收者) + func (c *Circle) Scale(factor float64) { + c.Radius *= factor + } + ``` + +--- + +### **7. 接口与错误处理** +- **接口** + ```go + type Shape interface { + Area() float64 + } + + // 隐式实现接口 + func (c Circle) Area() float64 { ... } + ``` + +- **错误处理** + ```go + func readFile() ([]byte, error) { + data, err := os.ReadFile("file.txt") + if err != nil { + return nil, err + } + return data, nil + } + + // 调用 + data, err := readFile() + if err != nil { + log.Fatal(err) + } + ``` + +- **Panic & Recover** + ```go + func safeCall() { + defer func() { + if r := recover(); r != nil { + fmt.Println("Recovered:", r) + } + }() + panic("Something went wrong!") + } + ``` + +--- + +### **8. 并发编程** +- **Goroutine** + ```go + go func() { + fmt.Println("Running in goroutine") + }() + ``` + +- **Channel**(数据通道) + ```go + ch := make(chan int) + go func() { ch <- 1 }() // 发送数据 + value := <-ch // 接收数据 + ``` + +- **Select**(主要用于事件驱动) + ```go + select { + case msg1 := <-ch1: + fmt.Println(msg1) + case msg2 := <-ch2: + fmt.Println(msg2) + case <-time.After(1 * time.Second): + fmt.Println("Timeout") + } + ``` + +--- + +### **9. 包与模块** +- **创建模块** + ```bash + go mod init xx + ``` + +- **导入包** + ```go + import ( + "fmt" + "math/rand" + "github.com/user/package" + ) + ``` + +--- + +### **10. 其他特性** +- **Defer** +- `defer` 是 Go 语言中的一个关键字,用于延迟执行一个函数调用。**被 `defer` 修饰的函数调用会推迟到当前函数返回之前执行**,无论当前函数是正常返回还是由于错误(如 `panic`)提前返回。`defer` 的主要用途是确保某些操作(如资源释放、清理工作等)一定会被执行,避免遗漏。 + ```go + func readFile() { + file, _ := os.Open("file.txt") + defer file.Close() // 函数返回前执行 + // ... + } + ``` + +- **JSON 处理** + ```go + type User struct { + Name string `json:"name"` + Age int `json:"age"` + } + data, _ := json.Marshal(user) //序列化 + ``` + +--- + +### **常用内置函数** +- `len()`:获取长度 +- `cap()`:切片容量 +- `make()`:创建切片/map/channel +- `append()`:切片追加元素 diff --git a/Go入门速成/Day1/ExampleCode/BasicGrammar/BasicGrammar b/Go入门速成/Day1/ExampleCode/BasicGrammar/BasicGrammar new file mode 100755 index 0000000..a3bd77e Binary files /dev/null and b/Go入门速成/Day1/ExampleCode/BasicGrammar/BasicGrammar differ diff --git a/Go入门速成/Day1/ExampleCode/BasicGrammar/PointerPanic.go b/Go入门速成/Day1/ExampleCode/BasicGrammar/PointerPanic.go new file mode 100644 index 0000000..06ab7d0 --- /dev/null +++ b/Go入门速成/Day1/ExampleCode/BasicGrammar/PointerPanic.go @@ -0,0 +1 @@ +package main diff --git a/Go入门速成/Day1/ExampleCode/BasicGrammar/go.mod b/Go入门速成/Day1/ExampleCode/BasicGrammar/go.mod new file mode 100644 index 0000000..110fe6f --- /dev/null +++ b/Go入门速成/Day1/ExampleCode/BasicGrammar/go.mod @@ -0,0 +1,3 @@ +module BasicGrammar + +go 1.24.1 diff --git a/Go入门速成/Day1/ExampleCode/BasicGrammar/main.go b/Go入门速成/Day1/ExampleCode/BasicGrammar/main.go new file mode 100644 index 0000000..d69aa4c --- /dev/null +++ b/Go入门速成/Day1/ExampleCode/BasicGrammar/main.go @@ -0,0 +1,147 @@ +package main + +import ( + "errors" + "fmt" + "sync" + "time" +) + +// 常量 +const Pi = 3.14159 + +// 结构体 +type Circle struct { + Radius float64 +} + +// 方法(值接收者) +func (c Circle) Area() float64 { + return Pi * c.Radius * c.Radius +} + +// 方法(指针接收者) +func (c *Circle) Scale(factor float64) { + c.Radius *= factor +} + +// 接口 +type Shape interface { + Area() float64 +} + +// 函数(多返回值) +func divide(a, b float64) (float64, error) { + if b == 0 { + return 0, errors.New("division by zero") + } + return a / b, nil +} + +// 主函数 +func main() { + // 变量声明 + var a int = 10 + b := 20 + c, d := 30, 40 + fmt.Println("Variables:", a, b, c, d) + + // 数组与切片 + arr := [3]int{1, 2, 3} + slice := []int{4, 5, 6} + slice = append(slice, 7) + fmt.Println("Array:", arr, "Slice:", slice) + + // 映射 + m := map[string]int{"one": 1, "two": 2} + fmt.Println("Map:", m) + + // 控制结构 + if a > 5 { + fmt.Println("a is greater than 5") + } else { + fmt.Println("a is not greater than 5") + } + + for i := 0; i < 3; i++ { + fmt.Println("Loop:", i) + } + + switch a { + case 10: + fmt.Println("a is 10") + default: + fmt.Println("a is not 10") + } + + // 函数调用 + result, err := divide(10, 2) + if err != nil { + fmt.Println("Error:", err) + } else { + fmt.Println("Division result:", result) + } + + // 指针 + x := 10 + p := &x + *p = 20 + fmt.Println("Pointer:", x) + + // 结构体与方法 + circle := Circle{Radius: 5} + fmt.Println("Circle area:", circle.Area()) + circle.Scale(2) + fmt.Println("Scaled circle area:", circle.Area()) + + // 接口 + var shape Shape = Circle{Radius: 3} + fmt.Println("Shape area:", shape.Area()) + + // 错误处理 + _, err = divide(10, 0) + if err != nil { + fmt.Println("Error:", err) + } + + // 并发 + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + fmt.Println("Running in goroutine") + }() + wg.Wait() + + // Channel + ch := make(chan int) + go func() { + ch <- 42 + }() + value := <-ch + fmt.Println("Channel value:", value) + + // Select + ch1 := make(chan string) + ch2 := make(chan string) + go func() { + time.Sleep(1 * time.Second) + ch1 <- "from ch1" + }() + go func() { + time.Sleep(2 * time.Second) + ch2 <- "from ch2" + }() + select { + case msg := <-ch1: + fmt.Println("Select:", msg) + case msg := <-ch2: + fmt.Println("Select:", msg) + case <-time.After(3 * time.Second): + fmt.Println("Timeout") + } + + // Defer + defer fmt.Println("Defer: This runs last") + fmt.Println("Main function end") +} diff --git a/Go入门速成/Day1/ExampleCode/BasicGrammar/switch.c b/Go入门速成/Day1/ExampleCode/BasicGrammar/switch.c new file mode 100644 index 0000000..734c21b --- /dev/null +++ b/Go入门速成/Day1/ExampleCode/BasicGrammar/switch.c @@ -0,0 +1,21 @@ +//go:build ignore +// +build ignore + + +#include "stdio.h" + +int main() { + int a = 2; + switch (a) { + case 1: + printf("a = 1\n"); + break; + case 2: + printf("a = 2\n"); + break; + case 3: + printf("a = 3\n"); + break; + default: + printf("FUCK"); // default 语句是可选的 +} \ No newline at end of file diff --git a/Go入门速成/Day1/ExampleCode/Helloworld/main b/Go入门速成/Day1/ExampleCode/Helloworld/main deleted file mode 100755 index 6ce6e10..0000000 Binary files a/Go入门速成/Day1/ExampleCode/Helloworld/main and /dev/null differ diff --git a/Go入门速成/Day2/Class3 Gorm的使用.md b/Go入门速成/Day2/Class3 Gorm的使用.md new file mode 100644 index 0000000..86c5bdf --- /dev/null +++ b/Go入门速成/Day2/Class3 Gorm的使用.md @@ -0,0 +1,31 @@ +# Gorm的使用 +## Gorm是什么 +- GORM 是 Go 语言的 ORM 库,提供模型定义、关联管理、事务支持、查询构建、数据迁移、钩子回调等功能,支持主流数据库(如 MySQL/PostgreSQL/SQLite),简化数据库操作。 +## 如何使用Gorm +### 导入Gorm库 +```go +import ( +"gorm.io/driver/mysql" +) +``` +### 基本操作 +- 1、连接数据库 +```go +const ( +USERNAME = "root" +PASSWD = "oppofindx2" +DATABASENAME = "Class" +) +dsn := fmt.Sprintf("%s:%s@tcp(127.0.0.1:3306)/%s?charset=utf8mb4&parseTime=True&loc=Local", USERNAME, PASSWD, DATABASENAME) +db := mysql.Open(dsn) +````` +- 2、数据库中表的定义 + - 在Gorm中,定义一张表使用的是结构体 + ```go +type User struct { + gorm.Model + Name string + Age int +} + ``` +- 自动迁移表结构(方便我们修改表,给表添加参数) \ No newline at end of file diff --git a/Go入门速成/Day2/ExampleCode/Gorm/go.mod b/Go入门速成/Day2/ExampleCode/Gorm/go.mod new file mode 100644 index 0000000..ab8ee47 --- /dev/null +++ b/Go入门速成/Day2/ExampleCode/Gorm/go.mod @@ -0,0 +1,12 @@ +module Gorm + +go 1.24.1 + +require gorm.io/driver/mysql v1.5.7 + +require ( + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + gorm.io/gorm v1.25.7 // indirect +) diff --git a/Go入门速成/Day2/ExampleCode/Gorm/go.sum b/Go入门速成/Day2/ExampleCode/Gorm/go.sum new file mode 100644 index 0000000..f0385d3 --- /dev/null +++ b/Go入门速成/Day2/ExampleCode/Gorm/go.sum @@ -0,0 +1,10 @@ +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= diff --git a/Go入门速成/Day2/ExampleCode/Gorm/main.go b/Go入门速成/Day2/ExampleCode/Gorm/main.go new file mode 100644 index 0000000..b4ead13 --- /dev/null +++ b/Go入门速成/Day2/ExampleCode/Gorm/main.go @@ -0,0 +1,26 @@ +package main + +import ( + "fmt" + + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +const ( + USERNAME = "root" + PASSWD = "oppofindx2" + DATABASENAME = "Class" +) + +func main() { + dsn := fmt.Sprintf("%s:%s@tcp(127.0.0.1:3306)/%s?charset=utf8mb4&parseTime=True&loc=Local", USERNAME, PASSWD, DATABASENAME) + db := mysql.Open(dsn) + print(db) +} + +type User struct { + gorm.Model + Name string + Age int +} diff --git a/NN_normal.py b/NN_normal.py new file mode 100644 index 0000000..93dc95f --- /dev/null +++ b/NN_normal.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torchvision import datasets, transforms +import matplotlib.pyplot as plt + +# 1. 数据准备(以MNIST手写数字识别为例) +transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) # 像素值归一化到[-1,1] +]) + +train_set = datasets.MNIST('data', download=True, train=True, transform=transform) +test_set = datasets.MNIST('data', download=True, train=False, transform=transform) + +train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True) +test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000) + +# 2. 神经网络模型(演示梯度控制技巧) +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Linear(784, 128) + self.fc2 = nn.Linear(128, 64) + self.fc3 = nn.Linear(64, 10) + + # He初始化适配ReLU + nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu') + nn.init.kaiming_normal_(self.fc2.weight, nonlinearity='relu') + + def forward(self, x): + x = x.view(-1, 784) # 展平图像 + x = torch.relu(self.fc1(x)) + x = torch.relu(self.fc2(x)) + x = self.fc3(x) # 输出层无需激活(CrossEntropyLoss内置Softmax) + return x + +# 3. 训练配置 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = Net().to(device) +optimizer = optim.Adam(model.parameters(), lr=0.001) +criterion = nn.CrossEntropyLoss() + +# 梯度裁剪阈值 +torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0) + +# 4. 训练过程可视化记录 +train_losses = [] +accuracies = [] + +def train(epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + # 记录训练损失 + if batch_idx % 100 == 0: + train_losses.append(loss.item()) + +# 5. 测试函数(含准确率计算) +def test(): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += criterion(output, target).item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + + accuracy = 100. * correct / len(test_loader.dataset) + accuracies.append(accuracy) + return test_loss + +# 6. 执行训练(3个epoch演示) +for epoch in range(1, 4): + train(epoch) + loss = test() + print(f'Epoch {epoch}: Test Loss={loss:.4f}, Accuracy={accuracies[-1]:.2f}%') + +# 7. 可视化训练过程 +plt.figure(figsize=(12,5)) +plt.subplot(1,2,1) +plt.plot(train_losses, label='Training Loss') +plt.title("Loss Curve") +plt.subplot(1,2,2) +plt.plot(accuracies, label='Accuracy', color='orange') +plt.title("Accuracy Curve") +plt.show() + +# 8. 示例预测展示 +sample_data, sample_label = next(iter(test_loader)) +with torch.no_grad(): + prediction = model(sample_data.to(device)).argmax(dim=1) + +# 显示预测结果对比 +plt.figure(figsize=(10,6)) +for i in range(6): + plt.subplot(2,3,i+1) + plt.imshow(sample_data[i][0], cmap='gray') + plt.title(f"True: {sample_label[i]}\nPred: {prediction[i].item()}") + plt.axis('off') +plt.tight_layout() +plt.show() \ No newline at end of file diff --git a/RNN_good.py b/RNN_good.py new file mode 100644 index 0000000..cbd32bb --- /dev/null +++ b/RNN_good.py @@ -0,0 +1,59 @@ +import numpy as np +from keras.layers import Embedding, SimpleRNN, Dense +from keras.models import Sequential + +# 训练数据(包含逗号) +text = "用户:今天想吃火锅吗? 客服:我们海鲜火锅很受欢迎。用户:但朋友对海鲜过敏,推荐其他吧。客服:好的,我们有菌汤火锅。" +base_chars = [',', '。', '?', ':'] # 确保基础标点存在 +chars = sorted(list(set(text + ''.join(base_chars)))) +char_to_idx = {c:i for i,c in enumerate(chars)} +idx_to_char = {i:c for c,i in char_to_idx.items()} + +# 创建训练序列 +max_length = 20 +X, y = [], [] +for i in range(len(text)-max_length): + seq = text[i:i+max_length] + target = text[i+max_length] + X.append([char_to_idx[c] for c in seq]) + y.append(char_to_idx[target]) + +# 模型构建 +model = Sequential([ + Embedding(input_dim=len(chars), output_dim=32, input_length=max_length), + SimpleRNN(128), + Dense(len(chars), activation='softmax') +]) +model.compile(loss='sparse_categorical_crossentropy', optimizer='adam') + +# 训练 +X = np.array(X) +y = np.array(y) +model.fit(X, y, epochs=50, batch_size=32) + +# 增强后的生成函数 +def generate_response(prompt): + generated = prompt + for _ in range(30): + # 过滤并处理未知字符 + valid_chars = [] + for c in generated[-max_length:]: + if c in char_to_idx: + valid_chars.append(c) + else: + valid_chars.append(' ') # 未知字符替换为空格 + + # 填充序列 + seq = valid_chars[-max_length:] + seq = seq + [' ']*(max_length - len(seq)) + + # 转换为索引 + seq_indices = [char_to_idx[c] for c in seq] + + # 生成下一个字符 + pred = model.predict(np.array([seq_indices]), verbose=0) + next_char = idx_to_char[np.argmax(pred)] + generated += next_char + return generated + +print(generate_response("用户:朋友海鲜过敏,能不能推荐一些其他的?")) \ No newline at end of file