49bbca0c by slene

orm Improve syncdb

1 parent 6686d923
...@@ -9,7 +9,7 @@ import ( ...@@ -9,7 +9,7 @@ import (
9 9
10 type commander interface { 10 type commander interface {
11 Parse([]string) 11 Parse([]string)
12 Run() 12 Run() error
13 } 13 }
14 14
15 var ( 15 var (
...@@ -62,6 +62,8 @@ type commandSyncDb struct { ...@@ -62,6 +62,8 @@ type commandSyncDb struct {
62 al *alias 62 al *alias
63 force bool 63 force bool
64 verbose bool 64 verbose bool
65 noInfo bool
66 rtOnError bool
65 } 67 }
66 68
67 func (d *commandSyncDb) Parse(args []string) { 69 func (d *commandSyncDb) Parse(args []string) {
...@@ -76,7 +78,7 @@ func (d *commandSyncDb) Parse(args []string) { ...@@ -76,7 +78,7 @@ func (d *commandSyncDb) Parse(args []string) {
76 d.al = getDbAlias(name) 78 d.al = getDbAlias(name)
77 } 79 }
78 80
79 func (d *commandSyncDb) Run() { 81 func (d *commandSyncDb) Run() error {
80 var drops []string 82 var drops []string
81 if d.force { 83 if d.force {
82 drops = getDbDropSql(d.al) 84 drops = getDbDropSql(d.al)
...@@ -87,25 +89,103 @@ func (d *commandSyncDb) Run() { ...@@ -87,25 +89,103 @@ func (d *commandSyncDb) Run() {
87 if d.force { 89 if d.force {
88 for i, mi := range modelCache.allOrdered() { 90 for i, mi := range modelCache.allOrdered() {
89 query := drops[i] 91 query := drops[i]
90 _, err := db.Exec(query) 92 if !d.noInfo {
91 result := "" 93 fmt.Printf("drop table `%s`\n", mi.table)
92 if err != nil {
93 result = err.Error()
94 } 94 }
95 fmt.Printf("drop table `%s` %s\n", mi.table, result) 95 _, err := db.Exec(query)
96 if d.verbose { 96 if d.verbose {
97 fmt.Printf(" %s\n\n", query) 97 fmt.Printf(" %s\n\n", query)
98 } 98 }
99 if err != nil {
100 if d.rtOnError {
101 return err
102 }
103 fmt.Printf(" %s\n", err.Error())
104 }
99 } 105 }
100 } 106 }
101 107
102 sqls, indexes := getDbCreateSql(d.al) 108 sqls, indexes := getDbCreateSql(d.al)
103 109
110 tables, err := d.al.DbBaser.GetTables(db)
111 if err != nil {
112 if d.rtOnError {
113 return err
114 }
115 fmt.Printf(" %s\n", err.Error())
116 }
117
104 for i, mi := range modelCache.allOrdered() { 118 for i, mi := range modelCache.allOrdered() {
119 if tables[mi.table] {
120 if !d.noInfo {
121 fmt.Printf("table `%s` already exists, skip\n", mi.table)
122 }
123
124 var fields []*fieldInfo
125 columns, err := d.al.DbBaser.GetColumns(db, mi.table)
126 if err != nil {
127 if d.rtOnError {
128 return err
129 }
130 fmt.Printf(" %s\n", err.Error())
131 }
132
133 for _, fi := range mi.fields.fieldsDB {
134 if _, ok := columns[fi.column]; ok == false {
135 fields = append(fields, fi)
136 }
137 }
138
139 for _, fi := range fields {
140 query := getColumnAddQuery(d.al, fi)
141
142 if !d.noInfo {
143 fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table)
144 }
145
146 _, err := db.Exec(query)
147 if d.verbose {
148 fmt.Printf(" %s\n", query)
149 }
150 if err != nil {
151 if d.rtOnError {
152 return err
153 }
154 fmt.Printf(" %s\n", err.Error())
155 }
156 }
157
158 for _, idx := range indexes[mi.table] {
159 if d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) == false {
160 if !d.noInfo {
161 fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
162 }
163
164 query := idx.Sql
165 _, err := db.Exec(query)
166 if d.verbose {
167 fmt.Printf(" %s\n", query)
168 }
169 if err != nil {
170 if d.rtOnError {
171 return err
172 }
173 fmt.Printf(" %s\n", err.Error())
174 }
175 }
176 }
177
178 continue
179 }
180
181 if !d.noInfo {
105 fmt.Printf("create table `%s` \n", mi.table) 182 fmt.Printf("create table `%s` \n", mi.table)
183 }
106 184
107 queries := []string{sqls[i]} 185 queries := []string{sqls[i]}
108 queries = append(queries, indexes[mi.table]...) 186 for _, idx := range indexes[mi.table] {
187 queries = append(queries, idx.Sql)
188 }
109 189
110 for _, query := range queries { 190 for _, query := range queries {
111 _, err := db.Exec(query) 191 _, err := db.Exec(query)
...@@ -114,6 +194,9 @@ func (d *commandSyncDb) Run() { ...@@ -114,6 +194,9 @@ func (d *commandSyncDb) Run() {
114 fmt.Println(query) 194 fmt.Println(query)
115 } 195 }
116 if err != nil { 196 if err != nil {
197 if d.rtOnError {
198 return err
199 }
117 fmt.Printf(" %s\n", err.Error()) 200 fmt.Printf(" %s\n", err.Error())
118 } 201 }
119 } 202 }
...@@ -121,6 +204,8 @@ func (d *commandSyncDb) Run() { ...@@ -121,6 +204,8 @@ func (d *commandSyncDb) Run() {
121 fmt.Println("") 204 fmt.Println("")
122 } 205 }
123 } 206 }
207
208 return nil
124 } 209 }
125 210
126 type commandSqlAll struct { 211 type commandSqlAll struct {
...@@ -137,19 +222,36 @@ func (d *commandSqlAll) Parse(args []string) { ...@@ -137,19 +222,36 @@ func (d *commandSqlAll) Parse(args []string) {
137 d.al = getDbAlias(name) 222 d.al = getDbAlias(name)
138 } 223 }
139 224
140 func (d *commandSqlAll) Run() { 225 func (d *commandSqlAll) Run() error {
141 sqls, indexes := getDbCreateSql(d.al) 226 sqls, indexes := getDbCreateSql(d.al)
142 var all []string 227 var all []string
143 for i, mi := range modelCache.allOrdered() { 228 for i, mi := range modelCache.allOrdered() {
144 queries := []string{sqls[i]} 229 queries := []string{sqls[i]}
145 queries = append(queries, indexes[mi.table]...) 230 for _, idx := range indexes[mi.table] {
231 queries = append(queries, idx.Sql)
232 }
146 sql := strings.Join(queries, "\n") 233 sql := strings.Join(queries, "\n")
147 all = append(all, sql) 234 all = append(all, sql)
148 } 235 }
149 fmt.Println(strings.Join(all, "\n\n")) 236 fmt.Println(strings.Join(all, "\n\n"))
237
238 return nil
150 } 239 }
151 240
152 func init() { 241 func init() {
153 commands["syncdb"] = new(commandSyncDb) 242 commands["syncdb"] = new(commandSyncDb)
154 commands["sqlall"] = new(commandSqlAll) 243 commands["sqlall"] = new(commandSqlAll)
155 } 244 }
245
246 func RunSyncdb(name string, force bool, verbose bool) error {
247 BootStrap()
248
249 al := getDbAlias(name)
250 cmd := new(commandSyncDb)
251 cmd.al = al
252 cmd.force = force
253 cmd.noInfo = !verbose
254 cmd.verbose = verbose
255 cmd.rtOnError = true
256 return cmd.Run()
257 }
......
...@@ -6,6 +6,12 @@ import ( ...@@ -6,6 +6,12 @@ import (
6 "strings" 6 "strings"
7 ) 7 )
8 8
9 type dbIndex struct {
10 Table string
11 Name string
12 Sql string
13 }
14
9 func getDbAlias(name string) *alias { 15 func getDbAlias(name string) *alias {
10 if al, ok := dataBaseCache.get(name); ok { 16 if al, ok := dataBaseCache.get(name); ok {
11 return al 17 return al
...@@ -31,36 +37,11 @@ func getDbDropSql(al *alias) (sqls []string) { ...@@ -31,36 +37,11 @@ func getDbDropSql(al *alias) (sqls []string) {
31 return sqls 37 return sqls
32 } 38 }
33 39
34 func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string) { 40 func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
35 if len(modelCache.cache) == 0 {
36 fmt.Println("no Model found, need register your model")
37 os.Exit(2)
38 }
39
40 Q := al.DbBaser.TableQuote()
41 T := al.DbBaser.DbTypes() 41 T := al.DbBaser.DbTypes()
42 sep := fmt.Sprintf("%s, %s", Q, Q)
43
44 tableIndexes = make(map[string][]string)
45
46 for _, mi := range modelCache.allOrdered() {
47 sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
48 sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName)
49 sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
50
51 sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q)
52
53 columns := make([]string, 0, len(mi.fields.fieldsDB))
54
55 sqlIndexes := [][]string{}
56
57 for _, fi := range mi.fields.fieldsDB {
58
59 fieldType := fi.fieldType 42 fieldType := fi.fieldType
60 column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q)
61 col := ""
62 43
63 checkColumn: 44 checkColumn:
64 switch fieldType { 45 switch fieldType {
65 case TypeBooleanField: 46 case TypeBooleanField:
66 col = T["bool"] 47 col = T["bool"]
...@@ -106,6 +87,48 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string) ...@@ -106,6 +87,48 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
106 goto checkColumn 87 goto checkColumn
107 } 88 }
108 89
90 return
91 }
92
93 func getColumnAddQuery(al *alias, fi *fieldInfo) string {
94 Q := al.DbBaser.TableQuote()
95 typ := getColumnTyp(al, fi)
96
97 if fi.null == false {
98 typ += " " + "NOT NULL"
99 }
100
101 return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
102 }
103
104 func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
105 if len(modelCache.cache) == 0 {
106 fmt.Println("no Model found, need register your model")
107 os.Exit(2)
108 }
109
110 Q := al.DbBaser.TableQuote()
111 T := al.DbBaser.DbTypes()
112 sep := fmt.Sprintf("%s, %s", Q, Q)
113
114 tableIndexes = make(map[string][]dbIndex)
115
116 for _, mi := range modelCache.allOrdered() {
117 sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
118 sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName)
119 sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
120
121 sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q)
122
123 columns := make([]string, 0, len(mi.fields.fieldsDB))
124
125 sqlIndexes := [][]string{}
126
127 for _, fi := range mi.fields.fieldsDB {
128
129 column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q)
130 col := getColumnTyp(al, fi)
131
109 if fi.auto { 132 if fi.auto {
110 switch al.Driver { 133 switch al.Driver {
111 case DR_Sqlite, DR_Postgres: 134 case DR_Sqlite, DR_Postgres:
...@@ -181,7 +204,13 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string) ...@@ -181,7 +204,13 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]string)
181 name := mi.table + "_" + strings.Join(names, "_") 204 name := mi.table + "_" + strings.Join(names, "_")
182 cols := strings.Join(names, sep) 205 cols := strings.Join(names, sep)
183 sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) 206 sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
184 tableIndexes[mi.table] = append(tableIndexes[mi.table], sql) 207
208 index := dbIndex{}
209 index.Table = mi.table
210 index.Name = name
211 index.Sql = sql
212
213 tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
185 } 214 }
186 215
187 } 216 }
......
...@@ -1116,3 +1116,61 @@ func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { ...@@ -1116,3 +1116,61 @@ func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
1116 func (d *dbBase) DbTypes() map[string]string { 1116 func (d *dbBase) DbTypes() map[string]string {
1117 return nil 1117 return nil
1118 } 1118 }
1119
1120 func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
1121 tables := make(map[string]bool)
1122 query := d.ins.ShowTablesQuery()
1123 rows, err := db.Query(query)
1124 if err != nil {
1125 return tables, err
1126 }
1127
1128 for rows.Next() {
1129 var table string
1130 err := rows.Scan(&table)
1131 if err != nil {
1132 return tables, err
1133 }
1134 if table != "" {
1135 tables[table] = true
1136 }
1137 }
1138
1139 return tables, nil
1140 }
1141
1142 func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
1143 columns := make(map[string][3]string)
1144 query := d.ins.ShowColumnsQuery(table)
1145 rows, err := db.Query(query)
1146 if err != nil {
1147 return columns, err
1148 }
1149
1150 for rows.Next() {
1151 var (
1152 name string
1153 typ string
1154 null string
1155 )
1156 err := rows.Scan(&name, &typ, &null)
1157 if err != nil {
1158 return columns, err
1159 }
1160 columns[name] = [3]string{name, typ, null}
1161 }
1162
1163 return columns, nil
1164 }
1165
1166 func (d *dbBase) ShowTablesQuery() string {
1167 panic(ErrNotImplement)
1168 }
1169
1170 func (d *dbBase) ShowColumnsQuery(table string) string {
1171 panic(ErrNotImplement)
1172 }
1173
1174 func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
1175 panic(ErrNotImplement)
1176 }
......
1 package orm 1 package orm
2 2
3 import (
4 "fmt"
5 )
6
3 var mysqlOperators = map[string]string{ 7 var mysqlOperators = map[string]string{
4 "exact": "= ?", 8 "exact": "= ?",
5 "iexact": "LIKE ?", 9 "iexact": "LIKE ?",
...@@ -51,6 +55,23 @@ func (d *dbBaseMysql) DbTypes() map[string]string { ...@@ -51,6 +55,23 @@ func (d *dbBaseMysql) DbTypes() map[string]string {
51 return mysqlTypes 55 return mysqlTypes
52 } 56 }
53 57
58 func (d *dbBaseMysql) ShowTablesQuery() string {
59 return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
60 }
61
62 func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
63 return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
64 "WHERE table_schema = DATABASE() AND table_name = '%s'", table)
65 }
66
67 func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
68 row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
69 "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
70 var cnt int
71 row.Scan(&cnt)
72 return cnt > 0
73 }
74
54 func newdbBaseMysql() dbBaser { 75 func newdbBaseMysql() dbBaser {
55 b := new(dbBaseMysql) 76 b := new(dbBaseMysql)
56 b.ins = b 77 b.ins = b
......
...@@ -107,10 +107,26 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) ...@@ -107,10 +107,26 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
107 return 107 return
108 } 108 }
109 109
110 func (d *dbBasePostgres) ShowTablesQuery() string {
111 return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
112 }
113
114 func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
115 return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
116 }
117
110 func (d *dbBasePostgres) DbTypes() map[string]string { 118 func (d *dbBasePostgres) DbTypes() map[string]string {
111 return postgresTypes 119 return postgresTypes
112 } 120 }
113 121
122 func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
123 query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
124 row := db.QueryRow(query)
125 var cnt int
126 row.Scan(&cnt)
127 return cnt > 0
128 }
129
114 func newdbBasePostgres() dbBaser { 130 func newdbBasePostgres() dbBaser {
115 b := new(dbBasePostgres) 131 b := new(dbBasePostgres)
116 b.ins = b 132 b.ins = b
......
1 package orm 1 package orm
2 2
3 import ( 3 import (
4 "database/sql"
4 "fmt" 5 "fmt"
5 ) 6 )
6 7
...@@ -67,6 +68,51 @@ func (d *dbBaseSqlite) DbTypes() map[string]string { ...@@ -67,6 +68,51 @@ func (d *dbBaseSqlite) DbTypes() map[string]string {
67 return sqliteTypes 68 return sqliteTypes
68 } 69 }
69 70
71 func (d *dbBaseSqlite) ShowTablesQuery() string {
72 return "SELECT name FROM sqlite_master WHERE type = 'table'"
73 }
74
75 func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
76 query := d.ins.ShowColumnsQuery(table)
77 rows, err := db.Query(query)
78 if err != nil {
79 return nil, err
80 }
81
82 columns := make(map[string][3]string)
83 for rows.Next() {
84 var tmp, name, typ, null sql.NullString
85 err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp)
86 if err != nil {
87 return nil, err
88 }
89 columns[name.String] = [3]string{name.String, typ.String, null.String}
90 }
91
92 return columns, nil
93 }
94
95 func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
96 return fmt.Sprintf("pragma table_info('%s')", table)
97 }
98
99 func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
100 query := fmt.Sprintf("PRAGMA index_list('%s')", table)
101 rows, err := db.Query(query)
102 if err != nil {
103 panic(err)
104 }
105 defer rows.Close()
106 for rows.Next() {
107 var tmp, index sql.NullString
108 rows.Scan(&tmp, &index, &tmp)
109 if name == index.String {
110 return true
111 }
112 }
113 return false
114 }
115
70 func newdbBaseSqlite() dbBaser { 116 func newdbBaseSqlite() dbBaser {
71 b := new(dbBaseSqlite) 117 b := new(dbBaseSqlite)
72 b.ins = b 118 b.ins = b
......
...@@ -198,28 +198,8 @@ func TestSyncDb(t *testing.T) { ...@@ -198,28 +198,8 @@ func TestSyncDb(t *testing.T) {
198 RegisterModel(new(Comment)) 198 RegisterModel(new(Comment))
199 RegisterModel(new(UserBig)) 199 RegisterModel(new(UserBig))
200 200
201 BootStrap() 201 err := RunSyncdb("default", true, false)
202 202 throwFail(t, err)
203 al := dataBaseCache.getDefault()
204 db := al.DB
205
206 drops := getDbDropSql(al)
207 for _, query := range drops {
208 _, err := db.Exec(query)
209 throwFail(t, err, query)
210 }
211
212 sqls, indexes := getDbCreateSql(al)
213
214 for i, mi := range modelCache.allOrdered() {
215 queries := []string{sqls[i]}
216 queries = append(queries, indexes[mi.table]...)
217
218 for _, query := range queries {
219 _, err := db.Exec(query)
220 throwFail(t, err, query)
221 }
222 }
223 203
224 modelCache.clean() 204 modelCache.clean()
225 } 205 }
......
...@@ -133,4 +133,9 @@ type dbBaser interface { ...@@ -133,4 +133,9 @@ type dbBaser interface {
133 TimeFromDB(*time.Time, *time.Location) 133 TimeFromDB(*time.Time, *time.Location)
134 TimeToDB(*time.Time, *time.Location) 134 TimeToDB(*time.Time, *time.Location)
135 DbTypes() map[string]string 135 DbTypes() map[string]string
136 GetTables(dbQuerier) (map[string]bool, error)
137 GetColumns(dbQuerier, string) (map[string][3]string, error)
138 ShowTablesQuery() string
139 ShowColumnsQuery(string) string
140 IndexExists(dbQuerier, string, string) bool
136 } 141 }
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!