Go的基本语法
This commit is contained in:
parent
171c9e9815
commit
626790f753
5
.vscode/settings.json
vendored
Normal file
5
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
{
|
||||
"files.associations": {
|
||||
"stdio.h": "c"
|
||||
}
|
||||
}
|
80
Game1.py
Normal file
80
Game1.py
Normal file
@ -0,0 +1,80 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import torch.nn as nn
|
||||
|
||||
import torch.optim as optim
|
||||
|
||||
# 生成合成数据
|
||||
t = np.linspace(0, 24*np.pi, 1000)
|
||||
data = np.sin(t) + 0.5*np.sin(3*t) + 0.05*t # 混合波形+趋势项
|
||||
|
||||
# 数据预处理
|
||||
def create_dataset(data, look_back=30):
|
||||
X, y = [], []
|
||||
for i in range(len(data)-look_back):
|
||||
X.append(data[i:i+look_back])
|
||||
y.append(data[i+look_back])
|
||||
return torch.FloatTensor(X).unsqueeze(-1), torch.FloatTensor(y)
|
||||
|
||||
X, y = create_dataset(data)
|
||||
train_size = int(0.8 * len(X))
|
||||
train_X, test_X = X[:train_size], X[train_size:]
|
||||
train_y, test_y = y[:train_size], y[train_size:]
|
||||
|
||||
# 模型定义
|
||||
class TimeSeriesModel(nn.Module):
|
||||
def __init__(self, model_type):
|
||||
super().__init__()
|
||||
self.model_type = model_type
|
||||
if model_type == 'LSTM':
|
||||
self.rnn = nn.LSTM(1, 64, num_layers=2)
|
||||
else:
|
||||
self.rnn = nn.RNN(1, 64)
|
||||
self.fc = nn.Linear(64, 1)
|
||||
|
||||
def forward(self, x):
|
||||
out, _ = self.rnn(x)
|
||||
return self.fc(out[-1, :, :])
|
||||
|
||||
# 训练函数
|
||||
# 修改后的训练函数,返回预测结果和测试损失
|
||||
def train_and_predict(model_type):
|
||||
model = TimeSeriesModel(model_type)
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
# 训练循环
|
||||
for epoch in range(100):
|
||||
output = model(train_X.transpose(0, 1))
|
||||
loss = criterion(output.squeeze(), train_y)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if epoch % 20 == 0:
|
||||
print(f"{model_type} Epoch {epoch} Loss: {loss.item():.4f}")
|
||||
|
||||
# 预测阶段
|
||||
with torch.no_grad():
|
||||
test_pred = model(test_X.transpose(0, 1))
|
||||
test_loss = criterion(test_pred.squeeze(), test_y)
|
||||
print(f"{model_type} Test MSE: {test_loss.item():.4f}")
|
||||
|
||||
return test_pred.squeeze().numpy(), test_loss.item()
|
||||
|
||||
# 同时训练两种模型并收集结果
|
||||
lstm_pred, lstm_loss = train_and_predict('LSTM')
|
||||
rnn_pred, rnn_loss = train_and_predict('RNN')
|
||||
|
||||
# 统一可视化比较
|
||||
plt.figure(figsize=(12,6))
|
||||
plt.plot(test_y.numpy(), label='True Values', alpha=0.7)
|
||||
plt.plot(lstm_pred, label=f'LSTM (MSE: {lstm_loss:.4f})', linestyle='--')
|
||||
plt.plot(rnn_pred, label=f'RNN (MSE: {rnn_loss:.4f})', linestyle='--')
|
||||
plt.title('Time Series Prediction Comparison')
|
||||
plt.xlabel('Time Steps')
|
||||
plt.ylabel('Value')
|
||||
plt.legend()
|
||||
plt.show()
|
124
Game2.py
Normal file
124
Game2.py
Normal file
@ -0,0 +1,124 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
#这个实验的目的是比较RNN和GRU在相同任务上的性能,即学习序列中两个随机位置数值的和。
|
||||
# 数据生成
|
||||
def generate_add_data(seq_len=30):
|
||||
data = torch.zeros(seq_len, 2) # (seq_len, 2)
|
||||
idx1, idx2 = np.random.choice(seq_len, 2, replace=False)
|
||||
val1, val2 = np.random.rand()*0.5, np.random.rand()*0.5
|
||||
data[idx1, 0] = val1
|
||||
data[idx2, 0] = val2
|
||||
target = torch.tensor([val1 + val2]).view(1,1)
|
||||
return data.unsqueeze(0), target # (1, seq_len, 2)
|
||||
|
||||
# 模型定义
|
||||
class AdditionRNN(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.rnn = nn.RNN(2, 16, batch_first=True)
|
||||
self.fc = nn.Linear(16, 1)
|
||||
|
||||
def forward(self, x):
|
||||
out, _ = self.rnn(x)
|
||||
return self.fc(out[:, -1, :])
|
||||
|
||||
class AdditionGRU(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gru = nn.GRU(2, 16, batch_first=True)
|
||||
self.fc = nn.Linear(16, 1)
|
||||
|
||||
def forward(self, x):
|
||||
out, _ = self.gru(x)
|
||||
return self.fc(out[:, -1, :])
|
||||
|
||||
# 训练对比
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 修改后的训练函数,记录损失变化
|
||||
def train_addition():
|
||||
rnn = AdditionRNN()
|
||||
gru = AdditionGRU()
|
||||
criterion = nn.MSELoss()
|
||||
optim_rnn = torch.optim.Adam(rnn.parameters(), lr=0.01)
|
||||
optim_gru = torch.optim.Adam(gru.parameters(), lr=0.01)
|
||||
|
||||
# 记录训练过程
|
||||
losses = {'RNN': [], 'GRU': []}
|
||||
|
||||
for step in range(1000):
|
||||
inputs, target = generate_add_data(seq_len=30)
|
||||
|
||||
# RNN训练
|
||||
optim_rnn.zero_grad()
|
||||
rnn_pred = rnn(inputs)
|
||||
loss_rnn = criterion(rnn_pred, target)
|
||||
loss_rnn.backward()
|
||||
optim_rnn.step()
|
||||
|
||||
# GRU训练
|
||||
optim_gru.zero_grad()
|
||||
gru_pred = gru(inputs)
|
||||
loss_gru = criterion(gru_pred, target)
|
||||
loss_gru.backward()
|
||||
optim_gru.step()
|
||||
|
||||
# 记录损失
|
||||
losses['RNN'].append(loss_rnn.item())
|
||||
losses['GRU'].append(loss_gru.item())
|
||||
|
||||
if step % 200 == 0:
|
||||
print(f"Step {step:03d} | RNN Loss: {loss_rnn.item():.4f} | GRU Loss: {loss_gru.item():.4f}")
|
||||
|
||||
# 绘制损失曲线
|
||||
plt.figure(figsize=(10, 5))
|
||||
plt.plot(losses['RNN'], label='RNN', alpha=0.7)
|
||||
plt.plot(losses['GRU'], label='GRU', alpha=0.7)
|
||||
plt.xlabel('Training Steps')
|
||||
plt.ylabel('MSE Loss')
|
||||
plt.title('Training Comparison: RNN vs GRU')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.show()
|
||||
|
||||
return rnn, gru
|
||||
|
||||
# 执行训练
|
||||
rnn_model, gru_model = train_addition()
|
||||
def show_test_cases(model, model_name, num_cases=5):
|
||||
print(f"\n{model_name} 测试样例:")
|
||||
criterion = nn.MSELoss()
|
||||
total_error = 0
|
||||
|
||||
for case_idx in range(num_cases):
|
||||
# 生成测试数据
|
||||
inputs, target = generate_add_data()
|
||||
seq_len = inputs.shape[1]
|
||||
|
||||
# 解析输入数据
|
||||
non_zero_indices = torch.nonzero(inputs[0, :, 0])
|
||||
pos1, pos2 = non_zero_indices[0].item(), non_zero_indices[1].item()
|
||||
val1 = inputs[0, pos1, 0].item()
|
||||
val2 = inputs[0, pos2, 0].item()
|
||||
|
||||
# 模型预测
|
||||
with torch.no_grad():
|
||||
pred = model(inputs)
|
||||
loss = criterion(pred, target)
|
||||
|
||||
# 格式输出
|
||||
print(f"案例 {case_idx+1}:")
|
||||
print(f"输入序列长度: {seq_len}")
|
||||
print(f"数值位置: [{pos1:2d}]={val1:.4f}, [{pos2:2d}]={val2:.4f}")
|
||||
print(f"真实值: {target.item():.4f}")
|
||||
print(f"预测值: {pred.item():.4f}")
|
||||
print(f"绝对误差: {abs(pred.item()-target.item()):.4f}")
|
||||
print("-" * 40)
|
||||
total_error += abs(pred.item()-target.item())
|
||||
|
||||
print(f"平均绝对误差: {total_error/num_cases:.4f}\n")
|
||||
|
||||
show_test_cases(rnn_model, "RNN")
|
||||
show_test_cases(gru_model, "GRU")
|
@ -32,10 +32,9 @@ func main() {
|
||||
}
|
||||
```
|
||||
- 我们写完了程序,如何把这个程序运行起来呢,分为两步:
|
||||
- 第一步:执行指令`go mod init`,初始化与版本相关联的 Go 包的集合,确定了根目录、定义了项目的依赖和版本,确保项目可以重建。(也叫做Go的模块)!这一步,会在当前路径下创建`go.mod`文件
|
||||
- 第一步:执行指令`go mod init xxx` ,初始化与版本相关联的 Go 包的集合,确定了根目录、定义了项目的依赖和版本,确保项目可以重建。(也叫做Go的模块)!这一步,会在当前路径下创建`go.mod`文件
|
||||
- 第二步:执行指令`go mod tidy`,拉取我们需要的go的组建(又叫做库),这个操作可以类比为`pip install -r requirements.txt`,所需要的包IDE会自动写入`go.mod`文件中。
|
||||
- 
|
||||
- 随后,通过上面这两步,go程序就初始化好了,相关的包也下载好了,接下来,我们就能重新运行了
|
||||
- 运行项目:`go run` 开始执行整个项目
|
||||
- 运行单个程序: `go run xxx.go`
|
||||
- 运行项目: `go run xxx.go`
|
||||
- 构建go项目(将整个项目打包成为可执行文件) `go build`
|
@ -0,0 +1,265 @@
|
||||
### **1. 基础结构**
|
||||
```go
|
||||
package main // 包声明(必须)
|
||||
import "fmt" // 导入包
|
||||
|
||||
func main() { // 主函数(程序入口)
|
||||
fmt.Println("Hello, World!")
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### **2. 变量与常量**
|
||||
- **变量声明**
|
||||
```go
|
||||
var a int = 10 // 显式类型
|
||||
var b = 20 // 类型推断
|
||||
c := 30 // 短声明(函数内使用)
|
||||
var d, e int = 1, 2 // 多变量声明
|
||||
```
|
||||
|
||||
- **常量**
|
||||
```go
|
||||
const Pi = 3.14
|
||||
const (
|
||||
A = 1
|
||||
B = 2
|
||||
)
|
||||
```
|
||||
- 着重强调:**开头是大写的是Public!**
|
||||
---
|
||||
|
||||
### **3. 基本数据类型**
|
||||
- **基本类型**
|
||||
```go
|
||||
int, int8, int16, int32, int64
|
||||
uint, uint8, uint16, uint32, uint64
|
||||
float32, float64
|
||||
bool
|
||||
string
|
||||
byte (等同于 uint8)
|
||||
rune (等同于 int32, 表示 Unicode 字符)
|
||||
```
|
||||
|
||||
- **复合类型**
|
||||
```go
|
||||
var arr [3]int // 数组(固定长度)
|
||||
slice := []int{1, 2, 3} // 切片(动态数组)
|
||||
m := map[string]int{"key": 1} // 映射(字典)
|
||||
type Person struct { Name string } // 结构体
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### **4. 控制结构**
|
||||
- **条件语句**
|
||||
```go
|
||||
if x > 0 {
|
||||
// ...
|
||||
} else if x == 0 {
|
||||
// ...
|
||||
} else {
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
- **循环**
|
||||
```go
|
||||
for i := 0; i < 10; i++ { ... } // 传统 for 循环
|
||||
for i < 10 { ... } // 类似 while
|
||||
for index, value := range slice { ... } // 遍历切片/数组/map
|
||||
```
|
||||
|
||||
- **Switch**
|
||||
```go
|
||||
switch x {
|
||||
case 1:
|
||||
// ...
|
||||
case 2, 3: // 多值匹配
|
||||
// ...
|
||||
default:
|
||||
// ...
|
||||
}
|
||||
```
|
||||
- 特点:不需要break!
|
||||
---
|
||||
|
||||
### **5. 函数**
|
||||
- **基本函数**
|
||||
```go
|
||||
func add(a int, b int) int {
|
||||
return a + b
|
||||
}
|
||||
```
|
||||
|
||||
- **多返回值**
|
||||
```go
|
||||
func swap(a, b int) (int, int) {
|
||||
return b, a
|
||||
}
|
||||
```
|
||||
|
||||
- **命名返回值**
|
||||
```go
|
||||
func split(sum int) (x, y int) {
|
||||
x = sum * 4 / 9
|
||||
y = sum - x
|
||||
return // 隐式返回 x, y
|
||||
}
|
||||
```
|
||||
|
||||
- **匿名函数与闭包**(即函数内套函数)
|
||||
```go
|
||||
func() {
|
||||
fmt.Println("Anonymous function")
|
||||
}()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### **6. 指针与结构体**
|
||||
- **指针**
|
||||
```go
|
||||
var p *int
|
||||
x := 10
|
||||
p = &x
|
||||
*p = 20 // 修改 x 的值
|
||||
```
|
||||
- 特点:
|
||||
- Go 的指针**不支持算术运算**,避免了内存越界和非法访问的风险,同时通过垃圾回收机制自动管理内存,**减少了内存泄漏的可能性**。
|
||||
- Go 的指针类型严格区分,空指针用 `nil` 表示,解引用空指针会触发 panic,不支持**指针算术运算和强制类型转换**。
|
||||
- **结构体与方法**
|
||||
```go
|
||||
type Circle struct {
|
||||
Radius float64
|
||||
}
|
||||
|
||||
// 方法(值接收者)
|
||||
func (c Circle) Area() float64 {
|
||||
return math.Pi * c.Radius * c.Radius
|
||||
}
|
||||
|
||||
// 方法(指针接收者)
|
||||
func (c *Circle) Scale(factor float64) {
|
||||
c.Radius *= factor
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### **7. 接口与错误处理**
|
||||
- **接口**
|
||||
```go
|
||||
type Shape interface {
|
||||
Area() float64
|
||||
}
|
||||
|
||||
// 隐式实现接口
|
||||
func (c Circle) Area() float64 { ... }
|
||||
```
|
||||
|
||||
- **错误处理**
|
||||
```go
|
||||
func readFile() ([]byte, error) {
|
||||
data, err := os.ReadFile("file.txt")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// 调用
|
||||
data, err := readFile()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
- **Panic & Recover**
|
||||
```go
|
||||
func safeCall() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
fmt.Println("Recovered:", r)
|
||||
}
|
||||
}()
|
||||
panic("Something went wrong!")
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### **8. 并发编程**
|
||||
- **Goroutine**
|
||||
```go
|
||||
go func() {
|
||||
fmt.Println("Running in goroutine")
|
||||
}()
|
||||
```
|
||||
|
||||
- **Channel**(数据通道)
|
||||
```go
|
||||
ch := make(chan int)
|
||||
go func() { ch <- 1 }() // 发送数据
|
||||
value := <-ch // 接收数据
|
||||
```
|
||||
|
||||
- **Select**(主要用于事件驱动)
|
||||
```go
|
||||
select {
|
||||
case msg1 := <-ch1:
|
||||
fmt.Println(msg1)
|
||||
case msg2 := <-ch2:
|
||||
fmt.Println(msg2)
|
||||
case <-time.After(1 * time.Second):
|
||||
fmt.Println("Timeout")
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### **9. 包与模块**
|
||||
- **创建模块**
|
||||
```bash
|
||||
go mod init xx
|
||||
```
|
||||
|
||||
- **导入包**
|
||||
```go
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"github.com/user/package"
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### **10. 其他特性**
|
||||
- **Defer**
|
||||
- `defer` 是 Go 语言中的一个关键字,用于延迟执行一个函数调用。**被 `defer` 修饰的函数调用会推迟到当前函数返回之前执行**,无论当前函数是正常返回还是由于错误(如 `panic`)提前返回。`defer` 的主要用途是确保某些操作(如资源释放、清理工作等)一定会被执行,避免遗漏。
|
||||
```go
|
||||
func readFile() {
|
||||
file, _ := os.Open("file.txt")
|
||||
defer file.Close() // 函数返回前执行
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
- **JSON 处理**
|
||||
```go
|
||||
type User struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
data, _ := json.Marshal(user) //序列化
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### **常用内置函数**
|
||||
- `len()`:获取长度
|
||||
- `cap()`:切片容量
|
||||
- `make()`:创建切片/map/channel
|
||||
- `append()`:切片追加元素
|
BIN
Go入门速成/Day1/ExampleCode/BasicGrammar/BasicGrammar
Executable file
BIN
Go入门速成/Day1/ExampleCode/BasicGrammar/BasicGrammar
Executable file
Binary file not shown.
1
Go入门速成/Day1/ExampleCode/BasicGrammar/PointerPanic.go
Normal file
1
Go入门速成/Day1/ExampleCode/BasicGrammar/PointerPanic.go
Normal file
@ -0,0 +1 @@
|
||||
package main
|
3
Go入门速成/Day1/ExampleCode/BasicGrammar/go.mod
Normal file
3
Go入门速成/Day1/ExampleCode/BasicGrammar/go.mod
Normal file
@ -0,0 +1,3 @@
|
||||
module BasicGrammar
|
||||
|
||||
go 1.24.1
|
147
Go入门速成/Day1/ExampleCode/BasicGrammar/main.go
Normal file
147
Go入门速成/Day1/ExampleCode/BasicGrammar/main.go
Normal file
@ -0,0 +1,147 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 常量
|
||||
const Pi = 3.14159
|
||||
|
||||
// 结构体
|
||||
type Circle struct {
|
||||
Radius float64
|
||||
}
|
||||
|
||||
// 方法(值接收者)
|
||||
func (c Circle) Area() float64 {
|
||||
return Pi * c.Radius * c.Radius
|
||||
}
|
||||
|
||||
// 方法(指针接收者)
|
||||
func (c *Circle) Scale(factor float64) {
|
||||
c.Radius *= factor
|
||||
}
|
||||
|
||||
// 接口
|
||||
type Shape interface {
|
||||
Area() float64
|
||||
}
|
||||
|
||||
// 函数(多返回值)
|
||||
func divide(a, b float64) (float64, error) {
|
||||
if b == 0 {
|
||||
return 0, errors.New("division by zero")
|
||||
}
|
||||
return a / b, nil
|
||||
}
|
||||
|
||||
// 主函数
|
||||
func main() {
|
||||
// 变量声明
|
||||
var a int = 10
|
||||
b := 20
|
||||
c, d := 30, 40
|
||||
fmt.Println("Variables:", a, b, c, d)
|
||||
|
||||
// 数组与切片
|
||||
arr := [3]int{1, 2, 3}
|
||||
slice := []int{4, 5, 6}
|
||||
slice = append(slice, 7)
|
||||
fmt.Println("Array:", arr, "Slice:", slice)
|
||||
|
||||
// 映射
|
||||
m := map[string]int{"one": 1, "two": 2}
|
||||
fmt.Println("Map:", m)
|
||||
|
||||
// 控制结构
|
||||
if a > 5 {
|
||||
fmt.Println("a is greater than 5")
|
||||
} else {
|
||||
fmt.Println("a is not greater than 5")
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
fmt.Println("Loop:", i)
|
||||
}
|
||||
|
||||
switch a {
|
||||
case 10:
|
||||
fmt.Println("a is 10")
|
||||
default:
|
||||
fmt.Println("a is not 10")
|
||||
}
|
||||
|
||||
// 函数调用
|
||||
result, err := divide(10, 2)
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
} else {
|
||||
fmt.Println("Division result:", result)
|
||||
}
|
||||
|
||||
// 指针
|
||||
x := 10
|
||||
p := &x
|
||||
*p = 20
|
||||
fmt.Println("Pointer:", x)
|
||||
|
||||
// 结构体与方法
|
||||
circle := Circle{Radius: 5}
|
||||
fmt.Println("Circle area:", circle.Area())
|
||||
circle.Scale(2)
|
||||
fmt.Println("Scaled circle area:", circle.Area())
|
||||
|
||||
// 接口
|
||||
var shape Shape = Circle{Radius: 3}
|
||||
fmt.Println("Shape area:", shape.Area())
|
||||
|
||||
// 错误处理
|
||||
_, err = divide(10, 0)
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
}
|
||||
|
||||
// 并发
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
fmt.Println("Running in goroutine")
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
// Channel
|
||||
ch := make(chan int)
|
||||
go func() {
|
||||
ch <- 42
|
||||
}()
|
||||
value := <-ch
|
||||
fmt.Println("Channel value:", value)
|
||||
|
||||
// Select
|
||||
ch1 := make(chan string)
|
||||
ch2 := make(chan string)
|
||||
go func() {
|
||||
time.Sleep(1 * time.Second)
|
||||
ch1 <- "from ch1"
|
||||
}()
|
||||
go func() {
|
||||
time.Sleep(2 * time.Second)
|
||||
ch2 <- "from ch2"
|
||||
}()
|
||||
select {
|
||||
case msg := <-ch1:
|
||||
fmt.Println("Select:", msg)
|
||||
case msg := <-ch2:
|
||||
fmt.Println("Select:", msg)
|
||||
case <-time.After(3 * time.Second):
|
||||
fmt.Println("Timeout")
|
||||
}
|
||||
|
||||
// Defer
|
||||
defer fmt.Println("Defer: This runs last")
|
||||
fmt.Println("Main function end")
|
||||
}
|
21
Go入门速成/Day1/ExampleCode/BasicGrammar/switch.c
Normal file
21
Go入门速成/Day1/ExampleCode/BasicGrammar/switch.c
Normal file
@ -0,0 +1,21 @@
|
||||
//go:build ignore
|
||||
// +build ignore
|
||||
|
||||
|
||||
#include "stdio.h"
|
||||
|
||||
int main() {
|
||||
int a = 2;
|
||||
switch (a) {
|
||||
case 1:
|
||||
printf("a = 1\n");
|
||||
break;
|
||||
case 2:
|
||||
printf("a = 2\n");
|
||||
break;
|
||||
case 3:
|
||||
printf("a = 3\n");
|
||||
break;
|
||||
default:
|
||||
printf("FUCK"); // default 语句是可选的
|
||||
}
|
Binary file not shown.
31
Go入门速成/Day2/Class3 Gorm的使用.md
Normal file
31
Go入门速成/Day2/Class3 Gorm的使用.md
Normal file
@ -0,0 +1,31 @@
|
||||
# Gorm的使用
|
||||
## Gorm是什么
|
||||
- GORM 是 Go 语言的 ORM 库,提供模型定义、关联管理、事务支持、查询构建、数据迁移、钩子回调等功能,支持主流数据库(如 MySQL/PostgreSQL/SQLite),简化数据库操作。
|
||||
## 如何使用Gorm
|
||||
### 导入Gorm库
|
||||
```go
|
||||
import (
|
||||
"gorm.io/driver/mysql"
|
||||
)
|
||||
```
|
||||
### 基本操作
|
||||
- 1、连接数据库
|
||||
```go
|
||||
const (
|
||||
USERNAME = "root"
|
||||
PASSWD = "oppofindx2"
|
||||
DATABASENAME = "Class"
|
||||
)
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(127.0.0.1:3306)/%s?charset=utf8mb4&parseTime=True&loc=Local", USERNAME, PASSWD, DATABASENAME)
|
||||
db := mysql.Open(dsn)
|
||||
`````
|
||||
- 2、数据库中表的定义
|
||||
- 在Gorm中,定义一张表使用的是结构体
|
||||
```go
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
```
|
||||
- 自动迁移表结构(方便我们修改表,给表添加参数)
|
12
Go入门速成/Day2/ExampleCode/Gorm/go.mod
Normal file
12
Go入门速成/Day2/ExampleCode/Gorm/go.mod
Normal file
@ -0,0 +1,12 @@
|
||||
module Gorm
|
||||
|
||||
go 1.24.1
|
||||
|
||||
require gorm.io/driver/mysql v1.5.7
|
||||
|
||||
require (
|
||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
gorm.io/gorm v1.25.7 // indirect
|
||||
)
|
10
Go入门速成/Day2/ExampleCode/Gorm/go.sum
Normal file
10
Go入门速成/Day2/ExampleCode/Gorm/go.sum
Normal file
@ -0,0 +1,10 @@
|
||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
|
||||
gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
|
||||
gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A=
|
||||
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
26
Go入门速成/Day2/ExampleCode/Gorm/main.go
Normal file
26
Go入门速成/Day2/ExampleCode/Gorm/main.go
Normal file
@ -0,0 +1,26 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
USERNAME = "root"
|
||||
PASSWD = "oppofindx2"
|
||||
DATABASENAME = "Class"
|
||||
)
|
||||
|
||||
func main() {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(127.0.0.1:3306)/%s?charset=utf8mb4&parseTime=True&loc=Local", USERNAME, PASSWD, DATABASENAME)
|
||||
db := mysql.Open(dsn)
|
||||
print(db)
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Age int
|
||||
}
|
111
NN_normal.py
Normal file
111
NN_normal.py
Normal file
@ -0,0 +1,111 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets, transforms
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 1. 数据准备(以MNIST手写数字识别为例)
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,)) # 像素值归一化到[-1,1]
|
||||
])
|
||||
|
||||
train_set = datasets.MNIST('data', download=True, train=True, transform=transform)
|
||||
test_set = datasets.MNIST('data', download=True, train=False, transform=transform)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
|
||||
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000)
|
||||
|
||||
# 2. 神经网络模型(演示梯度控制技巧)
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.fc1 = nn.Linear(784, 128)
|
||||
self.fc2 = nn.Linear(128, 64)
|
||||
self.fc3 = nn.Linear(64, 10)
|
||||
|
||||
# He初始化适配ReLU
|
||||
nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
|
||||
nn.init.kaiming_normal_(self.fc2.weight, nonlinearity='relu')
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(-1, 784) # 展平图像
|
||||
x = torch.relu(self.fc1(x))
|
||||
x = torch.relu(self.fc2(x))
|
||||
x = self.fc3(x) # 输出层无需激活(CrossEntropyLoss内置Softmax)
|
||||
return x
|
||||
|
||||
# 3. 训练配置
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = Net().to(device)
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
# 梯度裁剪阈值
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
|
||||
|
||||
# 4. 训练过程可视化记录
|
||||
train_losses = []
|
||||
accuracies = []
|
||||
|
||||
def train(epoch):
|
||||
model.train()
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# 记录训练损失
|
||||
if batch_idx % 100 == 0:
|
||||
train_losses.append(loss.item())
|
||||
|
||||
# 5. 测试函数(含准确率计算)
|
||||
def test():
|
||||
model.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
output = model(data)
|
||||
test_loss += criterion(output, target).item()
|
||||
pred = output.argmax(dim=1, keepdim=True)
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
|
||||
accuracy = 100. * correct / len(test_loader.dataset)
|
||||
accuracies.append(accuracy)
|
||||
return test_loss
|
||||
|
||||
# 6. 执行训练(3个epoch演示)
|
||||
for epoch in range(1, 4):
|
||||
train(epoch)
|
||||
loss = test()
|
||||
print(f'Epoch {epoch}: Test Loss={loss:.4f}, Accuracy={accuracies[-1]:.2f}%')
|
||||
|
||||
# 7. 可视化训练过程
|
||||
plt.figure(figsize=(12,5))
|
||||
plt.subplot(1,2,1)
|
||||
plt.plot(train_losses, label='Training Loss')
|
||||
plt.title("Loss Curve")
|
||||
plt.subplot(1,2,2)
|
||||
plt.plot(accuracies, label='Accuracy', color='orange')
|
||||
plt.title("Accuracy Curve")
|
||||
plt.show()
|
||||
|
||||
# 8. 示例预测展示
|
||||
sample_data, sample_label = next(iter(test_loader))
|
||||
with torch.no_grad():
|
||||
prediction = model(sample_data.to(device)).argmax(dim=1)
|
||||
|
||||
# 显示预测结果对比
|
||||
plt.figure(figsize=(10,6))
|
||||
for i in range(6):
|
||||
plt.subplot(2,3,i+1)
|
||||
plt.imshow(sample_data[i][0], cmap='gray')
|
||||
plt.title(f"True: {sample_label[i]}\nPred: {prediction[i].item()}")
|
||||
plt.axis('off')
|
||||
plt.tight_layout()
|
||||
plt.show()
|
59
RNN_good.py
Normal file
59
RNN_good.py
Normal file
@ -0,0 +1,59 @@
|
||||
import numpy as np
|
||||
from keras.layers import Embedding, SimpleRNN, Dense
|
||||
from keras.models import Sequential
|
||||
|
||||
# 训练数据(包含逗号)
|
||||
text = "用户:今天想吃火锅吗? 客服:我们海鲜火锅很受欢迎。用户:但朋友对海鲜过敏,推荐其他吧。客服:好的,我们有菌汤火锅。"
|
||||
base_chars = [',', '。', '?', ':'] # 确保基础标点存在
|
||||
chars = sorted(list(set(text + ''.join(base_chars))))
|
||||
char_to_idx = {c:i for i,c in enumerate(chars)}
|
||||
idx_to_char = {i:c for c,i in char_to_idx.items()}
|
||||
|
||||
# 创建训练序列
|
||||
max_length = 20
|
||||
X, y = [], []
|
||||
for i in range(len(text)-max_length):
|
||||
seq = text[i:i+max_length]
|
||||
target = text[i+max_length]
|
||||
X.append([char_to_idx[c] for c in seq])
|
||||
y.append(char_to_idx[target])
|
||||
|
||||
# 模型构建
|
||||
model = Sequential([
|
||||
Embedding(input_dim=len(chars), output_dim=32, input_length=max_length),
|
||||
SimpleRNN(128),
|
||||
Dense(len(chars), activation='softmax')
|
||||
])
|
||||
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
|
||||
|
||||
# 训练
|
||||
X = np.array(X)
|
||||
y = np.array(y)
|
||||
model.fit(X, y, epochs=50, batch_size=32)
|
||||
|
||||
# 增强后的生成函数
|
||||
def generate_response(prompt):
|
||||
generated = prompt
|
||||
for _ in range(30):
|
||||
# 过滤并处理未知字符
|
||||
valid_chars = []
|
||||
for c in generated[-max_length:]:
|
||||
if c in char_to_idx:
|
||||
valid_chars.append(c)
|
||||
else:
|
||||
valid_chars.append(' ') # 未知字符替换为空格
|
||||
|
||||
# 填充序列
|
||||
seq = valid_chars[-max_length:]
|
||||
seq = seq + [' ']*(max_length - len(seq))
|
||||
|
||||
# 转换为索引
|
||||
seq_indices = [char_to_idx[c] for c in seq]
|
||||
|
||||
# 生成下一个字符
|
||||
pred = model.predict(np.array([seq_indices]), verbose=0)
|
||||
next_char = idx_to_char[np.argmax(pred)]
|
||||
generated += next_char
|
||||
return generated
|
||||
|
||||
print(generate_response("用户:朋友海鲜过敏,能不能推荐一些其他的?"))
|
Loading…
x
Reference in New Issue
Block a user