@@ -50,7 +50,6 @@ func detectNodeType(node *pg_query.Node) []*pg_query.Node {
5050 DropStmt (stmt )
5151
5252 case * pg_query.Node_AlterTableStmt :
53- fmt .Println ("Alter Type" )
5453 return AlterStmt (node )
5554
5655 case * pg_query.Node_SelectStmt :
@@ -88,10 +87,6 @@ func AlterStmt(node *pg_query.Node) []*pg_query.Node {
8887 for _ , cmd := range initialCommands {
8988 switch v := cmd .Node .(type ) {
9089 case * pg_query.Node_AlterTableCmd :
91- fmt .Printf ("%#v\n " , v )
92- fmt .Printf ("%#v\n " , v .AlterTableCmd .Def .Node )
93- fmt .Println (v .AlterTableCmd .Subtype .Enum ())
94-
9590 switch v .AlterTableCmd .Subtype {
9691 case pg_query .AlterTableType_AT_AddColumn :
9792 def := v .AlterTableCmd .Def .GetColumnDef ()
@@ -106,13 +101,13 @@ func AlterStmt(node *pg_query.Node) []*pg_query.Node {
106101 if index , ok := constraintsMap [pg_query .ConstrType_CONSTR_DEFAULT ]; ok {
107102 def .Constraints = make ([]* pg_query.Node , 0 )
108103
109- alterStmts = append (alterStmts , node )
104+ alterStmts = append (alterStmts , wrapTransaction ([] * pg_query. Node { node }) ... )
110105
111106 defaultDefinitionTemp := fmt .Sprintf (`alter table %s alter column %s set default %v;` ,
112107 alterTableStmt .GetRelation ().GetRelname (), def .Colname ,
113108 constraints [index ].GetConstraint ().GetRawExpr ().GetAConst ().GetVal ().GetInteger ().GetIval ())
114109
115- alterStmts = append (alterStmts , generateNodes (defaultDefinitionTemp )... )
110+ alterStmts = append (alterStmts , wrapTransaction ( generateNodes (defaultDefinitionTemp ) )... )
116111
117112 // TODO: Update rows
118113
@@ -125,12 +120,12 @@ func AlterStmt(node *pg_query.Node) []*pg_query.Node {
125120 constraint := v .AlterTableCmd .Def .GetConstraint ()
126121 constraint .SkipValidation = true
127122
128- alterStmts = append (alterStmts , node )
123+ alterStmts = append (alterStmts , wrapTransaction ([] * pg_query. Node { node }) ... )
129124
130- validationTemp := fmt .Sprintf (`begin; alter table %s validate constraint %s; commit ;` ,
125+ validationTemp := fmt .Sprintf (`alter table %s validate constraint %s;` ,
131126 alterTableStmt .GetRelation ().GetRelname (), constraint .GetConname ())
132127
133- alterStmts = append (alterStmts , generateNodes (validationTemp )... )
128+ alterStmts = append (alterStmts , wrapTransaction ( generateNodes (validationTemp ) )... )
134129
135130 default :
136131 alterStmts = append (alterStmts , node )
@@ -160,3 +155,14 @@ func generateNodes(nodeTemplate string) []*pg_query.Node {
160155
161156 return nodes
162157}
158+
159+ // wrapTransaction wraps nodes into transaction statements.
160+ func wrapTransaction (nodes []* pg_query.Node ) []* pg_query.Node {
161+ begin := makeBeginTransactionStmt ()
162+ commit := makeCommitTransactionStmt ()
163+
164+ nodes = append ([]* pg_query.Node {begin }, nodes ... )
165+ nodes = append (nodes , commit )
166+
167+ return nodes
168+ }
0 commit comments